summaryrefslogtreecommitdiff
path: root/testsuite/tests/typing-labels/mixin3.ml
blob: 5113eeb6f8943cb8c518c1ed9d38a56c627ea882 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
(* $Id: mixin3.ml 5929 2003-11-19 02:36:58Z garrigue $ *)

(* Full-fledged version, using objects to structure code *)

open StdLabels
open MoreLabels

(* Use maps for substitutions and sets for free variables *)

(* Variable names ordered with the polymorphic comparison; shared by the
   substitution maps and the free-variable sets below. *)
module Ordered_string = struct
  type t = string
  let compare = compare
end

(* Substitutions: variable name -> replacement term. *)
module Subst = Map.Make (Ordered_string)

(* Sets of (free) variable names. *)
module Names = Set.Make (Ordered_string)

(* To build recursive objects *)

(* [lazy_fix make] ties a recursive knot: [make] receives a lazy handle on
   the very value it is building, and the result of [make] is returned.

   The handle is a single shared suspension whose value is the result itself.
   [make] therefore runs exactly once and forcing the handle later (e.g. from
   inside a method) yields the same object, not a fresh copy. [make] must not
   force the handle while it is still constructing its result; doing so
   raises [Lazy.Undefined]. *)
let lazy_fix make =
  let rec knot = lazy (make knot) in
  Lazy.force knot

(* Prefix shorthand for forcing a suspension. *)
let ( !! ) suspended = Lazy.force suspended

(* The basic operations *)

(* The operations every language provides. In [('a, 'b) ops], ['b] is the
   type of terms the operations take as input and ['a] is the type of terms
   they produce; substitutions map variable names to ['a] terms. *)
class type ['a, 'b] ops =
  object
    (* Set of the variable names occurring free in a term. *)
    method free : 'b -> Names.t
    (* Replace free variables of a term according to [sub]. *)
    method subst : sub:'a Subst.t -> 'b -> 'a
    (* Reduce a term. *)
    method eval : 'b -> 'a
  end

(* Variables are common to lambda and expr *)

type var = [`Var of string]

(* Operations on a bare variable, reused by every language below.
   - subst: the replacement bound to the name in [sub], or the variable
     itself when the name is unbound;
   - free: the singleton set of its name;
   - eval: a variable evaluates to itself. *)
let var = object (self : ([>var], var) #ops)
  method subst ~sub (`Var name as v) =
    if Subst.mem name sub then Subst.find name sub else v
  method free (`Var name) = Names.singleton name
  method eval = function #var as v -> v
end

(* The lambda language: free variables, substitutions, and evaluation *)

(* Lambda terms, with the type of sub-terms left as the parameter ['a] so the
   constructor set can later be merged with others (see [lexpr]). *)
type 'a lambda = [`Var of string | `Abs of string * 'a | `App of 'a * 'a]

(* Global counter used to build fresh variable names; each call returns the
   next integer, the first call returning 4. *)
let next_id =
  let counter = ref 3 in
  fun () ->
    counter := !counter + 1;
    !counter

(* [lambda_ops ops] implements free variables, substitution and evaluation for
   the lambda constructors. Recursive calls on sub-terms go through the
   (lazily forced) [ops], so the same code can be reused inside a larger
   language once the knot is tied with [lazy_fix]. *)
let lambda_ops (ops : ('a,'a) #ops Lazy.t) =
  let free = lazy !!ops#free
  and subst = lazy !!ops#subst
  and eval = lazy !!ops#eval in
  object (self : ([> 'a lambda], 'a lambda) #ops)
    (* An abstraction removes its bound name from its body's free variables. *)
    method free = function
        #var as x -> var#free x
      | `Abs (s, t) -> Names.remove s (!!free t)
      | `App (t1, t2) -> Names.union (!!free t1) (!!free t2)

    (* Apply [f] to the immediate sub-terms. The original node is returned
       (physically) when no sub-term changed, to preserve sharing. *)
    method private map ~f = function
        #var as x -> x
      | `Abs (s, t) as l ->
          let t' = f t in
          if t == t' then l else `Abs(s, t')
      | `App (t1, t2) as l ->
          let t'1 = f t1 and t'2 = f t2 in
          if t'1 == t1 && t'2 == t2 then l else `App (t'1, t'2)

    (* Capture-avoiding substitution. *)
    method subst ~sub = function
        #var as x -> var#subst ~sub x
      | `Abs(s, t) as l ->
          let used = !!free t in
          (* Replacement terms that will actually be inserted under the binder:
             those whose key occurs free in the body. A binding for [s] itself
             is shadowed by the abstraction and therefore ignored. *)
          let used_expr =
            Subst.fold sub ~init:[]
              ~f:(fun ~key ~data acc ->
                if key <> s && Names.mem key used then data::acc else acc) in
          if List.exists used_expr ~f:(fun e -> Names.mem s (!!free e)) then begin
            (* An inserted term mentions [s] freely and would be captured:
               rename the binder to a name that occurs neither free in the
               body nor free in any inserted term. *)
            let avoid =
              List.fold_left used_expr ~init:used
                ~f:(fun acc e -> Names.union acc (!!free e)) in
            let rec fresh () =
              let name = s ^ string_of_int (next_id ()) in
              if Names.mem name avoid then fresh () else name in
            let name = fresh () in
            `Abs(name,
                 !!subst ~sub:(Subst.add ~key:s ~data:(`Var name) sub) t)
          end else
            self#map ~f:(!!subst ~sub:(Subst.remove s sub)) l
      | `App _ as l ->
          self#map ~f:(!!subst ~sub) l

    (* Evaluate sub-terms first, then beta-reduce an application whose head
       is an abstraction and evaluate the result. *)
    method eval l =
      match self#map ~f:!!eval l with
        `App(`Abs(s,t1), t2) ->
          !!eval (!!subst ~sub:(Subst.add ~key:s ~data:t2 Subst.empty) t1)
      | t -> t
end

(* Operations specialized to lambda *)

(* Lambda operations whose recursion is closed over pure lambda terms. *)
let lambda = lazy_fix (fun ops -> lambda_ops ops)

(* The expr language of arithmetic expressions *)

(* Arithmetic expressions: variables, integer literals, addition, negation
   and multiplication. As with [lambda], the sub-term type is the
   parameter ['a]. *)
type 'a expr =
    [ `Var of string | `Num of int | `Add of 'a * 'a
    | `Neg of 'a | `Mult of 'a * 'a]

(* [expr_ops ops] implements free variables, substitution and evaluation for
   the arithmetic constructors. Recursive calls on sub-terms go through the
   (lazily forced) [ops], so that the operations can be combined with other
   languages before the knot is tied with [lazy_fix]. *)
let expr_ops (ops : ('a,'a) #ops Lazy.t) =
  let free = lazy !!ops#free
  and subst = lazy !!ops#subst
  and eval = lazy !!ops#eval in
  object (self : ([> 'a expr], 'a expr) #ops)
    (* Union of the free variables of the sub-terms; literals have none. *)
    method free = function
        #var as x -> var#free x
      | `Num _ -> Names.empty
      | `Add(x, y) -> Names.union (!!free x) (!!free y)
      | `Neg x -> !!free x
      | `Mult(x, y) -> Names.union (!!free x) (!!free y)

    (* Apply [f] to the immediate sub-terms, returning the original node
       (physically) when no sub-term changed, to preserve sharing. *)
    method private map ~f = function
        #var as x -> x
      | `Num _ as x -> x
      | `Add(x, y) as e ->
          let x' = f x and y' = f y in
          if x == x' && y == y' then e
          else `Add(x', y')
      | `Neg x as e ->
          let x' = f x in
          if x == x' then e else `Neg x'
      | `Mult(x, y) as e ->
          let x' = f x and y' = f y in
          if x == x' && y == y' then e
          else `Mult(x', y')

    (* No binders here, so substitution just recurses into sub-terms.
       [`Var] is caught by the first case even though [#expr] includes it. *)
    method subst ~sub = function
        #var as x -> var#subst ~sub x
      | #expr as e -> self#map ~f:(!!subst ~sub) e

    (* Evaluate sub-terms first, then fold an operator whose operands are all
       literals; any other term is returned unchanged. *)
    method eval (#expr as e) =
      match self#map ~f:!!eval e with
        `Add(`Num m, `Num n) -> `Num (m+n)
      | `Neg(`Num n) -> `Num (-n)
      | `Mult(`Num m, `Num n) -> `Num (m*n)
      | e -> e
  end

(* Specialized versions *)

(* Arithmetic operations whose recursion is closed over pure expr terms. *)
let expr = lazy_fix (fun ops -> expr_ops ops)

(* The lexpr language, the union of lambda and expr *)

(* The combined language: every lambda constructor and every arithmetic
   constructor, with sub-terms of type ['a]. *)
type 'a lexpr = [ 'a lambda | 'a expr ]

(* [lexpr_ops ops] combines the lambda and arithmetic operations. Each method
   dispatches on the head constructor: lambda constructors (including [`Var])
   go to the lambda operations, the remaining ones to the arithmetic
   operations. Both halves share [ops] for their recursive calls. *)
let lexpr_ops (ops : ('a,'a) #ops Lazy.t) =
  let lam_ops = lambda_ops ops in
  let arith_ops = expr_ops ops in
  object (self : ([> 'a lexpr], 'a lexpr) #ops)
    method free term =
      match term with
      | #lambda as x -> lam_ops#free x
      | #expr as x -> arith_ops#free x

    method subst ~sub term =
      match term with
      | #lambda as x -> lam_ops#subst ~sub x
      | #expr as x -> arith_ops#subst ~sub x

    method eval term =
      match term with
      | #lambda as x -> lam_ops#eval x
      | #expr as x -> arith_ops#eval x
  end

(* Combined operations whose recursion is closed over lexpr terms. *)
let lexpr = lazy_fix (fun ops -> lexpr_ops ops)

(* [add_term buf t] appends a rendering of [t] to [buf]. Abstractions are
   written [\x . body], with the body extending as far right as possible.
   Every non-atomic operand of an application or arithmetic operator is
   parenthesised, so distinct terms never print the same way. *)
let rec add_term buf = function
  | `Var id -> Buffer.add_string buf id
  | `Abs (id, l) -> Buffer.add_string buf ("\\" ^ id ^ " . "); add_term buf l
  | `App (l1, l2) ->
      add_operand buf l1; Buffer.add_char buf ' '; add_operand buf l2
  | `Num x -> Buffer.add_string buf (string_of_int x)
  | `Add (e1, e2) ->
      add_operand buf e1; Buffer.add_string buf " + "; add_operand buf e2
  | `Neg e -> Buffer.add_char buf '-'; add_operand buf e
  | `Mult (e1, e2) ->
      add_operand buf e1; Buffer.add_string buf " * "; add_operand buf e2

(* Render an operand, parenthesised unless it is a variable or a literal.
   Every constructor is listed explicitly (no wildcard) so the accepted
   variant type stays closed, exactly as for [add_term]. *)
and add_operand buf t =
  match t with
  | `Var _ | `Num _ -> add_term buf t
  | `Abs _ | `App _ | `Add _ | `Neg _ | `Mult _ ->
      Buffer.add_char buf '(';
      add_term buf t;
      Buffer.add_char buf ')'

(* [to_string t] is the rendering of [t] produced by [add_term]. *)
let to_string t =
  let buf = Buffer.create 64 in
  add_term buf t;
  Buffer.contents buf

(* [print t] writes the rendering of [t] on standard output. *)
let print t = print_string (to_string t)

(* Evaluate one sample term in each language, then print the three results,
   one per line. *)
let () =
  let show term = print term; print_newline () in
  let beta_reduced = lambda#eval (`App(`Abs("x",`Var"x"), `Var"y")) in
  let folded = expr#eval (`Add(`Mult(`Num 3,`Neg(`Num 2)), `Var"x")) in
  let mixed =
    lexpr#eval
      (`Add(`App(`Abs("x",`Mult(`Var"x",`Var"x")),`Num 2), `Num 5)) in
  show beta_reduced;
  show folded;
  show mixed