Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'improve' type traversal #132046

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions compiler/rustc_infer/src/infer/canonical/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ use rustc_index::{Idx, IndexVec};
use rustc_middle::arena::ArenaAllocatable;
use rustc_middle::mir::ConstraintCategory;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::traverse::AlwaysTraversable;
use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_type_ir::traverse::OptTryFoldWith;
use tracing::{debug, instrument};

use crate::infer::canonical::instantiate::{CanonicalExt, instantiate_value};
Expand Down Expand Up @@ -60,7 +62,7 @@ impl<'tcx> InferCtxt<'tcx> {
fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>,
) -> Result<CanonicalQueryResponse<'tcx, T>, NoSolution>
where
T: Debug + TypeFoldable<TyCtxt<'tcx>>,
T: OptTryFoldWith<TyCtxt<'tcx>>,
Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>,
{
let query_response = self.make_query_response(inference_vars, answer, fulfill_cx)?;
Expand Down Expand Up @@ -107,7 +109,7 @@ impl<'tcx> InferCtxt<'tcx> {
fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>,
) -> Result<QueryResponse<'tcx, T>, NoSolution>
where
T: Debug + TypeFoldable<TyCtxt<'tcx>>,
T: OptTryFoldWith<TyCtxt<'tcx>>,
{
let tcx = self.tcx;

Expand Down Expand Up @@ -243,7 +245,7 @@ impl<'tcx> InferCtxt<'tcx> {
output_query_region_constraints: &mut QueryRegionConstraints<'tcx>,
) -> InferResult<'tcx, R>
where
R: Debug + TypeFoldable<TyCtxt<'tcx>>,
R: OptTryFoldWith<TyCtxt<'tcx>>,
{
let InferOk { value: result_args, mut obligations } = self
.query_response_instantiation_guess(
Expand Down Expand Up @@ -326,8 +328,11 @@ impl<'tcx> InferCtxt<'tcx> {
.map(|p_c| instantiate_value(self.tcx, &result_args, p_c.clone())),
);

let user_result: R =
query_response.instantiate_projected(self.tcx, &result_args, |q_r| q_r.value.clone());
let user_result: R = query_response
.instantiate_projected(self.tcx, &result_args, |q_r| {
AlwaysTraversable(q_r.value.clone())
})
.0;

Ok(InferOk { value: user_result, obligations })
}
Expand Down Expand Up @@ -396,7 +401,7 @@ impl<'tcx> InferCtxt<'tcx> {
query_response: &Canonical<'tcx, QueryResponse<'tcx, R>>,
) -> InferResult<'tcx, CanonicalVarValues<'tcx>>
where
R: Debug + TypeFoldable<TyCtxt<'tcx>>,
R: OptTryFoldWith<TyCtxt<'tcx>>,
{
// For each new universe created in the query result that did
// not appear in the original query, create a local
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_infer/src/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rustc_ast_ir::try_visit;
use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFoldable};
use rustc_middle::ty::visit::{TypeVisitable, TypeVisitor};
use rustc_middle::ty::{self, TyCtxt};
use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable};

use crate::traits;
use crate::traits::project::Normalized;
Expand Down Expand Up @@ -55,6 +56,11 @@ impl<'tcx, O: TypeFoldable<TyCtxt<'tcx>>> TypeFoldable<TyCtxt<'tcx>>
}
}

impl<'tcx, O: TypeVisitable<TyCtxt<'tcx>>> TypeTraversable<TyCtxt<'tcx>>
for traits::Obligation<'tcx, O>
{
type Kind = ImportantTypeTraversal;
}
impl<'tcx, O: TypeVisitable<TyCtxt<'tcx>>> TypeVisitable<TyCtxt<'tcx>>
for traits::Obligation<'tcx, O>
{
Expand Down
25 changes: 4 additions & 21 deletions compiler/rustc_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod diagnostics;
mod extension;
mod hash_stable;
mod lift;
mod noop_type_traversable;
mod query;
mod serialize;
mod symbols;
Expand Down Expand Up @@ -81,27 +82,9 @@ decl_derive!([TyDecodable] => serialize::type_decodable_derive);
decl_derive!([TyEncodable] => serialize::type_encodable_derive);
decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive);
decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive);
decl_derive!(
[TypeFoldable, attributes(type_foldable)] =>
/// Derives `TypeFoldable` for the annotated `struct` or `enum` (`union` is not supported).
///
/// The fold will produce a value of the same struct or enum variant as the input, with
/// each field respectively folded using the `TypeFoldable` implementation for its type.
/// However, if a field of a struct or an enum variant is annotated with
/// `#[type_foldable(identity)]` then that field will retain its incumbent value (and its
/// type is not required to implement `TypeFoldable`).
type_foldable::type_foldable_derive
);
decl_derive!(
[TypeVisitable, attributes(type_visitable)] =>
/// Derives `TypeVisitable` for the annotated `struct` or `enum` (`union` is not supported).
///
/// Each field of the struct or enum variant will be visited in definition order, using the
/// `TypeVisitable` implementation for its type. However, if a field of a struct or an enum
/// variant is annotated with `#[type_visitable(ignore)]` then that field will not be
/// visited (and its type is not required to implement `TypeVisitable`).
type_visitable::type_visitable_derive
);
decl_derive!([NoopTypeTraversable] => noop_type_traversable::noop_type_traversable_derive);
decl_derive!([TypeVisitable] => type_visitable::type_visitable_derive);
decl_derive!([TypeFoldable] => type_foldable::type_foldable_derive);
decl_derive!([Lift, attributes(lift)] => lift::lift_derive);
decl_derive!(
[Diagnostic, attributes(
Expand Down
39 changes: 39 additions & 0 deletions compiler/rustc_macros/src/noop_type_traversable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use quote::quote;
use syn::parse_quote;

pub(super) fn noop_type_traversable_derive(
mut s: synstructure::Structure<'_>,
) -> proc_macro2::TokenStream {
if let syn::Data::Union(_) = s.ast().data {
panic!("cannot derive on union")
}

s.underscore_const(true);

if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(
::rustc_middle::ty::traverse::TypeTraversable<
::rustc_middle::ty::TyCtxt<'tcx>,
Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal,
>
),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

s.bound_impl(
quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
type Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal;
},
)
}
39 changes: 15 additions & 24 deletions compiler/rustc_macros/src/type_foldable.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use quote::{ToTokens, quote};
use quote::quote;
use syn::parse_quote;

pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
Expand All @@ -12,34 +12,25 @@ pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_m
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::Generics);
s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(::rustc_type_ir::traverse::OptTryFoldWith<::rustc_middle::ty::TyCtxt<'tcx>>),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone });
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

s.bind_with(|_| synstructure::BindStyle::Move);
let body_fold = s.each_variant(|vi| {
let bindings = vi.bindings();
vi.construct(|_, index| {
let bind = &bindings[index];

let mut fixed = false;

// retain value of fields with #[type_foldable(identity)]
bind.ast().attrs.iter().for_each(|x| {
if !x.path().is_ident("type_foldable") {
return;
}
let _ = x.parse_nested_meta(|nested| {
if nested.path.is_ident("identity") {
fixed = true;
}
Ok(())
});
});

if fixed {
bind.to_token_stream()
} else {
quote! {
::rustc_middle::ty::fold::TypeFoldable::try_fold_with(#bind, __folder)?
}
quote! {
::rustc_middle::ty::traverse::OptTryFoldWith::mk_try_fold_with()(#bind, __folder)?
}
})
});
Expand Down
54 changes: 31 additions & 23 deletions compiler/rustc_macros/src/type_visitable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,30 @@ pub(super) fn type_visitable_derive(

s.underscore_const(true);

// ignore fields with #[type_visitable(ignore)]
s.filter(|bi| {
let mut ignored = false;

bi.ast().attrs.iter().for_each(|attr| {
if !attr.path().is_ident("type_visitable") {
return;
}
let _ = attr.parse_nested_meta(|nested| {
if nested.path.is_ident("ignore") {
ignored = true;
}
Ok(())
});
});

!ignored
});

if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::Generics);
s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(
::rustc_middle::ty::traverse::OptVisitWith::<::rustc_middle::ty::TyCtxt<'tcx>>
),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone });
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

let impl_traversable_s = s.clone();

let body_visit = s.each(|bind| {
quote! {
match ::rustc_ast_ir::visit::VisitorResult::branch(
::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor)
::rustc_middle::ty::traverse::OptVisitWith::mk_visit_with()(#bind, __visitor)
) {
::core::ops::ControlFlow::Continue(()) => {},
::core::ops::ControlFlow::Break(r) => {
Expand All @@ -48,7 +44,7 @@ pub(super) fn type_visitable_derive(
});
s.bind_with(|_| synstructure::BindStyle::Move);

s.bound_impl(
let visitable_impl = s.bound_impl(
quote!(::rustc_middle::ty::visit::TypeVisitable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<::rustc_middle::ty::TyCtxt<'tcx>>>(
Expand All @@ -59,5 +55,17 @@ pub(super) fn type_visitable_derive(
<__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
}
},
)
);

let traversable_impl = impl_traversable_s.bound_impl(
quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
type Kind = ::rustc_middle::ty::traverse::ImportantTypeTraversal;
},
);

quote! {
#visitable_impl
#traversable_impl
}
}
8 changes: 5 additions & 3 deletions compiler/rustc_middle/src/hir/place.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use rustc_hir::HirId;
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
use rustc_macros::{
HashStable, NoopTypeTraversable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable,
};
use rustc_target::abi::{FieldIdx, VariantIdx};

use crate::ty;
use crate::ty::Ty;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)]
#[derive(TypeFoldable, TypeVisitable)]
#[derive(NoopTypeTraversable)]
pub enum PlaceBase {
/// A temporary variable.
Rvalue,
Expand All @@ -19,7 +21,7 @@ pub enum PlaceBase {
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)]
#[derive(TypeFoldable, TypeVisitable)]
#[derive(NoopTypeTraversable)]
pub enum ProjectionKind {
/// A dereference of a pointer, reference or `Box<T>` of the given type.
Deref,
Expand Down
28 changes: 2 additions & 26 deletions compiler/rustc_middle/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,8 @@ macro_rules! TrivialLiftImpls {
macro_rules! TrivialTypeTraversalImpls {
($($ty:ty),+ $(,)?) => {
$(
impl<'tcx> $crate::ty::fold::TypeFoldable<$crate::ty::TyCtxt<'tcx>> for $ty {
fn try_fold_with<F: $crate::ty::fold::FallibleTypeFolder<$crate::ty::TyCtxt<'tcx>>>(
self,
_: &mut F,
) -> ::std::result::Result<Self, F::Error> {
Ok(self)
}

#[inline]
fn fold_with<F: $crate::ty::fold::TypeFolder<$crate::ty::TyCtxt<'tcx>>>(
self,
_: &mut F,
) -> Self {
self
}
}

impl<'tcx> $crate::ty::visit::TypeVisitable<$crate::ty::TyCtxt<'tcx>> for $ty {
#[inline]
fn visit_with<F: $crate::ty::visit::TypeVisitor<$crate::ty::TyCtxt<'tcx>>>(
&self,
_: &mut F)
-> F::Result
{
<F::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
}
impl<'tcx> $crate::ty::traverse::TypeTraversable<$crate::ty::TyCtxt<'tcx>> for $ty {
type Kind = $crate::ty::traverse::NoopTypeTraversal;
}
)+
};
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_middle/src/mir/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ pub struct CoroutineLayout<'tcx> {
/// Which saved locals are storage-live at the same time. Locals that do not
/// have conflicts with each other are allowed to overlap in the computed
/// layout.
#[type_foldable(identity)]
#[type_visitable(ignore)]
pub storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
}

Expand Down
Loading
Loading