Home My Page Projects Code Snippets Project Openings SML/NJ
Summary Activity Forums Tracker Lists Tasks Docs Surveys News SCM Files

SCM Repository

[smlnj] Annotation of /sml/branches/num64/compiler/CPS/opt/num64cnv.sml
ViewVC logotype

Annotation of /sml/branches/num64/compiler/CPS/opt/num64cnv.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5164 - (view) (download)

1 : jhr 5110 (* num64cnv.sml
2 :     *
3 :     * COPYRIGHT (c) 2019 The Fellowship of SML/NJ (http://www.smlnj.org)
4 :     * All rights reserved.
5 :     *
6 :     * This module supports the 64-bit int/word types on 32-bit machines
7 :     * by expanding them to pairs of 32-bit words and replacing the primitive
8 :     * operations with 32-bit code.
9 :     *)
10 :    
11 :     structure Num64Cnv : sig
12 :    
13 : jhr 5122 (* eliminate 64-bit literals and operations on a 32-bit machine. This function
14 : jhr 5137 * does not rewrite its argument if either the target is a 64-bit machine or
15 :     * no 64-bit operations were detected.
16 : jhr 5110 *)
17 : jhr 5137 val elim : CPS.function -> CPS.function
18 : jhr 5110
19 :     end = struct
20 :    
21 :     structure C = CPS
22 :     structure P = C.P
23 :     structure LV = LambdaVar
24 :    
(* report an internal error detected by this module; the message is prefixed
 * with the module name and passed to ErrorMsg.impossible.
 *)
fun bug msg = ErrorMsg.impossible (concat ["Num64Cnv: ", msg])
26 :    
(* is the CPS type a 64-bit number type? *)
fun isNum64Ty ty = (case ty
       of C.NUMt{sz = 64, ...} => true
        | _ => false
      (* end case *))

(* CPS number types used by the generated 32-bit code *)
val raw32Ty = C.NUMt{sz = 32, tag = false}	(* assuming a 32-bit machine *)
val tagNumTy = C.NUMt{sz = 31, tag = true}
(* pointer types: `pairTy` is what 64-bit number types are replaced with *)
val pairTy = C.PTRt C.VPT
val box32Ty = C.PTRt C.VPT
(* primop number kinds for signed and unsigned 32-bit operations *)
val si32 = P.INT 32
val ui32 = P.UINT 32
35 :    
(* Break a 64-bit integer/word literal into its (high, low) 32-bit halves, each
 * returned as an untagged 32-bit CPS literal.  The argument is assumed to lie in
 * -2^63 .. 2^64-1 (the union of the Int64.int and Word64.word ranges); negative
 * values are first mapped to their unsigned two's-complement representation.
 *)
fun split (n : IntInf.int) = let
      fun lit32 w = C.NUM{ival = w, ty = {sz = 32, tag = false}}
      val unsigned = if (n < 0) then n + 0x10000000000000000 else n
      in
        (lit32 (IntInf.~>>(unsigned, 0w32)), lit32 (IntInf.andb(unsigned, 0xffffffff)))
      end
47 :    
(* short names for various CPS constructs *)

(* untagged 32-bit and tagged 31-bit numeric literals *)
fun num n = C.NUM{ival=n, ty={sz = 32, tag = false}}
fun tagNum n = C.NUM{ival=n, ty={sz = 31, tag = true}}
val zero = num 0
val one = num 1

(* two-way branches on a 32-bit comparison; a fresh variable is supplied for
 * the BRANCH node.
 *)
local
  fun branchOn kind (oper, a, b, tr, fl) =
        C.BRANCH(P.CMP{oper=oper, kind=kind}, [a, b], LV.mkLvar(), tr, fl)
in
fun uIf args = branchOn ui32 args	(* unsigned conditional *)
fun sIf args = branchOn si32 args	(* signed conditional *)
end (* local *)

(* branch to `tr` when `v` is zero, otherwise to `fl` *)
fun ifzero (v, tr, fl) = uIf(P.EQL, v, zero, tr, fl)

(* bind the result of a pure primop to a fresh variable, which is passed to
 * `k` to build the rest of the expression.
 *)
fun pure (rator, args, ty, k) = let
      val res = LV.mkLvar()
      val rest = k (C.VAR res)
      in
        C.PURE(rator, args, res, ty, rest)
      end

(* unsigned 32-bit pure arithmetic producing an untagged 32-bit result *)
fun pure_arith32 (oper, args, k) =
      pure (P.PURE_ARITH{oper=oper, kind=ui32}, args, raw32Ty, k)

(* 31-bit pure arithmetic producing a tagged result *)
fun taggedArith (oper, args, k) =
      pure (P.PURE_ARITH{oper=oper, kind=P.UINT 31}, args, tagNumTy, k)

(* 32-bit IARITH operation bound with ARITH (rather than PURE) *)
fun iarith32 (oper, args, k) = let
      val res = LV.mkLvar()
      val rest = k (C.VAR res)
      in
        C.ARITH(P.IARITH{oper=oper, sz=32}, args, res, raw32Ty, rest)
      end
72 :    
(* bitwise equivalence: the complement of `a xor b` is passed to `k` *)
fun bitEquiv (a, b, k) = let
      fun complement diff = pure_arith32 (P.NOTB, [diff], k)
      in
        pure_arith32 (P.XORB, [a, b], complement)
      end
77 :    
(* Build the application `f (retK, args...)` that is substituted for a primop
 * whose result `res : resTy` is consumed by `ce`.  When `ce` is itself just
 * `g res` for some continuation variable `g`, we pass `g` directly instead of
 * wrapping `ce` in a new continuation, which avoids creating an eta-redex.
 *)
fun mkApplyWithReturn (f, args, res, resTy, ce) = let
    (* general case: bind `ce` as a fresh return continuation *)
      fun withNewCont () = let
            val retK = LV.mkLvar()
            val contDef = (C.CONT, retK, [res], [resTy], ce)
            in
              C.FIX([contDef], C.APP(f, C.VAR retK :: args))
            end
      in
        case ce
         of C.APP(C.VAR g, [C.VAR arg]) =>
              if (arg <> res)
                then withNewCont ()
                else C.APP(f, C.VAR g :: args)
          | _ => withNewCont ()
        (* end case *)
      end
98 :    
(* Wrap `cexp` in a continuation with parameter `res` so that the code built by
 * `k` can reach it from several places without duplicating it.  `k` receives a
 * function that builds a jump to that continuation with a given value.  The
 * parameter is given type `pairTy`, i.e., `res` is assumed to be a wrapped
 * 64-bit integer.
 *)
fun join (res, cexp, k) = let
      val joinK = LV.mkLvar()
      fun jump v = C.APP(C.VAR joinK, [v])
      in
        C.FIX([(C.CONT, joinK, [res], [pairTy], cexp)], k jump)
      end
109 :    
(* Allocate the 64-bit object for the halves `hi` and `lo` (a raw-block record
 * with `hi` first) and pass the variable bound to it to `k`.
 *)
fun to64 (hi, lo, k) = let
      val obj = LV.mkLvar()
      val fields = List.map (fn w => (w, C.OFFp 0)) [hi, lo]
      in
        C.RECORD(C.RK_RAWBLOCK, fields, obj, k(C.VAR obj))
      end
119 :    
(* Unpack the 64-bit object `n`: field 0 is selected as the high word and
 * field 1 as the low word; both are passed to `k` as a pair.
 *)
fun from64 (n, k) = let
      val (hiVar, loVar) = (LV.mkLvar(), LV.mkLvar())
      val rest = k (C.VAR hiVar, C.VAR loVar)
      in
        C.SELECT(0, n, hiVar, raw32Ty, C.SELECT(1, n, loVar, raw32Ty, rest))
      end
131 :    
(* select the low 32-bit word (field 1) of the 64-bit object `n` and pass it
 * to `k`.
 *)
fun getLo32 (n, k) = let
      val w = LV.mkLvar()
      val rest = k (C.VAR w)
      in
        C.SELECT(1, n, w, raw32Ty, rest)
      end
140 :    
(* select the high 32-bit word (field 0) of the 64-bit object `n` and pass it
 * to `k`.
 *)
fun getHi32 (n, k) = let
      val w = LV.mkLvar()
      val rest = k (C.VAR w)
      in
        C.SELECT(0, n, w, raw32Ty, rest)
      end
149 :    
(* split a 32-bit value into two 16-bit values, which are passed to `k` as
 * (hi, lo).  The high half is extracted with a *logical* shift so that it is a
 * 16-bit quantity even when bit 31 of `n` is set (an arithmetic shift would
 * sign-extend it).
 * NOTE(review): the shift amount here is an untagged literal (`num 16`), whereas
 * the other shifts in this file use tagged amounts (`tagNum`); confirm which
 * representation the shift primops expect.
 *)
fun split32 (n, k) =
      pure_arith32(P.RSHIFTL, [n, num 16], fn hi =>
      pure_arith32(P.ANDB, [n, num 0xffff], fn lo =>
        k (hi, lo)))
155 :    
156 :     (***** Word64 primitive operations *****)
157 :    
(* Word64 addition (wraps on overflow).
 *
 *	fun add2 ((hi1, lo1), (hi2, lo2)) = let
 *	      val hi = hi1 + hi2
 *	      val lo = lo1 + lo2
 *	    (* from "Hacker's Delight": c = ((lo1 & lo2) | ((lo1 | lo2) & not lo)) >> 31 *)
 *	      val carry = (((lo1 ++ lo2) & ~~lo) ++ (lo1 & lo2)) >> 0w31
 *	      val hi = hi + carry
 *	      in
 *		(hi, lo)
 *	      end
 *
 * The carry is bit 31 of the intermediate value, so it has to be extracted with
 * a logical shift (RSHIFTL); an arithmetic shift would produce ~1 instead of 1.
 *)
fun w64Add (n1, n2, res, cexp) = join (res, cexp, fn k =>
      from64 (n1, fn (hi1, lo1) =>
      from64 (n2, fn (hi2, lo2) =>
        pure_arith32(P.ADD, [hi1, hi2], fn hi =>
        pure_arith32(P.ADD, [lo1, lo2], fn lo =>
        pure_arith32(P.ORB, [lo1, lo2], fn lo1_or_lo2 =>
        pure_arith32(P.NOTB, [lo], fn not_lo =>
        pure_arith32(P.ANDB, [lo1_or_lo2, not_lo], fn tmp1 =>
        pure_arith32(P.ANDB, [lo1, lo2], fn lo1_and_lo2 =>
        pure_arith32(P.ORB, [lo1_and_lo2, tmp1], fn tmp2 =>
        pure_arith32(P.RSHIFTL, [tmp2, tagNum 31], fn carry =>
        pure_arith32(P.ADD, [hi, carry], fn hi =>
          to64(hi, lo, k)))))))))))))
182 :    
(* Word64 subtraction (wraps on underflow).
 *
 *	fun sub ((hi1, lo1), (hi2, lo2)) = let
 *	      val hi = hi1 - hi2
 *	      val lo = lo1 - lo2
 *	    (* from "Hacker's Delight": b = ((not lo1 & lo2) | ((lo1 eqv lo2) & lo)) >> 31 *)
 *	      val b = (((lo1 ^= lo2) & lo) ++ (~~lo1 & lo2)) >> 0w31
 *	      val hi = hi - b
 *	      in
 *		(hi, lo)
 *	      end
 *
 * The borrow is bit 31 of the intermediate value, so it has to be extracted
 * with a logical shift (RSHIFTL); an arithmetic shift would produce ~1 instead
 * of 1.
 *)
fun w64Sub (n1, n2, res, cexp) = join (res, cexp, fn k =>
      from64 (n1, fn (hi1, lo1) =>
      from64 (n2, fn (hi2, lo2) =>
        pure_arith32(P.SUB, [hi1, hi2], fn hi =>
        pure_arith32(P.SUB, [lo1, lo2], fn lo =>
        bitEquiv(lo1, lo2, fn lo1_eqv_lo2 =>
        pure_arith32(P.ANDB, [lo1_eqv_lo2, lo], fn tmp1 =>
        pure_arith32(P.NOTB, [lo1], fn not_lo1 =>
        pure_arith32(P.ANDB, [not_lo1, lo2], fn tmp2 =>
        pure_arith32(P.ORB, [tmp1, tmp2], fn tmp3 =>
        pure_arith32(P.RSHIFTL, [tmp3, tagNum 31], fn borrow =>
        pure_arith32(P.SUB, [hi, borrow], fn hi =>
          to64(hi, lo, k)))))))))))))
207 :    
(* Word64 bitwise or, computed half by half:
 *	fun orb (W64(hi1, lo1), W64(hi2, lo2)) = W64(hi1 ++ hi2, lo1 ++ lo2)
 *)
fun w64Orb (n1, n2, res, cexp) = let
      fun orb32 (a, b, k) = pure_arith32(P.ORB, [a, b], k)
      fun body k =
            from64 (n1, fn (hi1, lo1) =>
            from64 (n2, fn (hi2, lo2) =>
              orb32 (lo1, lo2, fn lo =>
              orb32 (hi1, hi2, fn hi =>
                to64 (hi, lo, k)))))
      in
        join (res, cexp, body)
      end
217 :    
(* Word64 bitwise exclusive-or, computed half by half:
 *	fun xorb (W64(hi1, lo1), W64(hi2, lo2)) = W64(hi1 ^^ hi2, lo1 ^^ lo2)
 *)
fun w64Xorb (n1, n2, res, cexp) = let
      fun xorb32 (a, b, k) = pure_arith32(P.XORB, [a, b], k)
      fun body k =
            from64 (n1, fn (hi1, lo1) =>
            from64 (n2, fn (hi2, lo2) =>
              xorb32 (lo1, lo2, fn lo =>
              xorb32 (hi1, hi2, fn hi =>
                to64 (hi, lo, k)))))
      in
        join (res, cexp, body)
      end
227 :    
(* Word64 bitwise and, computed half by half:
 *	fun andb (W64(hi1, lo1), W64(hi2, lo2)) = W64(hi1 & hi2, lo1 & lo2)
 *)
fun w64Andb (n1, n2, res, cexp) = let
      fun andb32 (a, b, k) = pure_arith32(P.ANDB, [a, b], k)
      fun body k =
            from64 (n1, fn (hi1, lo1) =>
            from64 (n2, fn (hi2, lo2) =>
              andb32 (lo1, lo2, fn lo =>
              andb32 (hi1, hi2, fn hi =>
                to64 (hi, lo, k)))))
      in
        join (res, cexp, body)
      end
237 :    
(* Word64 bitwise complement, computed half by half:
 *	fun notb (W64(hi, lo)) = W64(Word32.notb hi, Word32.notb lo)
 *)
fun w64Notb (n, res, cexp) = let
      fun not32 (w, k) = pure_arith32(P.NOTB, [w], k)
      in
        join (res, cexp, fn k =>
          from64 (n, fn (hi, lo) =>
            not32 (hi, fn hi' =>
            not32 (lo, fn lo' =>
              to64 (hi', lo', k)))))
      end
246 :    
(* Word64 negation (non-trapping):
 *	fun neg (hi, 0w0) = (W32.~ hi, 0w0)
 *	  | neg (hi, lo) = (W32.notb hi, W32.~ lo)
 *)
fun w64Neg (n, res, cexp) = let
      fun negate k (hi, lo) = let
          (* lo = 0: only the high word is negated *)
            val loIsZero =
                  pure_arith32(P.NEG, [hi], fn hi' => to64 (hi', zero, k))
          (* lo <> 0: complement the high word and negate the low word *)
            val loNonZero =
                  pure_arith32(P.NOTB, [hi], fn hi' =>
                  pure_arith32(P.NEG, [lo], fn lo' =>
                    to64 (hi', lo', k)))
            in
              ifzero(lo, loIsZero, loNonZero)
            end
      in
        join (res, cexp, fn k => from64 (n, negate k))
      end
259 :    
(* logical shift-right, where we know that amt < 0w64
 *
 *	fun w64RShiftL ((hi, lo), amt) =
 *	      if (amt < 32)
 *	        then let
 *		  val hi' = (hi >> amt)
 *		  val lo' = (lo >> amt) | (hi << (0w32 - amt))
 *		  in
 *		    (hi', lo')
 *		  end
 *	        else (0, (hi >> (amt - 0w32)))
 *
 * Note, that while there is a branch-free version of this, it does not work
 * on the x86 architecture, which uses mod-32 shift amounts.
 * NOTE(review): when amt = 0 the first branch shifts `hi` left by 32; confirm
 * that the LSHIFT primop yields 0 for that amount on mod-32 targets.
 *)
fun w64RShiftL (n, amt, res, cexp) = join (res, cexp, fn k =>
      from64(n, fn (hi, lo) =>
        sIf(P.LT, amt, tagNum 32,
          pure_arith32(P.RSHIFTL, [hi, amt], fn hi' =>
          pure_arith32(P.RSHIFTL, [lo, amt], fn loBits =>
          taggedArith(P.SUB, [tagNum 32, amt], fn lshAmt =>
          pure_arith32(P.LSHIFT, [hi, lshAmt], fn hiBits =>
          (* combine the shifted low word with the bits shifted out of `hi` *)
          pure_arith32(P.ORB, [loBits, hiBits], fn lo' =>
            to64(hi', lo', k)))))),
        (* else *)
          taggedArith(P.SUB, [amt, tagNum 32], fn rshAmt =>
          pure_arith32(P.RSHIFTL, [hi, rshAmt], fn lo' =>
            to64(zero, lo', k))))))
288 :    
(* arithmetic shift-right, where we know that amt < 0w64
 *
 *	fun w64RShift ((hi, lo), amt) =
 *	      if (amt < 32)
 *	        then let
 *		  val hi' = (hi ~>> amt)
 *		  val lo' = (lo >> amt) | (hi << (0w32 - amt))
 *		  in
 *		    (hi', lo')
 *		  end
 *	        else (hi ~>> 0w31, (hi ~>> (amt - 0w32)))
 *
 * Only the high word carries the sign, so the low word is shifted logically in
 * the first branch, while the second branch fills the low word with an
 * arithmetic shift of the high word.
 *
 * Note, that while there is a branch-free version of this, it does not work
 * on the x86 architecture, which uses mod-32 shift amounts.
 * NOTE(review): when amt = 0 the first branch shifts `hi` left by 32; confirm
 * that the LSHIFT primop yields 0 for that amount on mod-32 targets.
 *)
fun w64RShift (n, amt, res, cexp) = join (res, cexp, fn k =>
      from64(n, fn (hi, lo) =>
        sIf(P.LT, amt, tagNum 32,
          pure_arith32(P.RSHIFT, [hi, amt], fn hi' =>
          pure_arith32(P.RSHIFTL, [lo, amt], fn loBits =>
          taggedArith(P.SUB, [tagNum 32, amt], fn lshAmt =>
          pure_arith32(P.LSHIFT, [hi, lshAmt], fn hiBits =>
          (* combine the shifted low word with the bits shifted out of `hi` *)
          pure_arith32(P.ORB, [loBits, hiBits], fn lo' =>
            to64(hi', lo', k)))))),
        (* else *)
          pure_arith32(P.RSHIFT, [hi, tagNum 31], fn hi' =>
          taggedArith(P.SUB, [amt, tagNum 32], fn rshAmt =>
          pure_arith32(P.RSHIFT, [hi, rshAmt], fn lo' =>
            to64(hi', lo', k)))))))
318 :    
(* shift-left, where we know that amt < 0w64
 *
 *	fun w64LShift ((hi, lo), amt) =
 *	      if (amt < 0w32)
 *	        then let
 *		  val hi' = (hi << amt) | (lo >> (0w32 - amt))
 *		  val lo' = (lo << amt)
 *		  in
 *		    (hi', lo')
 *		  end
 *	        else (lo << (amt - 0w32), 0)
 *
 * Note, that while there is a branch-free version of this, it does not work
 * on the x86 architecture, which uses mod-32 shift amounts.
 * NOTE(review): when amt = 0 the first branch shifts `lo` right by 32; confirm
 * that the RSHIFTL primop yields 0 for that amount on mod-32 targets.
 *)
fun w64LShift (n, amt, res, cexp) = join (res, cexp, fn k =>
      from64(n, fn (hi, lo) =>
        sIf(P.LT, amt, tagNum 32,
          pure_arith32(P.LSHIFT, [hi, amt], fn hiBits =>
          taggedArith(P.SUB, [tagNum 32, amt], fn rshAmt =>
          pure_arith32(P.RSHIFTL, [lo, rshAmt], fn loBits =>
          (* combine the shifted high word with the bits shifted out of `lo` *)
          pure_arith32(P.ORB, [hiBits, loBits], fn hi' =>
          pure_arith32(P.LSHIFT, [lo, amt], fn lo' =>
            to64(hi', lo', k)))))),
        (* else *)
          taggedArith(P.SUB, [amt, tagNum 32], fn lshAmt =>
          pure_arith32(P.LSHIFT, [lo, lshAmt], fn hi' =>
            to64(hi', zero, k))))))
347 :    
(* 64-bit equality test: branch to `tr` when both halves agree, else to `fl`.
 *	fun w64Eql ((hi1, lo1), (hi2, lo2)) =
 *	      (W32.orb(W32.xorb(hi1, hi2), W32.xorb(lo1, lo2)) = 0)
 *)
fun w64Eql (n1, n2, tr, fl) = let
      fun xor32 (a, b, k) = pure_arith32(P.XORB, [a, b], k)
      in
        from64(n1, fn (hi1, lo1) =>
        from64(n2, fn (hi2, lo2) =>
          xor32 (hi1, hi2, fn hiDiff =>
          xor32 (lo1, lo2, fn loDiff =>
          pure_arith32(P.ORB, [hiDiff, loDiff], fn diff =>
            ifzero (diff, tr, fl))))))
      end
359 :    
(* Unsigned 64-bit ordering tests.  The basic pattern for comparisons is
 *	fun cmp ((hi1, lo1), (hi2, lo2)) =
 *	      cmpHi(hi1, hi2) orelse ((hi1 = hi2) andalso cmpLo(lo1, lo2))
 * where every comparison is an unsigned 32-bit test.
 *)
local
  fun w64Cmp (cmpHi, cmpLo) (n1, n2, tr, fl) = let
      (* the two outcomes are bound as continuations so that they can be
       * reached from several places without duplicating their code.
       *)
        val trK = LV.mkLvar()
        val flK = LV.mkLvar()
        val jumpTr = C.APP(C.VAR trK, [])
        val jumpFl = C.APP(C.VAR flK, [])
      (* hi words are equal: the low words decide *)
        fun testLo () =
              getLo32(n1, fn lo1 =>
              getLo32(n2, fn lo2 =>
                uIf(cmpLo, lo1, lo2, jumpTr, jumpFl)))
      (* cmpHi(hi1, hi2) orelse ((hi1 = hi2) andalso <low-word test>) *)
        fun testHi (hi1, hi2) =
              uIf(cmpHi, hi1, hi2,
                jumpTr,
                uIf(P.EQL, hi1, hi2, testLo (), jumpFl))
        in
        (* NOTE: closure conversion requires that there only be one continuation
         * function per FIX!
         *)
          C.FIX([(C.CONT, trK, [], [], tr)],
          C.FIX([(C.CONT, flK, [], [], fl)],
            getHi32(n1, fn hi1 =>
            getHi32(n2, fn hi2 =>
              testHi (hi1, hi2)))))
        end
in
val w64Less = w64Cmp (P.LT, P.LT)
val w64LessEq = w64Cmp (P.LT, P.LTE)
val w64Greater = w64Cmp (P.GT, P.GT)
val w64GreaterEq = w64Cmp (P.GT, P.GTE)
end (* local *)
394 :    
395 :     (***** Int64 primitive operations *****)
396 :    
(* Int64 addition with overflow detection.
 *
 *	fun add64 ((hi1, lo1), (hi2, lo2)) = let
 *	      val lo = lo1 + lo2
 *	    (* from "Hacker's Delight": c = ((lo1 & lo2) | ((lo1 | lo2) & not lo)) >> 31 *)
 *	      val carry = ((lo1 & lo2) ++ ((lo1 ++ lo2) & not lo)) >> 0w31
 *	    (* we add the carry to the smaller hi component before adding them; this
 *	     * check is needed to get Overflow right in the edge cases
 *	     *)
 *	      val hi = if InLine.int32_le(hi1, hi2)
 *		    then InLine.int32_add(InLine.int32_add(hi1, carry), hi2)
 *		    else InLine.int32_add(InLine.int32_add(hi2, carry), hi1)
 *	      in
 *		(hi, lo)
 *	      end
 *
 * The carry is extracted with a logical shift (RSHIFTL) so that it is 0 or 1.
 * The two branches that compute the high word meet at a local continuation whose
 * parameter `hi` is a raw 32-bit number; it is built directly (instead of with
 * `join`) because `join` would declare the parameter as a pointer.
 *)
fun i64Add (n1, n2, res, cexp) = let
      val hi = LV.mkLvar()
      val hiK = LV.mkLvar()
      fun retHi v = C.APP(C.VAR hiK, [v])
      in
        join (res, cexp, fn k =>
          from64(n1, fn (hi1, lo1) =>
          from64(n2, fn (hi2, lo2) =>
            pure_arith32(P.ADD, [lo1, lo2], fn lo =>
            pure_arith32(P.ORB, [lo1, lo2], fn lo1_or_lo2 =>
            pure_arith32(P.NOTB, [lo], fn not_lo =>
            pure_arith32(P.ANDB, [lo1_or_lo2, not_lo], fn tmp1 =>
            pure_arith32(P.ANDB, [lo1, lo2], fn lo1_and_lo2 =>
            pure_arith32(P.ORB, [lo1_and_lo2, tmp1], fn tmp2 =>
            pure_arith32(P.RSHIFTL, [tmp2, tagNum 31], fn carry =>
              C.FIX([(C.CONT, hiK, [hi], [raw32Ty], to64(C.VAR hi, lo, k))],
                sIf(P.LTE, hi1, hi2,
                  iarith32(P.IADD, [hi1, carry], fn sum1 =>
                  iarith32(P.IADD, [sum1, hi2], retHi)),
                (* else *)
                  iarith32(P.IADD, [hi2, carry], fn sum2 =>
                  iarith32(P.IADD, [sum2, hi1], retHi))))))))))))))
      end
435 :    
(* Int64 subtraction with overflow detection.
 *
 *	fun sub64 ((hi1, lo1), (hi2, lo2)) = let
 *	      val lo = lo1 - lo2
 *	    (* from "Hacker's Delight": b = ((not lo1 & lo2) | ((lo1 eqv lo2) & lo)) >> 31 *)
 *	      val b = ((InLine.word32_notb lo1 & lo2) ++ ((lo1 ^= lo2) & lo)) >> 0w31
 *	    (* we need this test to get Overflow right in the edge cases *)
 *	      val hi = if InLine.int32_le(hi1, hi2)
 *		    then InLine.int32_sub(InLine.int32_sub(hi1, hi2), b)
 *		    else InLine.int32_sub(InLine.int32_sub(hi1, b), hi2)
 *	      in
 *		(hi, lo)
 *	      end
 *
 * The borrow is extracted with a logical shift (RSHIFTL) so that it is 0 or 1,
 * and the high word is computed with trapping *subtractions* (ISUB).  The two
 * branches meet at a local continuation whose parameter `hi` is a raw 32-bit
 * number; it is built directly (instead of with `join`) because `join` would
 * declare the parameter as a pointer.
 *)
fun i64Sub (n1, n2, res, cexp) = let
      val hi = LV.mkLvar()
      val hiK = LV.mkLvar()
      fun retHi v = C.APP(C.VAR hiK, [v])
      in
        join (res, cexp, fn k =>
          from64(n1, fn (hi1, lo1) =>
          from64(n2, fn (hi2, lo2) =>
            pure_arith32(P.SUB, [lo1, lo2], fn lo =>
            bitEquiv(lo1, lo2, fn lo1_eqv_lo2 =>
            pure_arith32(P.ANDB, [lo1_eqv_lo2, lo], fn tmp1 =>
            pure_arith32(P.NOTB, [lo1], fn not_lo1 =>
            pure_arith32(P.ANDB, [not_lo1, lo2], fn tmp2 =>
            pure_arith32(P.ORB, [tmp1, tmp2], fn tmp3 =>
            pure_arith32(P.RSHIFTL, [tmp3, tagNum 31], fn borrow =>
              C.FIX([(C.CONT, hiK, [hi], [raw32Ty], to64(C.VAR hi, lo, k))],
                sIf(P.LTE, hi1, hi2,
                  iarith32(P.ISUB, [hi1, hi2], fn diff1 =>
                  iarith32(P.ISUB, [diff1, borrow], retHi)),
                (* else *)
                  iarith32(P.ISUB, [hi1, borrow], fn diff2 =>
                  iarith32(P.ISUB, [diff2, hi2], retHi))))))))))))))
      end
472 :    
(* Int64 negation; the high-word negation in the `lo = 0` case uses the trapping
 * INEG operation.
 *	fun neg (hi, 0w0) = (I32.~ hi, 0w0)
 *	  | neg (hi, lo) = (W32.notb hi, W32.~ lo)
 *)
fun i64Neg (n, res, cexp) = let
      fun negate k (hi, lo) = let
          (* lo = 0: only the high word is negated *)
            val loIsZero =
                  iarith32(P.INEG, [hi], fn hi' => to64 (hi', zero, k))
          (* lo <> 0: complement the high word and negate the low word *)
            val loNonZero =
                  pure_arith32(P.NOTB, [hi], fn hi' =>
                  pure_arith32(P.NEG, [lo], fn lo' =>
                    to64 (hi', lo', k)))
            in
              ifzero(lo, loIsZero, loNonZero)
            end
      in
        join (res, cexp, fn k => from64 (n, negate k))
      end
485 :    
(* equality is a bitwise test, so the signed version is the unsigned one *)
fun i64Eql (n1, n2, tr, fl) = w64Eql (n1, n2, tr, fl)
487 :    
(* Signed 64-bit ordering tests.  The basic pattern for comparisons is
 *	fun cmp ((hi1, lo1), (hi2, lo2)) =
 *	      cmpHi(hi1, hi2) orelse ((hi1 = hi2) andalso cmpLo(lo1, lo2))
 * Only the high word carries the sign: `cmpHi` is a signed 32-bit test, while
 * the low words are plain 32-bit magnitudes and must be compared unsigned.
 *)
local
  fun i64Cmp (cmpHi, cmpLo) (n1, n2, tr, fl) = let
      (* continuations for the branches so that we can avoid code duplication *)
        val trFnId = LV.mkLvar()
        val tr' = C.APP(C.VAR trFnId, [])
        val flFnId = LV.mkLvar()
        val fl' = C.APP(C.VAR flFnId, [])
        in
        (* NOTE: closure conversion requires that there only be one continuation
         * function per FIX!
         *)
          C.FIX([(C.CONT, trFnId, [], [], tr)],
          C.FIX([(C.CONT, flFnId, [], [], fl)],
          (* cmpHi(hi1, hi2) orelse ((hi1 = hi2) andalso cmpLo(lo1, lo2)) *)
            getHi32(n1, fn hi1 =>
            getHi32(n2, fn hi2 =>
              sIf(cmpHi, hi1, hi2,
                tr',
                sIf(P.EQL, hi1, hi2,
                  getLo32(n1, fn lo1 =>
                  getLo32(n2, fn lo2 =>
                    uIf(cmpLo, lo1, lo2, tr', fl'))),
                  fl'))))))
        end
in
val i64Less = i64Cmp (P.LT, P.LT)
val i64LessEq = i64Cmp (P.LT, P.LTE)
val i64Greater = i64Cmp (P.GT, P.GT)
val i64GreaterEq = i64Cmp (P.GT, P.GTE)
end (* local *)
522 :    
523 : jhr 5139 (***** conversions *****)
524 : jhr 5110
(* signed conversion from 64-bit word with test for overflow.  The argument
 * list must be `[x, f]`, where `f` is the function that is applied to do the
 * conversion; any other shape is reported as an internal error.
 *)
fun test64To (toSz, [x, f], res, resTy, ce) =
      if (toSz <= Target.defaultIntSz)
        then let (* need extra conversion from 32-bits to toSz *)
          val rk = LV.mkLvar()
          val v = LV.mkLvar()
          val ce' = C.ARITH(P.TEST{from=32, to=toSz}, [C.VAR v], res, resTy, ce)
          in
            C.FIX([(C.CONT, rk, [v], [raw32Ty], ce')], C.APP (f, [C.VAR rk, x]))
          end
        else mkApplyWithReturn (f, [x], res, resTy, ce)
  | test64To _ = bug "impossible TEST; from=64"
536 : jhr 5111
(* unsigned conversion from 64-bit word with test for overflow.  The argument
 * list must be `[x, f]`, where `f` is the function that is applied to do the
 * conversion; any other shape is reported as an internal error.
 *)
fun testu64To (toSz, [x, f], res, resTy, ce) =
      if (toSz <= Target.defaultIntSz)
        then let (* need extra conversion from 32-bits to toSz *)
          val rk = LV.mkLvar()
          val v = LV.mkLvar()
          val ce' = C.ARITH(P.TESTU{from=32, to=toSz}, [C.VAR v], res, resTy, ce)
          in
            C.FIX([(C.CONT, rk, [v], [raw32Ty], ce')], C.APP (f, [C.VAR rk, x]))
          end
        else mkApplyWithReturn (f, [x], res, resTy, ce)
  | testu64To _ = bug "impossible TESTU; from=64"
548 : jhr 5111
(* truncate a 64-bit number to a size <= 32 bit number.  The result is the low
 * word of `n` (field 1), further truncated when toSz < 32.  `res` is bound
 * directly by the final operation with its actual type (raw 32-bit or tagged),
 * rather than being passed through `join`, which would declare it as a pointer.
 *)
fun trunc64To (toSz, n, res, ce) =
      if (toSz = 32)
        then C.SELECT(1, n, res, raw32Ty, ce)
        else getLo32 (n, fn lo =>
          C.PURE(P.TRUNC{from=32, to=toSz}, [lo], res, tagNumTy, ce))
554 : jhr 5111
(* copy (zero-extend) a number to a 64-bit representation.  A 64-bit source is
 * passed through unchanged; otherwise the value (widened to 32 bits if needed)
 * becomes the low word and the high word is zero.
 *)
fun copy64From (fromSz, n, res, ce) = let
      fun body k = (case fromSz
             of 64 => k n
              | 32 => to64 (zero, n, k)
              | _ => pure(P.COPY{from=fromSz, to=32}, [n], raw32Ty, fn lo =>
                  to64 (zero, lo, k))
            (* end case *))
      in
        join (res, ce, body)
      end
561 : jhr 5111
(* sign-extend a number to a 64-bit representation, where fromSz <= 32.  The
 * value is first sign-extended to 32 bits (if narrower); the high word is then
 * the low word arithmetically shifted right by 31.
 *)
fun extend64From (fromSz, n, res, ce) = let
      fun withSignWord (lo, k) =
            pure_arith32(P.RSHIFT, [lo, tagNum 31], fn hi =>
              to64 (hi, lo, k))
      fun body k = if (fromSz <> 32)
            then pure(P.EXTEND{from=fromSz, to=32}, [n], raw32Ty, fn lo =>
              withSignWord (lo, k))
            else withSignWord (n, k)
      in
        join (res, ce, body)
      end
569 : jhr 5111
570 : jhr 5122 (***** other functions *****)
571 :    
(* wrapping and unwrapping a 64-bit value both just pass `v` on as `res` *)
fun wrap64 (v, res, cexp) = join (res, cexp, fn ret => ret v)
fun unwrap64 (v, res, cexp) = join (res, cexp, fn ret => ret v)
574 : jhr 5122
575 :    
576 : jhr 5110 (***** main function *****)
577 :    
(* check if an expression needs rewriting.  The answer is true only on a 32-bit
 * target and only if the function mentions a 64-bit literal, a 64-bit primitive
 * operation, or a 64-bit number type in a position that `elim` rewrites
 * (function parameters and the result types of SELECT, LOOKER, and PURE).
 *)
fun needsRewrite func = let
      fun chkTy (C.NUMt{sz=64, ...}) = true
        | chkTy _ = false
      fun chkValue (C.NUM{ty={sz=64, ...}, ...}) = true
        | chkValue _ = false
      fun chkValues vs = List.exists chkValue vs
      fun chkExp (C.RECORD(_, vs, _, e)) =
            List.exists (chkValue o #1) vs orelse chkExp e
        | chkExp (C.SELECT(_, v, _, ty, e)) =
            chkValue v orelse chkTy ty orelse chkExp e
        | chkExp (C.OFFSET(_, v, _, e)) = chkValue v orelse chkExp e
        | chkExp (C.APP(_, vs)) = chkValues vs
        | chkExp (C.FIX(fns, e)) = List.exists chkFun fns orelse chkExp e
        | chkExp (C.SWITCH(v, _, es)) = chkValue v orelse List.exists chkExp es
        | chkExp (C.BRANCH(P.CMP{kind=P.INT 64, ...}, _, _, _, _)) = true
        | chkExp (C.BRANCH(P.CMP{kind=P.UINT 64, ...}, _, _, _, _)) = true
        | chkExp (C.BRANCH(_, vs, _, e1, e2)) =
            chkValues vs orelse chkExp e1 orelse chkExp e2
      (* QUESTION: what about RAWUPDATE and RAWSTORE? *)
        | chkExp (C.SETTER(_, vs, e)) = chkValues vs orelse chkExp e
        | chkExp (C.LOOKER(_, vs, _, ty, e)) =
            chkValues vs orelse chkTy ty orelse chkExp e
        | chkExp (C.ARITH(P.IARITH{sz=64, ...}, _, _, _, _)) = true
        | chkExp (C.ARITH(P.TEST{from=64, ...}, _, _, _, _)) = true
        | chkExp (C.ARITH(P.TESTU{from=64, ...}, _, _, _, _)) = true
        | chkExp (C.ARITH(_, vs, _, _, e)) = chkValues vs orelse chkExp e
        | chkExp (C.PURE(P.PURE_ARITH{kind=P.UINT 64, ...}, _, _, _, _)) = true
        | chkExp (C.PURE(P.COPY{to=64, ...}, _, _, _, _)) = true
        | chkExp (C.PURE(P.EXTEND{to=64, ...}, _, _, _, _)) = true
        | chkExp (C.PURE(P.TRUNC{from=64, ...}, _, _, _, _)) = true
        | chkExp (C.PURE(P.WRAP(P.INT 64), _, _, _, _)) = true
        | chkExp (C.PURE(P.UNWRAP(P.INT 64), _, _, _, _)) = true
        | chkExp (C.PURE(_, vs, _, ty, e)) =
            chkValues vs orelse chkTy ty orelse chkExp e
        | chkExp (C.RCC(_, _, _, vs, _, e)) = chkValues vs orelse chkExp e
      and chkFun (_, _, _, tys, e) = List.exists chkTy tys orelse chkExp e
      in
        (not Target.is64) andalso (chkFun func)
      end
616 :    
(* we replace occurrences of the 64-bit number type with "pointer to pair";
 * every other type is returned unchanged.
 *)
fun cvtTy ty = (case ty
       of C.NUMt{sz=64, ...} => pairTy
        | _ => ty
      (* end case *))
620 :    
(* Rewrite a CPS function so that 64-bit literals become allocated pairs of
 * 32-bit words and 64-bit primitive operations become 32-bit code (or calls to
 * the function supplied as the primop's extra argument).  The argument is
 * returned unchanged when `needsRewrite` says there is nothing to do.
 *)
fun elim cfun = let
    (* pass `v` to `k`, first allocating a pair if `v` is a 64-bit literal *)
      fun value (C.NUM{ival, ty={sz=64, ...}}, k) = let
            val (hi, lo) = split ival
            in
              to64 (hi, lo, k)
            end
        | value (v, k) = k v
    (* like `value`, but for a list of arguments (order is preserved) *)
      and values (vl, k) = let
            fun f ([], vl') = k (List.rev vl')
              | f (C.NUM{ival, ty={sz=64, ...}}::vs, vl') = let
                  val (hi, lo) = split ival
                  in
                    to64 (hi, lo, fn v => f (vs, v::vl'))
                  end
              | f (v::vs, vl') = f (vs, v::vl')
            in
              f (vl, [])
            end
      fun cexp (C.RECORD (rk, xl, v, e)) = let
            fun f ([], args') = C.RECORD (rk, List.rev args', v, cexp e)
              | f ((C.NUM{ival, ty={sz=64, ...}}, offp)::args, args') = let
                  val (hi, lo) = split ival
                  in
                    to64 (hi, lo, fn pair => f (args, (pair, offp)::args'))
                  end
              | f (arg::args, args') = f (args, arg::args')
            in
              f (xl, [])
            end
        | cexp (C.SELECT(i, x, v, t, e)) = C.SELECT(i, x, v, cvtTy t, cexp e)
        | cexp (C.OFFSET(i, v, x, e)) = C.OFFSET(i, v, x, cexp e)
        | cexp (C.APP(f, xl)) = values (xl, fn xl' => C.APP (f, xl'))
        | cexp (C.FIX(fl, e)) = C.FIX(List.map function fl, cexp e)
        | cexp (C.SWITCH(x, v, el)) =
            value (x, fn x' => C.SWITCH(x', v, List.map cexp el))
        | cexp (C.BRANCH(P.CMP{oper, kind=P.INT 64}, args, _, e1, e2)) =
            values (args, fn args' => (case (oper, args')
               of (P.GT, [a, b]) => i64Greater(a, b, cexp e1, cexp e2)
                | (P.GTE, [a, b]) => i64GreaterEq(a, b, cexp e1, cexp e2)
                | (P.LT, [a, b]) => i64Less(a, b, cexp e1, cexp e2)
                | (P.LTE, [a, b]) => i64LessEq(a, b, cexp e1, cexp e2)
                | (P.EQL, [a, b]) => i64Eql(a, b, cexp e1, cexp e2)
              (* inequality is equality with the branches swapped *)
                | (P.NEQ, [a, b]) => i64Eql(a, b, cexp e2, cexp e1)
                | _ => bug "impossible BRANCH; INT 64"
              (* end case *)))
        | cexp (C.BRANCH(P.CMP{oper, kind=P.UINT 64}, args, _, e1, e2)) =
            values (args, fn args' => (case (oper, args')
               of (P.GT, [a, b]) => w64Greater(a, b, cexp e1, cexp e2)
                | (P.GTE, [a, b]) => w64GreaterEq(a, b, cexp e1, cexp e2)
                | (P.LT, [a, b]) => w64Less(a, b, cexp e1, cexp e2)
                | (P.LTE, [a, b]) => w64LessEq(a, b, cexp e1, cexp e2)
                | (P.EQL, [a, b]) => w64Eql(a, b, cexp e1, cexp e2)
              (* inequality is equality with the branches swapped *)
                | (P.NEQ, [a, b]) => w64Eql(a, b, cexp e2, cexp e1)
                | _ => bug "impossible BRANCH; UINT 64"
              (* end case *)))
        | cexp (C.BRANCH(rator, args, v, e1, e2)) =
            values (args, fn args' => C.BRANCH(rator, args', v, cexp e1, cexp e2))
        | cexp (C.SETTER(rator, xl, e)) =
            values (xl, fn xl' => C.SETTER (rator, xl', cexp e))
        | cexp (C.LOOKER (rator, xl, v, ty, e)) =
            values (xl, fn xl' => C.LOOKER (rator, xl', v, cvtTy ty, cexp e))
        | cexp (C.ARITH(P.IARITH{oper, sz=64}, args, res, _, e)) =
            values (args, fn args' => (case (oper, args')
               of (P.IADD, [a, b]) => i64Add(a, b, res, cexp e)
                | (P.ISUB, [a, b]) => i64Sub(a, b, res, cexp e)
              (* the remaining binary operations are calls to the function `f`;
               * `mkApply` converts the continuation `e` itself.
               *)
                | (P.IMUL, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.IDIV, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.IMOD, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.IQUOT, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.IREM , [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.INEG, [a]) => i64Neg(a, res, cexp e)
                | _ => bug "impossible IARITH; sz=64"
              (* end case *)))
        | cexp (C.ARITH(P.TEST{from=64, to}, args, res, cty, e)) =
            values (args, fn args' => test64To(to, args', res, cty, cexp e))
        | cexp (C.ARITH(P.TESTU{from=64, to}, args, res, cty, e)) =
            values (args, fn args' => testu64To(to, args', res, cty, cexp e))
        | cexp (C.ARITH(rator, args, res, ty, e)) =
            values (args, fn args' => C.ARITH(rator, args', res, ty, cexp e))
        | cexp (C.PURE(P.PURE_ARITH{oper, kind=P.UINT 64}, args, res, _, e)) =
            values (args, fn args' => (case (oper, args')
               of (P.ADD, [a, b]) => w64Add(a, b, res, cexp e)
                | (P.SUB, [a, b]) => w64Sub(a, b, res, cexp e)
                | (P.MUL, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.QUOT, [a, b, f]) => mkApply(f, [a, b], res, e)
                | (P.REM , [a, b, f]) => mkApply(f, [a, b], res, e)
              (* pure (word) negation must not trap, so use the Word64 version *)
                | (P.NEG, [a]) => w64Neg(a, res, cexp e)
                | (P.ORB, [a, b]) => w64Orb(a, b, res, cexp e)
                | (P.XORB, [a, b]) => w64Xorb(a, b, res, cexp e)
                | (P.ANDB, [a, b]) => w64Andb(a, b, res, cexp e)
                | (P.NOTB, [a]) => w64Notb(a, res, cexp e)
                | (P.RSHIFT, [a, b]) => w64RShift(a, b, res, cexp e)
                | (P.RSHIFTL, [a, b]) => w64RShiftL(a, b, res, cexp e)
                | (P.LSHIFT, [a, b]) => w64LShift(a, b, res, cexp e)
                | _ => bug "impossible PURE_ARITH; UINT 64"
              (* end case *)))
        | cexp (C.PURE(P.TRUNC{from=64, to}, [a], res, _, e)) =
            value (a, fn x => trunc64To(to, x, res, cexp e))
        | cexp (C.PURE(P.COPY{from, to=64}, [a], res, _, e)) =
            value (a, fn x => copy64From(from, x, res, cexp e))
        | cexp (C.PURE(P.EXTEND{from, to=64}, [a], res, _, e)) =
            value (a, fn x => extend64From(from, x, res, cexp e))
        | cexp (C.PURE(P.WRAP(P.INT 64), [a], res, _, e)) =
            value (a, fn x => wrap64 (x, res, cexp e))
        | cexp (C.PURE(P.UNWRAP(P.INT 64), [a], res, _, e)) =
            value (a, fn x => unwrap64 (x, res, cexp e))
        | cexp (C.PURE(rator, args, res, ty, e)) =
            values (args, fn args => C.PURE(rator, args, res, cvtTy ty, cexp e))
        | cexp (C.RCC(rk, cf, proto, args, res, e)) =
            values (args, fn args => C.RCC(rk, cf, proto, args, res, cexp e))
    (* make an application of the function `f`, where `exp` is the continuation
     * of the original primop and we assume the result type is a pair of
     * 32-bit integers.
     *)
      and mkApply (f, args, res, exp) =
            values (args, fn args' =>
              mkApplyWithReturn (f, args', res, pairTy, cexp exp))
    (* convert a function definition: parameter types and body *)
      and function (fk, f, params, tys, body) =
            (fk, f, params, List.map cvtTy tys, cexp body)
      in
        if needsRewrite cfun
          then function cfun
          else cfun
      end (* elim *)
745 :    
746 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0