Skip to content

Commit

Permalink
feat(tc): Support partial application of procedures
Browse files Browse the repository at this point in the history
Closes #577
  • Loading branch information
jubnzv committed Jan 25, 2023
1 parent 7fe3623 commit d47b34a
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 36 deletions.
5 changes: 4 additions & 1 deletion src/base/Syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ module ScillaSyntax (SR : Rep) (ER : Rep) (Lit : ScillaLiteral) = struct
ER.rep SIdentifier.t option
* SR.rep SIdentifier.t
* ER.rep SIdentifier.t list
(** [CallProc(I, P, [A1, ... An])] is a procedure call: [I = P A1 ... An] *)
(** [CallProc(I, P, [A1, ... An])] is a procedure call, when all the
arguments are specified: [I = P A1 ... An]. Otherwise, it is or
partial application of procedure that creates a new local variable
[I] that has function type: [I = P A1 ... An]. *)
| Throw of ER.rep SIdentifier.t option
(** [Throw(I)] represents: [throw I] *)
| GasStmt of SGasCharge.gas_charge
Expand Down
12 changes: 12 additions & 0 deletions src/base/Type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ module type ScillaType = sig
| PolyFun of string * t
(** [PolyFun('A, T)] represents a polymorphic function type where
['A] is a type parameter. For example: [forall 'A. List 'a -> List 'A] *)
| ProcType of string * t list
(** [ProcType(P, Args)] is a type of partial application of procedure
[P] which has formal arguments with types [Args] *)
| Unit (** [Unit] is a unit type *)
| Address of t addr_kind (** [Address(A)] represents address *)
[@@deriving sexp, to_yojson]
Expand Down Expand Up @@ -227,6 +230,7 @@ module MkType (I : ScillaIdentifier) = struct
| ADT of loc TIdentifier.t * t list
| TypeVar of string
| PolyFun of string * t
| ProcType of string * t list
| Unit
| Address of (t addr_kind[@to_yojson fun _ -> `String "Address"])
[@@deriving sexp, to_yojson]
Expand All @@ -243,6 +247,9 @@ module MkType (I : ScillaIdentifier) = struct
in
String.concat ~sep:" " elems
| FunType (at, vt) -> sprintf "%s -> %s" (with_paren at) (recurser vt)
| ProcType (p, args_tys) ->
sprintf "%s (%s)" p
(List.map args_tys ~f:recurser |> String.concat ~sep:", ")
| TypeVar tv -> tv
| PolyFun (tv, bt) -> sprintf "forall %s. %s" tv (recurser bt)
| Unit -> sprintf "()"
Expand Down Expand Up @@ -318,6 +325,9 @@ module MkType (I : ScillaIdentifier) = struct
let ats = subst_type_in_type tvar tp at in
let rts = subst_type_in_type tvar tp rt in
FunType (ats, rts)
| ProcType (p, args_tys) ->
let args_tyss = List.map args_tys ~f:(subst_type_in_type tvar tp) in
ProcType (p, args_tyss)
| TypeVar n -> if String.(tvar = n) then tp else tm
| ADT (s, ts) ->
let ts' = List.map ts ~f:(subst_type_in_type tvar tp) in
Expand All @@ -341,6 +351,8 @@ module MkType (I : ScillaIdentifier) = struct
match t with
| MapType (kt, vt) -> MapType (kt, recursor vt taken)
| FunType (at, rt) -> FunType (recursor at taken, recursor rt taken)
| ProcType (p, args_tys) ->
ProcType (p, List.map args_tys ~f:(fun ty -> recursor ty taken))
| ADT (n, ts) ->
let ts' = List.map ts ~f:(fun w -> recursor w taken) in
ADT (n, ts')
Expand Down
103 changes: 68 additions & 35 deletions src/base/TypeChecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
match t with
| PrimType _ | Unit | TypeVar _ -> 1
| PolyFun (_, t) -> 1 + type_size t
| ProcType (_, args) -> 1 + List.length args
| MapType (t1, t2) | FunType (t1, t2) -> 1 + type_size t1 + type_size t2
| ADT (_, ts) ->
List.fold_left ts ~init:1 ~f:(fun acc t -> acc + type_size t)
Expand All @@ -163,7 +164,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
| MapType (_, _)
| FunType (_, _)
| ADT (_, _)
| PolyFun (_, _) ->
| PolyFun (_, _)
| ProcType (_, _) ->
1
| TypeVar n -> if String.(n = tvar) then tp_size else 1
| Address AnyAddr | Address LibAddr | Address CodeAddr -> 1
Expand Down Expand Up @@ -201,6 +203,9 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
else
let%bind res = recurser t' in
pure (PolyFun (arg, res))
| ProcType (pname, args) ->
let%bind args' = mapM args ~f:recurser in
pure (ProcType (pname, args'))
| Address AnyAddr | Address LibAddr | Address CodeAddr -> pure t
| Address (ContrAddr fts) ->
let%bind fts_res =
Expand Down Expand Up @@ -985,8 +990,8 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CreateEvnt typed_i, rep)
checked_stmts
| CallProc (id_opt, p, args) ->
let%bind arg_typs, ret_ty_opt =
| CallProc (id_opt, p, actual_args) ->
let%bind formal_args, ret_ty_opt =
match lookup_proc env p with
| Some (arg_typs, ret_ty_opt) -> pure (arg_typs, ret_ty_opt)
| None ->
Expand All @@ -995,41 +1000,69 @@ module ScillaTypechecker (SR : Rep) (ER : Rep) = struct
~inst:(as_error_string p)
(SR.get_loc (get_rep p)))
in
let%bind typed_args =
let%bind targs, typed_actuals = type_actuals env.pure args in
let%bind targs, typed_actuals = type_actuals env.pure actual_args in
let is_partial_application =
Option.is_some id_opt
&& List.length formal_args > List.length targs
in
if is_partial_application then
let%bind _ =
fromR_TE
@@ proc_type_applies arg_typs targs ~lc:(SR.get_loc rep)
@@ partial_proc_type_applies formal_args targs
~lc:(SR.get_loc rep)
in
pure typed_actuals
in
let%bind typed_id_opt, checked_stmts =
match id_opt with
| None ->
let%bind checked_stmts = type_stmts comp sts get_loc env in
pure @@ (None, checked_stmts)
| Some id -> (
match ret_ty_opt with
| Some ret_ty ->
let typed_id = add_type_to_ident id (mk_qual_tp ret_ty) in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, ret_ty) ]
[]
(type_stmts comp sts get_loc)
in
pure @@ (Some typed_id, checked_stmts)
| None ->
fail
(mk_type_error1
~kind:"Procedure does not return a value"
~inst:(as_error_string p)
(SR.get_loc (get_rep p))))
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (typed_id_opt, p, typed_args), rep)
checked_stmts
let id = Option.value_exn id_opt in
let proc_name =
SIdentifier.Name.as_string (SIdentifier.get_id p)
in
let partial_applied_type = ProcType (proc_name, targs) in
let typed_id =
add_type_to_ident id (mk_qual_tp partial_applied_type)
in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, partial_applied_type) ]
[]
(type_stmts comp sts get_loc)
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (Some typed_id, p, typed_actuals), rep)
checked_stmts
else
let%bind _ =
fromR_TE
@@ proc_type_applies formal_args targs ~lc:(SR.get_loc rep)
in
let%bind typed_id_opt, checked_stmts =
match id_opt with
| None ->
let%bind checked_stmts = type_stmts comp sts get_loc env in
pure @@ (None, checked_stmts)
| Some id -> (
match ret_ty_opt with
| Some ret_ty ->
let typed_id =
add_type_to_ident id (mk_qual_tp ret_ty)
in
let%bind checked_stmts =
with_extended_env env get_tenv_pure
[ (id, ret_ty) ]
[]
(type_stmts comp sts get_loc)
in
pure @@ (Some typed_id, checked_stmts)
| None ->
fail
(mk_type_error1
~kind:"Procedure does not return a value"
~inst:(as_error_string p)
(SR.get_loc (get_rep p))))
in
pure
@@ add_stmt_to_stmts_env_gas
(TypedSyntax.CallProc (typed_id_opt, p, typed_actuals), rep)
checked_stmts
| Iterate (l, p) -> (
let%bind lt =
fromR_TE
Expand Down
7 changes: 7 additions & 0 deletions src/base/TypeUtil.ml
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,13 @@ module TypeUtilities = struct
mk_error1 ~kind:"Incorrect number of arguments to procedure" ?inst:None
lc)
let partial_proc_type_applies ~lc formals actuals =
if List.length formals < List.length actuals then
fail (mk_error1 ~kind:"Extra arguments in procedure call" ?inst:None lc)
else
let formals' = List.sub ~pos:0 ~len:(List.length actuals) formals in
proc_type_applies ~lc formals' actuals
let rec elab_tfun_with_args_no_gas tf args =
match (tf, args) with
| (PolyFun _ as pf), a :: args' ->
Expand Down
8 changes: 8 additions & 0 deletions src/base/TypeUtil.mli
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ module TypeUtilities : sig
TUType.t list ->
(unit list, scilla_error list) result

(* Checks if types of the specified [actuals] arguments are assignable to the
corresponding types of [formals]. *)
val partial_proc_type_applies :
lc:ErrorUtils.loc ->
TUType.t list ->
TUType.t list ->
(unit list, scilla_error list) result

(* Applying a type function without gas charge (for builtins) *)
val elab_tfun_with_args_no_gas :
TUType.t -> TUType.t list -> (TUType.t, scilla_error list) result
Expand Down

0 comments on commit d47b34a

Please sign in to comment.