Skip to content

Commit

Permalink
Auto merge of rust-lang#132046 - lcnr:trivial-type-visitable, r=<try>
Browse files Browse the repository at this point in the history
'improve' type traversal

questionable :3

r? `@ghost`
  • Loading branch information
bors committed Oct 23, 2024
2 parents 86d69c7 + 7bfb0ec commit f592dd1
Show file tree
Hide file tree
Showing 39 changed files with 647 additions and 438 deletions.
17 changes: 11 additions & 6 deletions compiler/rustc_infer/src/infer/canonical/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ use rustc_index::{Idx, IndexVec};
use rustc_middle::arena::ArenaAllocatable;
use rustc_middle::mir::ConstraintCategory;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::traverse::AlwaysTraversable;
use rustc_middle::ty::{self, BoundVar, GenericArg, GenericArgKind, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_type_ir::traverse::OptTryFoldWith;
use tracing::{debug, instrument};

use crate::infer::canonical::instantiate::{CanonicalExt, instantiate_value};
Expand Down Expand Up @@ -60,7 +62,7 @@ impl<'tcx> InferCtxt<'tcx> {
fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>,
) -> Result<CanonicalQueryResponse<'tcx, T>, NoSolution>
where
T: Debug + TypeFoldable<TyCtxt<'tcx>>,
T: OptTryFoldWith<TyCtxt<'tcx>>,
Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>,
{
let query_response = self.make_query_response(inference_vars, answer, fulfill_cx)?;
Expand Down Expand Up @@ -107,7 +109,7 @@ impl<'tcx> InferCtxt<'tcx> {
fulfill_cx: &mut dyn TraitEngine<'tcx, ScrubbedTraitError<'tcx>>,
) -> Result<QueryResponse<'tcx, T>, NoSolution>
where
T: Debug + TypeFoldable<TyCtxt<'tcx>>,
T: OptTryFoldWith<TyCtxt<'tcx>>,
{
let tcx = self.tcx;

Expand Down Expand Up @@ -243,7 +245,7 @@ impl<'tcx> InferCtxt<'tcx> {
output_query_region_constraints: &mut QueryRegionConstraints<'tcx>,
) -> InferResult<'tcx, R>
where
R: Debug + TypeFoldable<TyCtxt<'tcx>>,
R: OptTryFoldWith<TyCtxt<'tcx>>,
{
let InferOk { value: result_args, mut obligations } = self
.query_response_instantiation_guess(
Expand Down Expand Up @@ -326,8 +328,11 @@ impl<'tcx> InferCtxt<'tcx> {
.map(|p_c| instantiate_value(self.tcx, &result_args, p_c.clone())),
);

let user_result: R =
query_response.instantiate_projected(self.tcx, &result_args, |q_r| q_r.value.clone());
let user_result: R = query_response
.instantiate_projected(self.tcx, &result_args, |q_r| {
AlwaysTraversable(q_r.value.clone())
})
.0;

Ok(InferOk { value: user_result, obligations })
}
Expand Down Expand Up @@ -396,7 +401,7 @@ impl<'tcx> InferCtxt<'tcx> {
query_response: &Canonical<'tcx, QueryResponse<'tcx, R>>,
) -> InferResult<'tcx, CanonicalVarValues<'tcx>>
where
R: Debug + TypeFoldable<TyCtxt<'tcx>>,
R: OptTryFoldWith<TyCtxt<'tcx>>,
{
// For each new universe created in the query result that did
// not appear in the original query, create a local
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_infer/src/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rustc_ast_ir::try_visit;
use rustc_middle::ty::fold::{FallibleTypeFolder, TypeFoldable};
use rustc_middle::ty::visit::{TypeVisitable, TypeVisitor};
use rustc_middle::ty::{self, TyCtxt};
use rustc_type_ir::traverse::{ImportantTypeTraversal, TypeTraversable};

use crate::traits;
use crate::traits::project::Normalized;
Expand Down Expand Up @@ -55,6 +56,11 @@ impl<'tcx, O: TypeFoldable<TyCtxt<'tcx>>> TypeFoldable<TyCtxt<'tcx>>
}
}

impl<'tcx, O: TypeVisitable<TyCtxt<'tcx>>> TypeTraversable<TyCtxt<'tcx>>
for traits::Obligation<'tcx, O>
{
type Kind = ImportantTypeTraversal;
}
impl<'tcx, O: TypeVisitable<TyCtxt<'tcx>>> TypeVisitable<TyCtxt<'tcx>>
for traits::Obligation<'tcx, O>
{
Expand Down
25 changes: 4 additions & 21 deletions compiler/rustc_macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod diagnostics;
mod extension;
mod hash_stable;
mod lift;
mod noop_type_traversable;
mod query;
mod serialize;
mod symbols;
Expand Down Expand Up @@ -81,27 +82,9 @@ decl_derive!([TyDecodable] => serialize::type_decodable_derive);
decl_derive!([TyEncodable] => serialize::type_encodable_derive);
decl_derive!([MetadataDecodable] => serialize::meta_decodable_derive);
decl_derive!([MetadataEncodable] => serialize::meta_encodable_derive);
decl_derive!(
[TypeFoldable, attributes(type_foldable)] =>
/// Derives `TypeFoldable` for the annotated `struct` or `enum` (`union` is not supported).
///
/// The fold will produce a value of the same struct or enum variant as the input, with
/// each field respectively folded using the `TypeFoldable` implementation for its type.
/// However, if a field of a struct or an enum variant is annotated with
/// `#[type_foldable(identity)]` then that field will retain its incumbent value (and its
/// type is not required to implement `TypeFoldable`).
type_foldable::type_foldable_derive
);
decl_derive!(
[TypeVisitable, attributes(type_visitable)] =>
/// Derives `TypeVisitable` for the annotated `struct` or `enum` (`union` is not supported).
///
/// Each field of the struct or enum variant will be visited in definition order, using the
/// `TypeVisitable` implementation for its type. However, if a field of a struct or an enum
/// variant is annotated with `#[type_visitable(ignore)]` then that field will not be
/// visited (and its type is not required to implement `TypeVisitable`).
type_visitable::type_visitable_derive
);
decl_derive!([NoopTypeTraversable] => noop_type_traversable::noop_type_traversable_derive);
decl_derive!([TypeVisitable] => type_visitable::type_visitable_derive);
decl_derive!([TypeFoldable] => type_foldable::type_foldable_derive);
decl_derive!([Lift, attributes(lift)] => lift::lift_derive);
decl_derive!(
[Diagnostic, attributes(
Expand Down
39 changes: 39 additions & 0 deletions compiler/rustc_macros/src/noop_type_traversable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use quote::quote;
use syn::parse_quote;

pub(super) fn noop_type_traversable_derive(
mut s: synstructure::Structure<'_>,
) -> proc_macro2::TokenStream {
if let syn::Data::Union(_) = s.ast().data {
panic!("cannot derive on union")
}

s.underscore_const(true);

if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(
::rustc_middle::ty::traverse::TypeTraversable<
::rustc_middle::ty::TyCtxt<'tcx>,
Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal,
>
),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

s.bound_impl(
quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
type Kind = ::rustc_middle::ty::traverse::NoopTypeTraversal;
},
)
}
39 changes: 15 additions & 24 deletions compiler/rustc_macros/src/type_foldable.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use quote::{ToTokens, quote};
use quote::quote;
use syn::parse_quote;

pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream {
Expand All @@ -12,34 +12,25 @@ pub(super) fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_m
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::Generics);
s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(::rustc_type_ir::traverse::OptTryFoldWith<::rustc_middle::ty::TyCtxt<'tcx>>),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone });
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

s.bind_with(|_| synstructure::BindStyle::Move);
let body_fold = s.each_variant(|vi| {
let bindings = vi.bindings();
vi.construct(|_, index| {
let bind = &bindings[index];

let mut fixed = false;

// retain value of fields with #[type_foldable(identity)]
bind.ast().attrs.iter().for_each(|x| {
if !x.path().is_ident("type_foldable") {
return;
}
let _ = x.parse_nested_meta(|nested| {
if nested.path.is_ident("identity") {
fixed = true;
}
Ok(())
});
});

if fixed {
bind.to_token_stream()
} else {
quote! {
::rustc_middle::ty::fold::TypeFoldable::try_fold_with(#bind, __folder)?
}
quote! {
::rustc_middle::ty::traverse::OptTryFoldWith::mk_try_fold_with()(#bind, __folder)?
}
})
});
Expand Down
54 changes: 31 additions & 23 deletions compiler/rustc_macros/src/type_visitable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,30 @@ pub(super) fn type_visitable_derive(

s.underscore_const(true);

// ignore fields with #[type_visitable(ignore)]
s.filter(|bi| {
let mut ignored = false;

bi.ast().attrs.iter().for_each(|attr| {
if !attr.path().is_ident("type_visitable") {
return;
}
let _ = attr.parse_nested_meta(|nested| {
if nested.path.is_ident("ignore") {
ignored = true;
}
Ok(())
});
});

!ignored
});

if !s.ast().generics.lifetimes().any(|lt| lt.lifetime.ident == "tcx") {
s.add_impl_generic(parse_quote! { 'tcx });
}

s.add_bounds(synstructure::AddBounds::Generics);
s.add_bounds(synstructure::AddBounds::None);
let mut where_clauses = None;
s.add_trait_bounds(
&parse_quote!(
::rustc_middle::ty::traverse::OptVisitWith::<::rustc_middle::ty::TyCtxt<'tcx>>
),
&mut where_clauses,
synstructure::AddBounds::Fields,
);
s.add_where_predicate(parse_quote! { Self: std::fmt::Debug + Clone });
for pred in where_clauses.into_iter().flat_map(|c| c.predicates) {
s.add_where_predicate(pred);
}

let impl_traversable_s = s.clone();

let body_visit = s.each(|bind| {
quote! {
match ::rustc_ast_ir::visit::VisitorResult::branch(
::rustc_middle::ty::visit::TypeVisitable::visit_with(#bind, __visitor)
::rustc_middle::ty::traverse::OptVisitWith::mk_visit_with()(#bind, __visitor)
) {
::core::ops::ControlFlow::Continue(()) => {},
::core::ops::ControlFlow::Break(r) => {
Expand All @@ -48,7 +44,7 @@ pub(super) fn type_visitable_derive(
});
s.bind_with(|_| synstructure::BindStyle::Move);

s.bound_impl(
let visitable_impl = s.bound_impl(
quote!(::rustc_middle::ty::visit::TypeVisitable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
fn visit_with<__V: ::rustc_middle::ty::visit::TypeVisitor<::rustc_middle::ty::TyCtxt<'tcx>>>(
Expand All @@ -59,5 +55,17 @@ pub(super) fn type_visitable_derive(
<__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
}
},
)
);

let traversable_impl = impl_traversable_s.bound_impl(
quote!(::rustc_middle::ty::traverse::TypeTraversable<::rustc_middle::ty::TyCtxt<'tcx>>),
quote! {
type Kind = ::rustc_middle::ty::traverse::ImportantTypeTraversal;
},
);

quote! {
#visitable_impl
#traversable_impl
}
}
8 changes: 5 additions & 3 deletions compiler/rustc_middle/src/hir/place.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use rustc_hir::HirId;
use rustc_macros::{HashStable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable};
use rustc_macros::{
HashStable, NoopTypeTraversable, TyDecodable, TyEncodable, TypeFoldable, TypeVisitable,
};
use rustc_target::abi::{FieldIdx, VariantIdx};

use crate::ty;
use crate::ty::Ty;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)]
#[derive(TypeFoldable, TypeVisitable)]
#[derive(NoopTypeTraversable)]
pub enum PlaceBase {
/// A temporary variable.
Rvalue,
Expand All @@ -19,7 +21,7 @@ pub enum PlaceBase {
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)]
#[derive(TypeFoldable, TypeVisitable)]
#[derive(NoopTypeTraversable)]
pub enum ProjectionKind {
/// A dereference of a pointer, reference or `Box<T>` of the given type.
Deref,
Expand Down
28 changes: 2 additions & 26 deletions compiler/rustc_middle/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,32 +73,8 @@ macro_rules! TrivialLiftImpls {
macro_rules! TrivialTypeTraversalImpls {
($($ty:ty),+ $(,)?) => {
$(
impl<'tcx> $crate::ty::fold::TypeFoldable<$crate::ty::TyCtxt<'tcx>> for $ty {
fn try_fold_with<F: $crate::ty::fold::FallibleTypeFolder<$crate::ty::TyCtxt<'tcx>>>(
self,
_: &mut F,
) -> ::std::result::Result<Self, F::Error> {
Ok(self)
}

#[inline]
fn fold_with<F: $crate::ty::fold::TypeFolder<$crate::ty::TyCtxt<'tcx>>>(
self,
_: &mut F,
) -> Self {
self
}
}

impl<'tcx> $crate::ty::visit::TypeVisitable<$crate::ty::TyCtxt<'tcx>> for $ty {
#[inline]
fn visit_with<F: $crate::ty::visit::TypeVisitor<$crate::ty::TyCtxt<'tcx>>>(
&self,
_: &mut F)
-> F::Result
{
<F::Result as ::rustc_ast_ir::visit::VisitorResult>::output()
}
impl<'tcx> $crate::ty::traverse::TypeTraversable<$crate::ty::TyCtxt<'tcx>> for $ty {
type Kind = $crate::ty::traverse::NoopTypeTraversal;
}
)+
};
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_middle/src/mir/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ pub struct CoroutineLayout<'tcx> {
/// Which saved locals are storage-live at the same time. Locals that do not
/// have conflicts with each other are allowed to overlap in the computed
/// layout.
#[type_foldable(identity)]
#[type_visitable(ignore)]
pub storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
}

Expand Down
Loading

0 comments on commit f592dd1

Please sign in to comment.