diff options
author | Gabriel Scherer <gabriel.scherer@gmail.com> | 2020-03-29 15:02:21 +0200 |
---|---|---|
committer | Gabriel Scherer <gabriel.scherer@gmail.com> | 2021-11-02 15:42:55 +0100 |
commit | e9397f6605beb57d8ccc08df1b3f928b40e9435e (patch) | |
tree | 91aac6b5b1553f437399bcd7a3e8ebd3f9bd676e /lambda | |
parent | 90dd724a3e75d51d826addfb87868816ae308696 (diff) | |
download | ocaml-e9397f6605beb57d8ccc08df1b3f928b40e9435e.tar.gz |
TMC: generalize `Choice.t` to use binding operators
Diffstat (limited to 'lambda')
-rw-r--r-- | lambda/tmc.ml | 449 |
1 files changed, 246 insertions, 203 deletions
diff --git a/lambda/tmc.ml b/lambda/tmc.ml index a80a85d76b..a3cf667365 100644 --- a/lambda/tmc.ml +++ b/lambda/tmc.ml @@ -1,15 +1,5 @@ open Lambda -open struct - let combine_upto short long = - let prefix, rest = Misc.Stdlib.List.split_at (List.length short) long in - List.combine short prefix, rest - - let option_of_list = function - | [] -> None, [] - | x::xs -> Some x, xs -end - (** TMC (Tail Modulo Cons) is a code transformation that rewrites transformed functions in destination-passing-style, in such a way that certain calls that were not in tail position in the @@ -106,9 +96,17 @@ module Dps = struct (** Create a new destination-passing-style term which is simply setting the destination with the given [v], hence "returning" it. *) - let return (v : lambda): lambda t = - fun ~tail:_ ~dst -> - assign_to_dst dst v + let return (v : lambda): lambda t = fun ~tail:_ ~dst -> + assign_to_dst dst v + + let map (f : 'a -> 'b) (dps : 'a t) : 'b t = fun ~tail ~dst -> + f @@ dps ~tail ~dst + + let pair (fa : 'a t) (fb : 'b t) : ('a * 'b) t = fun ~tail ~dst -> + (fa ~tail ~dst, fb ~tail ~dst) + + let unit : unit t = fun ~tail:_ ~dst:_ -> + () end (** The TMC transformation requires information flows in two opposite @@ -130,16 +128,26 @@ end 2. Code-production operators that have contextual information to transform a "code choice" into the final code. + + The code-production choices for a single term have type [lambda Choice.t]; + using a parametrized type ['a Choice.t] is useful to represent + simultaneous choices over several subterms; for example + [(lambda * lambda) Choice.t] makes a choice for a pair of terms, + for example the [then] and [else] cases of a conditional. With + this parameter, ['a Choice.t] has an applicative structure, which + is useful to write the actual code transformation in the {!choice} + function. *) module Choice = struct - type t = { - dps : lambda Dps.t; - direct : unit -> lambda; + type 'a t = { + dps : 'a Dps.t; + direct : unit -> 'a; has_tmc_calls : bool; } (** - A [Choice.t] represents code that may be written in destination-passing style - if its usage context allows it. More precisely: + An ['a Choice.t] represents code that may be written + in destination-passing style if its usage context allows it. + More precisely: - If the surrounding context is already in destination-passing style, it has a destination available, we should produce the @@ -158,25 +166,64 @@ module Choice = struct position and are rewritten into tailcalls in the [dps] version. *) - let return v = { + let return (v : lambda) : lambda t = { dps = Dps.return v; direct = (fun () -> v); has_tmc_calls = false; } - let direct (c : t) : lambda = + let map f s = { + dps = Dps.map f s.dps; + direct = (fun () -> f (s.direct ())); + has_tmc_calls = s.has_tmc_calls; + } + (** Apply function [f] to the transformed term. *) + + let direct (c : lambda t) : lambda = c.direct () - let dps (c : t) ~tail ~dst = + let dps (c : lambda t) ~tail ~dst = c.dps ~tail:tail ~dst:dst + let pair ((c1, c2) : 'a t * 'b t) : ('a * 'b) t = { + dps = Dps.pair c1.dps c2.dps; + direct = (fun () -> (c1.direct (), c2.direct ())); + has_tmc_calls = + c1.has_tmc_calls || c2.has_tmc_calls; + } + + let unit = { + dps = Dps.unit; + direct = (fun () -> ()); + has_tmc_calls = false; + } + + module Syntax = struct + let (let+) a f = map f a + let (and+) a1 a2 = pair (a1, a2) + end + open Syntax + + let option (c : 'a t option) : 'a option t = + match c with + | None -> let+ () = unit in None + | Some c -> let+ v = c in Some v + + let rec list (c : 'a t list) : 'a list t = + match c with + | [] -> let+ () = unit in [] + | c :: cs -> + let+ v = c + and+ vs = list cs + in v :: vs + (** Finds the first [Choice.t] in a list that [has_tmc_calls] *) - type zipper = { - rev_before : lambda list; - choice : t; - after: t list + type 'a zipper = { + rev_before : 'a list; + choice : 'a t; + after: 'a t list } - let find_tmc_calls : t list -> (zipper, lambda list) result = + let find_tmc_calls : 'a t list -> ('a zipper, 'a list) result = let rec find rev_before = function | [] -> Error (List.rev rev_before) | choice :: after -> @@ -186,6 +233,8 @@ module Choice = struct in find [] end +open Choice.Syntax + type context = { specialized: specialized Ident.Map.t; } @@ -212,116 +261,93 @@ let declare_binding ctx (var, def) = let rec choice ctx t = let rec choice ctx t = - begin[@warning "-8"] - (*FIXME: allows non-exhaustive pattern matching; - use an overkill functor-based solution instead? *) - match t with - | (Lvar _ | Lconst _ | Lfunction _ | Lsend _ - | Lassign _ | Lfor _ | Lwhile _) -> - let t = traverse ctx t in - Choice.return t - - (* [choice_prim] handles most primitives, but the important case of construction - [Lprim(Pmakeblock(...), ...)] is handled by [choice_makeblock] *) - | Lprim (prim, primargs, loc) -> - choice_prim ctx prim primargs loc - - (* [choice_apply] handles applications, in particular tail-calls which - generate Set choices at the leaves *) - | Lapply apply -> - choice_apply ctx apply - (* other cases use the [lift] helper that takes the sub-terms in tail - position and the context around them, and generates a choice for - the whole term from choices for the tail subterms. *) - | Lsequence (l1, l2) -> - let l1 = traverse ctx l1 in - lift ctx [l2] @@ fun [l2] -> - Lsequence (l1, l2) - | Lifthenelse (l1, l2, l3) -> - let l1 = traverse ctx l1 in - lift ctx [l2; l3] - (fun [l2; l3] -> Lifthenelse (l1, l2, l3)) - | Llet (lk, vk, var, def, body) -> - (* non-recursive bindings are not specialized *) - let def = traverse ctx def in - lift ctx [body] @@ fun [body] -> - Llet (lk, vk, var, def, body) - | Lletrec (bindings, body) -> - let ctx, bindings = traverse_letrec ctx bindings in - lift ctx [body] @@ fun [body] -> - Lletrec(bindings, body) - | Lswitch (l1, sw, loc) -> - let l1 = traverse ctx l1 in - let consts_lhs, consts_rhs = List.split sw.sw_consts in - let blocks_lhs, blocks_rhs = List.split sw.sw_blocks in - let failaction = Option.to_list sw.sw_failaction in - lift ctx (consts_rhs @ blocks_rhs @ failaction) - (fun li -> - let consts, li = combine_upto consts_lhs li in - let blocks, li = combine_upto blocks_lhs li in - let fail, li = option_of_list li in - assert (li = []); - let sw = - { sw with - sw_consts = consts; - sw_blocks = blocks; - sw_failaction = fail; - } - in - Lswitch (l1, sw, loc)) - | Lstringswitch (l1, ls, lo, loc) -> - let l1 = traverse ctx l1 in - let cases_lhs, cases_rhs = List.split ls in - let failaction = Option.to_list lo in - lift ctx (cases_rhs @ failaction) - (fun li -> - let cases, li = combine_upto cases_lhs li in - let fail, li = option_of_list li in - assert (li = []); - Lstringswitch (l1, cases, fail, loc)) - | Lstaticraise (id, ls) -> - let ls = List.map (traverse ctx) ls in - Choice.return (Lstaticraise (id, ls)) - | Ltrywith (l1, id, l2) -> - (* in [try l1 with id -> l2], the term [l1] is - not in tail-call position (after it returns - we need to remove the exception handler), - so it is not transformed here *) - let l1 = traverse ctx l1 in - lift ctx [l2] - (fun [l2] -> Ltrywith (l1, id, l2)) - | Lstaticcatch (l1, ids, l2) -> - (* In [static-catch l1 with ids -> l2], - the term [l1] is in fact in tail-position *) - lift ctx [l1; l2] - (fun [l1; l2] -> Lstaticcatch (l1, ids, l2)) - | Levent (lam, lev) -> - lift ctx [lam] - (fun [lam] -> Levent (lam, lev)) - | Lifused (x, lam) -> - lift ctx [lam] - (fun [lam] -> Lifused (x, lam)) - end - - (* [lift ctx tail_terms context] optimizes a term of the form - C[t1,..,tn] where the t1,..,tn are subterms of the multi-context C - that are all in tail position. - - It works by recursively compiling each t1..tn into the corresponding choice. - If they are all Return, we Return the overall context; - otherwise there is at least one tail-term - that is Set (would benefit from TMC), so we Set. - *) - and lift ctx tail_terms context = - let choices = List.map (choice ctx) tail_terms in - { - Choice.dps = (fun ~tail ~dst -> - context (List.map (Choice.dps ~tail ~dst) choices)); - direct = (fun () -> - context (List.map Choice.direct choices)); - has_tmc_calls = - List.exists (fun choice -> choice.Choice.has_tmc_calls) choices; - } + match t with + | (Lvar _ | Lmutvar _ | Lconst _ | Lfunction _ | Lsend _ + | Lassign _ | Lfor _ | Lwhile _) -> + let t = traverse ctx t in + Choice.return t + + (* [choice_prim] handles most primitives, but the important case of construction + [Lprim(Pmakeblock(...), ...)] is handled by [choice_makeblock] *) + | Lprim (prim, primargs, loc) -> + choice_prim ctx prim primargs loc + + (* [choice_apply] handles applications, in particular tail-calls which + generate Set choices at the leaves *) + | Lapply apply -> + choice_apply ctx apply + (* other cases use the [lift] helper that takes the sub-terms in tail + position and the context around them, and generates a choice for + the whole term from choices for the tail subterms. *) + | Lsequence (l1, l2) -> + let l1 = traverse ctx l1 in + let+ l2 = choice ctx l2 in + Lsequence (l1, l2) + | Lifthenelse (l1, l2, l3) -> + let l1 = traverse ctx l1 in + let+ (l2, l3) = choice_pair ctx (l2, l3) in + Lifthenelse (l1, l2, l3) + | Lmutlet (vk, var, def, body) -> + (* non-recursive bindings are not specialized *) + let def = traverse ctx def in + let+ body = choice ctx body in + Lmutlet (vk, var, def, body) + | Llet (lk, vk, var, def, body) -> + (* non-recursive bindings are not specialized *) + let def = traverse ctx def in + let+ body = choice ctx body in + Llet (lk, vk, var, def, body) + | Lletrec (bindings, body) -> + let ctx, bindings = traverse_letrec ctx bindings in + let+ body = choice ctx body in + Lletrec(bindings, body) + | Lswitch (l1, sw, loc) -> + (* decompose *) + let consts_lhs, consts_rhs = List.split sw.sw_consts in + let blocks_lhs, blocks_rhs = List.split sw.sw_blocks in + (* transform *) + let l1 = traverse ctx l1 in + let+ consts_rhs = choice_list ctx consts_rhs + and+ blocks_rhs = choice_list ctx blocks_rhs + and+ sw_failaction = choice_option ctx sw.sw_failaction in + (* rebuild *) + let sw_consts = List.combine consts_lhs consts_rhs in + let sw_blocks = List.combine blocks_lhs blocks_rhs in + let sw = { sw with sw_consts; sw_blocks; sw_failaction; } in + Lswitch (l1, sw, loc) + | Lstringswitch (l1, cases, fail, loc) -> + (* decompose *) + let cases_lhs, cases_rhs = List.split cases in + (* transform *) + let l1 = traverse ctx l1 in + let+ cases_rhs = choice_list ctx cases_rhs + and+ fail = choice_option ctx fail in + (* rebuild *) + let cases = List.combine cases_lhs cases_rhs in + Lstringswitch (l1, cases, fail, loc) + | Lstaticraise (id, ls) -> + let ls = traverse_list ctx ls in + Choice.return (Lstaticraise (id, ls)) + | Ltrywith (l1, id, l2) -> + (* in [try l1 with id -> l2], the term [l1] is + not in tail-call position (after it returns + we need to remove the exception handler), + so it is not transformed here *) + let l1 = traverse ctx l1 in + let+ l2 = choice ctx l2 in + Ltrywith (l1, id, l2) + | Lstaticcatch (l1, ids, l2) -> + (* In [static-catch l1 with ids -> l2], + the term [l1] is in fact in tail-position *) + let+ l1 = choice ctx l1 + and+ l2 = choice ctx l2 in + Lstaticcatch (l1, ids, l2) + | Levent (lam, lev) -> + let+ lam = choice ctx lam in + Levent (lam, lev) + | Lifused (x, lam) -> + let+ lam = choice ctx lam in + Lifused (x, lam) and choice_apply ctx apply = let exception No_tmc in @@ -403,73 +429,87 @@ let rec choice ctx t = } and choice_prim ctx prim primargs loc = - begin [@warning "-8"] (* see choice *) - match prim with - (* The important case is the construction case *) - | Pmakeblock (tag, flag, shape) -> - choice_makeblock ctx (tag, flag, shape) primargs loc - - (* Some primitives have arguments in tail-position *) - | Popaque -> - let [l1] = primargs in - lift ctx [l1] (fun [l1] -> Lprim (Popaque, [l1], loc)) - | (Psequand | Psequor) as shortcutop -> - let [l1; l2] = primargs in - lift ctx [l2] - (fun [l2] -> Lprim (shortcutop, [l1; l2], loc)) - - (* cases we don't handle yet *) - | (Pmakearray _ | Pduparray _) -> - failwith "TODO: we don't handle array indices as destinations yet" - - | Pduprecord _ -> - failwith "TODO" - - (* in common cases we just Return *) - | Pbytes_to_string | Pbytes_of_string - | Pgetglobal _ | Psetglobal _ - | Pfield _ | Pfield_computed - | Psetfield _ | Psetfield_computed _ - | Pfloatfield _ | Psetfloatfield _ - | Pccall _ - | Praise _ - | Pnot - | Pnegint | Paddint | Psubint | Pmulint | Pdivint _ | Pmodint _ - | Pandint | Porint | Pxorint - | Plslint | Plsrint | Pasrint - | Pintcomp _ - | Poffsetint _ | Poffsetref _ - | Pintoffloat | Pfloatofint - | Pnegfloat | Pabsfloat - | Paddfloat | Psubfloat | Pmulfloat | Pdivfloat - | Pfloatcomp _ - | Pstringlength | Pstringrefu | Pstringrefs - | Pbyteslength | Pbytesrefu | Pbytessetu | Pbytesrefs | Pbytessets - | Parraylength _ | Parrayrefu _ | Parraysetu _ | Parrayrefs _ | Parraysets _ - | Pisint | Pisout - - (* operations returning boxed values could be considered constructions someday *) - | Pbintofint _ | Pintofbint _ - | Pcvtbint _ - | Pnegbint _ - | Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _ - | Pandbint _ | Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _ - | Pbintcomp _ - - (* more common cases... *) - | Pbigarrayref _ | Pbigarrayset _ - | Pbigarraydim _ - | Pstring_load_16 _ | Pstring_load_32 _ | Pstring_load_64 _ - | Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_64 _ - | Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ - | Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _ - | Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _ | Pctconst _ - | Pbswap16 - | Pbbswap _ - | Pint_as_pointer - -> - Choice.return (Lprim (prim, primargs, loc)) - end + match prim with + (* The important case is the construction case *) + | Pmakeblock (tag, flag, shape) -> + choice_makeblock ctx (tag, flag, shape) primargs loc + + (* Some primitives have arguments in tail-position *) + | Popaque -> + let l1 = match primargs with + | [l1] -> l1 + | _ -> invalid_arg "choice_prim" in + let+ l1 = choice ctx l1 in + Lprim (Popaque, [l1], loc) + | (Psequand | Psequor) as shortcutop -> + let l1, l2 = match primargs with + | [l1; l2] -> l1, l2 + | _ -> invalid_arg "choice_prim" in + let l1 = traverse ctx l1 in + let+ l2 = choice ctx l2 in + Lprim (shortcutop, [l1; l2], loc) + + (* in common cases we just return *) + | Pbytes_to_string | Pbytes_of_string + | Pgetglobal _ | Psetglobal _ + | Pfield _ | Pfield_computed + | Psetfield _ | Psetfield_computed _ + | Pfloatfield _ | Psetfloatfield _ + | Pccall _ + | Praise _ + | Pnot + | Pnegint | Paddint | Psubint | Pmulint | Pdivint _ | Pmodint _ + | Pandint | Porint | Pxorint + | Plslint | Plsrint | Pasrint + | Pintcomp _ + | Poffsetint _ | Poffsetref _ + | Pintoffloat | Pfloatofint + | Pnegfloat | Pabsfloat + | Paddfloat | Psubfloat | Pmulfloat | Pdivfloat + | Pfloatcomp _ + | Pstringlength | Pstringrefu | Pstringrefs + | Pbyteslength | Pbytesrefu | Pbytessetu | Pbytesrefs | Pbytessets + | Parraylength _ | Parrayrefu _ | Parraysetu _ | Parrayrefs _ | Parraysets _ + | Pisint | Pisout + | Pignore + | Pcompare_ints | Pcompare_floats | Pcompare_bints _ + + (* we don't handle array indices as destinations yet *) + | (Pmakearray _ | Pduparray _) + + (* we don't handle { foo with x = ...; y = recursive-call } *) + | Pduprecord _ + + (* operations returning boxed values could be considered constructions someday *) + | Pbintofint _ | Pintofbint _ + | Pcvtbint _ + | Pnegbint _ + | Paddbint _ | Psubbint _ | Pmulbint _ | Pdivbint _ | Pmodbint _ + | Pandbint _ | Porbint _ | Pxorbint _ | Plslbint _ | Plsrbint _ | Pasrbint _ + | Pbintcomp _ + + (* more common cases... *) + | Pbigarrayref _ | Pbigarrayset _ + | Pbigarraydim _ + | Pstring_load_16 _ | Pstring_load_32 _ | Pstring_load_64 _ + | Pbytes_load_16 _ | Pbytes_load_32 _ | Pbytes_load_64 _ + | Pbytes_set_16 _ | Pbytes_set_32 _ | Pbytes_set_64 _ + | Pbigstring_load_16 _ | Pbigstring_load_32 _ | Pbigstring_load_64 _ + | Pbigstring_set_16 _ | Pbigstring_set_32 _ | Pbigstring_set_64 _ | Pctconst _ + | Pbswap16 + | Pbbswap _ + | Pint_as_pointer + -> + let primargs = traverse_list ctx primargs in + Choice.return (Lprim (prim, primargs, loc)) + + and choice_list ctx terms = + Choice.list (List.map (choice ctx) terms) + and choice_pair ctx (t1, t2) = + Choice.pair (choice ctx t1, choice ctx t2) + and choice_option ctx t = + Choice.option (Option.map (choice ctx) t) + in choice ctx t and traverse ctx = function @@ -506,6 +546,9 @@ and traverse_binding ctx (var, def) = let dps_var = special.dps_id in [(var, direct); (dps_var, dps)] +and traverse_list ctx terms = + List.map (traverse ctx) terms + let rewrite t = let ctx = { specialized = Ident.Map.empty } in traverse ctx t |