summaryrefslogtreecommitdiff
path: root/lambda
diff options
context:
space:
mode:
authorGabriel Scherer <gabriel.scherer@gmail.com>2020-03-29 15:02:21 +0200
committerGabriel Scherer <gabriel.scherer@gmail.com>2021-11-02 15:42:55 +0100
commite9397f6605beb57d8ccc08df1b3f928b40e9435e (patch)
tree91aac6b5b1553f437399bcd7a3e8ebd3f9bd676e /lambda
parent90dd724a3e75d51d826addfb87868816ae308696 (diff)
downloadocaml-e9397f6605beb57d8ccc08df1b3f928b40e9435e.tar.gz
TMC: generalize `Choice.t` to use binding operators
Diffstat (limited to 'lambda')
-rw-r--r--lambda/tmc.ml449
1 files changed, 246 insertions, 203 deletions
diff --git a/lambda/tmc.ml b/lambda/tmc.ml
index a80a85d76b..a3cf667365 100644
--- a/lambda/tmc.ml
+++ b/lambda/tmc.ml
@@ -1,15 +1,5 @@
open Lambda
-open struct
- let combine_upto short long =
- let prefix, rest = Misc.Stdlib.List.split_at (List.length short) long in
- List.combine short prefix, rest
-
- let option_of_list = function
- | [] -> None, []
- | x::xs -> Some x, xs
-end
-
(** TMC (Tail Modulo Cons) is a code transformation that
rewrites transformed functions in destination-passing-style, in
such a way that certain calls that were not in tail position in the
@@ -106,9 +96,17 @@ module Dps = struct
(** Create a new destination-passing-style term which is simply
setting the destination with the given [v], hence "returning"
it. *)
- let return (v : lambda): lambda t =
- fun ~tail:_ ~dst ->
- assign_to_dst dst v
+ let return (v : lambda): lambda t = fun ~tail:_ ~dst ->
+ assign_to_dst dst v
+
+ let map (f : 'a -> 'b) (dps : 'a t) : 'b t = fun ~tail ~dst ->
+ f @@ dps ~tail ~dst
+
+ let pair (fa : 'a t) (fb : 'b t) : ('a * 'b) t = fun ~tail ~dst ->
+ (fa ~tail ~dst, fb ~tail ~dst)
+
+ let unit : unit t = fun ~tail:_ ~dst:_ ->
+ ()
end
(** The TMC transformation requires information flows in two opposite
@@ -130,16 +128,26 @@ end
2. Code-production operators that have contextual information
to transform a "code choice" into the final code.
+
+ The code-production choices for a single term have type [lambda Choice.t];
+ using a parametrized type ['a Choice.t] is useful to represent
+ simultaneous choices over several subterms; for example
+ [(lambda * lambda) Choice.t] makes a choice for a pair of terms,
+ for example the [then] and [else] cases of a conditional. With
+ this parameter, ['a Choice.t] has an applicative structure, which
+ is useful to write the actual code transformation in the {!choice}
+ function.
*)
module Choice = struct
- type t = {
- dps : lambda Dps.t;
- direct : unit -> lambda;
+ type 'a t = {
+ dps : 'a Dps.t;
+ direct : unit -> 'a;
has_tmc_calls : bool;
}
(**
- A [Choice.t] represents code that may be written in destination-passing style
- if its usage context allows it. More precisely:
+ An ['a Choice.t] represents code that may be written
+ in destination-passing style if its usage context allows it.
+ More precisely:
- If the surrounding context is already in destination-passing
style, it has a destination available, we should produce the
@@ -158,25 +166,64 @@ module Choice = struct
position and are rewritten into tailcalls in the [dps] version.
*)
- let return v = {
+ let return (v : lambda) : lambda t = {
dps = Dps.return v;
direct = (fun () -> v);
has_tmc_calls = false;
}
- let direct (c : t) : lambda =
+ let map f s = {
+ dps = Dps.map f s.dps;
+ direct = (fun () -> f (s.direct ()));
+ has_tmc_calls = s.has_tmc_calls;
+ }
+ (** Apply function [f] to the transformed term. *)
+
+ let direct (c : lambda t) : lambda =
c.direct ()
- let dps (c : t) ~tail ~dst =
+ let dps (c : lambda t) ~tail ~dst =
c.dps ~tail:tail ~dst:dst
+ let pair ((c1, c2) : 'a t * 'b t) : ('a * 'b) t = {
+ dps = Dps.pair c1.dps c2.dps;
+ direct = (fun () -> (c1.direct (), c2.direct ()));
+ has_tmc_calls =
+ c1.has_tmc_calls || c2.has_tmc_calls;
+ }
+
+ let unit = {
+ dps = Dps.unit;
+ direct = (fun () -> ());
+ has_tmc_calls = false;
+ }
+
+ module Syntax = struct
+ let (let+) a f = map f a
+ let (and+) a1 a2 = pair (a1, a2)
+ end
+ open Syntax
+
+ let option (c : 'a t option) : 'a option t =
+ match c with
+ | None -> let+ () = unit in None
+ | Some c -> let+ v = c in Some v
+
+ let rec list (c : 'a t list) : 'a list t =
+ match c with
+ | [] -> let+ () = unit in []
+ | c :: cs ->
+ let+ v = c
+ and+ vs = list cs
+ in v :: vs
+
(** Finds the first [Choice.t] in a list that [has_tmc_calls] *)
- type zipper = {
- rev_before : lambda list;
- choice : t;
- after: t list
+ type 'a zipper = {
+ rev_before : 'a list;
+ choice : 'a t;
+ after: 'a t list
}
- let find_tmc_calls : t list -> (zipper, lambda list) result =
+ let find_tmc_calls : 'a t list -> ('a zipper, 'a list) result =
let rec find rev_before = function
| [] -> Error (List.rev rev_before)
| choice :: after ->
@@ -186,6 +233,8 @@ module Choice = struct
in find []
end
+open Choice.Syntax
+
type context = {
specialized: specialized Ident.Map.t;
}
@@ -212,116 +261,93 @@ let declare_binding ctx (var, def) =
let rec choice ctx t =
let rec choice ctx t =
- begin[@warning "-8"]
- (*FIXME: allows non-exhaustive pattern matching;
- use an overkill functor-based solution instead? *)
- match t with
- | (Lvar _ | Lconst _ | Lfunction _ | Lsend _
- | Lassign _ | Lfor _ | Lwhile _) ->
- let t = traverse ctx t in
- Choice.return t
-
- (* [choice_prim] handles most primitives, but the important case of construction
- [Lprim(Pmakeblock(...), ...)] is handled by [choice_makeblock] *)
- | Lprim (prim, primargs, loc) ->
- choice_prim ctx prim primargs loc
-
- (* [choice_apply] handles applications, in particular tail-calls which
- generate Set choices at the leaves *)
- | Lapply apply ->
- choice_apply ctx apply
- (* other cases use the [lift] helper that takes the sub-terms in tail
- position and the context around them, and generates a choice for
- the whole term from choices for the tail subterms. *)
- | Lsequence (l1, l2) ->
- let l1 = traverse ctx l1 in
- lift ctx [l2] @@ fun [l2] ->
- Lsequence (l1, l2)
- | Lifthenelse (l1, l2, l3) ->
- let l1 = traverse ctx l1 in
- lift ctx [l2; l3]
- (fun [l2; l3] -> Lifthenelse (l1, l2, l3))
- | Llet (lk, vk, var, def, body) ->
- (* non-recursive bindings are not specialized *)
- let def = traverse ctx def in
- lift ctx [body] @@ fun [body] ->
- Llet (lk, vk, var, def, body)
- | Lletrec (bindings, body) ->
- let ctx, bindings = traverse_letrec ctx bindings in
- lift ctx [body] @@ fun [body] ->
- Lletrec(bindings, body)
- | Lswitch (l1, sw, loc) ->
- let l1 = traverse ctx l1 in
- let consts_lhs, consts_rhs = List.split sw.sw_consts in
- let blocks_lhs, blocks_rhs = List.split sw.sw_blocks in
- let failaction = Option.to_list sw.sw_failaction in
- lift ctx (consts_rhs @ blocks_rhs @ failaction)
- (fun li ->
- let consts, li = combine_upto consts_lhs li in
- let blocks, li = combine_upto blocks_lhs li in
- let fail, li = option_of_list li in
- assert (li = []);
- let sw =
- { sw with
- sw_consts = consts;
- sw_blocks = blocks;
- sw_failaction = fail;
- }
- in
- Lswitch (l1, sw, loc))
- | Lstringswitch (l1, ls, lo, loc) ->
- let l1 = traverse ctx l1 in
- let cases_lhs, cases_rhs = List.split ls in
- let failaction = Option.to_list lo in
- lift ctx (cases_rhs @ failaction)
- (fun li ->
- let cases, li = combine_upto cases_lhs li in
- let fail, li = option_of_list li in
- assert (li = []);
- Lstringswitch (l1, cases, fail, loc))
- | Lstaticraise (id, ls) ->
- let ls = List.map (traverse ctx) ls in
- Choice.return (Lstaticraise (id, ls))
- | Ltrywith (l1, id, l2) ->
- (* in [try l1 with id -> l2], the term [l1] is
- not in tail-call position (after it returns
- we need to remove the exception handler),
- so it is not transformed here *)
- let l1 = traverse ctx l1 in
- lift ctx [l2]
- (fun [l2] -> Ltrywith (l1, id, l2))
- | Lstaticcatch (l1, ids, l2) ->
- (* In [static-catch l1 with ids -> l2],
- the term [l1] is in fact in tail-position *)
- lift ctx [l1; l2]
- (fun [l1; l2] -> Lstaticcatch (l1, ids, l2))
- | Levent (lam, lev) ->
- lift ctx [lam]
- (fun [lam] -> Levent (lam, lev))
- | Lifused (x, lam) ->
- lift ctx [lam]
- (fun [lam] -> Lifused (x, lam))
- end
-
- (* [lift ctx tail_terms context] optimizes a term of the form
- C[t1,..,tn] where the t1,..,tn are subterms of the multi-context C
- that are all in tail position.
-
- It works by recursively compiling each t1..tn into the corresponding choice.
- If they are all Return, we Return the overall context;
- otherwise there is at least one tail-term
- that is Set (would benefit from TMC), so we Set.
- *)
- and lift ctx tail_terms context =
- let choices = List.map (choice ctx) tail_terms in
- {
- Choice.dps = (fun ~tail ~dst ->
- context (List.map (Choice.dps ~tail ~dst) choices));
- direct = (fun () ->
- context (List.map Choice.direct choices));
- has_tmc_calls =
- List.exists (fun choice -> choice.Choice.has_tmc_calls) choices;
- }
+ match t with
+ | (Lvar _ | Lmutvar _ | Lconst _ | Lfunction _ | Lsend _
+ | Lassign _ | Lfor _ | Lwhile _) ->
+ let t = traverse ctx t in
+ Choice.return t
+
+ (* [choice_prim] handles most primitives, but the important case of construction
+ [Lprim(Pmakeblock(...), ...)] is handled by [choice_makeblock] *)
+ | Lprim (prim, primargs, loc) ->
+ choice_prim ctx prim primargs loc
+
+ (* [choice_apply] handles applications, in particular tail-calls which
+ generate Set choices at the leaves *)
+ | Lapply apply ->
+ choice_apply ctx apply
+ (* other cases use the [lift] helper that takes the sub-terms in tail
+ position and the context around them, and generates a choice for
+ the whole term from choices for the tail subterms. *)
+ | Lsequence (l1, l2) ->
+ let l1 = traverse ctx l1 in
+ let+ l2 = choice ctx l2 in
+ Lsequence (l1, l2)
+ | Lifthenelse (l1, l2, l3) ->
+ let l1 = traverse ctx l1 in
+ let+ (l2, l3) = choice_pair ctx (l2, l3) in
+ Lifthenelse (l1, l2, l3)
+ | Lmutlet (vk, var, def, body) ->
+ (* non-recursive bindings are not specialized *)
+ let def = traverse ctx def in
+ let+ body = choice ctx body in
+ Lmutlet (vk, var, def, body)
+ | Llet (lk, vk, var, def, body) ->
+ (* non-recursive bindings are not specialized *)
+ let def = traverse ctx def in
+ let+ body = choice ctx body in
+ Llet (lk, vk, var, def, body)
+ | Lletrec (bindings, body) ->
+ let ctx, bindings = traverse_letrec ctx bindings in
+ let+ body = choice ctx body in
+ Lletrec(bindings, body)
+ | Lswitch (l1, sw, loc) ->
+ (* decompose *)
+ let consts_lhs, consts_rhs = List.split sw.sw_consts in
+ let blocks_lhs, blocks_rhs = List.split sw.sw_blocks in
+ (* transform *)
+ let l1 = traverse ctx l1 in
+ let+ consts_rhs = choice_list ctx consts_rhs
+ and+ blocks_rhs = choice_list ctx blocks_rhs
+ and+ sw_failaction = choice_option ctx sw.sw_failaction in
+ (* rebuild *)
+ let sw_consts = List.combine consts_lhs consts_rhs in
+ let sw_blocks = List.combine blocks_lhs blocks_rhs in
+ let sw = { sw with sw_consts; sw_blocks; sw_failaction; } in
+ Lswitch (l1, sw, loc)
+ | Lstringswitch (l1, cases, fail, loc) ->
+ (* decompose *)
+ let cases_lhs, cases_rhs = List.split cases in
+ (* transform *)
+ let l1 = traverse ctx l1 in
+ let+ cases_rhs = choice_list ctx cases_rhs
+ and+ fail = choice_option ctx fail in
+ (* rebuild *)
+ let cases = List.combine cases_lhs cases_rhs in
+ Lstringswitch (l1, cases, fail, loc)
+ | Lstaticraise (id, ls) ->
+ let ls = traverse_list ctx ls in
+ Choice.return (Lstaticraise (id, ls))
+ | Ltrywith (l1, id, l2) ->
+ (* in [try l1 with id -> l2], the term [l1] is
+ not in tail-call position (after it returns
+ we need to remove the exception handler),
+ so it is not transformed here *)
+ let l1 = traverse ctx l1 in
+ let+ l2 = choice ctx l2 in
+ Ltrywith (l1, id, l2)
+ | Lstaticcatch (l1, ids, l2) ->
+ (* In [static-catch l1 with ids -> l2],
+ the term [l1] is in fact in tail-position *)
+ let+ l1 = choice ctx l1
+ and+ l2 = choice ctx l2 in
+ Lstaticcatch (l1, ids, l2)
+ | Levent (lam, lev) ->
+ let+ lam = choice ctx lam in
+ Levent (lam, lev)
+ | Lifused (x, lam) ->
+ let+ lam = choice ctx lam in
+ Lifused (x, lam)
and choice_apply ctx apply =
let exception No_tmc in
@@ -403,73 +429,87 @@ let rec choice ctx t =
}
and choice_prim ctx prim primargs loc =
- begin [@warning "-8"] (* see choice *)
- match prim with
- (* The important case is the construction case *)
- | Pmakeblock (tag, flag, shape) ->
- choice_makeblock ctx (tag, flag, shape) primargs loc
-
- (* Some primitives have arguments in tail-position *)
- | Popaque ->
- let [l1] = primargs in
- lift ctx [l1] (fun [l1] -> Lprim (Popaque, [l1], loc))
- | (Psequand | Psequor) as shortcutop ->
- let [l1; l2] = primargs in
- lift ctx [l2]
- (fun [l2] -> Lprim (shortcutop, [l1; l2], loc))
-
- (* cases we don't handle yet *)
- | (Pmakearray _ | Pduparray _) ->
- failwith "TODO: we don't handle array indices as destinations yet"
-
- | Pduprecord _ ->
- failwith "TODO"
-
- (* in common cases we just Return *)
- | Pbytes_to_string | Pbytes_of_string
- | Pgetglobal _ | Psetglobal _
- | Pfield _ | Pfield_computed
- | Psetfield _ | Psetfield_computed _
- | Pfloatfield _ | Psetfloatfield _
- | Pccall _
- | Praise _
- | Pnot
- | Pnegint | Paddint | Psubint | Pmulint | Pdivint _ | Pmodint _
- | Pandint | Porint | Pxorint
- | Plslint | Plsrint | Pasrint
- | Pintcomp _
- | Poffsetint _ | Poffsetref _
- | Pintoffloat | Pfloatofint
- | Pnegfloat | Pabsfloat
- | Paddfloat | Psubfloat | Pmulfloat | Pdivfloat
- | Pfloatcomp _
- | Pstringlength | Pstringrefu | Pstringrefs
- | Pbyteslength | Pbytesrefu | Pbytessetu | Pbytesrefs | Pbytessets
- | Parraylength _ | Parrayrefu _ | Parraysetu _ | Parrayrefs _ | Parraysets _
- | Pisint | Pisout
-
- (* operations returning boxed values could be considered constructions someday *)
- | Pbintofint _ | Pintofbint _
- | Pcvtbint _
- | Pnegbint _
- | Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _
- | Pandbint _ | Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _
- | Pbintcomp _
-
- (* more common cases... *)
- | Pbigarrayref _ | Pbigarrayset _
- | Pbigarraydim _
- | Pstring_load_16 _ | Pstring_load_32 _ | Pstring_load_64 _
- | Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_64 _
- | Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _
- | Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
- | Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _ | Pctconst _
- | Pbswap16
- | Pbbswap _
- | Pint_as_pointer
- ->
- Choice.return (Lprim (prim, primargs, loc))
- end
+ match prim with
+ (* The important case is the construction case *)
+ | Pmakeblock (tag, flag, shape) ->
+ choice_makeblock ctx (tag, flag, shape) primargs loc
+
+ (* Some primitives have arguments in tail-position *)
+ | Popaque ->
+ let l1 = match primargs with
+ | [l1] -> l1
+ | _ -> invalid_arg "choice_prim" in
+ let+ l1 = choice ctx l1 in
+ Lprim (Popaque, [l1], loc)
+ | (Psequand | Psequor) as shortcutop ->
+ let l1, l2 = match primargs with
+ | [l1; l2] -> l1, l2
+ | _ -> invalid_arg "choice_prim" in
+ let l1 = traverse ctx l1 in
+ let+ l2 = choice ctx l2 in
+ Lprim (shortcutop, [l1; l2], loc)
+
+ (* in common cases we just return *)
+ | Pbytes_to_string | Pbytes_of_string
+ | Pgetglobal _ | Psetglobal _
+ | Pfield _ | Pfield_computed
+ | Psetfield _ | Psetfield_computed _
+ | Pfloatfield _ | Psetfloatfield _
+ | Pccall _
+ | Praise _
+ | Pnot
+ | Pnegint | Paddint | Psubint | Pmulint | Pdivint _ | Pmodint _
+ | Pandint | Porint | Pxorint
+ | Plslint | Plsrint | Pasrint
+ | Pintcomp _
+ | Poffsetint _ | Poffsetref _
+ | Pintoffloat | Pfloatofint
+ | Pnegfloat | Pabsfloat
+ | Paddfloat | Psubfloat | Pmulfloat | Pdivfloat
+ | Pfloatcomp _
+ | Pstringlength | Pstringrefu | Pstringrefs
+ | Pbyteslength | Pbytesrefu | Pbytessetu | Pbytesrefs | Pbytessets
+ | Parraylength _ | Parrayrefu _ | Parraysetu _ | Parrayrefs _ | Parraysets _
+ | Pisint | Pisout
+ | Pignore
+ | Pcompare_ints | Pcompare_floats | Pcompare_bints _
+
+ (* we don't handle array indices as destinations yet *)
+ | (Pmakearray _ | Pduparray _)
+
+ (* we don't handle { foo with x = ...; y = recursive-call } *)
+ | Pduprecord _
+
+ (* operations returning boxed values could be considered constructions someday *)
+ | Pbintofint _ | Pintofbint _
+ | Pcvtbint _
+ | Pnegbint _
+ | Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _
+ | Pandbint _ | Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _
+ | Pbintcomp _
+
+ (* more common cases... *)
+ | Pbigarrayref _ | Pbigarrayset _
+ | Pbigarraydim _
+ | Pstring_load_16 _ | Pstring_load_32 _ | Pstring_load_64 _
+ | Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_64 _
+ | Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _
+ | Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _
+ | Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _ | Pctconst _
+ | Pbswap16
+ | Pbbswap _
+ | Pint_as_pointer
+ ->
+ let primargs = traverse_list ctx primargs in
+ Choice.return (Lprim (prim, primargs, loc))
+
+ and choice_list ctx terms =
+ Choice.list (List.map (choice ctx) terms)
+ and choice_pair ctx (t1, t2) =
+ Choice.pair (choice ctx t1, choice ctx t2)
+ and choice_option ctx t =
+ Choice.option (Option.map (choice ctx) t)
+
in choice ctx t
and traverse ctx = function
@@ -506,6 +546,9 @@ and traverse_binding ctx (var, def) =
let dps_var = special.dps_id in
[(var, direct); (dps_var, dps)]
+and traverse_list ctx terms =
+ List.map (traverse ctx) terms
+
let rewrite t =
let ctx = { specialized = Ident.Map.empty } in
traverse ctx t