Skip to content

Commit

Permalink
fix type inference for recursive let bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
byorgey committed Oct 19, 2024
1 parent e99e5ca commit b40fc1d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/swarm-lang/Swarm/Language/Typecheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import Data.Set qualified as S
import Data.Text (Text)
import Data.Text qualified as T
import Prettyprinter
import Swarm.Effect.Unify (Unification, UnificationError, (=:=))
import Swarm.Effect.Unify (Unification, UnificationError)
import Swarm.Effect.Unify qualified as U
import Swarm.Effect.Unify.Fast qualified as U
import Swarm.Language.Context hiding (lookup)
Expand Down Expand Up @@ -342,7 +342,7 @@ unify ::
TypeJoin ->
m UType
unify ms j = do
res <- expected =:= actual
res <- expected U.=:= actual
case res of
Left _ -> do
j' <- traverse U.applyBindings j
Expand Down Expand Up @@ -1135,7 +1135,7 @@ check s@(CSyntax l t cs) expected = addLocToTypeErr l $ case t of
traverse_ (adaptToTypeErr l KindErr . checkKind) mxTy
case toU mxTy of
Just xTy -> do
res <- argTy =:= xTy
res <- argTy U.=:= xTy
case res of
-- Generate a special error when the explicit type annotation
-- on a lambda doesn't match the expected type,
Expand Down Expand Up @@ -1179,8 +1179,8 @@ check s@(CSyntax l t cs) expected = addLocToTypeErr l $ case t of
xTy <- fresh
t1' <- withBinding (lvVar x) (Forall [] xTy) $ infer t1
let uty = t1' ^. sType
_ <- xTy =:= uty
upty <- generalize uty
uty' <- unify (Just t1) (joined xTy uty)
upty <- generalize uty'
return ([], upty, t1')
-- An explicit polytype annotation has been provided. Skolemize it and check
-- definition and body under an extended context.
Expand Down
6 changes: 6 additions & 0 deletions test/unit/TestLanguagePipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,12 @@ testLanguagePipeline =
"(\\f. f 3) 2"
"1:11: Type mismatch:\n From context, expected `2` to have a type like `Int -> _`"
)
, testCase
"inferring type of bad recursive function - #2186"
( process
"def bad = \\acc.\\n. if (n <= 0) {fst acc} {bad (fst acc + 1) (n - 1)} end"
"1:1: Type mismatch:\n From context, expected `\\acc. \\n. if (n <= 0) {fst acc} {\n bad (fst acc + 1) (n - 1)\n }` to have type `Int -> Int -> Int`,\n but it actually has a type like `(Int * _) -> Int -> Int`"
)
]
, testGroup
"generalize top-level binds #351 #1501"
Expand Down

0 comments on commit b40fc1d

Please sign in to comment.