Skip to content

Commit

Permalink
Datatype alpha equivalence 3.x (#20025)
Browse files Browse the repository at this point in the history
* Check for alpha equivalence in data types (#20016)

* Add two tests for Phantom type, fix error messages

* lint
  • Loading branch information
dylant-da authored Sep 27, 2024
1 parent d5dda36 commit 0ccfc47
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ data UnwarnableError
| EForbiddenNewImplementation !TypeConName !TypeConName
| EUpgradeDependenciesFormACycle ![(PackageId, PackageMetadata)]
| EUpgradeMultiplePackagesWithSameNameAndVersion !PackageName !RawPackageVersion ![PackageId]
| EUpgradeDifferentParamsCount !UpgradedRecordOrigin
| EUpgradeDifferentParamsKinds !UpgradedRecordOrigin
deriving (Show)

data WarnableError
Expand Down Expand Up @@ -687,6 +689,8 @@ instance Pretty UnwarnableError where
where
pprintDep (pkgId, meta) = pPrint pkgId <> "(" <> pPrint (packageName meta) <> ", " <> pPrint (packageVersion meta) <> ")"
EUpgradeMultiplePackagesWithSameNameAndVersion name version ids -> "Multiple packages with name " <> pPrint name <> " and version " <> pPrint (show version) <> ": " <> hcat (L.intersperse ", " (map pPrint ids))
EUpgradeDifferentParamsCount origin -> "The upgraded " <> pPrint origin <> " has changed the number of type variables it has."
EUpgradeDifferentParamsKinds origin -> "The upgraded " <> pPrint origin <> " has changed the kind of one of its type variables."

instance Pretty UpgradedRecordOrigin where
pPrint = \case
Expand Down
39 changes: 22 additions & 17 deletions sdk/compiler/daml-lf-tools/src/DA/Daml/LF/TypeChecker/Upgrade.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ module DA.Daml.LF.TypeChecker.Upgrade (
) where

import Control.Monad (unless, forM, forM_, when)
import Control.Monad.Extra (unlessM)
import Control.Monad.Reader (withReaderT, ask)
import Control.Monad.Reader.Class (asks)
import Control.Lens hiding (Context)
Expand Down Expand Up @@ -593,8 +594,8 @@ checkTemplate module_ template = do
(existingChoices, _existingNew) <- checkDeleted (EUpgradeMissingChoice . NM.name) $ NM.toHashMap . tplChoices <$> template
forM_ existingChoices $ \choice -> do
withContextF present' (ContextTemplate (_present module_) (_present template) (TPChoice (_present choice))) $ do
checkType (fmap chcReturnType choice)
(EUpgradeChoiceChangedReturnType (NM.name (_present choice)))
unlessM (isUpgradedTypeNoTypeVars (fmap chcReturnType choice)) $
throwWithContextF present' (EUpgradeChoiceChangedReturnType (NM.name (_present choice)))

whenDifferent "controllers" (extractFuncFromFuncThisArg . chcControllers) choice $
\mismatches -> warnWithContextF present' (WChoiceChangedControllers (NM.name (_present choice)) mismatches)
Expand Down Expand Up @@ -642,7 +643,7 @@ checkTemplate module_ template = do
let tplKey = Upgrading pastKey presentKey

-- Key type must be a valid upgrade
iset <- isUpgradedType (fmap tplKeyType tplKey)
iset <- isUpgradedTypeNoTypeVars (fmap tplKeyType tplKey)
when (not iset) $
diagnosticWithContextF present' (EUpgradeTemplateChangedKeyType (NM.name (_present template)))

Expand Down Expand Up @@ -758,15 +759,21 @@ checkTemplate module_ template = do

checkDefDataType :: UpgradedRecordOrigin -> Upgrading LF.DefDataType -> TcUpgradeM ()
checkDefDataType origin datatype = do
let params = dataParams <$> datatype
paramsLengthMatch = foldU (==) (length <$> params)
allKindsMatch = foldU (==) (map snd <$> params)
when (not paramsLengthMatch) $ throwWithContextF present' $ EUpgradeDifferentParamsCount origin
when (not allKindsMatch) $ throwWithContextF present' $ EUpgradeDifferentParamsKinds origin
let paramNames = unsafeZipUpgrading (map fst <$> params)
case fmap dataCons datatype of
Upgrading { _past = DataRecord _past, _present = DataRecord _present } ->
checkFields origin (Upgrading {..})
checkFields origin paramNames (Upgrading {..})
Upgrading { _past = DataVariant _past, _present = DataVariant _present } -> do
let upgrade = Upgrading{..}
(existing, _new) <- checkDeleted (\_ -> EUpgradeVariantRemovedVariant origin) (fmap HMS.fromList upgrade)
when (not $ and $ foldU (zipWith (==)) $ fmap (map fst) upgrade) $
throwWithContextF present' (EUpgradeVariantVariantsOrderChanged origin)
different <- filterHashMapM (fmap not . isUpgradedType) existing
different <- filterHashMapM (fmap not . isUpgradedType paramNames) existing
when (not (null different)) $
throwWithContextF present' $ EUpgradeVariantChangedVariantType origin
Upgrading { _past = DataEnum _past, _present = DataEnum _present } -> do
Expand All @@ -786,11 +793,11 @@ filterHashMapM :: (Applicative m) => (a -> m Bool) -> HMS.HashMap k a -> m (HMS.
filterHashMapM pred t =
fmap fst . HMS.filter snd <$> traverse (\x -> (x,) <$> pred x) t

checkFields :: UpgradedRecordOrigin -> Upgrading [(FieldName, Type)] -> TcUpgradeM ()
checkFields origin fields = do
checkFields :: UpgradedRecordOrigin -> [Upgrading TypeVarName] -> Upgrading [(FieldName, Type)] -> TcUpgradeM ()
checkFields origin paramNames fields = do
(existing, new) <- checkDeleted (\_ -> EUpgradeRecordFieldsMissing origin) (fmap HMS.fromList fields)
-- If a field from the upgraded package has had its type changed
different <- filterHashMapM (fmap not . isUpgradedType) existing
different <- filterHashMapM (fmap not . isUpgradedType paramNames) existing
when (not (HMS.null different)) $
throwWithContextF present' (EUpgradeRecordFieldsExistingChanged origin)
when (not (all newFieldOptionalType new)) $
Expand All @@ -807,12 +814,6 @@ checkFields origin fields = do
newFieldOptionalType (TOptional _) = True
newFieldOptionalType _ = False

-- Check type upgradability
checkType :: SomeErrorOrWarning e => Upgrading Type -> e -> TcUpgradeM ()
checkType type_ err = do
sameType <- isUpgradedType type_
unless sameType $ diagnosticWithContextF present' err

checkQualName :: Alpha.IsSomeName a => DepsMap -> Upgrading (Qualified a) -> [Alpha.Mismatch UpgradeMismatchReason]
checkQualName deps name =
let namesAreSame = foldU Alpha.alphaEq' (fmap removePkgId name)
Expand Down Expand Up @@ -870,14 +871,18 @@ checkQualName deps name =
ifMismatch (isNothing pastDepLookup) (CouldNotFindPackageForPastIdentifier (Left (_past pkgId)))
`Alpha.andMismatches` ifMismatch (isNothing presentDepLookup) (CouldNotFindPackageForPastIdentifier (Left (_present pkgId)))

isUpgradedType :: Upgrading Type -> TcUpgradeM Bool
isUpgradedType type_ = do
isUpgradedTypeNoTypeVars :: Upgrading Type -> TcUpgradeM Bool
isUpgradedTypeNoTypeVars = isUpgradedType []

isUpgradedType :: [Upgrading TypeVarName] -> Upgrading Type -> TcUpgradeM Bool
isUpgradedType varNames type_ = do
expandedTypes <- runGammaUnderUpgrades (expandTypeSynonyms <$> type_)
alphaEnv <- upgradingAlphaEnv
let alphaEnvWithVars = foldl' (flip (foldU Alpha.bindTypeVar)) alphaEnv varNames
-- NOTE: The warning messages generated by alphaType' via checkQualName
-- below are only designed for the expression warnings and won't make sense
-- if used to describe issues with types.
pure $ null $ foldU (Alpha.alphaType' alphaEnv) expandedTypes
pure $ null $ foldU (Alpha.alphaType' alphaEnvWithVars) expandedTypes

upgradingAlphaEnv :: TcUpgradeM (Alpha.AlphaEnv UpgradeMismatchReason)
upgradingAlphaEnv = do
Expand Down
4 changes: 4 additions & 0 deletions sdk/compiler/damlc/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ da_haskell_test(
"//test-common:upgrades-CannotUpgradeView-files",
"//test-common:upgrades-FailWhenATopLevelEnumChangesChangesTheOrderOfItsVariants-files",
"//test-common:upgrades-FailWhenATopLevelVariantChangesChangesTheOrderOfItsVariants-files",
"//test-common:upgrades-FailWhenParamCountChanges-files",
"//test-common:upgrades-FailWhenParamKindChanges-files",
"//test-common:upgrades-FailsOnlyInModuleNotInReexports-files",
"//test-common:upgrades-FailsWhenATopLevelRecordAddsANonOptionalField-files",
"//test-common:upgrades-FailsWhenATopLevelRecordAddsAnOptionalFieldBeforeTheEnd-files",
Expand Down Expand Up @@ -489,6 +491,8 @@ da_haskell_test(
"//test-common:upgrades-MissingTemplate-files",
"//test-common:upgrades-RecordFieldsNewNonOptional-files",
"//test-common:upgrades-SucceedWhenATopLevelEnumAddsAField-files",
"//test-common:upgrades-SucceedWhenParamNameChanges-files",
"//test-common:upgrades-SucceedWhenPhantomParamBecomesUsed-files",
"//test-common:upgrades-SucceedsWhenATopLevelEnumChanges-files",
"//test-common:upgrades-SucceedsWhenATopLevelRecordAddsAnOptionalFieldAtTheEnd-files",
"//test-common:upgrades-SucceedsWhenATopLevelRecordAddsAnOptionalFieldAtTheEnd-v1.dar",
Expand Down
28 changes: 28 additions & 0 deletions sdk/compiler/damlc/tests/src/DA/Test/DamlcUpgrades.hs
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,34 @@ tests damlc =
"my-package"
"0.0.2"
LF.version2_dev
, test
"FailWhenParamCountChanges"
(FailWithError "\ESC\\[0;91merror type checking data type Main.MyStruct:\n The upgraded data type MyStruct has changed the number of type variables it has.")
versionDefault
NoDependencies
False
True
, test
"FailWhenParamKindChanges"
(FailWithError "\ESC\\[0;91merror type checking data type Main.MyStruct:\n The upgraded data type MyStruct has changed the kind of one of its type variables.")
versionDefault
NoDependencies
False
True
, test
"SucceedWhenParamNameChanges"
Succeed
versionDefault
NoDependencies
False
True
, test
"SucceedWhenPhantomParamBecomesUsed"
Succeed
versionDefault
NoDependencies
False
True
]
)
where
Expand Down
4 changes: 4 additions & 0 deletions sdk/test-common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ da_scala_dar_resources_library(
{"data_dependencies": ["//test-common:upgrades-WarnsWhenExpressionChangesPackageId-dep-name2.dar"]},
),
("SucceedsWhenUpgradingLFVersionWithoutExpressionWarning", {}, {}),
("FailWhenParamCountChanges", {}, {}),
("FailWhenParamKindChanges", {}, {}),
("SucceedWhenParamNameChanges", {}, {}),
("SucceedWhenPhantomParamBecomesUsed", {}, {}),
]
]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a = MyStruct { field1 : a }

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a b = MyStruct { field1 : a }

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a f = MyStruct { field1 : a }

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a f = MyStruct { field1 : a, field2 : f a }

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct1 a = MyStruct1 { field1 : a }
data MyStruct2 a b = MyStruct2 { field1 : a }
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct1 b = MyStruct1 { field1 : b }
data MyStruct2 b a = MyStruct2 { field1 : b, field2 : Optional a, field3: Optional b }

Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a b = MyStruct { field1 : a }


Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-- Copyright (c) 2024 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0

module Main where

data MyStruct a b = MyStruct { field1 : a, field2 : Optional b }

0 comments on commit 0ccfc47

Please sign in to comment.