diff --git a/compiler/rustc_infer/src/infer/canonical/query_response.rs b/compiler/rustc_infer/src/infer/canonical/query_response.rs index 1d3d32ef74979..cdc3d50c4a84e 100644 --- a/compiler/rustc_infer/src/infer/canonical/query_response.rs +++ b/compiler/rustc_infer/src/infer/canonical/query_response.rs @@ -15,8 +15,10 @@ use rustc_index::{Idx, IndexVec}; use rustc_middle::arena::ArenaAllocatable; use rustc_middle::mir::ConstraintCategory; use rustc_middle::ty::fold::TypeFoldable; +use rustc_middle::ty::traverse::AlwaysTraversable; use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; +use rustc_type_ir::traverse::OptTryFoldWith; use tracing::{debug, instrument}; use crate::infer::canonical::instantiate::{CanonicalExt, instantiate_value}; @@ -60,7 +62,7 @@ impl<'tcx> InferCtxt<'tcx> { fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>, ) -> Result, NoSolution> where - T: Debug + TypeFoldable>, + T: OptTryFoldWith>, Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>, { let query_response = self.make_query_response(inference_vars, answer, fulfill_cx)?; @@ -107,7 +109,7 @@ impl<'tcx> InferCtxt<'tcx> { fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>, ) -> Result, NoSolution> where - T: Debug + TypeFoldable>, + T: OptTryFoldWith>, { let tcx = self.tcx; @@ -243,7 +245,7 @@ impl<'tcx> InferCtxt<'tcx> { output_query_region_constraints: &mut QueryRegionConstraints<'tcx>, ) -> InferResult<'tcx, R> where - R: Debug + TypeFoldable>, + R: OptTryFoldWith>, { let InferOk { value: result_args, mut obligations } = self .query_response_instantiation_guess( @@ -326,8 +328,11 @@ impl<'tcx> InferCtxt<'tcx> { .map(|p_c| instantiate_value(self.tcx, &result_args, p_c.clone())), ); - let user_result: R = - query_response.instantiate_projected(self.tcx, &result_args, |q_r| q_r.value.clone()); + let user_result: R = query_response + .instantiate_projected(self.tcx, &result_args, |q_r| { + AlwaysTraversable(q_r.value.clone()) + }) + .0; Ok(InferOk { value: user_result, obligations }) } @@ -396,7 +401,7 @@ impl<'tcx> InferCtxt<'tcx> { query_response: &Canonical<'tcx, QueryResponse<'tcx, R>>, ) -> InferResult<'tcx, CanonicalVarValues<'tcx>> where - R: Debug + TypeFoldable>, + R: OptTryFoldWith>, { // For each new universe created in the query result that did // not appear in the original query, create a local diff --git a/compiler/rustc_infer/src/traits/structural_impls.rs b/compiler/rustc_infer/src/traits/structural_impls.rs index 31f585c0c9edd..e185feba26db9 100644 --- a/compiler/rustc_infer/src/traits/structural_impls.rs +++ b/compiler/rustc_infer/src/traits/structural_impls.rs @@ -4,6 +4,7 @@ use rustc_ast_ir::try_visit; use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFoldable}; use rustc_middle::ty::visit::{TypeVisitable, TypeVisitor}; use rustc_middle::ty::{self, TyCtxt}; +use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable}; use crate::traits; use crate::traits::project::Normalized; @@ -55,6 +56,11 @@ impl<'tcx, O: TypeFoldable>> TypeFoldable> } } +impl<'tcx, O: TypeVisitable>> TypeTraversable> + for traits::Obligation<'tcx, O> +{ + type Kind = ImportantTypeTraversal; +} impl<'tcx, O: TypeVisitable>> TypeVisitable> for traits::Obligation<'tcx, O> { diff --git a/compiler/rustc_macros/src/lib.rs b/compiler/rustc_macros/src/lib.rs index f46c795b9565c..b3e491c651a93 100644 --- a/compiler/rustc_macros/src/lib.rs +++ b/compiler/rustc_macros/src/lib.rs @@ -17,6 +17,7 @@ mod diagnostics; mod extension; mod hash_stable; mod lift; +mod noop_type_traversable; mod query; mod serialize; mod symbols; @@ -81,27 +82,9 @@ decl_derive!([TyDecodable] => serialize::type_decodable_derive); decl_derive!([TyEncodable] => serialize::type_encodable_derive); decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive); decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive); -decl_derive!( - [TypeFoldable, attributes(type_foldable)] => - /// Derives `TypeFoldable` for the annotated `struct` or `enum` (`union` is not supported). - /// - /// The fold will produce a value of the same struct or enum variant as the input, with - /// each field respectively folded using the `TypeFoldable` implementation for its type. - /// However, if a field of a struct or an enum variant is annotated with - /// `#[type_foldable(identity)]` then that field will retain its incumbent value (and its - /// type is not required to implement `TypeFoldable`). - type_foldable::type_foldable_derive -); -decl_derive!( - [TypeVisitable, attributes(type_visitable)] => - /// Derives `TypeVisitable` for the annotated `struct` or `enum` (`union` is not supported). - /// - /// Each field of the struct or enum variant will be visited in definition order, using the - /// `TypeVisitable` implementation for its type. However, if a field of a struct or an enum - /// variant is annotated with `#[type_visitable(ignore)]` then that field will not be - /// visited (and its type is not required to implement `TypeVisitable`). - type_visitable::type_visitable_derive -); +decl_derive!([NoopTypeTraversable] => noop_type_traversable::noop_type_traversable_derive); +decl_derive!([TypeVisitable] => type_visitable::type_visitable_derive); +decl_derive!([TypeFoldable] => type_foldable::type_foldable_derive); decl_derive!([Lift, attributes(lift)] => lift::lift_derive); decl_derive!( [Diagnostic, attributes( diff --git a/compiler/rustc_macros/src/noop_type_traversable.rs b/compiler/rustc_macros/src/noop_type_traversable.rs new file mode 100644 index 0000000000000..eeb4bd2bc5863 --- /dev/null +++ b/compiler/rustc_macros/src/noop_type_traversable.rs @@ -0,0 +1,39 @@ +use quote::quote; +use syn::parse_quote; + +pub(super) fn noop_type_traversable_derive( + mut s: synstructure::Structure<'_>, +) -> proc_macro2::TokenStream { + if let syn::Data::Union(_) = s.ast().data { + panic!("cannot derive on union") + } + + s.underscore_const(true); + + if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") { + s.add_impl_generic(parse_quote! { 'tcx }); + } + + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!( + ::rustc_middle::ty::traverse::TypeTraversable< + ::rustc_middle::ty::TyCtxt<'tcx>, + Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal, + > + ), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + + s.bound_impl( + quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>), + quote! { + type Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal; + }, + ) +} diff --git a/compiler/rustc_macros/src/type_foldable.rs b/compiler/rustc_macros/src/type_foldable.rs index bc3b82c2893fa..afea1089a9401 100644 --- a/compiler/rustc_macros/src/type_foldable.rs +++ b/compiler/rustc_macros/src/type_foldable.rs @@ -1,4 +1,4 @@ -use quote::{ToTokens, quote}; +use quote::quote; use syn::parse_quote; pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { @@ -12,34 +12,25 @@ pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_m s.add_impl_generic(parse_quote! { 'tcx }); } - s.add_bounds(synstructure::AddBounds::Generics); + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!(::rustc_type_ir::traverse::OptTryFoldWith<::rustc_middle::ty::TyCtxt<'tcx>>), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); + s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone }); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + s.bind_with(|_| synstructure::BindStyle::Move); let body_fold = s.each_variant(|vi| { let bindings = vi.bindings(); vi.construct(|_, index| { let bind = &bindings[index]; - - let mut fixed = false; - - // retain value of fields with #[type_foldable(identity)] - bind.ast().attrs.iter().for_each(|x| { - if !x.path().is_ident("type_foldable") { - return; - } - let _ = x.parse_nested_meta(|nested| { - if nested.path.is_ident("identity") { - fixed = true; - } - Ok(()) - }); - }); - - if fixed { - bind.to_token_stream() - } else { - quote! { - ::rustc_middle::ty::fold::TypeFoldable::try_fold_with(#bind, __folder)? - } + quote! { + ::rustc_middle::ty::traverse::OptTryFoldWith::mk_try_fold_with()(#bind, __folder)? } }) }); diff --git a/compiler/rustc_macros/src/type_visitable.rs b/compiler/rustc_macros/src/type_visitable.rs index 527ca26c0eb10..24850dbdb4071 100644 --- a/compiler/rustc_macros/src/type_visitable.rs +++ b/compiler/rustc_macros/src/type_visitable.rs @@ -10,34 +10,30 @@ pub(super) fn type_visitable_derive( s.underscore_const(true); - // ignore fields with #[type_visitable(ignore)] - s.filter(|bi| { - let mut ignored = false; - - bi.ast().attrs.iter().for_each(|attr| { - if !attr.path().is_ident("type_visitable") { - return; - } - let _ = attr.parse_nested_meta(|nested| { - if nested.path.is_ident("ignore") { - ignored = true; - } - Ok(()) - }); - }); - - !ignored - }); - if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") { s.add_impl_generic(parse_quote! { 'tcx }); } - s.add_bounds(synstructure::AddBounds::Generics); + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!( + ::rustc_middle::ty::traverse::OptVisitWith::<::rustc_middle::ty::TyCtxt<'tcx>> + ), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); + s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone }); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + + let impl_traversable_s = s.clone(); + let body_visit = s.each(|bind| { quote! { match ::rustc_ast_ir::visit::VisitorResult::branch( - ::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor) + ::rustc_middle::ty::traverse::OptVisitWith::mk_visit_with()(#bind, __visitor) ) { ::core::ops::ControlFlow::Continue(()) => {}, ::core::ops::ControlFlow::Break(r) => { @@ -48,7 +44,7 @@ pub(super) fn type_visitable_derive( }); s.bind_with(|_| synstructure::BindStyle::Move); - s.bound_impl( + let visitable_impl = s.bound_impl( quote!(::rustc_middle::ty::visit::TypeVisitable<::rustc_middle::ty::TyCtxt<'tcx>>), quote! { fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<::rustc_middle::ty::TyCtxt<'tcx>>>( @@ -59,5 +55,17 @@ pub(super) fn type_visitable_derive( <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output() } }, - ) + ); + + let traversable_impl = impl_traversable_s.bound_impl( + quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>), + quote! { + type Kind = ::rustc_middle::ty::traverse::ImportantTypeTraversal; + }, + ); + + quote! { + #visitable_impl + #traversable_impl + } } diff --git a/compiler/rustc_middle/src/hir/place.rs b/compiler/rustc_middle/src/hir/place.rs index 4c7af0bc3726d..ceaa74156cd98 100644 --- a/compiler/rustc_middle/src/hir/place.rs +++ b/compiler/rustc_middle/src/hir/place.rs @@ -1,12 +1,14 @@ use rustc_hir::HirId; -use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable}; +use rustc_macros::{ + HashStable, NoopTypeTraversable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable, +}; use rustc_target::abi::{FieldIdx, VariantIdx}; use crate::ty; use crate::ty::Ty; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)] -#[derive(TypeFoldable, TypeVisitable)] +#[derive(NoopTypeTraversable)] pub enum PlaceBase { /// A temporary variable. Rvalue, @@ -19,7 +21,7 @@ pub enum PlaceBase { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)] -#[derive(TypeFoldable, TypeVisitable)] +#[derive(NoopTypeTraversable)] pub enum ProjectionKind { /// A dereference of a pointer, reference or `Box` of the given type. Deref, diff --git a/compiler/rustc_middle/src/macros.rs b/compiler/rustc_middle/src/macros.rs index 39816c17b985f..8e3427e941dd1 100644 --- a/compiler/rustc_middle/src/macros.rs +++ b/compiler/rustc_middle/src/macros.rs @@ -73,32 +73,8 @@ macro_rules! TrivialLiftImpls { macro_rules! TrivialTypeTraversalImpls { ($($ty:ty),+ $(,)?) => { $( - impl<'tcx> $crate::ty::fold::TypeFoldable<$crate::ty::TyCtxt<'tcx>> for $ty { - fn try_fold_with>>( - self, - _: &mut F, - ) -> ::std::result::Result { - Ok(self) - } - - #[inline] - fn fold_with>>( - self, - _: &mut F, - ) -> Self { - self - } - } - - impl<'tcx> $crate::ty::visit::TypeVisitable<$crate::ty::TyCtxt<'tcx>> for $ty { - #[inline] - fn visit_with>>( - &self, - _: &mut F) - -> F::Result - { - ::output() - } + impl<'tcx> $crate::ty::traverse::TypeTraversable<$crate::ty::TyCtxt<'tcx>> for $ty { + type Kind = $crate::ty::traverse::NoopTypeTraversal; } )+ }; diff --git a/compiler/rustc_middle/src/mir/query.rs b/compiler/rustc_middle/src/mir/query.rs index 70331214ac5a8..088c041c4b2be 100644 --- a/compiler/rustc_middle/src/mir/query.rs +++ b/compiler/rustc_middle/src/mir/query.rs @@ -57,8 +57,6 @@ pub struct CoroutineLayout<'tcx> { /// Which saved locals are storage-live at the same time. Locals that do not /// have conflicts with each other are allowed to overlap in the computed /// layout. - #[type_foldable(identity)] - #[type_visitable(ignore)] pub storage_conflicts: BitMatrix, } diff --git a/compiler/rustc_middle/src/mir/type_foldable.rs b/compiler/rustc_middle/src/mir/type_foldable.rs index b798f0788007f..178b6767ce8ca 100644 --- a/compiler/rustc_middle/src/mir/type_foldable.rs +++ b/compiler/rustc_middle/src/mir/type_foldable.rs @@ -1,7 +1,5 @@ //! `TypeFoldable` implementations for MIR types - -use rustc_ast::InlineAsmTemplatePiece; -use rustc_hir::def_id::LocalDefId; +use rustc_index::bit_set::BitMatrix; use super::*; @@ -20,6 +18,7 @@ TrivialTypeTraversalImpls! { SwitchTargets, CoroutineKind, CoroutineSavedLocal, + BitMatrix, } TrivialTypeTraversalImpls! { @@ -27,33 +26,6 @@ TrivialTypeTraversalImpls! { NullOp<'tcx>, } -impl<'tcx> TypeFoldable> for &'tcx [InlineAsmTemplatePiece] { - fn try_fold_with>>( - self, - _folder: &mut F, - ) -> Result { - Ok(self) - } -} - -impl<'tcx> TypeFoldable> for &'tcx [Span] { - fn try_fold_with>>( - self, - _folder: &mut F, - ) -> Result { - Ok(self) - } -} - -impl<'tcx> TypeFoldable> for &'tcx ty::List { - fn try_fold_with>>( - self, - _folder: &mut F, - ) -> Result { - Ok(self) - } -} - impl<'tcx> TypeFoldable> for &'tcx ty::List> { fn try_fold_with>>( self, diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index fe865b8a51508..9173a7518a4b5 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -17,7 +17,7 @@ use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::{BindingMode, ByRef, HirId, MatchSource, RangeEnd}; use rustc_index::{IndexVec, newtype_index}; -use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeVisitable}; +use rustc_macros::{HashStable, NoopTypeTraversable, TyDecodable, TyEncodable, TypeVisitable}; use rustc_middle::middle::region; use rustc_middle::mir::interpret::AllocId; use rustc_middle::mir::{self, BinOp, BorrowKind, FakeReadCause, UnOp}; @@ -234,7 +234,8 @@ pub enum StmtKind<'tcx> { }, } -#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, HashStable, TyEncodable, TyDecodable)] +#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash, HashStable)] +#[derive(NoopTypeTraversable, TyEncodable, TyDecodable)] pub struct LocalVarId(pub HirId); /// A THIR expression. @@ -739,9 +740,7 @@ pub enum PatKind<'tcx> { /// `x`, `ref x`, `x @ P`, etc. Binding { name: Symbol, - #[type_visitable(ignore)] mode: BindingMode, - #[type_visitable(ignore)] var: LocalVarId, ty: Ty<'tcx>, subpattern: Option>>, @@ -844,7 +843,6 @@ pub struct PatRange<'tcx> { pub lo: PatRangeBoundary<'tcx>, /// Must not be `NegInfinity`. pub hi: PatRangeBoundary<'tcx>, - #[type_visitable(ignore)] pub end: RangeEnd, pub ty: Ty<'tcx>, } diff --git a/compiler/rustc_middle/src/traits/solve.rs b/compiler/rustc_middle/src/traits/solve.rs index f659bf8125a0e..e538c3839320a 100644 --- a/compiler/rustc_middle/src/traits/solve.rs +++ b/compiler/rustc_middle/src/traits/solve.rs @@ -1,8 +1,9 @@ use rustc_ast_ir::try_visit; use rustc_data_structures::intern::Interned; use rustc_macros::HashStable; -use rustc_type_ir as ir; pub use rustc_type_ir::solve::*; +use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable}; +use rustc_type_ir::{self as ir}; use crate::ty::{ self, FallibleTypeFolder, TyCtxt, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor, @@ -72,6 +73,9 @@ impl<'tcx> TypeFoldable> for ExternalConstraints<'tcx> { } } +impl<'tcx> TypeTraversable> for ExternalConstraints<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ExternalConstraints<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { try_visit!(self.region_constraints.visit_with(visitor)); @@ -106,6 +110,9 @@ impl<'tcx> TypeFoldable> for PredefinedOpaques<'tcx> { } } +impl<'tcx> TypeTraversable> for PredefinedOpaques<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for PredefinedOpaques<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { self.opaque_types.visit_with(visitor) diff --git a/compiler/rustc_middle/src/ty/adt.rs b/compiler/rustc_middle/src/ty/adt.rs index 3322a2643d7d3..c32ba563816f4 100644 --- a/compiler/rustc_middle/src/ty/adt.rs +++ b/compiler/rustc_middle/src/ty/adt.rs @@ -13,7 +13,7 @@ use rustc_hir::def::{CtorKind, DefKind, Res}; use rustc_hir::def_id::DefId; use rustc_hir::{self as hir, LangItem}; use rustc_index::{IndexSlice, IndexVec}; -use rustc_macros::{HashStable, TyDecodable, TyEncodable}; +use rustc_macros::{HashStable, NoopTypeTraversable, TyDecodable, TyEncodable}; use rustc_query_system::ich::StableHashingContext; use rustc_session::DataTypeKind; use rustc_span::symbol::sym; @@ -168,7 +168,7 @@ impl<'a> HashStable> for AdtDefData { } } -#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, HashStable, NoopTypeTraversable)] #[rustc_pass_by_value] pub struct AdtDef<'tcx>(pub Interned<'tcx, AdtDefData>); diff --git a/compiler/rustc_middle/src/ty/consts/valtree.rs b/compiler/rustc_middle/src/ty/consts/valtree.rs index 9f9bf41c3355a..e3bad806d153f 100644 --- a/compiler/rustc_middle/src/ty/consts/valtree.rs +++ b/compiler/rustc_middle/src/ty/consts/valtree.rs @@ -1,11 +1,11 @@ -use rustc_macros::{HashStable, TyDecodable, TyEncodable}; +use rustc_macros::{HashStable, NoopTypeTraversable, TyDecodable, TyEncodable}; use super::ScalarInt; use crate::mir::interpret::Scalar; use crate::ty::{self, Ty, TyCtxt}; #[derive(Copy, Clone, Debug, Hash, TyEncodable, TyDecodable, Eq, PartialEq)] -#[derive(HashStable)] +#[derive(HashStable, NoopTypeTraversable)] /// This datastructure is used to represent the value of constants used in the type system. /// /// We explicitly choose a different datastructure from the way values are processed within diff --git a/compiler/rustc_middle/src/ty/error.rs b/compiler/rustc_middle/src/ty/error.rs index b02eff3bfd6a3..c49824bb418cd 100644 --- a/compiler/rustc_middle/src/ty/error.rs +++ b/compiler/rustc_middle/src/ty/error.rs @@ -35,9 +35,6 @@ impl<'tcx> TypeError<'tcx> { TypeError::CyclicTy(_) => "cyclic type of infinite size".into(), TypeError::CyclicConst(_) => "encountered a self-referencing constant".into(), TypeError::Mismatch => "types differ".into(), - TypeError::ConstnessMismatch(values) => { - format!("expected {} bound, found {} bound", values.expected, values.found).into() - } TypeError::PolarityMismatch(values) => { format!("expected {} polarity, found {} polarity", values.expected, values.found) .into() diff --git a/compiler/rustc_middle/src/ty/generic_args.rs b/compiler/rustc_middle/src/ty/generic_args.rs index daf1362e25c1f..7f6edb0a61c76 100644 --- a/compiler/rustc_middle/src/ty/generic_args.rs +++ b/compiler/rustc_middle/src/ty/generic_args.rs @@ -14,6 +14,7 @@ use rustc_hir::def_id::DefId; use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable, extension}; use rustc_serialize::{Decodable, Encodable}; use rustc_type_ir::WithCachedTypeInfo; +use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable}; use smallvec::SmallVec; use crate::ty::codec::{TyDecoder, TyEncoder}; @@ -329,6 +330,9 @@ impl<'tcx> TypeFoldable> for GenericArg<'tcx> { } } +impl<'tcx> TypeTraversable> for GenericArg<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for GenericArg<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { match self.unpack() { @@ -642,6 +646,9 @@ impl<'tcx> TypeFoldable> for &'tcx ty::List> { } } +impl<'tcx, T: TypeTraversable>> TypeTraversable> for &'tcx ty::List { + type Kind = T::Kind; +} impl<'tcx, T: TypeVisitable>> TypeVisitable> for &'tcx ty::List { #[inline] fn visit_with>>(&self, visitor: &mut V) -> V::Result { diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index ed24fcc7eb88a..9e653e0e20a42 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -55,8 +55,10 @@ pub use rustc_type_ir::ConstKind::{ Placeholder as PlaceholderCt, Unevaluated, Value, }; pub use rustc_type_ir::relate::VarianceDiagInfo; +use rustc_type_ir::traverse::TypeTraversable; pub use rustc_type_ir::*; use tracing::{debug, instrument}; +use traverse::ImportantTypeTraversal; pub use vtable::*; use {rustc_ast as ast, rustc_attr as attr, rustc_hir as hir}; @@ -547,6 +549,9 @@ impl<'tcx> TypeFoldable> for Term<'tcx> { } } +impl<'tcx> TypeTraversable> for Term<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for Term<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { match self.unpack() { @@ -1031,15 +1036,18 @@ impl<'tcx> TypeFoldable> for ParamEnv<'tcx> { ) -> Result { Ok(ParamEnv::new( self.caller_bounds().try_fold_with(folder)?, - self.reveal().try_fold_with(folder)?, + self.reveal().noop_try_fold_with(folder)?, )) } } +impl<'tcx> TypeTraversable> for ParamEnv<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ParamEnv<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { try_visit!(self.caller_bounds().visit_with(visitor)); - self.reveal().visit_with(visitor) + self.reveal().noop_visit_with(visitor) } } diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index 4c7bcb1bf2e88..504a3c8a6d832 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -1,7 +1,5 @@ use std::iter; -use rustc_hir as hir; -use rustc_target::spec::abi; pub use rustc_type_ir::relate::*; use crate::ty::error::{ExpectedFound, TypeError}; @@ -121,26 +119,6 @@ impl<'tcx> Relate> for &'tcx ty::List Relate> for hir::Safety { - fn relate>>( - _relation: &mut R, - a: hir::Safety, - b: hir::Safety, - ) -> RelateResult<'tcx, hir::Safety> { - if a != b { Err(TypeError::SafetyMismatch(ExpectedFound::new(true, a, b))) } else { Ok(a) } - } -} - -impl<'tcx> Relate> for abi::Abi { - fn relate>>( - _relation: &mut R, - a: abi::Abi, - b: abi::Abi, - ) -> RelateResult<'tcx, abi::Abi> { - if a == b { Ok(a) } else { Err(TypeError::AbiMismatch(ExpectedFound::new(true, a, b))) } - } -} - impl<'tcx> Relate> for ty::GenericArgsRef<'tcx> { fn relate>>( relation: &mut R, diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index cd9ff9b60d859..9c743f06c6c78 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -11,6 +11,7 @@ use rustc_hir::def::Namespace; use rustc_span::source_map::Spanned; use rustc_target::abi::TyAndLayout; use rustc_type_ir::ConstKind; +use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable}; use super::print::PrettyPrinter; use super::{GenericArg, GenericArgKind, Pattern, Region}; @@ -18,7 +19,7 @@ use crate::mir::interpret; use crate::ty::fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable}; use crate::ty::print::{FmtPrinter, Printer, with_no_trimmed_paths}; use crate::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}; -use crate::ty::{self, InferConst, Lift, Term, TermKind, Ty, TyCtxt}; +use crate::ty::{self, Lift, Term, TermKind, Ty, TyCtxt}; impl fmt::Debug for ty::TraitDef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -208,55 +209,56 @@ impl<'tcx> fmt::Debug for Region<'tcx> { // For things for which the type library provides traversal implementations // for all Interners, we only need to provide a Lift implementation: TrivialLiftImpls! { - (), - bool, - usize, - u64, + (), + bool, + usize, + u64, + crate::ty::ParamConst, } // For some things about which the type library does not know, or does not // provide any traversal implementations, we need to provide a traversal // implementation (only for TyCtxt<'_> interners). TrivialTypeTraversalImpls! { - ::rustc_target::abi::FieldIdx, - ::rustc_target::abi::VariantIdx, - crate::middle::region::Scope, - ::rustc_ast::InlineAsmOptions, - ::rustc_ast::InlineAsmTemplatePiece, - ::rustc_ast::NodeId, - ::rustc_span::symbol::Symbol, - ::rustc_hir::def::Res, - ::rustc_hir::def_id::LocalDefId, - ::rustc_hir::ByRef, - ::rustc_hir::HirId, - ::rustc_hir::MatchSource, - ::rustc_target::asm::InlineAsmRegOrRegClass, - crate::mir::coverage::BlockMarkerId, - crate::mir::coverage::CounterId, - crate::mir::coverage::ExpressionId, - crate::mir::coverage::ConditionId, - crate::mir::Local, - crate::mir::Promoted, - crate::traits::Reveal, - crate::ty::adjustment::AutoBorrowMutability, - crate::ty::AdtKind, - crate::ty::BoundRegion, - // Including `BoundRegionKind` is a *bit* dubious, but direct - // references to bound region appear in `ty::Error`, and aren't - // really meant to be folded. In general, we can only fold a fully - // general `Region`. - crate::ty::BoundRegionKind, - crate::ty::AssocItem, - crate::ty::AssocKind, - crate::ty::Placeholder, - crate::ty::Placeholder, - crate::ty::Placeholder, +::rustc_target::abi::FieldIdx, +::rustc_target::abi::VariantIdx, +crate::middle::region::Scope, +::rustc_ast::InlineAsmOptions, +::rustc_ast::InlineAsmTemplatePiece, +::rustc_ast::NodeId, +::rustc_ast::ast::BindingMode, +::rustc_span::symbol::Symbol, +::rustc_hir::def::Res, +::rustc_hir::def_id::LocalDefId, +::rustc_hir::ByRef, +::rustc_hir::HirId, +::rustc_hir::RangeEnd, +::rustc_hir::MatchSource, +::rustc_target::asm::InlineAsmRegOrRegClass, +crate::mir::coverage::BlockMarkerId, +crate::mir::coverage::CounterId, +crate::mir::coverage::ExpressionId, +crate::mir::coverage::ConditionId, +crate::mir::Local, +crate::mir::Promoted, +crate::ty::adjustment::AutoBorrowMutability, +crate::ty::AdtKind, +crate::ty::BoundRegion, +// Including `BoundRegionKind` is a *bit* dubious, but direct +// references to bound region appear in `ty::Error`, and aren't +// really meant to be folded. In general, we can only fold a fully +// general `Region`. +crate::ty::BoundRegionKind, +crate::ty::AssocItem, +crate::ty::AssocKind, +crate::ty::Placeholder, +crate::ty::Placeholder, +crate::ty::Placeholder,} +TrivialTypeTraversalImpls! { crate::ty::LateParamRegion, crate::ty::adjustment::PointerCoercion, ::rustc_span::Span, ::rustc_span::symbol::Ident, - ty::BoundVar, - ty::ValTree<'tcx>, } // For some things about which the type library does not know, or does not // provide any traversal implementations, we need to provide a traversal @@ -267,7 +269,6 @@ TrivialTypeTraversalAndLiftImpls! { ::rustc_hir::Safety, ::rustc_target::spec::abi::Abi, crate::ty::ClosureKind, - crate::ty::ParamConst, crate::ty::ParamTy, crate::ty::instance::ReifyReason, interpret::AllocId, @@ -302,12 +303,6 @@ impl<'a, 'tcx> Lift> for Term<'a> { /////////////////////////////////////////////////////////////////////////// // Traversal implementations. -impl<'tcx> TypeVisitable> for ty::AdtDef<'tcx> { - fn visit_with>>(&self, _visitor: &mut V) -> V::Result { - V::Result::output() - } -} - impl<'tcx> TypeFoldable> for &'tcx ty::List> { fn try_fold_with>>( self, @@ -336,6 +331,9 @@ impl<'tcx> TypeFoldable> for Pattern<'tcx> { } } +impl<'tcx> TypeTraversable> for Pattern<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for Pattern<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { (**self).visit_with(visitor) @@ -351,6 +349,9 @@ impl<'tcx> TypeFoldable> for Ty<'tcx> { } } +impl<'tcx> TypeTraversable> for Ty<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for Ty<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_ty(*self) @@ -467,6 +468,9 @@ impl<'tcx> TypeFoldable> for ty::Region<'tcx> { } } +impl<'tcx> TypeTraversable> for ty::Region<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ty::Region<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_region(*self) @@ -492,12 +496,18 @@ impl<'tcx> TypeFoldable> for ty::Clause<'tcx> { } } +impl<'tcx> TypeTraversable> for ty::Predicate<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ty::Predicate<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_predicate(*self) } } +impl<'tcx> TypeTraversable> for ty::Clause<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ty::Clause<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_predicate(self.as_predicate()) @@ -520,6 +530,9 @@ impl<'tcx> TypeSuperVisitable> for ty::Predicate<'tcx> { } } +impl<'tcx> TypeTraversable> for ty::Clauses<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ty::Clauses<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_clauses(self) @@ -550,6 +563,9 @@ impl<'tcx> TypeFoldable> for ty::Const<'tcx> { } } +impl<'tcx> TypeTraversable> for ty::Const<'tcx> { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for ty::Const<'tcx> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_const(*self) @@ -562,15 +578,15 @@ impl<'tcx> TypeSuperFoldable> for ty::Const<'tcx> { folder: &mut F, ) -> Result { let kind = match self.kind() { - ConstKind::Param(p) => ConstKind::Param(p.try_fold_with(folder)?), - ConstKind::Infer(i) => ConstKind::Infer(i.try_fold_with(folder)?), + ConstKind::Param(p) => ConstKind::Param(p.noop_try_fold_with(folder)?), + ConstKind::Infer(i) => ConstKind::Infer(i.noop_try_fold_with(folder)?), ConstKind::Bound(d, b) => { - ConstKind::Bound(d.try_fold_with(folder)?, b.try_fold_with(folder)?) + ConstKind::Bound(d.noop_try_fold_with(folder)?, b.noop_try_fold_with(folder)?) } - ConstKind::Placeholder(p) => ConstKind::Placeholder(p.try_fold_with(folder)?), + ConstKind::Placeholder(p) => ConstKind::Placeholder(p.noop_try_fold_with(folder)?), ConstKind::Unevaluated(uv) => ConstKind::Unevaluated(uv.try_fold_with(folder)?), ConstKind::Value(t, v) => { - ConstKind::Value(t.try_fold_with(folder)?, v.try_fold_with(folder)?) + ConstKind::Value(t.try_fold_with(folder)?, v.noop_try_fold_with(folder)?) } ConstKind::Error(e) => ConstKind::Error(e.try_fold_with(folder)?), ConstKind::Expr(e) => ConstKind::Expr(e.try_fold_with(folder)?), @@ -582,17 +598,17 @@ impl<'tcx> TypeSuperFoldable> for ty::Const<'tcx> { impl<'tcx> TypeSuperVisitable> for ty::Const<'tcx> { fn super_visit_with>>(&self, visitor: &mut V) -> V::Result { match self.kind() { - ConstKind::Param(p) => p.visit_with(visitor), - ConstKind::Infer(i) => i.visit_with(visitor), + ConstKind::Param(p) => p.noop_visit_with(visitor), + ConstKind::Infer(i) => i.noop_visit_with(visitor), ConstKind::Bound(d, b) => { - try_visit!(d.visit_with(visitor)); - b.visit_with(visitor) + try_visit!(d.noop_visit_with(visitor)); + b.noop_visit_with(visitor) } - ConstKind::Placeholder(p) => p.visit_with(visitor), + ConstKind::Placeholder(p) => p.noop_visit_with(visitor), ConstKind::Unevaluated(uv) => uv.visit_with(visitor), ConstKind::Value(t, v) => { try_visit!(t.visit_with(visitor)); - v.visit_with(visitor) + v.noop_visit_with(visitor) } ConstKind::Error(e) => e.visit_with(visitor), ConstKind::Expr(e) => e.visit_with(visitor), @@ -600,6 +616,9 @@ impl<'tcx> TypeSuperVisitable> for ty::Const<'tcx> { } } +impl<'tcx> TypeTraversable> for rustc_span::ErrorGuaranteed { + type Kind = ImportantTypeTraversal; +} impl<'tcx> TypeVisitable> for rustc_span::ErrorGuaranteed { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_error(*self) @@ -615,33 +634,26 @@ impl<'tcx> TypeFoldable> for rustc_span::ErrorGuaranteed { } } -impl<'tcx> TypeFoldable> for InferConst { - fn try_fold_with>>( - self, - _folder: &mut F, - ) -> Result { - Ok(self) - } +impl<'tcx> TypeTraversable> for TyAndLayout<'tcx, Ty<'tcx>> { + type Kind = ImportantTypeTraversal; } - -impl<'tcx> TypeVisitable> for InferConst { - fn visit_with>>(&self, _visitor: &mut V) -> V::Result { - V::Result::output() - } -} - impl<'tcx> TypeVisitable> for TyAndLayout<'tcx, Ty<'tcx>> { fn visit_with>>(&self, visitor: &mut V) -> V::Result { visitor.visit_ty(self.ty) } } +impl<'tcx, T: TypeVisitable> + Debug + Clone> TypeTraversable> + for Spanned +{ + type Kind = ImportantTypeTraversal; +} impl<'tcx, T: TypeVisitable> + Debug + Clone> TypeVisitable> for Spanned { fn visit_with>>(&self, visitor: &mut V) -> V::Result { try_visit!(self.node.visit_with(visitor)); - self.span.visit_with(visitor) + self.span.noop_visit_with(visitor) } } @@ -654,7 +666,7 @@ impl<'tcx, T: TypeFoldable> + Debug + Clone> TypeFoldable Result { Ok(Spanned { node: self.node.try_fold_with(folder)?, - span: self.span.try_fold_with(folder)?, + span: self.span.noop_try_fold_with(folder)?, }) } } diff --git a/compiler/rustc_middle/src/ty/sty.rs b/compiler/rustc_middle/src/ty/sty.rs index 74de378c4d78e..16d5dc03dcb17 100644 --- a/compiler/rustc_middle/src/ty/sty.rs +++ b/compiler/rustc_middle/src/ty/sty.rs @@ -13,7 +13,9 @@ use rustc_errors::{ErrorGuaranteed, MultiSpan}; use rustc_hir as hir; use rustc_hir::LangItem; use rustc_hir::def_id::DefId; -use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, extension}; +use rustc_macros::{ + HashStable, NoopTypeTraversable, TyDecodable, TyEncodable, TypeFoldable, extension, +}; use rustc_span::symbol::{Symbol, sym}; use rustc_span::{DUMMY_SP, Span}; use rustc_target::abi::{FIRST_VARIANT, FieldIdx, VariantIdx}; @@ -323,7 +325,7 @@ impl<'tcx> ParamTy { } #[derive(Copy, Clone, Hash, TyEncodable, TyDecodable, Eq, PartialEq, Ord, PartialOrd)] -#[derive(HashStable)] +#[derive(HashStable, NoopTypeTraversable)] pub struct ParamConst { pub index: u32, pub name: Symbol, diff --git a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs index 71ce0cce77224..72743f440149e 100644 --- a/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs +++ b/compiler/rustc_next_trait_solver/src/solve/eval_ctxt/canonical.rs @@ -15,6 +15,7 @@ use rustc_index::IndexVec; use rustc_type_ir::fold::TypeFoldable; use rustc_type_ir::inherent::*; use rustc_type_ir::relate::solver_relating::RelateExt; +use rustc_type_ir::traverse::OptTryFoldWith; use rustc_type_ir::{self as ty, Canonical, CanonicalVarValues, InferCtxtLike, Interner}; use tracing::{debug, instrument, trace}; @@ -426,7 +427,7 @@ pub(in crate::solve) fn make_canonical_state( where D: SolverDelegate, I: Interner, - T: TypeFoldable, + T: OptTryFoldWith, { let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) }; let state = inspect::State { var_values, data }; @@ -441,7 +442,7 @@ where // FIXME: needs to be pub to be accessed by downstream // `rustc_trait_selection::solve::inspect::analyse`. -pub fn instantiate_canonical_state>( +pub fn instantiate_canonical_state>( delegate: &D, span: D::Span, param_env: I::ParamEnv, diff --git a/compiler/rustc_trait_selection/src/infer.rs b/compiler/rustc_trait_selection/src/infer.rs index bacb3b1b1b861..ee7037f799454 100644 --- a/compiler/rustc_trait_selection/src/infer.rs +++ b/compiler/rustc_trait_selection/src/infer.rs @@ -1,5 +1,3 @@ -use std::fmt::Debug; - use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; pub use rustc_infer::infer::*; @@ -11,6 +9,7 @@ use rustc_middle::infer::canonical::{ use rustc_middle::traits::query::NoSolution; use rustc_middle::ty::{self, GenericArg, Ty, TyCtxt, TypeFoldable, TypeVisitableExt, Upcast}; use rustc_span::DUMMY_SP; +use rustc_type_ir::traverse::OptTryFoldWith; use tracing::instrument; use crate::infer::at::ToTrace; @@ -139,7 +138,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> { ) -> Result, NoSolution> where K: TypeFoldable>, - R: Debug + TypeFoldable>, + R: OptTryFoldWith>, Canonical<'tcx, QueryResponse<'tcx, R>>: ArenaAllocatable<'tcx>, { let (infcx, key, canonical_inference_vars) = diff --git a/compiler/rustc_trait_selection/src/traits/engine.rs b/compiler/rustc_trait_selection/src/traits/engine.rs index 5e270b62b0081..5b2115cb61d1a 100644 --- a/compiler/rustc_trait_selection/src/traits/engine.rs +++ b/compiler/rustc_trait_selection/src/traits/engine.rs @@ -1,5 +1,4 @@ use std::cell::RefCell; -use std::fmt::Debug; use rustc_data_structures::fx::FxIndexSet; use rustc_errors::ErrorGuaranteed; @@ -15,6 +14,7 @@ use rustc_macros::extension; use rustc_middle::arena::ArenaAllocatable; use rustc_middle::traits::query::NoSolution; use rustc_middle::ty::error::TypeError; +use rustc_middle::ty::traverse::OptTryFoldWith; use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, Upcast, Variance}; use rustc_type_ir::relate::Relate; @@ -259,7 +259,7 @@ impl<'tcx> ObligationCtxt<'_, 'tcx, ScrubbedTraitError<'tcx>> { answer: T, ) -> Result, NoSolution> where - T: Debug + TypeFoldable>, + T: OptTryFoldWith>, Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>, { self.infcx.make_canonicalized_query_response( diff --git a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs index 18010603286d1..4c6d4fb8a161a 100644 --- a/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs +++ b/compiler/rustc_trait_selection/src/traits/query/type_op/custom.rs @@ -3,8 +3,10 @@ use std::fmt; use rustc_errors::ErrorGuaranteed; use rustc_infer::infer::region_constraints::RegionConstraintData; use rustc_middle::traits::query::NoSolution; -use rustc_middle::ty::{TyCtxt, TypeFoldable}; +use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::traverse::AlwaysTraversable; use rustc_span::Span; +use rustc_type_ir::traverse::OptTryFoldWith; use tracing::info; use crate::infer::InferCtxt; @@ -29,7 +31,7 @@ impl CustomTypeOp { impl<'tcx, F, R> super::TypeOp<'tcx> for CustomTypeOp where F: FnOnce(&ObligationCtxt<'_, 'tcx>) -> Result, - R: fmt::Debug + TypeFoldable>, + R: fmt::Debug + OptTryFoldWith>, { type Output = R; /// We can't do any custom error reporting for `CustomTypeOp`, so @@ -67,7 +69,7 @@ pub fn scrape_region_constraints<'tcx, Op, R>( span: Span, ) -> Result<(TypeOpOutput<'tcx, Op>, RegionConstraintData<'tcx>), ErrorGuaranteed> where - R: TypeFoldable>, + R: OptTryFoldWith>, Op: super::TypeOp<'tcx, Output = R>, { // During NLL, we expect that nobody will register region @@ -97,7 +99,7 @@ where })?; // Next trait solver performs operations locally, and normalize goals should resolve vars. - let value = infcx.resolve_vars_if_possible(value); + let value = infcx.resolve_vars_if_possible(AlwaysTraversable(value)).0; let region_obligations = infcx.take_registered_region_obligations(); let region_constraint_data = infcx.take_and_reset_region_constraints(); diff --git a/compiler/rustc_trait_selection/src/traits/query/type_op/mod.rs b/compiler/rustc_trait_selection/src/traits/query/type_op/mod.rs index a618d96ce9507..66d349ff51823 100644 --- a/compiler/rustc_trait_selection/src/traits/query/type_op/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/query/type_op/mod.rs @@ -6,6 +6,7 @@ use rustc_middle::traits::query::NoSolution; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::{ParamEnvAnd, TyCtxt}; use rustc_span::Span; +use rustc_type_ir::traverse::OptTryFoldWith; use crate::infer::canonical::{ CanonicalQueryInput, CanonicalQueryResponse, Certainty, OriginalQueryValues, @@ -62,7 +63,7 @@ pub struct TypeOpOutput<'tcx, Op: TypeOp<'tcx>> { /// /// [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html pub trait QueryTypeOp<'tcx>: fmt::Debug + Copy + TypeFoldable> + 'tcx { - type QueryResponse: TypeFoldable>; + type QueryResponse: OptTryFoldWith>; /// Give query the option for a simple fast path that never /// actually hits the tcx cache lookup etc. Return `Some(r)` with diff --git a/compiler/rustc_type_ir/src/binder.rs b/compiler/rustc_type_ir/src/binder.rs index f20beb797500e..40fba75cc90a1 100644 --- a/compiler/rustc_type_ir/src/binder.rs +++ b/compiler/rustc_type_ir/src/binder.rs @@ -14,6 +14,7 @@ use crate::data_structures::SsoHashSet; use crate::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable}; use crate::inherent::*; use crate::lift::Lift; +use crate::traverse::{ImportantTypeTraversal, OptVisitWith, TypeTraversable}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor}; use crate::{self as ty, Interner}; @@ -125,6 +126,9 @@ impl> TypeFoldable for Binder { } } +impl> TypeTraversable for Binder { + type Kind = ImportantTypeTraversal; +} impl> TypeVisitable for Binder { fn visit_with>(&self, visitor: &mut V) -> V::Result { visitor.visit_binder(self) @@ -182,14 +186,14 @@ impl Binder { Binder { value: &self.value, bound_vars: self.bound_vars } } - pub fn map_bound_ref>(&self, f: F) -> Binder + pub fn map_bound_ref>(&self, f: F) -> Binder where F: FnOnce(&T) -> U, { self.as_ref().map_bound(f) } - pub fn map_bound>(self, f: F) -> Binder + pub fn map_bound>(self, f: F) -> Binder where F: FnOnce(T) -> U, { @@ -197,7 +201,7 @@ impl Binder { let value = f(value); if cfg!(debug_assertions) { let mut validator = ValidateBoundVars::new(bound_vars); - value.visit_with(&mut validator); + OptVisitWith::mk_visit_with()(&value, &mut validator); } Binder { value, bound_vars } } diff --git a/compiler/rustc_type_ir/src/const_kind.rs b/compiler/rustc_type_ir/src/const_kind.rs index 7a8c612057fa2..29b2b67b632bb 100644 --- a/compiler/rustc_type_ir/src/const_kind.rs +++ b/compiler/rustc_type_ir/src/const_kind.rs @@ -5,7 +5,9 @@ use derive_where::derive_where; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; #[cfg(feature = "nightly")] use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable}; -use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic}; +use rustc_type_ir_macros::{ + Lift_Generic, NoopTypeTraversable_Generic, TypeFoldable_Generic, TypeVisitable_Generic, +}; use crate::{self as ty, DebruijnIndex, Interner}; @@ -77,6 +79,7 @@ impl UnevaluatedConst { rustc_index::newtype_index! { /// A **`const`** **v**ariable **ID**. + #[derive(NoopTypeTraversable_Generic)] #[encodable] #[orderable] #[debug_format = "?{}c"] @@ -92,6 +95,7 @@ rustc_index::newtype_index! { /// relate an effect variable with a normal one, we would ICE, which can catch bugs /// where we are not correctly using the effect var for an effect param. Fallback /// is also implemented on top of having separate effect and normal const variables. + #[derive(NoopTypeTraversable_Generic)] #[encodable] #[orderable] #[debug_format = "?{}e"] @@ -100,7 +104,7 @@ rustc_index::newtype_index! { } /// An inference variable for a const, for use in const generics. -#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)] +#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash, NoopTypeTraversable_Generic)] #[cfg_attr(feature = "nightly", derive(TyEncodable, TyDecodable))] pub enum InferConst { /// Infer the value of the const. diff --git a/compiler/rustc_type_ir/src/error.rs b/compiler/rustc_type_ir/src/error.rs index 8a6d37b7d23f6..39672977c008f 100644 --- a/compiler/rustc_type_ir/src/error.rs +++ b/compiler/rustc_type_ir/src/error.rs @@ -27,7 +27,6 @@ impl ExpectedFound { #[cfg_attr(feature = "nightly", rustc_pass_by_value)] pub enum TypeError { Mismatch, - ConstnessMismatch(ExpectedFound), PolarityMismatch(ExpectedFound), SafetyMismatch(ExpectedFound), AbiMismatch(ExpectedFound), @@ -73,9 +72,9 @@ impl TypeError { pub fn must_include_note(self) -> bool { use self::TypeError::*; match self { - CyclicTy(_) | CyclicConst(_) | SafetyMismatch(_) | ConstnessMismatch(_) - | PolarityMismatch(_) | Mismatch | AbiMismatch(_) | FixedArraySize(_) - | ArgumentSorts(..) | Sorts(_) | VariadicMismatch(_) | TargetFeatureCast(_) => false, + CyclicTy(_) | CyclicConst(_) | SafetyMismatch(_) | PolarityMismatch(_) | Mismatch + | AbiMismatch(_) | FixedArraySize(_) | ArgumentSorts(..) | Sorts(_) + | VariadicMismatch(_) | TargetFeatureCast(_) => false, Mutability | ArgumentMutability(_) diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index f7875bb515270..90f85e1f2f20c 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -12,6 +12,7 @@ use crate::elaborate::Elaboratable; use crate::fold::{TypeFoldable, TypeSuperFoldable}; use crate::relate::Relate; use crate::solve::Reveal; +use crate::traverse::{NoopTypeTraversal, TypeTraversable}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty, CollectAndApply, Interner, UpcastFrom}; @@ -208,14 +209,18 @@ pub trait Tys>: fn output(self) -> I::Ty; } -pub trait Abi>: Copy + Debug + Hash + Eq + Relate { +pub trait Abi>: + Copy + Hash + Eq + TypeTraversable +{ fn rust() -> Self; /// Whether this ABI is `extern "Rust"`. fn is_rust(self) -> bool; } -pub trait Safety>: Copy + Debug + Hash + Eq + Relate { +pub trait Safety>: + Copy + Hash + Eq + TypeTraversable +{ fn safe() -> Self; fn is_safe(self) -> bool; @@ -545,7 +550,9 @@ pub trait Features: Copy { fn associated_const_equality(self) -> bool; } -pub trait DefId: Copy + Debug + Hash + Eq + TypeFoldable { +pub trait DefId: + Copy + Debug + Hash + Eq + TypeTraversable +{ fn is_local(self) -> bool; fn as_local(self) -> Option; @@ -565,7 +572,9 @@ pub trait BoundExistentialPredicates: ) -> impl IntoIterator>>; } -pub trait Span: Copy + Debug + Hash + Eq + TypeFoldable { +pub trait Span: + Copy + Debug + Hash + Eq + TypeTraversable +{ fn dummy() -> Self; } diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 4184e9e313ff7..ceb50e196d28f 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -14,6 +14,7 @@ use crate::relate::Relate; use crate::solve::{ CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult, SolverMode, }; +use crate::traverse::{NoopTypeTraversal, TypeTraversable}; use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable}; use crate::{self as ty, search_graph}; @@ -33,7 +34,12 @@ pub trait Interner: + IrPrint> { type DefId: DefId; - type LocalDefId: Copy + Debug + Hash + Eq + Into + TypeFoldable; + type LocalDefId: Copy + + Debug + + Hash + + Eq + + Into + + TypeTraversable; type Span: Span; type GenericArgs: GenericArgs; @@ -60,7 +66,7 @@ pub trait Interner: + Hash + Default + Eq - + TypeVisitable + + TypeTraversable + SliceLike; type CanonicalVars: Copy diff --git a/compiler/rustc_type_ir/src/lib.rs b/compiler/rustc_type_ir/src/lib.rs index 9e6d1f424ba49..81d4c74853792 100644 --- a/compiler/rustc_type_ir/src/lib.rs +++ b/compiler/rustc_type_ir/src/lib.rs @@ -19,14 +19,13 @@ use rustc_macros::{Decodable, Encodable, HashStable_NoContext}; // These modules are `pub` since they are not glob-imported. #[macro_use] -pub mod visit; +pub mod traverse; #[cfg(feature = "nightly")] pub mod codec; pub mod data_structures; pub mod elaborate; pub mod error; pub mod fast_reject; -pub mod fold; #[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_inherent")] pub mod inherent; pub mod ir_print; @@ -76,6 +75,8 @@ pub use opaque_ty::*; pub use predicate::*; pub use predicate_kind::*; pub use region_kind::*; +use rustc_type_ir_macros::NoopTypeTraversable_Generic; +pub use traverse::{fold, visit}; pub use ty_info::*; pub use ty_kind::*; pub use upcast::*; @@ -379,6 +380,7 @@ impl Default for UniverseIndex { rustc_index::newtype_index! { #[cfg_attr(feature = "nightly", derive(HashStable_NoContext))] + #[derive(NoopTypeTraversable_Generic)] #[encodable] #[orderable] #[debug_format = "{}"] diff --git a/compiler/rustc_type_ir/src/macros.rs b/compiler/rustc_type_ir/src/macros.rs index aae5aeb5fb363..92cd9322a4252 100644 --- a/compiler/rustc_type_ir/src/macros.rs +++ b/compiler/rustc_type_ir/src/macros.rs @@ -3,32 +3,8 @@ macro_rules! TrivialTypeTraversalImpls { ($($ty:ty,)+) => { $( - impl $crate::fold::TypeFoldable for $ty { - fn try_fold_with>( - self, - _: &mut F, - ) -> ::std::result::Result { - Ok(self) - } - - #[inline] - fn fold_with>( - self, - _: &mut F, - ) -> Self { - self - } - } - - impl $crate::visit::TypeVisitable for $ty { - #[inline] - fn visit_with>( - &self, - _: &mut F) - -> F::Result - { - ::output() - } + impl $crate::traverse::TypeTraversable for $ty { + type Kind = $crate::traverse::NoopTypeTraversal; } )+ }; diff --git a/compiler/rustc_type_ir/src/relate.rs b/compiler/rustc_type_ir/src/relate.rs index a0b93064694e3..ad17911830b3d 100644 --- a/compiler/rustc_type_ir/src/relate.rs +++ b/compiler/rustc_type_ir/src/relate.rs @@ -174,12 +174,17 @@ impl Relate for ty::FnSig { ExpectedFound::new(true, a, b) })); } - let safety = relation.relate(a.safety, b.safety)?; - let abi = relation.relate(a.abi, b.abi)?; + + if a.safety != b.safety { + return Err(TypeError::SafetyMismatch(ExpectedFound::new(true, a.safety, b.safety))); + } + + if a.abi != b.abi { + return Err(TypeError::AbiMismatch(ExpectedFound::new(true, a.abi, b.abi))); + }; let a_inputs = a.inputs(); let b_inputs = b.inputs(); - if a_inputs.len() != b_inputs.len() { return Err(TypeError::ArgCount); } @@ -212,26 +217,12 @@ impl Relate for ty::FnSig { Ok(ty::FnSig { inputs_and_output: cx.mk_type_list_from_iter(inputs_and_output)?, c_variadic: a.c_variadic, - safety, - abi, + safety: a.safety, + abi: a.abi, }) } } -impl Relate for ty::BoundConstness { - fn relate>( - _relation: &mut R, - a: ty::BoundConstness, - b: ty::BoundConstness, - ) -> RelateResult { - if a != b { - Err(TypeError::ConstnessMismatch(ExpectedFound::new(true, a, b))) - } else { - Ok(a) - } - } -} - impl Relate for ty::AliasTy { fn relate>( relation: &mut R, @@ -659,29 +650,18 @@ impl> Relate for ty::Binder { } } -impl Relate for ty::PredicatePolarity { - fn relate>( - _relation: &mut R, - a: ty::PredicatePolarity, - b: ty::PredicatePolarity, - ) -> RelateResult { - if a != b { - Err(TypeError::PolarityMismatch(ExpectedFound::new(true, a, b))) - } else { - Ok(a) - } - } -} - impl Relate for ty::TraitPredicate { fn relate>( relation: &mut R, a: ty::TraitPredicate, b: ty::TraitPredicate, ) -> RelateResult> { - Ok(ty::TraitPredicate { - trait_ref: relation.relate(a.trait_ref, b.trait_ref)?, - polarity: relation.relate(a.polarity, b.polarity)?, - }) + let trait_ref = relation.relate(a.trait_ref, b.trait_ref)?; + if a.polarity != b.polarity { + return Err(TypeError::PolarityMismatch(ExpectedFound::new( + true, a.polarity, b.polarity, + ))); + } + Ok(ty::TraitPredicate { trait_ref, polarity: a.polarity }) } } diff --git a/compiler/rustc_type_ir/src/solve/mod.rs b/compiler/rustc_type_ir/src/solve/mod.rs index b3f8390bbf062..a8a4475a19e5c 100644 --- a/compiler/rustc_type_ir/src/solve/mod.rs +++ b/compiler/rustc_type_ir/src/solve/mod.rs @@ -6,13 +6,15 @@ use std::hash::Hash; use derive_where::derive_where; #[cfg(feature = "nightly")] use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable}; -use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic}; +use rustc_type_ir_macros::{ + Lift_Generic, NoopTypeTraversable_Generic, TypeFoldable_Generic, TypeVisitable_Generic, +}; use crate::{self as ty, Canonical, CanonicalVarValues, Interner, Upcast}; /// Depending on the stage of compilation, we want projection to be /// more or less conservative. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, NoopTypeTraversable_Generic)] #[cfg_attr(feature = "nightly", derive(TyDecodable, TyEncodable, HashStable_NoContext))] pub enum Reveal { /// At type-checking time, we refuse to project any associated diff --git a/compiler/rustc_type_ir/src/fold.rs b/compiler/rustc_type_ir/src/traverse/fold.rs similarity index 92% rename from compiler/rustc_type_ir/src/fold.rs rename to compiler/rustc_type_ir/src/traverse/fold.rs index 8209d6f5fe3b1..7c50633ccab68 100644 --- a/compiler/rustc_type_ir/src/fold.rs +++ b/compiler/rustc_type_ir/src/traverse/fold.rs @@ -51,6 +51,7 @@ use rustc_index::{Idx, IndexVec}; use thin_vec::ThinVec; use tracing::instrument; +use super::OptTryFoldWith; use crate::data_structures::Lrc; use crate::inherent::*; use crate::visit::{TypeVisitable, TypeVisitableExt as _}; @@ -234,9 +235,12 @@ where /////////////////////////////////////////////////////////////////////////// // Traversal implementations. -impl, U: TypeFoldable> TypeFoldable for (T, U) { +impl, U: OptTryFoldWith> TypeFoldable for (T, U) { fn try_fold_with>(self, folder: &mut F) -> Result<(T, U), F::Error> { - Ok((self.0.try_fold_with(folder)?, self.1.try_fold_with(folder)?)) + Ok(( + OptTryFoldWith::mk_try_fold_with()(self.0, folder)?, + OptTryFoldWith::mk_try_fold_with()(self.1, folder)?, + )) } } @@ -248,17 +252,17 @@ impl, B: TypeFoldable, C: TypeFoldable> Ty folder: &mut F, ) -> Result<(A, B, C), F::Error> { Ok(( - self.0.try_fold_with(folder)?, - self.1.try_fold_with(folder)?, - self.2.try_fold_with(folder)?, + OptTryFoldWith::mk_try_fold_with()(self.0, folder)?, + OptTryFoldWith::mk_try_fold_with()(self.1, folder)?, + OptTryFoldWith::mk_try_fold_with()(self.2, folder)?, )) } } -impl> TypeFoldable for Option { +impl> TypeFoldable for Option { fn try_fold_with>(self, folder: &mut F) -> Result { Ok(match self { - Some(v) => Some(v.try_fold_with(folder)?), + Some(v) => Some(OptTryFoldWith::mk_try_fold_with()(v, folder)?), None => None, }) } @@ -310,32 +314,32 @@ impl> TypeFoldable for Lrc { } } -impl> TypeFoldable for Box { +impl> TypeFoldable for Box { fn try_fold_with>(mut self, folder: &mut F) -> Result { - *self = (*self).try_fold_with(folder)?; + *self = OptTryFoldWith::mk_try_fold_with()(*self, folder)?; Ok(self) } } -impl> TypeFoldable for Vec { +impl> TypeFoldable for Vec { fn try_fold_with>(self, folder: &mut F) -> Result { - self.into_iter().map(|t| t.try_fold_with(folder)).collect() + self.into_iter().map(|t| OptTryFoldWith::mk_try_fold_with()(t, folder)).collect() } } -impl> TypeFoldable for ThinVec { +impl> TypeFoldable for ThinVec { fn try_fold_with>(self, folder: &mut F) -> Result { - self.into_iter().map(|t| t.try_fold_with(folder)).collect() + self.into_iter().map(|t| OptTryFoldWith::mk_try_fold_with()(t, folder)).collect() } } -impl> TypeFoldable for Box<[T]> { +impl> TypeFoldable for Box<[T]> { fn try_fold_with>(self, folder: &mut F) -> Result { Vec::from(self).try_fold_with(folder).map(Vec::into_boxed_slice) } } -impl, Ix: Idx> TypeFoldable for IndexVec { +impl, Ix: Idx> TypeFoldable for IndexVec { fn try_fold_with>(self, folder: &mut F) -> Result { self.raw.try_fold_with(folder).map(IndexVec::from_raw) } diff --git a/compiler/rustc_type_ir/src/traverse/mod.rs b/compiler/rustc_type_ir/src/traverse/mod.rs new file mode 100644 index 0000000000000..b732bc11925fb --- /dev/null +++ b/compiler/rustc_type_ir/src/traverse/mod.rs @@ -0,0 +1,133 @@ +//! A visiting traversal mechanism for complex data structures that contain type +//! information. See the documentation of the [visit] and [fold] modules for more +//! details. + +#[macro_use] +pub mod visit; +pub mod fold; + +use std::fmt; + +use fold::{FallibleTypeFolder, TypeFoldable}; +use rustc_ast_ir::visit::VisitorResult; +use rustc_type_ir_macros::{TypeFoldable_Generic, TypeVisitable_Generic}; +use visit::{TypeVisitable, TypeVisitor}; + +use crate::Interner; + +#[derive(Debug, Clone, TypeVisitable_Generic, TypeFoldable_Generic)] +pub struct AlwaysTraversable(pub T); + +/// A trait which allows the compiler to reason about the disjointness +/// of `TypeVisitable` and `NoopTypeTraversable`. +/// +/// This trait has a blanket impls for everything that implements `TypeVisitable` +/// while requiring a manual impl for all types whose traversal is a noop. +pub trait TypeTraversable: fmt::Debug + Clone { + type Kind; + + #[inline(always)] + fn noop_visit_with>(&self, _: &mut V) -> V::Result + where + Self: TypeTraversable, + { + V::Result::output() + } + + #[inline(always)] + fn noop_try_fold_with>(self, _: &mut F) -> Result + where + Self: TypeTraversable, + { + Ok(self) + } +} +pub struct ImportantTypeTraversal; +pub struct NoopTypeTraversal; + +pub trait OptVisitWith: TypeTraversable { + fn mk_visit_with>() -> fn(&Self, &mut V) -> V::Result; +} + +impl OptVisitWith for T +where + I: Interner, + T: TypeTraversable + Clone + OptVisitWithHelper, +{ + #[inline(always)] + fn mk_visit_with>() -> fn(&Self, &mut V) -> V::Result { + Self::mk_visit_with_helper() + } +} + +trait OptVisitWithHelper { + fn mk_visit_with_helper>() -> fn(&Self, &mut V) -> V::Result; +} + +impl OptVisitWithHelper for T +where + I: Interner, + T: TypeVisitable, +{ + #[inline(always)] + fn mk_visit_with_helper>() -> fn(&Self, &mut V) -> V::Result { + Self::visit_with + } +} + +/// While this is implemented for all `T`, it is only useable via `OptVisitWith` if +/// `T` implements `TypeTraversable`. +impl OptVisitWithHelper for T +where + I: Interner, +{ + #[inline(always)] + fn mk_visit_with_helper>() -> fn(&Self, &mut V) -> V::Result { + |_, _| V::Result::output() + } +} + +pub trait OptTryFoldWith: OptVisitWith + Sized { + fn mk_try_fold_with>() -> fn(Self, &mut F) -> Result; +} + +impl OptTryFoldWith for T +where + I: Interner, + T: OptVisitWith + OptTryFoldWithHelper, +{ + #[inline(always)] + fn mk_try_fold_with>() -> fn(Self, &mut F) -> Result { + Self::mk_try_fold_with_helper() + } +} + +pub trait OptTryFoldWithHelper: Sized { + fn mk_try_fold_with_helper>() + -> fn(Self, &mut F) -> Result; +} + +impl OptTryFoldWithHelper for T +where + I: Interner, + T: TypeFoldable, +{ + #[inline(always)] + fn mk_try_fold_with_helper>() + -> fn(Self, &mut F) -> Result { + Self::try_fold_with + } +} + +/// While this is implemented for all `T`, it is only useable via `OptTryFoldWith` if +/// `T` implements `TypeTraversable`. +impl OptTryFoldWithHelper for T +where + I: Interner, +{ + #[inline(always)] + fn mk_try_fold_with_helper>() + -> fn(Self, &mut F) -> Result { + |this, _| Ok(this) + } +} diff --git a/compiler/rustc_type_ir/src/visit.rs b/compiler/rustc_type_ir/src/traverse/visit.rs similarity index 84% rename from compiler/rustc_type_ir/src/visit.rs rename to compiler/rustc_type_ir/src/traverse/visit.rs index 71c3646498b9f..ebba04cfc0407 100644 --- a/compiler/rustc_type_ir/src/visit.rs +++ b/compiler/rustc_type_ir/src/traverse/visit.rs @@ -41,14 +41,14 @@ //! - u.visit_with(visitor) //! ``` -use std::fmt; use std::ops::ControlFlow; +use rustc_ast_ir::try_visit; use rustc_ast_ir::visit::VisitorResult; -use rustc_ast_ir::{try_visit, walk_visitable_list}; use rustc_index::{Idx, IndexVec}; use thin_vec::ThinVec; +use super::{ImportantTypeTraversal, OptVisitWith, TypeTraversable}; use crate::data_structures::Lrc; use crate::inherent::*; use crate::{self as ty, Interner, TypeFlags}; @@ -58,7 +58,7 @@ use crate::{self as ty, Interner, TypeFlags}; /// /// To implement this conveniently, use the derive macro located in /// `rustc_macros`. -pub trait TypeVisitable: fmt::Debug + Clone { +pub trait TypeVisitable: TypeTraversable { /// The entry point for visiting. To visit a value `t` with a visitor `v` /// call: `t.visit_with(v)`. /// @@ -131,87 +131,128 @@ pub trait TypeVisitor: Sized { /////////////////////////////////////////////////////////////////////////// // Traversal implementations. -impl, U: TypeVisitable> TypeVisitable for (T, U) { +impl, U: OptVisitWith> TypeTraversable for (T, U) { + type Kind = ImportantTypeTraversal; +} +impl, U: OptVisitWith> TypeVisitable for (T, U) { fn visit_with>(&self, visitor: &mut V) -> V::Result { - try_visit!(self.0.visit_with(visitor)); - self.1.visit_with(visitor) + try_visit!(OptVisitWith::mk_visit_with()(&self.0, visitor)); + OptVisitWith::mk_visit_with()(&self.1, visitor) } } -impl, B: TypeVisitable, C: TypeVisitable> TypeVisitable +impl, B: OptVisitWith, C: OptVisitWith> TypeTraversable + for (A, B, C) +{ + type Kind = ImportantTypeTraversal; +} +impl, B: OptVisitWith, C: OptVisitWith> TypeVisitable for (A, B, C) { fn visit_with>(&self, visitor: &mut V) -> V::Result { - try_visit!(self.0.visit_with(visitor)); - try_visit!(self.1.visit_with(visitor)); - self.2.visit_with(visitor) + try_visit!(OptVisitWith::mk_visit_with()(&self.0, visitor)); + try_visit!(OptVisitWith::mk_visit_with()(&self.1, visitor)); + OptVisitWith::mk_visit_with()(&self.2, visitor) } } -impl> TypeVisitable for Option { +impl> TypeTraversable for Option { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for Option { fn visit_with>(&self, visitor: &mut V) -> V::Result { match self { - Some(v) => v.visit_with(visitor), + Some(v) => OptVisitWith::mk_visit_with()(v, visitor), None => V::Result::output(), } } } -impl, E: TypeVisitable> TypeVisitable for Result { +impl, E: OptVisitWith> TypeTraversable for Result { + type Kind = ImportantTypeTraversal; +} +impl, E: OptVisitWith> TypeVisitable for Result { fn visit_with>(&self, visitor: &mut V) -> V::Result { match self { - Ok(v) => v.visit_with(visitor), - Err(e) => e.visit_with(visitor), + Ok(v) => OptVisitWith::mk_visit_with()(v, visitor), + Err(e) => OptVisitWith::mk_visit_with()(e, visitor), } } } -impl> TypeVisitable for Lrc { +impl> TypeTraversable for Lrc { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for Lrc { fn visit_with>(&self, visitor: &mut V) -> V::Result { - (**self).visit_with(visitor) + OptVisitWith::mk_visit_with()(&**self, visitor) } } - -impl> TypeVisitable for Box { +impl> TypeTraversable for Box { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for Box { fn visit_with>(&self, visitor: &mut V) -> V::Result { - (**self).visit_with(visitor) + OptVisitWith::mk_visit_with()(&**self, visitor) } } -impl> TypeVisitable for Vec { +impl> TypeTraversable for Vec { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for Vec { fn visit_with>(&self, visitor: &mut V) -> V::Result { - walk_visitable_list!(visitor, self.iter()); + for elem in self.iter() { + try_visit!(OptVisitWith::mk_visit_with()(elem, visitor)); + } V::Result::output() } } -impl> TypeVisitable for ThinVec { +impl> TypeTraversable for ThinVec { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for ThinVec { fn visit_with>(&self, visitor: &mut V) -> V::Result { - walk_visitable_list!(visitor, self.iter()); + for elem in self.iter() { + try_visit!(OptVisitWith::mk_visit_with()(elem, visitor)); + } V::Result::output() } } -// `TypeFoldable` isn't impl'd for `&[T]`. It doesn't make sense in the general -// case, because we can't return a new slice. But note that there are a couple -// of trivial impls of `TypeFoldable` for specific slice types elsewhere. +impl> TypeTraversable for &[T] { + type Kind = T::Kind; +} impl> TypeVisitable for &[T] { fn visit_with>(&self, visitor: &mut V) -> V::Result { - walk_visitable_list!(visitor, self.iter()); + for elem in self.iter() { + try_visit!(elem.visit_with(visitor)); + } V::Result::output() } } -impl> TypeVisitable for Box<[T]> { +impl> TypeTraversable for Box<[T]> { + type Kind = ImportantTypeTraversal; +} +impl> TypeVisitable for Box<[T]> { fn visit_with>(&self, visitor: &mut V) -> V::Result { - walk_visitable_list!(visitor, self.iter()); + for elem in self.iter() { + try_visit!(OptVisitWith::mk_visit_with()(elem, visitor)); + } V::Result::output() } } -impl, Ix: Idx> TypeVisitable for IndexVec { +impl, Ix: Idx> TypeTraversable for IndexVec { + type Kind = ImportantTypeTraversal; +} +impl, Ix: Idx> TypeVisitable for IndexVec { fn visit_with>(&self, visitor: &mut V) -> V::Result { - walk_visitable_list!(visitor, self.iter()); + for elem in self.iter() { + try_visit!(OptVisitWith::mk_visit_with()(elem, visitor)); + } V::Result::output() } } diff --git a/compiler/rustc_type_ir/src/ty_kind.rs b/compiler/rustc_type_ir/src/ty_kind.rs index b7f6ef4ffbb9a..3f98417e4bb8e 100644 --- a/compiler/rustc_type_ir/src/ty_kind.rs +++ b/compiler/rustc_type_ir/src/ty_kind.rs @@ -8,7 +8,9 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_data_structures::unify::{NoError, UnifyKey, UnifyValue}; #[cfg(feature = "nightly")] use rustc_macros::{Decodable, Encodable, HashStable_NoContext, TyDecodable, TyEncodable}; -use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic}; +use rustc_type_ir_macros::{ + Lift_Generic, NoopTypeTraversable_Generic, TypeFoldable_Generic, TypeVisitable_Generic, +}; use self::TyKind::*; pub use self::closure::*; @@ -1012,7 +1014,7 @@ impl ty::Binder> { #[derive_where(Clone, Copy, Debug, PartialEq, Eq, Hash; I: Interner)] #[cfg_attr(feature = "nightly", derive(TyEncodable, TyDecodable, HashStable_NoContext))] -#[derive(TypeVisitable_Generic, TypeFoldable_Generic, Lift_Generic)] +#[derive(NoopTypeTraversable_Generic, Lift_Generic)] pub struct FnHeader { pub c_variadic: bool, pub safety: I::Safety, diff --git a/compiler/rustc_type_ir_macros/src/lib.rs b/compiler/rustc_type_ir_macros/src/lib.rs index 1a0a2479f6f07..74bb22bbbfe36 100644 --- a/compiler/rustc_type_ir_macros/src/lib.rs +++ b/compiler/rustc_type_ir_macros/src/lib.rs @@ -3,15 +3,96 @@ use syn::parse_quote; use syn::visit_mut::VisitMut; use synstructure::decl_derive; -decl_derive!( - [TypeFoldable_Generic] => type_foldable_derive -); -decl_derive!( - [TypeVisitable_Generic] => type_visitable_derive -); -decl_derive!( - [Lift_Generic] => lift_derive -); +decl_derive!([NoopTypeTraversable_Generic] => noop_type_traversable_derive); +decl_derive!([TypeFoldable_Generic] => type_foldable_derive); +decl_derive!([TypeVisitable_Generic] => type_visitable_derive); +decl_derive!([Lift_Generic] => lift_derive); + +fn noop_type_traversable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { + if let syn::Data::Union(_) = s.ast().data { + panic!("cannot derive on union") + } + + if !s.ast().generics.type_params().any(|ty| ty.ident == "I") { + s.add_impl_generic(parse_quote! { I }); + } + + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!(::rustc_type_ir::traverse::TypeTraversable), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); + s.add_where_predicate(parse_quote! { I: Interner }); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + + s.bound_impl(quote!(::rustc_type_ir::traverse::TypeTraversable), quote! { + type Kind = ::rustc_type_ir::traverse::NoopTypeTraversal; + }) +} + +fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { + if let syn::Data::Union(_) = s.ast().data { + panic!("cannot derive on union") + } + + if !s.ast().generics.type_params().any(|ty| ty.ident == "I") { + s.add_impl_generic(parse_quote! { I }); + } + + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!(::rustc_type_ir::traverse::OptVisitWith), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); + s.add_where_predicate(parse_quote! { I: Interner }); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + + let impl_traversable_s = s.clone(); + + let body_visit = s.each(|bind| { + quote! { + match ::rustc_ast_ir::visit::VisitorResult::branch( + ::rustc_type_ir::traverse::OptVisitWith::mk_visit_with()(#bind, __visitor) + ) { + ::core::ops::ControlFlow::Continue(()) => {}, + ::core::ops::ControlFlow::Break(r) => { + return ::rustc_ast_ir::visit::VisitorResult::from_residual(r); + }, + } + } + }); + s.bind_with(|_| synstructure::BindStyle::Move); + + let visitable_impl = s.bound_impl(quote!(::rustc_type_ir::visit::TypeVisitable), quote! { + fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor>( + &self, + __visitor: &mut __V + ) -> __V::Result { + match *self { #body_visit } + <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output() + } + }); + + let traversable_impl = impl_traversable_s.bound_impl( + quote!(::rustc_type_ir::traverse::TypeTraversable), + quote! { + type Kind = ::rustc_type_ir::traverse::ImportantTypeTraversal; + }, + ); + + quote! { + #visitable_impl + #traversable_impl + } +} fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { if let syn::Data::Union(_) = s.ast().data { @@ -22,15 +103,25 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke s.add_impl_generic(parse_quote! { I }); } + s.add_bounds(synstructure::AddBounds::None); + let mut where_clauses = None; + s.add_trait_bounds( + &parse_quote!(::rustc_type_ir::traverse::OptTryFoldWith), + &mut where_clauses, + synstructure::AddBounds::Fields, + ); s.add_where_predicate(parse_quote! { I: Interner }); - s.add_bounds(synstructure::AddBounds::Fields); + for pred in where_clauses.into_iter().flat_map(|c| c.predicates) { + s.add_where_predicate(pred); + } + s.bind_with(|_| synstructure::BindStyle::Move); let body_fold = s.each_variant(|vi| { let bindings = vi.bindings(); vi.construct(|_, index| { let bind = &bindings[index]; quote! { - ::rustc_type_ir::fold::TypeFoldable::try_fold_with(#bind, __folder)? + ::rustc_type_ir::traverse::OptTryFoldWith::mk_try_fold_with()(#bind, __folder)? } }) }); @@ -113,39 +204,3 @@ fn lift(mut ty: syn::Type) -> syn::Type { ty } - -fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { - if let syn::Data::Union(_) = s.ast().data { - panic!("cannot derive on union") - } - - if !s.ast().generics.type_params().any(|ty| ty.ident == "I") { - s.add_impl_generic(parse_quote! { I }); - } - - s.add_where_predicate(parse_quote! { I: Interner }); - s.add_bounds(synstructure::AddBounds::Fields); - let body_visit = s.each(|bind| { - quote! { - match ::rustc_ast_ir::visit::VisitorResult::branch( - ::rustc_type_ir::visit::TypeVisitable::visit_with(#bind, __visitor) - ) { - ::core::ops::ControlFlow::Continue(()) => {}, - ::core::ops::ControlFlow::Break(r) => { - return ::rustc_ast_ir::visit::VisitorResult::from_residual(r); - }, - } - } - }); - s.bind_with(|_| synstructure::BindStyle::Move); - - s.bound_impl(quote!(::rustc_type_ir::visit::TypeVisitable), quote! { - fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor>( - &self, - __visitor: &mut __V - ) -> __V::Result { - match *self { #body_visit } - <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output() - } - }) -}