1 |
(* COPYRIGHT (c) 1996 Bell Laboratories *) |
(* COPYRIGHT (c) 1998 YALE FLINT PROJECT *) |
2 |
(* equal.sml *) |
(* equal.sml *) |
3 |
|
|
4 |
signature EQUAL = |
signature EQUAL = |
8 |
* Constructing generic equality functions; the current version will |
* Constructing generic equality functions; the current version will |
9 |
* use runtime polyequal function to deal with abstract types. (ZHONG) |
* use runtime polyequal function to deal with abstract types. (ZHONG) |
10 |
*) |
*) |
11 |
val equal : LambdaVar.lvar * LambdaVar.lvar * LtyDef.tyc -> Lambda.lexp |
val equal_branch : FLINT.primop * FLINT.value list * FLINT.lexp * FLINT.lexp |
12 |
|
-> FLINT.lexp |
13 |
val debugging : bool ref |
val debugging : bool ref |
14 |
|
|
15 |
end (* signature EQUAL *) |
end (* signature EQUAL *) |
18 |
structure Equal : EQUAL = |
structure Equal : EQUAL = |
19 |
struct |
struct |
20 |
|
|
21 |
local structure DA = Access |
local structure BT = BasicTypes |
|
structure BT = BasicTypes |
|
22 |
structure LT = LtyExtern |
structure LT = LtyExtern |
23 |
structure PT = PrimTyc |
structure PT = PrimTyc |
24 |
structure PO = PrimOp |
structure PO = PrimOp |
25 |
structure PP = PrettyPrint |
structure PP = PrettyPrint |
26 |
open Lambda |
structure FU = FlintUtil |
27 |
|
open FLINT |
28 |
in |
in |
29 |
|
|
30 |
val debugging = ref false |
val debugging = ref false |
31 |
fun bug msg = ErrorMsg.impossible("Equal: "^msg) |
fun bug msg = ErrorMsg.impossible("Equal: "^msg) |
32 |
val say = Control.Print.say |
val say = Control.Print.say |
33 |
|
val mkv = LambdaVar.mkLvar |
34 |
|
val ident = fn x => x |
35 |
|
|
36 |
|
|
37 |
val (trueDcon', falseDcon') = |
val (trueDcon', falseDcon') = |
38 |
let val lt = LT.ltc_parrow(LT.ltc_unit, LT.ltc_bool) |
let val lt = LT.ltc_arrow(LT.ffc_rrflint, [LT.ltc_unit], [LT.ltc_bool]) |
39 |
fun h (Types.DATACON{name, rep, ...}) = (name, rep, lt) |
fun h (Types.DATACON{name, rep, ...}) = (name, rep, lt) |
40 |
in (h BT.trueDcon, h BT.falseDcon) |
in (h BT.trueDcon, h BT.falseDcon) |
41 |
end |
end |
42 |
|
|
43 |
val tcEqv = LT.tc_eqv |
val tcEqv = LT.tc_eqv |
44 |
|
|
|
(* |
|
|
* MAJOR CLEANUP REQUIRED ! The function mkv is currently directly taken |
|
|
* from the LambdaVar module; I think it should be taken from the |
|
|
* "compInfo". Similarly, should we replace all mkLvar in the backend |
|
|
* with the mkv in "compInfo" ? (ZHONG) |
|
|
*) |
|
|
val mkv = LambdaVar.mkLvar |
|
|
|
|
|
val ident = fn x => x |
|
|
fun split(SVAL v) = (v, ident) |
|
|
| split x = let val v = mkv() |
|
|
in (VAR v, fn z => LET(v, x, z)) |
|
|
end |
|
|
|
|
|
fun APPg(e1, e2) = |
|
|
let val (v1, h1) = split e1 |
|
|
val (v2, h2) = split e2 |
|
|
in h1(h2(APP(v1, v2))) |
|
|
end |
|
|
|
|
|
fun RECORDg es = |
|
|
let fun f ([], vs, hdr) = hdr(RECORD (rev vs)) |
|
|
| f (e::r, vs, hdr) = |
|
|
let val (v, h) = split e |
|
|
in f(r, v::vs, hdr o h) |
|
|
end |
|
|
in f(es, [], ident) |
|
|
end |
|
45 |
|
|
46 |
fun SWITCHg(e, csig, ces, oe) = |
fun boolLexp b = |
47 |
let val (v, h) = split e |
let val v = mkv() and w = mkv() |
48 |
in h(SWITCH(v, csig, ces, oe)) |
val dc = if b then trueDcon' else falseDcon' |
49 |
|
in RECORD(FU.rk_tuple, [], v, CON(dc, [], VAR v, w, RET[VAR w])) |
50 |
end |
end |
51 |
|
|
52 |
fun CONg(dc, ts, e) = |
fun trueLexp () = boolLexp true |
53 |
let val (v, h) = split e |
fun falseLexp () = boolLexp false |
|
in h(CON(dc, ts, v)) |
|
|
end |
|
|
|
|
|
val (trueLexp, falseLexp) = |
|
|
let val unitLexp = RECORD [] |
|
|
in (CONg (trueDcon', [], unitLexp), CONg (falseDcon', [], unitLexp)) |
|
|
end |
|
54 |
|
|
55 |
exception Poly |
exception Poly |
56 |
|
|
58 |
* Commonly-used Lambda Types * |
* Commonly-used Lambda Types * |
59 |
****************************************************************************) |
****************************************************************************) |
60 |
|
|
61 |
val boolty = LT.ltc_bool |
(** assumptions: typed created here will be reprocessed in wrapping.sml *) |
62 |
fun eqLty lt = LT.ltc_arw(LT.ltc_tuple [lt, lt], boolty) |
fun eqLty lt = LT.ltc_arrow(LT.ffc_rrflint, [lt, lt], [LT.ltc_bool]) |
63 |
|
fun eqTy tc = eqLty(LT.ltc_tyc tc) |
64 |
|
|
65 |
val inteqty = eqLty (LT.ltc_int) |
val inteqty = eqLty (LT.ltc_int) |
66 |
val int32eqty = eqLty (LT.ltc_int32) |
val int32eqty = eqLty (LT.ltc_int32) |
67 |
val booleqty = eqLty (LT.ltc_bool) |
val booleqty = eqLty (LT.ltc_bool) |
68 |
val realeqty = eqLty (LT.ltc_real) |
val realeqty = eqLty (LT.ltc_real) |
69 |
|
|
70 |
fun eqTy tc = eqLty(LT.ltc_tyc tc) |
datatype resKind |
71 |
fun ptrEq(p, tc) = PRIM(p, eqTy tc, []) |
= VBIND of value |
72 |
fun prim(p, lt) = PRIM(p, lt, []) |
| PBIND of primop |
73 |
|
| EBIND of lexp |
74 |
|
|
75 |
|
fun ptrEq(p, tc) = PBIND (NONE, p, eqTy tc, []) |
76 |
|
fun prim(p, lt) = PBIND (NONE, p, lt, []) |
77 |
|
|
78 |
fun isRef tc = |
fun isRef tc = |
79 |
if LT.tcp_app tc then |
if LT.tcp_app tc then |
86 |
end) |
end) |
87 |
else false |
else false |
88 |
|
|
89 |
exception Notfound |
fun branch(PBIND p, vs, e1, e2) = BRANCH(p, vs, e1, e2) |
90 |
|
| branch(VBIND v, vs, e1, e2) = |
91 |
|
let val x = mkv() |
92 |
|
in LET([x], APP(v, vs), |
93 |
|
SWITCH(VAR x, BT.boolsign, |
94 |
|
[(DATAcon(trueDcon', [], mkv()), e1), |
95 |
|
(DATAcon(falseDcon', [], mkv()), e2)], NONE)) |
96 |
|
end |
97 |
|
| branch(EBIND e, vs, e1, e2) = |
98 |
|
let val x = mkv() |
99 |
|
in LET([x], e, branch(VBIND (VAR x), vs, e1, e2)) |
100 |
|
end |
101 |
|
|
102 |
(**************************************************************************** |
(**************************************************************************** |
103 |
* equal --- the equality function generator * |
* equal --- the equality function generator * |
104 |
****************************************************************************) |
****************************************************************************) |
105 |
|
exception Notfound |
106 |
|
|
107 |
fun equal (peqv, seqv, tc) = |
fun equal (peqv, seqv, tc) = |
108 |
let val cache : (tyc * lvar * lexp ref) list ref = ref nil |
let |
109 |
|
|
110 |
|
val cache : (tyc * lvar * (fundec option ref)) list ref = ref nil |
111 |
|
(* lexp ref is used for recursions ? *) |
112 |
|
|
113 |
fun enter tc = |
fun enter tc = |
114 |
let val v = mkv() |
let val v = mkv() |
115 |
val r = ref (SVAL(VAR v)) |
val r = ref NONE |
116 |
in cache := (tc, v, r) :: !cache; (VAR v, r) |
in cache := (tc, v, r) :: !cache; (v, r) |
117 |
end |
end |
118 |
|
(* the order of cache is relevant; the hdr may use the tail *) |
119 |
|
|
120 |
fun find tc = |
fun find tc = |
121 |
let fun f ((t,v,e)::r) = if tcEqv(tc,t) then VAR v else f r |
let fun f ((t,v,e)::r) = if tcEqv(tc,t) then VBIND(VAR v) else f r |
122 |
| f [] = (if !debugging |
| f [] = (if !debugging |
123 |
then say "equal.sml-find-notfound\n" else (); |
then say "equal.sml-find-notfound\n" else (); |
124 |
raise Notfound) |
raise Notfound) |
130 |
else if tcEqv(tc,LT.tcc_int32) then prim(PO.IEQL,int32eqty) |
else if tcEqv(tc,LT.tcc_int32) then prim(PO.IEQL,int32eqty) |
131 |
else if tcEqv(tc,LT.tcc_bool) then prim(PO.IEQL,booleqty) |
else if tcEqv(tc,LT.tcc_bool) then prim(PO.IEQL,booleqty) |
132 |
else if tcEqv(tc,LT.tcc_real) then prim(PO.FEQLd,realeqty) |
else if tcEqv(tc,LT.tcc_real) then prim(PO.FEQLd,realeqty) |
133 |
else if tcEqv(tc,LT.tcc_string) then (VAR seqv) |
else if tcEqv(tc,LT.tcc_string) then VBIND (VAR seqv) |
134 |
else if isRef(tc) then ptrEq(PO.PTREQL, tc) |
else if isRef(tc) then ptrEq(PO.PTREQL, tc) |
135 |
else raise Poly |
else raise Poly |
136 |
|
|
137 |
|
val fkfun = FK_FUN{isrec=NONE, known=false, fixed=LT.ffc_rrflint, inline=true} |
138 |
|
|
139 |
fun test(tc, 0) = raise Poly |
fun test(tc, 0) = raise Poly |
140 |
| test(tc, depth) = |
| test(tc, depth) = |
141 |
if LT.tcp_tuple tc then |
if LT.tcp_tuple tc then |
142 |
(let val ts = LT.tcd_tuple tc |
(let val ts = LT.tcd_tuple tc |
143 |
in (find tc handle Notfound => |
in (find tc handle Notfound => |
144 |
let val v = mkv() and x=mkv() and y=mkv() |
let val x=mkv() and y=mkv() |
145 |
val (eqv, patch) = enter tc |
val (v, patch) = enter tc |
|
fun loop(n, [tx]) = |
|
|
APPg(SVAL (test(tx, depth)), |
|
|
RECORDg[SELECT(n, VAR x), |
|
|
SELECT(n, VAR y)]) |
|
|
|
|
|
| loop(n, tx::r) = |
|
|
SWITCHg(loop(n,[tx]), BT.boolsign, |
|
|
[(DATAcon(trueDcon'), loop(n+1,r)), |
|
|
(DATAcon(falseDcon'), falseLexp)], |
|
|
NONE) |
|
146 |
|
|
147 |
| loop(_,nil) = trueLexp |
fun loop(n, tx::r) = |
148 |
|
let val a = mkv() and b = mkv() |
149 |
|
in SELECT(VAR x, n, a, |
150 |
|
SELECT(VAR y, n, b, |
151 |
|
branch(test(tx, depth), [VAR a, VAR b], |
152 |
|
loop(n+1, r), falseLexp()))) |
153 |
|
end |
154 |
|
| loop(_, []) = trueLexp() |
155 |
|
|
156 |
val lt = LT.ltc_tyc tc |
val lt = LT.ltc_tyc tc |
157 |
in patch := FN(v, LT.ltc_tuple [lt,lt], |
in patch := SOME (fkfun, v, [(x, lt), (y, lt)], loop(0, ts)); |
158 |
LET(x, SELECT(0, VAR v), |
VBIND(VAR v) |
|
LET(y, SELECT(1, VAR v), |
|
|
loop(0, ts)))); |
|
|
eqv |
|
159 |
end) |
end) |
160 |
end) |
end) |
161 |
else atomeq tc |
else atomeq tc |
162 |
|
|
163 |
val body = SVAL(test(tc, 10)) |
val body = test(tc, 10) |
164 |
val fl = !cache |
val fl = !cache |
165 |
|
|
166 |
in |
in |
167 |
(case fl |
(case fl |
168 |
of [] => body |
of [] => body |
169 |
| _ => let fun g ((tc, v, e), (vs, ts, es)) = |
| _ => let fun g ((tc, f, store), e) = |
170 |
(v::vs, (eqTy tc)::ts, (!e)::es) |
(case !store |
171 |
val (vs, ts, es) = foldr g ([], [], []) fl |
of NONE => e |
172 |
in FIX(vs, ts, es, body) |
| SOME fd => FIX([fd], e)) |
173 |
|
in case body |
174 |
|
of PBIND _ => bug "unexpected PBIND in equal" |
175 |
|
| VBIND u => EBIND(foldr g (RET[u]) fl) |
176 |
|
| EBIND e => EBIND(foldr g e fl) |
177 |
end) |
end) |
178 |
end handle Poly => (TAPP(VAR peqv, [tc])) |
|
179 |
|
end handle Poly => EBIND(TAPP(VAR peqv, [tc])) |
180 |
|
|
181 |
|
|
182 |
|
fun equal_branch ((d, p, lt, ts), vs, e1, e2) = |
183 |
|
(case (d, p, ts) |
184 |
|
of (SOME{default=pv, table=[(_,sv)]}, PO.POLYEQL, [tc]) => |
185 |
|
branch(equal(pv, sv, tc), vs, e1, e2) |
186 |
|
| _ => bug "unexpected case in equal_branch") |
187 |
|
|
188 |
end (* toplevel local *) |
end (* toplevel local *) |
189 |
end (* structure Equal *) |
end (* structure Equal *) |