Skip to content

Commit

Permalink
Merge 'origin/main' into new/validate_types, update SumType
Browse files Browse the repository at this point in the history
  • Loading branch information
acl-cqc committed Aug 29, 2023
2 parents bea1d94 + b58b582 commit c2b3b7c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
34 changes: 29 additions & 5 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,18 @@ pub enum InferExtensionError {
/// The incompatible solution that we found was already there
actual: ExtensionSet,
},
#[error("Solved extensions {expected} at {expected_loc:?} and {actual} at {actual_loc:?} should be equal.")]
/// A version of the above with info about which nodes failed to unify
MismatchedConcreteWithLocations {
/// Where the solution we want to insert came from
expected_loc: (Node, Direction),
/// The solution we were trying to insert for this meta
expected: ExtensionSet,
/// Which node we're trying to add a solution for
actual_loc: (Node, Direction),
/// The incompatible solution that we found was already there
actual: ExtensionSet,
},
/// A variable went unsolved that wasn't related to a parameter
#[error("Unsolved variable at location {:?}", location)]
Unsolved {
Expand Down Expand Up @@ -314,7 +326,10 @@ impl UnificationContext {
}
}

/// Try to turn mismatches into `ExtensionError` when possible
/// When trying to unify two metas, check if they both correspond to
/// different ends of the same wire. If so, return an `ExtensionError`.
/// Otherwise check whether they both correspond to *some* location on the
/// graph and include that info the otherwise generic `MismatchedConcrete`.
fn report_mismatch(
&self,
m1: Meta,
Expand Down Expand Up @@ -375,10 +390,19 @@ impl UnificationContext {
} else {
None
};
err.unwrap_or(InferExtensionError::MismatchedConcrete {
expected: rs1,
actual: rs2,
})
if let (Some(loc1), Some(loc2)) = (loc1, loc2) {
err.unwrap_or(InferExtensionError::MismatchedConcreteWithLocations {
expected_loc: *loc1,
expected: rs1,
actual_loc: *loc2,
actual: rs2,
})
} else {
err.unwrap_or(InferExtensionError::MismatchedConcrete {
expected: rs1,
actual: rs2,
})
}
}

/// Take a group of equal metas and merge them into a new, single meta.
Expand Down
39 changes: 22 additions & 17 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,31 +95,36 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator<Item = TypeBound>) -> Ty
}

#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display, Serialize, Deserialize)]
#[serde(tag = "s")]
/// Representation of a Sum type.
/// Either store the types of the variants, or in the special (but common) case
/// of a "simple predicate" (sum over empty tuples), store only the size of the predicate.
enum SumType {
#[display(fmt = "SimplePredicate({})", "_0")]
Simple(u8),
General(TypeRow),
#[display(fmt = "SimplePredicate({})", "size")]
Simple {
size: u8,
},
General {
row: TypeRow,
},
}

impl SumType {
fn new(types: impl Into<TypeRow>) -> Self {
let row: TypeRow = types.into();

let len = row.len();
let len: usize = row.len();
if len <= (u8::MAX as usize) && row.iter().all(|t| *t == Type::UNIT) {
Self::Simple(len as u8)
Self::Simple { size: len as u8 }
} else {
Self::General(row)
Self::General { row }
}
}

fn get_variant(&self, tag: usize) -> Option<&Type> {
match self {
SumType::Simple(size) if tag < (*size as usize) => Some(Type::UNIT_REF),
SumType::General(row) => row.get(tag),
SumType::Simple { size } if tag < (*size as usize) => Some(Type::UNIT_REF),
SumType::General { row } => row.get(tag),
_ => None,
}
}
Expand All @@ -128,8 +133,8 @@ impl SumType {
impl From<SumType> for Type {
fn from(sum: SumType) -> Type {
match sum {
SumType::Simple(size) => Type::new_simple_predicate(size),
SumType::General(types) => Type::new_sum(types),
SumType::Simple { size } => Type::new_simple_predicate(size),
SumType::General { row } => Type::new_sum(row),
}
}
}
Expand All @@ -148,9 +153,9 @@ impl TypeEnum {
fn least_upper_bound(&self) -> TypeBound {
match self {
TypeEnum::Prim(p) => p.bound(),
TypeEnum::Sum(SumType::Simple(_)) => TypeBound::Eq,
TypeEnum::Sum(SumType::General(ts)) => {
least_upper_bound(ts.iter().map(Type::least_upper_bound))
TypeEnum::Sum(SumType::Simple { size: _ }) => TypeBound::Eq,
TypeEnum::Sum(SumType::General { row }) => {
least_upper_bound(row.iter().map(Type::least_upper_bound))
}
TypeEnum::Tuple(ts) => least_upper_bound(ts.iter().map(Type::least_upper_bound)),
}
Expand Down Expand Up @@ -238,7 +243,7 @@ impl Type {
/// New simple predicate with empty Tuple variants
pub const fn new_simple_predicate(size: u8) -> Self {
// should be the only way to avoid going through SumType::new
Self(TypeEnum::Sum(SumType::Simple(size)), TypeBound::Eq)
Self(TypeEnum::Sum(SumType::Simple { size }), TypeBound::Eq)
}

/// Report the least upper TypeBound, if there is one.
Expand All @@ -260,10 +265,10 @@ impl Type {
// There is no need to check the components against the bound,
// that is guaranteed by construction (even for deserialization)
match &self.0 {
TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General(row)) => {
TypeEnum::Tuple(row) | TypeEnum::Sum(SumType::General { row }) => {
row.iter().try_for_each(|t| t.validate(extension_registry))
}
TypeEnum::Sum(SumType::Simple(_)) => Ok(()), // No leaves there
TypeEnum::Sum(SumType::Simple { .. }) => Ok(()), // No leaves there
TypeEnum::Prim(PrimType::Alias(_)) => Ok(()),
TypeEnum::Prim(PrimType::Extension(custy)) => custy.validate(extension_registry),
TypeEnum::Prim(PrimType::Function(ft)) => ft
Expand Down Expand Up @@ -349,7 +354,7 @@ pub(crate) mod test {

assert_eq!(pred1, pred2);

let pred_direct = SumType::Simple(2);
let pred_direct = SumType::Simple { size: 2 };
assert_eq!(pred1, pred_direct.into())
}
}

0 comments on commit c2b3b7c

Please sign in to comment.