From d47b34a388cb6409c3ec4230744816d3387cfc3a Mon Sep 17 00:00:00 2001 From: Georgiy Komarov Date: Wed, 25 Jan 2023 16:00:43 +0700 Subject: [PATCH] feat(tc): Support partial application of procedures Closes #577 --- src/base/Syntax.ml | 5 +- src/base/Type.ml | 12 +++++ src/base/TypeChecker.ml | 103 ++++++++++++++++++++++++++-------------- src/base/TypeUtil.ml | 7 +++ src/base/TypeUtil.mli | 8 ++++ 5 files changed, 99 insertions(+), 36 deletions(-) diff --git a/src/base/Syntax.ml b/src/base/Syntax.ml index acba45c8a..728e6bb7e 100644 --- a/src/base/Syntax.ml +++ b/src/base/Syntax.ml @@ -374,7 +374,10 @@ module ScillaSyntax (SR : Rep) (ER : Rep) (Lit : ScillaLiteral) = struct ER.rep SIdentifier.t option * SR.rep SIdentifier.t * ER.rep SIdentifier.t list - (** [CallProc(I, P, [A1, ... An])] is a procedure call: [I = P A1 ... An] *) + (** [CallProc(I, P, [A1, ... An])] is a procedure call, when all the + arguments are specified: [I = P A1 ... An]. Otherwise, it is or + partial application of procedure that creates a new local variable + [I] that has function type: [I = P A1 ... An]. *) | Throw of ER.rep SIdentifier.t option (** [Throw(I)] represents: [throw I] *) | GasStmt of SGasCharge.gas_charge diff --git a/src/base/Type.ml b/src/base/Type.ml index a110d3b08..f2f8dc6ac 100644 --- a/src/base/Type.ml +++ b/src/base/Type.ml @@ -139,6 +139,9 @@ module type ScillaType = sig | PolyFun of string * t (** [PolyFun('A, T)] represents a polymorphic function type where ['A] is a type parameter. For example: [forall 'A. List 'a -> List 'A] *) + | ProcType of string * t list + (** [ProcType(P, Args)] is a type of partial application of procedure + [P] which has formal arguments with types [Args] *) | Unit (** [Unit] is a unit type *) | Address of t addr_kind (** [Address(A)] represents address *) [@@deriving sexp, to_yojson] @@ -227,6 +230,7 @@ module MkType (I : ScillaIdentifier) = struct | ADT of loc TIdentifier.t * t list | TypeVar of string | PolyFun of string * t + | ProcType of string * t list | Unit | Address of (t addr_kind[@to_yojson fun _ -> `String "Address"]) [@@deriving sexp, to_yojson] @@ -243,6 +247,9 @@ module MkType (I : ScillaIdentifier) = struct in String.concat ~sep:" " elems | FunType (at, vt) -> sprintf "%s -> %s" (with_paren at) (recurser vt) + | ProcType (p, args_tys) -> + sprintf "%s (%s)" p + (List.map args_tys ~f:recurser |> String.concat ~sep:", ") | TypeVar tv -> tv | PolyFun (tv, bt) -> sprintf "forall %s. %s" tv (recurser bt) | Unit -> sprintf "()" @@ -318,6 +325,9 @@ module MkType (I : ScillaIdentifier) = struct let ats = subst_type_in_type tvar tp at in let rts = subst_type_in_type tvar tp rt in FunType (ats, rts) + | ProcType (p, args_tys) -> + let args_tyss = List.map args_tys ~f:(subst_type_in_type tvar tp) in + ProcType (p, args_tyss) | TypeVar n -> if String.(tvar = n) then tp else tm | ADT (s, ts) -> let ts' = List.map ts ~f:(subst_type_in_type tvar tp) in @@ -341,6 +351,8 @@ module MkType (I : ScillaIdentifier) = struct match t with | MapType (kt, vt) -> MapType (kt, recursor vt taken) | FunType (at, rt) -> FunType (recursor at taken, recursor rt taken) + | ProcType (p, args_tys) -> + ProcType (p, List.map args_tys ~f:(fun ty -> recursor ty taken)) | ADT (n, ts) -> let ts' = List.map ts ~f:(fun w -> recursor w taken) in ADT (n, ts') diff --git a/src/base/TypeChecker.ml b/src/base/TypeChecker.ml index b0adcbe68..3b7431c14 100644 --- a/src/base/TypeChecker.ml +++ b/src/base/TypeChecker.ml @@ -147,6 +147,7 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct match t with | PrimType _ | Unit | TypeVar _ -> 1 | PolyFun (_, t) -> 1 + type_size t + | ProcType (_, args) -> 1 + List.length args | MapType (t1, t2) | FunType (t1, t2) -> 1 + type_size t1 + type_size t2 | ADT (_, ts) -> List.fold_left ts ~init:1 ~f:(fun acc t -> acc + type_size t) @@ -163,7 +164,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct | MapType (_, _) | FunType (_, _) | ADT (_, _) - | PolyFun (_, _) -> + | PolyFun (_, _) + | ProcType (_, _) -> 1 | TypeVar n -> if String.(n = tvar) then tp_size else 1 | Address AnyAddr | Address LibAddr | Address CodeAddr -> 1 @@ -201,6 +203,9 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct else let%bind res = recurser t' in pure (PolyFun (arg, res)) + | ProcType (pname, args) -> + let%bind args' = mapM args ~f:recurser in + pure (ProcType (pname, args')) | Address AnyAddr | Address LibAddr | Address CodeAddr -> pure t | Address (ContrAddr fts) -> let%bind fts_res = @@ -985,8 +990,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct @@ add_stmt_to_stmts_env_gas (TypedSyntax.CreateEvnt typed_i, rep) checked_stmts - | CallProc (id_opt, p, args) -> - let%bind arg_typs, ret_ty_opt = + | CallProc (id_opt, p, actual_args) -> + let%bind formal_args, ret_ty_opt = match lookup_proc env p with | Some (arg_typs, ret_ty_opt) -> pure (arg_typs, ret_ty_opt) | None -> @@ -995,41 +1000,69 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct ~inst:(as_error_string p) (SR.get_loc (get_rep p))) in - let%bind typed_args = - let%bind targs, typed_actuals = type_actuals env.pure args in + let%bind targs, typed_actuals = type_actuals env.pure actual_args in + let is_partial_application = + Option.is_some id_opt + && List.length formal_args > List.length targs + in + if is_partial_application then let%bind _ = fromR_TE - @@ proc_type_applies arg_typs targs ~lc:(SR.get_loc rep) + @@ partial_proc_type_applies formal_args targs + ~lc:(SR.get_loc rep) in - pure typed_actuals - in - let%bind typed_id_opt, checked_stmts = - match id_opt with - | None -> - let%bind checked_stmts = type_stmts comp sts get_loc env in - pure @@ (None, checked_stmts) - | Some id -> ( - match ret_ty_opt with - | Some ret_ty -> - let typed_id = add_type_to_ident id (mk_qual_tp ret_ty) in - let%bind checked_stmts = - with_extended_env env get_tenv_pure - [ (id, ret_ty) ] - [] - (type_stmts comp sts get_loc) - in - pure @@ (Some typed_id, checked_stmts) - | None -> - fail - (mk_type_error1 - ~kind:"Procedure does not return a value" - ~inst:(as_error_string p) - (SR.get_loc (get_rep p)))) - in - pure - @@ add_stmt_to_stmts_env_gas - (TypedSyntax.CallProc (typed_id_opt, p, typed_args), rep) - checked_stmts + let id = Option.value_exn id_opt in + let proc_name = + SIdentifier.Name.as_string (SIdentifier.get_id p) + in + let partial_applied_type = ProcType (proc_name, targs) in + let typed_id = + add_type_to_ident id (mk_qual_tp partial_applied_type) + in + let%bind checked_stmts = + with_extended_env env get_tenv_pure + [ (id, partial_applied_type) ] + [] + (type_stmts comp sts get_loc) + in + pure + @@ add_stmt_to_stmts_env_gas + (TypedSyntax.CallProc (Some typed_id, p, typed_actuals), rep) + checked_stmts + else + let%bind _ = + fromR_TE + @@ proc_type_applies formal_args targs ~lc:(SR.get_loc rep) + in + let%bind typed_id_opt, checked_stmts = + match id_opt with + | None -> + let%bind checked_stmts = type_stmts comp sts get_loc env in + pure @@ (None, checked_stmts) + | Some id -> ( + match ret_ty_opt with + | Some ret_ty -> + let typed_id = + add_type_to_ident id (mk_qual_tp ret_ty) + in + let%bind checked_stmts = + with_extended_env env get_tenv_pure + [ (id, ret_ty) ] + [] + (type_stmts comp sts get_loc) + in + pure @@ (Some typed_id, checked_stmts) + | None -> + fail + (mk_type_error1 + ~kind:"Procedure does not return a value" + ~inst:(as_error_string p) + (SR.get_loc (get_rep p)))) + in + pure + @@ add_stmt_to_stmts_env_gas + (TypedSyntax.CallProc (typed_id_opt, p, typed_actuals), rep) + checked_stmts | Iterate (l, p) -> ( let%bind lt = fromR_TE diff --git a/src/base/TypeUtil.ml b/src/base/TypeUtil.ml index 7574438a9..215581179 100644 --- a/src/base/TypeUtil.ml +++ b/src/base/TypeUtil.ml @@ -542,6 +542,13 @@ module TypeUtilities = struct mk_error1 ~kind:"Incorrect number of arguments to procedure" ?inst:None lc) + let partial_proc_type_applies ~lc formals actuals = + if List.length formals < List.length actuals then + fail (mk_error1 ~kind:"Extra arguments in procedure call" ?inst:None lc) + else + let formals' = List.sub ~pos:0 ~len:(List.length actuals) formals in + proc_type_applies ~lc formals' actuals + let rec elab_tfun_with_args_no_gas tf args = match (tf, args) with | (PolyFun _ as pf), a :: args' -> diff --git a/src/base/TypeUtil.mli b/src/base/TypeUtil.mli index 27a711a61..8307fbd8f 100644 --- a/src/base/TypeUtil.mli +++ b/src/base/TypeUtil.mli @@ -200,6 +200,14 @@ module TypeUtilities : sig TUType.t list -> (unit list, scilla_error list) result + (* Checks if types of the specified [actuals] arguments are assignable to the + corresponding types of [formals]. *) + val partial_proc_type_applies : + lc:ErrorUtils.loc -> + TUType.t list -> + TUType.t list -> + (unit list, scilla_error list) result + (* Applying a type function without gas charge (for builtins) *) val elab_tfun_with_args_no_gas : TUType.t -> TUType.t list -> (TUType.t, scilla_error list) result