Skip to content

Commit

Permalink
implement changes from experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
rajgodse committed Aug 16, 2023
1 parent fdc3df8 commit a3dfee8
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 118 deletions.
69 changes: 20 additions & 49 deletions ocaml/lambda/matching.ml
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,10 @@ let may_compat = MayCompat.compat

and may_compats = MayCompat.compats

(* The free variables in a guarded rhs can be precomputed to be used in an
optimization, or uncomputed, in which case the optimization is not applied.
*)
type guarded_free_variables =
| Precomputed of Ident.Set.t
| Uncomputed

let map_guarded_free_variables ~f = function
| Precomputed free_variables -> Precomputed (f free_variables)
| Uncomputed -> Uncomputed

type rhs =
| Guarded of
{ patch_guarded: patch:lambda -> lambda
; free_variables: guarded_free_variables }
; free_variables: Ident.Set.t }
(* Guarded rhs's must allow for fallthrough if the guard fails.
When translating a guarded rhs, the code to execute on fallthrough must
Expand All @@ -158,14 +147,19 @@ type rhs =
*)
| Unguarded of lambda

let mk_boolean_guarded_rhs ~patch_guarded ~free_variables =
Guarded { patch_guarded; free_variables = Precomputed free_variables }

let mk_pattern_guarded_rhs ~patch_guarded =
Guarded { patch_guarded; free_variables = Uncomputed }
let mk_guarded_rhs ~patch_guarded ~free_variables =
Guarded { patch_guarded; free_variables }

let mk_unguarded_rhs action = Unguarded action

let free_variables_of_rhs = function
| Guarded { free_variables; _ } -> free_variables
| Unguarded lam -> free_variables lam

let unguarded_exn = function
| Unguarded lam -> lam
| Guarded _ -> fatal_error "Matching.unguarded_exn"

let is_guarded = function
| Guarded _ -> true
| Unguarded _ -> false
Expand All @@ -178,13 +172,10 @@ let bind_rhs_with_layout str (var, layout) exp body =
| Lvar var' when Ident.same var var' -> body
| _ ->
let patch_guarded ~patch =
Llet (str, layout, var, exp, patch_guarded ~patch) in
Llet (str, layout, var, exp, patch_guarded ~patch)
in
let free_variables =
map_guarded_free_variables
~f:(fun free ->
Ident.Set.union
(free_variables exp) (Ident.Set.remove var free))
free
Ident.Set.union (free_variables exp) (Ident.Set.remove var free)
in
Guarded { patch_guarded; free_variables }

Expand Down Expand Up @@ -1205,24 +1196,10 @@ let what_is_first_case = what_is_cases ~skip_any:false

let what_is_cases = what_is_cases ~skip_any:true

type pm_free_variables =
| Known of Ident.Set.t
(* Pattern match free variables are known: optimization can be applied *)
| Unknown
(* Pattern match free variables are unknown: optimization cannot be applied *)

let pm_free_variables { cases } =
List.fold_right
(fun (_, act) -> function
| Unknown -> Unknown
| Known free ->
match act with
| Unguarded lam ->
Known (Ident.Set.union free (free_variables lam))
| Guarded { free_variables = Precomputed free_variables } ->
Known (Ident.Set.union free free_variables)
| Guarded { free_variables = Uncomputed } -> Unknown)
cases (Known Ident.Set.empty)
(fun (_, rhs) free -> Ident.Set.union free (free_variables_of_rhs rhs))
cases Ident.Set.empty

(* Basic grouping predicates *)

Expand Down Expand Up @@ -1669,17 +1646,11 @@ and precompile_or ~arg ~arg_sort (cls : Simple.clause list) ors args def k =
fires and determine if the resulting change in code is important.
*)
(* Optimization: discard pattern vars not bound in orpm actions *)
let pm_fv = pm_free_variables orpm in
let patbound_idents =
match pm_free_variables orpm with
(* Give up on the optimization: there is some action not tracking
free variables, so the free variable set is not known. *)
| Unknown -> patbound_idents
(* The free variables set is known: apply the optimization by
filtering out pattern-bound variables unused by the actions. *)
| Known pm_fv ->
List.filter
(fun (id, _, _, _) -> Ident.Set.mem id pm_fv)
patbound_idents
List.filter
(fun (id, _, _, _) -> Ident.Set.mem id pm_fv)
patbound_idents
in
let patbound_action_vars =
List.map
Expand Down
20 changes: 9 additions & 11 deletions ocaml/lambda/matching.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,27 @@ open Debuginfo.Scoped_location
type rhs

(* Creates a guarded rhs.
If a guard fails, a guarded rhs must fallthrough to the remaining cases.
To facilitate this, guarded rhs's are constructed using a continuation.
[mk_pattern_guarded_rhs ~patch_guarded] produces a guarded rhs with a
[mk_guarded_rhs ~patch_guarded ~free_variables] produces a guarded rhs with a
lambda representation given by [patch_guarded ~patch], where [patch] contains
an expression that falls through to the remaining cases.
[mk_boolean_guarded_rhs ~patch_guarded ~free_variables] produces a similar
rhs where [free_variables] contains the free variables of the rhs.
an expression that falls through to the remaining cases and [free_variables]
contains the free variables of the rhs.
*)
val mk_boolean_guarded_rhs:
val mk_guarded_rhs:
patch_guarded:(patch:lambda -> lambda) ->
free_variables:Ident.Set.t ->
rhs

val mk_pattern_guarded_rhs:
patch_guarded:(patch:lambda -> lambda) ->
rhs

(* Creates an unguarded rhs from its lambda representation. *)
val mk_unguarded_rhs: lambda -> rhs

val free_variables_of_rhs : rhs -> Ident.Set.t

val unguarded_exn: rhs -> lambda

(* Entry points to match compiler *)
val for_function:
scopes:scopes ->
Expand Down
147 changes: 89 additions & 58 deletions ocaml/lambda/translcore.ml
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,12 @@ and transl_exp0 ~in_new_scope ~scopes sort e =
~position ~mode (transl_exp ~scopes Sort.for_function funct)
oargs (of_location ~scopes e.exp_loc))
| Texp_match(arg, arg_sort, pat_expr_list, partial) ->
transl_match ~scopes ~arg_sort ~return_sort:sort ~return_type:e.exp_type
~loc:e.exp_loc ~env:e.exp_env ~extra_cases:[] arg pat_expr_list partial
let scrutineel = transl_match_scrutinee ~scopes arg_sort arg in
let pat_rhs_list =
transl_match_cases ~scopes ~return_sort:sort pat_expr_list
in
transl_match ~scopes ~arg_sort ~return_type:e.exp_type ~loc:e.exp_loc
~env:e.exp_env ~extra_cases:[] arg scrutineel pat_rhs_list partial
| Texp_try(body, pat_expr_list) ->
let id = Typecore.name_cases "exn" pat_expr_list in
let return_layout = layout_exp sort e in
Expand Down Expand Up @@ -1032,12 +1036,6 @@ and pure_module m =
and transl_list ~scopes expr_list =
List.map (fun (exp, sort) -> transl_exp ~scopes sort exp) expr_list

and transl_list_with_layout ~scopes expr_list =
List.map (fun (exp, sort) -> transl_exp ~scopes sort exp,
sort,
layout_exp sort exp)
expr_list

(* Will raise if a list element has a non-value layout. *)
and transl_list_with_shape ~scopes expr_list =
let transl_with_shape (e, sort) =
Expand All @@ -1054,17 +1052,25 @@ and transl_rhs ~scopes rhs_sort rhs =
(event_before ~scopes rhs (transl_exp ~scopes rhs_sort rhs))
| Boolean_guarded_rhs { guard = typed_guard; rhs } ->
let guard = transl_exp ~scopes Sort.for_predef_value typed_guard in
let body = event_before ~scopes rhs (transl_exp ~scopes rhs_sort rhs) in
let body =
event_before ~scopes rhs (transl_exp ~scopes rhs_sort rhs)
in
let patch_guarded ~patch =
event_before
~scopes typed_guard (Lifthenelse (guard, body, patch, layout))
in
let free_variables =
Ident.Set.union (free_variables guard) (free_variables body)
in
Matching.mk_boolean_guarded_rhs ~patch_guarded ~free_variables
Matching.mk_guarded_rhs ~patch_guarded ~free_variables
| Pattern_guarded_rhs { scrutinee; scrutinee_sort; cases; partial;
loc; env; rhs_type } ->
let scrutineel =
transl_match_scrutinee ~scopes scrutinee_sort scrutinee
in
let pat_rhs_list =
transl_match_cases ~scopes ~return_sort:rhs_sort cases
in
match partial with
| Partial ->
(* Partial pattern guards may fail to match, so we must construct a
Expand All @@ -1084,21 +1090,24 @@ and transl_rhs ~scopes rhs_sort rhs =
let extra_cases = [ any_pat, Matching.mk_unguarded_rhs patch ] in
event_before ~scopes scrutinee
(transl_match ~scopes ~arg_sort:scrutinee_sort
~return_sort:rhs_sort ~return_type:rhs_type ~loc ~env
~extra_cases scrutinee cases partial)
~return_type:rhs_type ~loc:loc ~env:env ~extra_cases scrutinee
scrutineel pat_rhs_list partial)
in
let free_variables =
free_variables_of_match scrutineel pat_rhs_list
in
Matching.mk_pattern_guarded_rhs ~patch_guarded
Matching.mk_guarded_rhs ~patch_guarded ~free_variables
| Total ->
(* Total pattern guards are equivalent to nested matches. *)
let nested_match =
transl_match ~scopes ~arg_sort:scrutinee_sort ~return_sort:rhs_sort
~return_type:rhs_type ~loc ~env ~extra_cases:[] scrutinee cases
partial
transl_match ~scopes ~arg_sort:scrutinee_sort ~return_type:rhs_type
~loc:loc ~env:env ~extra_cases:[] scrutinee scrutineel
pat_rhs_list partial
in
Matching.mk_unguarded_rhs
(event_before ~scopes scrutinee nested_match)

and transl_case ~scopes rhs_sort {c_lhs; c_rhs} =
and transl_case ~scopes rhs_sort { c_lhs; c_rhs } =
c_lhs, transl_rhs ~scopes rhs_sort c_rhs

and transl_cases ~scopes rhs_sort cases =
Expand All @@ -1107,10 +1116,10 @@ and transl_cases ~scopes rhs_sort cases =
in
List.map (transl_case ~scopes rhs_sort) cases

and transl_case_try ~scopes rhs_sort {c_lhs; c_rhs} =
and transl_case_try ~scopes rhs_sort ({c_lhs; _} as case) =
iter_exn_names Translprim.add_exception_ident c_lhs;
Misc.try_finally
(fun () -> c_lhs, transl_rhs ~scopes rhs_sort c_rhs)
(fun () -> transl_case ~scopes rhs_sort case)
~always:(fun () ->
iter_exn_names Translprim.remove_exception_ident c_lhs)

Expand Down Expand Up @@ -1617,31 +1626,57 @@ and transl_record ~scopes loc env mode fields repres opt_init_expr =
end
end

and transl_match ~scopes ~arg_sort ~return_sort ~return_type ~loc ~env
~extra_cases arg pat_expr_list partial =
and free_variables_of_match scrutineel cases =
let scrutinee_free_variables =
List.fold_left
(fun free arg -> Ident.Set.union free (free_variables arg))
Ident.Set.empty scrutineel
in
List.fold_left
(fun free (pat, rhs) ->
let case_free =
List.fold_left
(fun s id -> Ident.Set.remove id s)
(Matching.free_variables_of_rhs rhs) (pat_bound_idents pat)
in
Ident.Set.union free case_free)
scrutinee_free_variables cases

and transl_match_scrutinee ~scopes scrutinee_sort scrutinee =
let argl =
match scrutinee with
| { exp_desc = Texp_tuple (argl, _) } ->
List.map (fun arg -> arg, Sort.for_tuple_element) argl
| _ -> [ scrutinee, scrutinee_sort ]
in
transl_list ~scopes argl

and transl_match_cases ~scopes ~return_sort cases =
List.filter_map
(fun case -> transl_match_case ~scopes ~return_sort case) cases

and transl_match_case ~scopes ~return_sort ({ c_lhs; c_rhs } as case) =
if is_rhs_unreachable c_rhs then None else
let _, rhs =
match split_pattern c_lhs with
| None, None -> assert false
| Some pv, None -> transl_case ~scopes return_sort { case with c_lhs = pv }
| _, Some pe ->
transl_case_try ~scopes return_sort { case with c_lhs = pe }
in
Some (c_lhs, rhs)

and transl_match ~scopes ~arg_sort ~return_type ~loc ~env ~extra_cases arg
scrutineel pat_rhs_list partial =
let return_layout = layout env loc arg_sort return_type in
let rewrite_case (val_cases, exn_cases, static_handlers as acc)
({ c_lhs; c_rhs } as case) =
if is_rhs_unreachable c_rhs then acc else
let rewrite_case (val_cases, exn_cases, static_handlers)
(c_lhs, (c_rhs : Matching.rhs)) =
let val_pat, exn_pat = split_pattern c_lhs in
match val_pat, exn_pat with
| None, None -> assert false
| Some pv, None ->
let val_case =
transl_case ~scopes return_sort { case with c_lhs = pv }
in
val_case :: val_cases, exn_cases, static_handlers
| None, Some pe ->
let exn_case =
transl_case_try ~scopes return_sort { case with c_lhs = pe }
in
val_cases, exn_case :: exn_cases, static_handlers
| Some pv, None -> (pv, c_rhs) :: val_cases, exn_cases, static_handlers
| None, Some pe -> val_cases, (pe, c_rhs) :: exn_cases, static_handlers
| Some pv, Some pe ->
let rhs_exp =
match c_rhs with
| Simple_rhs rhs -> rhs
| Boolean_guarded_rhs _ | Pattern_guarded_rhs _ -> assert false
in
let lbl = next_raise_count () in
let static_raise ids =
Lstaticraise (lbl, List.map (fun id -> Lvar id) ids)
Expand All @@ -1657,21 +1692,12 @@ and transl_match ~scopes ~arg_sort ~return_sort ~return_type ~loc ~env
in
let vids = List.map Ident.rename ids in
let pv = alpha_pat (List.combine ids vids) pv in
(* Also register the names of the exception so Re-raise happens. *)
iter_exn_names Translprim.add_exception_ident pe;
let rhs =
Misc.try_finally
(fun () -> event_before ~scopes rhs_exp
(transl_exp ~scopes return_sort rhs_exp))
~always:(fun () ->
iter_exn_names Translprim.remove_exception_ident pe)
in
(pv, Matching.mk_unguarded_rhs (static_raise vids)) :: val_cases,
(pe, Matching.mk_unguarded_rhs (static_raise ids)) :: exn_cases,
(lbl, ids_kinds, rhs) :: static_handlers
(lbl, ids_kinds, Matching.unguarded_exn c_rhs) :: static_handlers
in
let val_cases, exn_cases, static_handlers =
let x, y, z = List.fold_left rewrite_case ([], [], []) pat_expr_list in
let x, y, z = List.fold_left rewrite_case ([], [], []) pat_rhs_list in
List.rev_append x extra_cases, List.rev y, List.rev z
in
(* In presence of exception patterns, the code we generate for
Expand Down Expand Up @@ -1708,9 +1734,14 @@ and transl_match ~scopes ~arg_sort ~return_sort ~return_type ~loc ~env
| {exp_desc = Texp_tuple (argl, alloc_mode)}, [] ->
assert (static_handlers = []);
let mode = transl_alloc_mode alloc_mode in
let argl = List.map (fun a -> (a, Sort.for_tuple_element)) argl in
Matching.for_multiple_match ~scopes ~return_layout loc
(transl_list_with_layout ~scopes argl) mode val_cases partial
let argl =
List.map
(fun (exp, lam) ->
lam, Sort.for_tuple_element, layout_exp Sort.for_tuple_element exp)
(List.combine argl scrutineel)
in
Matching.for_multiple_match
~scopes ~return_layout loc argl mode val_cases partial
| {exp_desc = Texp_tuple (argl, alloc_mode)}, _ :: _ ->
let argl = List.map (fun a -> (a, Sort.for_tuple_element)) argl in
let val_ids, lvars =
Expand All @@ -1723,14 +1754,14 @@ and transl_match ~scopes ~arg_sort ~return_sort ~return_type ~loc ~env
|> List.split
in
let mode = transl_alloc_mode alloc_mode in
static_catch (transl_list ~scopes argl) val_ids
static_catch scrutineel val_ids
(Matching.for_multiple_match ~scopes ~return_layout loc lvars mode
val_cases partial)
| arg, [] ->
assert (static_handlers = []);
let arg_layout = layout_exp arg_sort arg in
Matching.for_function ~scopes ~arg_sort ~arg_layout ~return_layout
loc None (transl_exp ~scopes arg_sort arg) val_cases partial
assert (static_handlers = []);
let arg_layout = layout_exp arg_sort arg in
Matching.for_function ~scopes ~arg_sort ~arg_layout ~return_layout loc
None (List.hd scrutineel) val_cases partial
| arg, _ :: _ ->
let val_id = Typecore.name_pattern "val" (List.map fst val_cases) in
let arg_layout = layout_exp arg_sort arg in
Expand Down

0 comments on commit a3dfee8

Please sign in to comment.