From b40fc1d15e5fc320e1aea5aee646b640c6751676 Mon Sep 17 00:00:00 2001 From: Brent Yorgey Date: Sat, 19 Oct 2024 09:54:46 -0500 Subject: [PATCH] fix type inference for recursive let bindings --- src/swarm-lang/Swarm/Language/Typecheck.hs | 10 +++++----- test/unit/TestLanguagePipeline.hs | 6 ++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/swarm-lang/Swarm/Language/Typecheck.hs b/src/swarm-lang/Swarm/Language/Typecheck.hs index df33e1b1a..2f2784fbd 100644 --- a/src/swarm-lang/Swarm/Language/Typecheck.hs +++ b/src/swarm-lang/Swarm/Language/Typecheck.hs @@ -69,7 +69,7 @@ import Data.Set qualified as S import Data.Text (Text) import Data.Text qualified as T import Prettyprinter -import Swarm.Effect.Unify (Unification, UnificationError, (=:=)) +import Swarm.Effect.Unify (Unification, UnificationError) import Swarm.Effect.Unify qualified as U import Swarm.Effect.Unify.Fast qualified as U import Swarm.Language.Context hiding (lookup) @@ -342,7 +342,7 @@ unify :: TypeJoin -> m UType unify ms j = do - res <- expected =:= actual + res <- expected U.=:= actual case res of Left _ -> do j' <- traverse U.applyBindings j @@ -1135,7 +1135,7 @@ check s@(CSyntax l t cs) expected = addLocToTypeErr l $ case t of traverse_ (adaptToTypeErr l KindErr . checkKind) mxTy case toU mxTy of Just xTy -> do - res <- argTy =:= xTy + res <- argTy U.=:= xTy case res of -- Generate a special error when the explicit type annotation -- on a lambda doesn't match the expected type, @@ -1179,8 +1179,8 @@ check s@(CSyntax l t cs) expected = addLocToTypeErr l $ case t of xTy <- fresh t1' <- withBinding (lvVar x) (Forall [] xTy) $ infer t1 let uty = t1' ^. sType - _ <- xTy =:= uty - upty <- generalize uty + uty' <- unify (Just t1) (joined xTy uty) + upty <- generalize uty' return ([], upty, t1') -- An explicit polytype annotation has been provided. Skolemize it and check -- definition and body under an extended context. diff --git a/test/unit/TestLanguagePipeline.hs b/test/unit/TestLanguagePipeline.hs index 1bb61d388..6eb772614 100644 --- a/test/unit/TestLanguagePipeline.hs +++ b/test/unit/TestLanguagePipeline.hs @@ -556,6 +556,12 @@ testLanguagePipeline = "(\\f. f 3) 2" "1:11: Type mismatch:\n From context, expected `2` to have a type like `Int -> _`" ) + , testCase + "inferring type of bad recursive function - #2186" + ( process + "def bad = \\acc.\\n. if (n <= 0) {fst acc} {bad (fst acc + 1) (n - 1)} end" + "1:1: Type mismatch:\n From context, expected `\\acc. \\n. if (n <= 0) {fst acc} {\n bad (fst acc + 1) (n - 1)\n }` to have type `Int -> Int -> Int`,\n but it actually has a type like `(Int * _) -> Int -> Int`" + ) ] , testGroup "generalize top-level binds #351 #1501"