diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 1fec7bc30..873d93c6c 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -76,6 +76,18 @@ pub enum InferExtensionError { /// The incompatible solution that we found was already there actual: ExtensionSet, }, + #[error("Solved extensions {expected} at {expected_loc:?} and {actual} at {actual_loc:?} should be equal.")] + /// A version of the above with info about which nodes failed to unify + MismatchedConcreteWithLocations { + /// Where the solution we want to insert came from + expected_loc: (Node, Direction), + /// The solution we were trying to insert for this meta + expected: ExtensionSet, + /// Which node we're trying to add a solution for + actual_loc: (Node, Direction), + /// The incompatible solution that we found was already there + actual: ExtensionSet, + }, /// A variable went unsolved that wasn't related to a parameter #[error("Unsolved variable at location {:?}", location)] Unsolved { @@ -314,7 +326,10 @@ impl UnificationContext { } } - /// Try to turn mismatches into `ExtensionError` when possible + /// When trying to unify two metas, check if they both correspond to + /// different ends of the same wire. If so, return an `ExtensionError`. + /// Otherwise check whether they both correspond to *some* location on the + /// graph and include that info the otherwise generic `MismatchedConcrete`. fn report_mismatch( &self, m1: Meta, @@ -375,10 +390,19 @@ impl UnificationContext { } else { None }; - err.unwrap_or(InferExtensionError::MismatchedConcrete { - expected: rs1, - actual: rs2, - }) + if let (Some(loc1), Some(loc2)) = (loc1, loc2) { + err.unwrap_or(InferExtensionError::MismatchedConcreteWithLocations { + expected_loc: *loc1, + expected: rs1, + actual_loc: *loc2, + actual: rs2, + }) + } else { + err.unwrap_or(InferExtensionError::MismatchedConcrete { + expected: rs1, + actual: rs2, + }) + } } /// Take a group of equal metas and merge them into a new, single meta. diff --git a/src/types.rs b/src/types.rs index a86aef65f..164213171 100644 --- a/src/types.rs +++ b/src/types.rs @@ -95,31 +95,36 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator) -> Ty } #[derive(Clone, PartialEq, Debug, Eq, derive_more::Display, Serialize, Deserialize)] +#[serde(tag = "s")] /// Representation of a Sum type. /// Either store the types of the variants, or in the special (but common) case /// of a "simple predicate" (sum over empty tuples), store only the size of the predicate. enum SumType { - #[display(fmt = "SimplePredicate({})", "_0")] - Simple(u8), - General(TypeRow), + #[display(fmt = "SimplePredicate({})", "size")] + Simple { + size: u8, + }, + General { + row: TypeRow, + }, } impl SumType { fn new(types: impl Into) -> Self { let row: TypeRow = types.into(); - let len = row.len(); + let len: usize = row.len(); if len <= (u8::MAX as usize) && row.iter().all(|t| *t == Type::UNIT) { - Self::Simple(len as u8) + Self::Simple { size: len as u8 } } else { - Self::General(row) + Self::General { row } } } fn get_variant(&self, tag: usize) -> Option<&Type> { match self { - SumType::Simple(size) if tag < (*size as usize) => Some(Type::UNIT_REF), - SumType::General(row) => row.get(tag), + SumType::Simple { size } if tag < (*size as usize) => Some(Type::UNIT_REF), + SumType::General { row } => row.get(tag), _ => None, } } @@ -128,8 +133,8 @@ impl SumType { impl From for Type { fn from(sum: SumType) -> Type { match sum { - SumType::Simple(size) => Type::new_simple_predicate(size), - SumType::General(types) => Type::new_sum(types), + SumType::Simple { size } => Type::new_simple_predicate(size), + SumType::General { row } => Type::new_sum(row), } } } @@ -148,9 +153,9 @@ impl TypeEnum { fn least_upper_bound(&self) -> TypeBound { match self { TypeEnum::Prim(p) => p.bound(), - TypeEnum::Sum(SumType::Simple(_)) => TypeBound::Eq, - TypeEnum::Sum(SumType::General(ts)) => { - least_upper_bound(ts.iter().map(Type::least_upper_bound)) + TypeEnum::Sum(SumType::Simple { size: _ }) => TypeBound::Eq, + TypeEnum::Sum(SumType::General { row }) => { + least_upper_bound(row.iter().map(Type::least_upper_bound)) } TypeEnum::Tuple(ts) => least_upper_bound(ts.iter().map(Type::least_upper_bound)), } @@ -238,7 +243,7 @@ impl Type { /// New simple predicate with empty Tuple variants pub const fn new_simple_predicate(size: u8) -> Self { // should be the only way to avoid going through SumType::new - Self(TypeEnum::Sum(SumType::Simple(size)), TypeBound::Eq) + Self(TypeEnum::Sum(SumType::Simple { size }), TypeBound::Eq) } /// Report the least upper TypeBound, if there is one. @@ -260,10 +265,10 @@ impl Type { // There is no need to check the components against the bound, // that is guaranteed by construction (even for deserialization) match &self.0 { - TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General(row)) => { + TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General { row }) => { row.iter().try_for_each(|t| t.validate(extension_registry)) } - TypeEnum::Sum(SumType::Simple(_)) => Ok(()), // No leaves there + TypeEnum::Sum(SumType::Simple { .. }) => Ok(()), // No leaves there TypeEnum::Prim(PrimType::Alias(_)) => Ok(()), TypeEnum::Prim(PrimType::Extension(custy)) => custy.validate(extension_registry), TypeEnum::Prim(PrimType::Function(ft)) => ft @@ -349,7 +354,7 @@ pub(crate) mod test { assert_eq!(pred1, pred2); - let pred_direct = SumType::Simple(2); + let pred_direct = SumType::Simple { size: 2 }; assert_eq!(pred1, pred_direct.into()) } }