Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/translate/translate.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/translate/translate.sml

Parent Directory | Revision Log


Revision 5574 - (view) (download)

1 : jhr 3471 (* translate.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 : jhr 3476 * Translate Simple-AST code into the HighIR representation. This translation is based on the
9 : jhr 3471 * algorithm described in
10 :     *
11 :     * Single-pass generation of static single assignment form for structured languages
12 :     * ACM TOPLAS, Nov. 1994
13 :     * by Brandis and Mössenböck.
14 :     *)
15 :    
16 :     structure Translate : sig
17 :    
18 : jhr 3476 val translate : Simple.program -> HighIR.program
19 : jhr 3471
20 :     end = struct
21 :    
(* abbreviations for the source (Simple AST) side of the translation *)
structure S = Simple
structure Ty = SimpleTypes
structure SV = SimpleVar
structure VMap = SV.Map
structure VSet = SV.Set

(* abbreviations for the target (HighIR) side of the translation *)
structure IR = HighIR
structure Op = HighOps
structure DstTy = HighTypes
structure Census = HighCensus

structure Inp = Inputs

(* convert a Simple AST type to the corresponding HighIR type *)
fun cvtTy ty = TranslateTy.tr ty
34 :    
(* the kind of code being translated; the translation of S_Continue and of the end of a
 * global block depends on it (see inMethod and inGlobalUpdate)
 *)
datatype context = Other | GlobalUpdate | Method

(* a translation environment: the current code context paired with a map from each
 * SimpleAST variable to its current corresponding SSA variable
 *)
datatype env = E of context * IR.var VMap.map
40 : jhr 3471
(* mapping from differentiable field functions to their translated HighIR definitions;
 * asking for a function that has no recorded definition raises Fail.
 *)
local
  fun noFieldFnDef f : IR.func_def = raise Fail(concat[
          "no binding for field function '", SimpleFunc.uniqueNameOf f, "'"
        ])
  val {getFn, setFn, ...} = SimpleFunc.newProp noFieldFnDef
in
val getFieldFnDef = getFn
val setFieldFnDef = setFn
end (* local *)
51 :    
(* +DEBUG *)
(* print the bindings of an environment, preceded by prefix; output lines are wrapped
 * once they reach 100 characters.
 *)
fun prEnv (prefix, E(_, env)) = let
      val col = ref 0
      fun emit s = (print s; col := !col + size s)
      fun newline () = if !col > 0 then (print "\n"; col := 0) else ()
      fun emitBinding (src, dst) = (
            emit (String.concat [" ", SV.uniqueNameOf src, "->", IR.Var.toString dst]);
            if !col >= 100 then (newline (); emit " ") else ())
      in
        emit prefix; emit " ENV: {"; newline (); emit " ";
        VMap.appi emitBinding env;
        newline (); emit "}"; newline ()
      end
(* -DEBUG *)
71 :    
(* a property that maps Simple strand-state variables to HighIR state variables (created
 * on first use).  We need this to support reading the state of other strands.
 *)
local
  fun newSVar x = let
        val isOutput = (SV.kindOf x = SV.StrandOutputVar)
        val name = SV.nameOf x
        val ty = cvtTy (SV.typeOf x)
        val isVarying = AnalyzeSimple.varyingStateVar x
        val isShared = AnalyzeSimple.sharedStateVar x
        in
          IR.StateVar.new (isOutput, name, ty, isVarying, isShared)
        end
in
val {getFn=getStateVar, ...} = SV.newProp newSVar
end (* local *)
84 : jhr 3846
(* an environment for the given context that has no variable bindings *)
fun emptyEnv cxt = E(cxt, VMap.empty)

(* the current SSA variable for x, if there is one *)
fun find (E(_, bindings), x) = VMap.find (bindings, x)

(* the current SSA variable for x; raises Fail when x is unbound *)
fun lookup env x = (case find (env, x)
       of SOME x' => x'
        | NONE => raise Fail(concat[
              "no binding for ", SV.kindToString(SV.kindOf x), " ",
              SV.uniqueNameOf x, " in environment"
            ])
      (* end case *))

(* extend the environment by mapping x to x'; the context is unchanged *)
fun insert (E(cxt, bindings), x, x') = E(cxt, VMap.insert (bindings, x, x'))

(* replace the context of an environment, keeping its bindings *)
fun context (E(_, bindings), cxt) = E(cxt, bindings)

(* tests on the code context of an environment *)
fun inMethod (E(cxt, _)) = (cxt = Method)

fun inGlobalUpdate (E(cxt, _)) = (cxt = GlobalUpdate)
106 :    
(* create a new IR variable for the Simple variable x, with x's name and the HighIR
 * translation of x's type
 *)
fun newVar x = let
      val name = SV.nameOf x
      val ty = cvtTy (SV.typeOf x)
      in
        IR.Var.new (name, ty)
      end
109 : jhr 3471
(* is a Simple AST variable mapped to an IR.global_var?  This is the case for constants,
 * inputs, and other globals.
 *)
fun isGlobalVar x = let
      val k = SV.kindOf x
      in
        (k = SV.ConstVar) orelse (k = SV.InputVar) orelse (k = SV.GlobalVar)
      end
117 : jhr 3506
(* convert a global and cache the result in a property *)
local
  (* the IR kind of a global; raises Fail for non-global Simple variables *)
  fun globalKind x = (case SV.kindOf x
         of SV.ConstVar => IR.ConstVar
          | SV.InputVar => IR.InputVar
          | SV.GlobalVar => IR.GlobalVar
          | k => raise Fail(concat[
                "global variable ", SV.uniqueNameOf x,
                " has kind ", SV.kindToString k
              ])
        (* end case *))
  fun new x = IR.GlobalVar.new (
        globalKind x, AnalyzeSimple.updatedGlobal x, SV.nameOf x, cvtTy(SV.typeOf x))
in
val {getFn = cvtGlobalVar, ...} = SV.newProp new
end (* local *)
137 : jhr 3506
(* convert a function variable and cache the result in a property *)
local
  fun newFuncVar f = let
        val (resTy, paramTys) = SimpleFunc.typeOf f
        val name = SimpleFunc.nameOf f
        val resTy' = cvtTy resTy
        val paramTys' = List.map cvtTy paramTys
        in
          IR.Func.new (name, resTy', paramTys')
        end
in
val {getFn = cvtFuncVar, ...} = SimpleFunc.newProp newFuncVar
end (* local *)
148 :    
(* generate fresh SSA variables for xs and add them to the environment; returns the
 * extended environment and the fresh variables (in the same order as xs)
 *)
fun freshVars (env, xs) = let
      val xs' = List.map newVar xs
      fun bind (x, x', env) = insert (env, x, x')
      in
        (ListPair.foldl bind env (xs, xs'), xs')
      end
160 :    
(* a pending-join node tracks the phi nodes needed to join the assignments
 * that flow into the join node.
 *)
datatype join = JOIN of {
    nd : IR.node,               (* the CFG node (JOIN or FOREACH) for this pending join *)
    env : env,                  (* the environment that was current at the conditional
                                 * (or loop) associated with this node
                                 *)
    arity : int ref,            (* actual number of (non-killed) predecessors *)
    predKill : bool array,      (* killed predecessor edges (see killPath and commitJoin);
                                 * its length is the original arity
                                 *)
    phiMap : (IR.var * IR.var list) VMap.map ref
                                (* maps the Simple AST variables that are assigned to their
                                 * phi nodes, given as (lhs, rhs) pairs
                                 *)
  }

(* a stack of pending joins; the integer paired with each join is the index of the current
 * path into that join.
 *)
type pending_joins = (int * join) list
179 :    
(* create a new pending-join node (backed by a JOIN CFG node) for a conditional with
 * arity incoming paths; initially no path is killed and no phis are recorded
 *)
fun newJoin (env, arity) = let
      val joinNd = IR.Node.mkJOIN []
      in
        JOIN{
            env = env, arity = ref arity, nd = joinNd,
            phiMap = ref VMap.empty, predKill = Array.array(arity, false)
          }
      end
185 :    
(* create a new pending-join node (backed by a FOREACH CFG node) for a loop.  Path 0 is the
 * entry into the loop and path 1 is the end of the loop body.
 *
 * Each variable y in phiVars gets a phi
 *      y'' = PHI(y', y''')
 * where y' is the binding of y on entry to the loop and y''' is the binding of y at the
 * end of the loop body.  Because y''' is not known yet, y'' is used as a placeholder.
 * The join's environment maps each such y to y''.
 *)
fun newForeach (env, x, xs, phiVars) = let
      fun addPhi (y, (loopEnv, phis)) = let
            val yIn = lookup env y
            val yLoop = newVar y
            in
              (insert (loopEnv, y, yLoop), VMap.insert (phis, y, (yLoop, [yIn, yLoop])))
            end
      val (loopEnv, phis) = List.foldl addPhi (env, VMap.empty) phiVars
      in
        JOIN{
            env = loopEnv,
            arity = ref 2,
            nd = IR.Node.mkFOREACH(x, xs),
            phiMap = ref phis,
            predKill = Array.array(2, false)
          }
      end
210 : jhr 3500
(* record that the current path into the join on top of the stack has been killed (by
 * RETURN, DIE, STABILIZE, or continue): the join's arity is decremented and the path is
 * marked in predKill.  An empty stack is ignored.
 *)
fun killPath [] = ()
  | killPath ((pathIdx, JOIN{arity, predKill, ...}) :: _) = (
      arity := !arity - 1;
      Array.update (predKill, pathIdx, true))
218 :    
(* record an assignment to the IR variable dstVar (corresponding to the Simple AST variable
 * srcVar) in the pending join on top of the stack.  predIndex is the path into that join
 * on which the assignment occurs.  An assignment to a variable that is not bound in the
 * join's environment (a local temporary) is ignored, as is an empty stack.
 *)
fun recordAssign ([], _, _) = ()
  | recordAssign ((predIndex, JOIN{env, phiMap, predKill, ...})::_, srcVar, dstVar) = (
      case find (env, srcVar)
       of NONE => () (* local temporary *)
        | SOME dstVar' => let
            val m = !phiMap
          (* the original arity of the join; killPath does not shrink predKill *)
            val arity = Array.length predKill
          (* replace the predIndex'th element of an existing phi RHS with dstVar *)
            fun replaceAt (_, []) = raise Fail "invalid predecessor index"
              | replaceAt (i, y::ys) =
                  if (i = predIndex) then dstVar :: ys else y :: replaceAt (i+1, ys)
          (* a new phi uses dstVar on this path and the join-time binding elsewhere *)
            val phi = (case VMap.find (m, srcVar)
                   of NONE => (
                        newVar srcVar,
                        List.tabulate (arity, fn i => if (i = predIndex) then dstVar else dstVar'))
                    | SOME(lhs, rhs) => (lhs, replaceAt (0, rhs))
                  (* end case *))
            in
              phiMap := VMap.insert (m, srcVar, phi)
            end
      (* end case *))
252 : jhr 3471
(* complete a pending join operation by filling in the phi nodes from the phi map and
 * updating the environment.  Returns the environment that is current after the join.
 *
 * Predecessors of the join node that are unreachable are marked as killed (adjusting
 * the arity), and the resulting kill mask is stored in the node.  If every predecessor
 * is killed, the join's environment is returned unchanged and no phis are produced.
 * Otherwise each phi's RHS has NONE for a killed predecessor and SOME x for a live one.
 * The phi's lhs is given a VB_PHI binding in both the partial and the full case.  Each
 * phi assignment is also recorded in the enclosing pending join on joinStk.
 *)
fun commitJoin (joinStk, JOIN{env, arity, nd, phiMap, predKill}) = let
      val (preds, phis, mask) = (case IR.Node.kind nd
             of IR.JOIN{preds, phis, mask, ...} => (!preds, phis, mask)
              | IR.FOREACH{pred, bodyExit, phis, mask, ...} => ([!pred, !bodyExit], phis, mask)
              | _ => raise Fail "invalid JOIN node"
            (* end case *))
    (* update the predKill array based on reachability *)
      val _ = let
            fun update (_, []) = ()
              | update (i, nd::nds) = (
                  if IR.Node.isReachable nd then ()
                  else if Array.sub(predKill, i) then ()
                  else (arity := !arity-1; Array.update(predKill, i, true));
                  update (i+1, nds))
            in
              update (0, preds)
            end
    (* compute the predecessor mask *)
      val mask' = Array.foldr (op ::) [] predKill
    (* convert the RHS of a phi to options, using NONE for the variables that correspond
     * to fake (killed) predecessors; when no predecessor is killed this is the same as
     * mapping SOME over the list.
     *)
      fun phiRHS xs = let
            fun f ([], _, xs') = List.rev xs'
              | f (x::xs, i, xs') = if Array.sub(predKill, i)
                  then f (xs, i+1, NONE :: xs')
                  else f (xs, i+1, (SOME x) :: xs')
            in
              f (xs, 0, [])
            end
    (* build the phi for srcVar, bind its lhs to the phi, and extend the environment.
     * The binding was previously only set when no predecessor had been killed.
     *)
      fun doVar (srcVar, (dstVar, srcVars), (env, phis)) = let
            val xs = phiRHS srcVars
            in
              recordAssign (joinStk, srcVar, dstVar);
              IR.Var.setBinding (dstVar, IR.VB_PHI xs);
              (insert (env, srcVar, dstVar), (dstVar, xs)::phis)
            end
      in
        mask := mask';
        if (!arity = 0)
          then env (* all incoming edges are fake *)
          else let
            val (env, phis') = VMap.foldli doVar (env, []) (!phiMap)
            in
              phis := phis';
              env
            end
      end
318 :    
(* gather the statements of a straight-line chain of ASSIGN nodes by following pred links
 * back from the given node until the ENTRY node is reached.  The statements are returned
 * in reverse program order (the given node's statement comes first).  Any other kind of
 * node raises Fail; previously it raised an uninformative Match exception.
 *)
fun gather (IR.ND{kind, ...}) = (case kind
       of IR.ASSIGN{stm, pred, ...} => (IR.ASSGN stm) :: gather(!pred)
        | IR.ENTRY _ => []
        | _ => raise Fail "gather: expected a straight-line sequence of assignments"
      (* end case *))
323 :    
(* the argument of the TensorTy type of an IR variable (its tensor size, in the terminology
 * used by the E_FieldFn translation); raises Fail if the variable does not have a tensor
 * type
 *)
fun tensorSize v = (case IR.Var.ty v
       of DstTy.TensorTy alpha => alpha
        | _ => raise Fail(concat[
              "tensorSize: variable ", IR.Var.toString v, " does not have a tensor type"
            ])
      (* end case *))
328 :    
(* expression translation: translate the Simple AST expression exp into a list of HighIR
 * statements.  In the cases handled directly here, the value of exp is assigned to the IR
 * variable lhs by the last statement; E_Prim is delegated to TranslateBasis.translate.
 * Simple AST variables are mapped to IR variables with env (lookup raises Fail when a
 * variable is unbound).  E_Tuple and E_Project are not supported yet and raise Fail.
 *)
fun cvtExp (env : env, lhs, exp) = (case exp
       of S.E_Var x => [IR.ASSGN(lhs, IR.VAR(lookup env x))]
        | S.E_Lit lit => [IR.ASSGN(lhs, IR.LIT lit)]
        | S.E_Kernel h => [IR.ASSGN(lhs, IR.OP(Op.Kernel(h, 0), []))]
      (* reading a state variable of the strand bound to x *)
        | S.E_Select(x, fld) => [IR.ASSGN(lhs, IR.STATE(SOME(lookup env x), getStateVar fld))]
        | S.E_Apply(f, args) =>
            [IR.ASSGN(lhs, IR.APPLY(cvtFuncVar f, List.map (lookup env) args))]
        | S.E_Prim(f, tyArgs, args, ty) => let
            val args' = List.map (lookup env) args
            in
              TranslateBasis.translate (lhs, f, tyArgs, args')
            end
        | S.E_Tensor(args, _) =>
            [IR.ASSGN(lhs, IR.CONS(List.map (lookup env) args, IR.Var.ty lhs))]
      (* a field built from the argument fields via an EIN concatenation operator; the
       * rator is built from dim, the tail of shape, and the number of arguments.
       * NOTE(review): List.tl raises Empty if shape is [] -- presumably not possible here.
       *)
        | S.E_Field(args, Ty.T_Field{diff, dim, shape}) => let
            val rator = MkOperators.concatField(dim, List.tl shape, List.length(args))
            val ein = IR.EINAPP(rator, List.map (lookup env) args)
            in
              [IR.ASSGN(lhs, ein)]
            end
        | S.E_Seq(args, _) => [IR.ASSGN(lhs, IR.SEQ(List.map (lookup env) args, IR.Var.ty lhs))]
        | S.E_Tuple xs => raise Fail "FIXME: E_Tuple"
        | S.E_Project(x, i) => raise Fail "FIXME: E_Project"
      (* slicing a field: always translated to an EIN slice operator *)
        | S.E_Slice(x, indices, ty as Ty.T_Field{diff, dim, shape}) => let
            val x = lookup env x
          (* extract the integer indices from the mask *)
            val args' = List.mapPartial Fn.id indices
            val mask' = List.map Option.isSome indices
            val rator = MkOperators.sliceF(mask', args', shape, dim)
            val ein = IR.EINAPP(rator, [x])
            in
              [IR.ASSGN(lhs, ein)]
            end
        | S.E_Slice(x, indices, ty) => let
            val x = lookup env x
          (* check the indices to the slice. There are two special cases: if all of the indices
           * are NONE, then the result is just the original tensor; and if all of the indices
           * are SOME ix, then the result is scalar so we use TensorIndex.
           * The two booleans track whether every index seen so far is NONE (first) or
           * SOME (second); idxs accumulates the SOME indices in reverse order.
           *)
            fun chkIndices ([], _, true, idxs) = IR.OP(Op.TensorIndex(IR.Var.ty x, rev idxs), [x])
              | chkIndices ([], true, _, _) = IR.VAR x (* all axes *)
              | chkIndices (NONE :: r, true, _, _) = chkIndices (r, true, false, [])
              | chkIndices (SOME ix :: r, _, true, idxs) = chkIndices (r, false, true, ix::idxs)
              | chkIndices _ = let
                (* a mix of NONE and SOME indices: use an EIN slice operator *)
                (* extract the integer indices from the mask *)
                  val args' = List.mapPartial Fn.id indices
                  val mask' = List.map Option.isSome indices
                (* NOTE(review): the T_Field case below appears unreachable, because
                 * field-typed slices are matched by the preceding E_Slice clause.
                 *)
                  val rator = (case (IR.Var.ty lhs, IR.Var.ty x, ty)
                         of (DstTy.TensorTy rstTy, DstTy.TensorTy argTy, _) =>
                              MkOperators.sliceT (mask', args', rstTy, argTy)
                          | (_, _, Ty.T_Field{diff, dim, shape}) =>
                              MkOperators.sliceF(mask', args', shape, dim)
                          | (_, _, _ ) => raise Fail "unsupported type"
                        (* end case *))
                  in
                    IR.EINAPP(rator, [x])
                  end
            in
              [IR.ASSGN(lhs, chkIndices (indices, true, true, []))]
            end
        | S.E_Coerce{srcTy, dstTy, x} => (case (srcTy, dstTy)
             of (Ty.T_Int, Ty.T_Tensor _) =>
                  [IR.ASSGN(lhs, IR.OP(Op.IntToReal, [lookup env x]))]
              | (Ty.T_Sequence(ty, SOME n), Ty.T_Sequence(_, NONE)) =>
                (* fixed-size sequence to dynamic sequence *)
                  [IR.ASSGN(lhs, IR.OP(Op.MkDynamic(cvtTy ty, n), [lookup env x]))]
              | (Ty.T_Field _, Ty.T_Field _) =>
                (* change in continuity is a no-op *)
                  [IR.ASSGN(lhs, IR.VAR(lookup env x))]
              | (Ty.T_Kernel, Ty.T_Kernel) =>
                (* change in continuity is a no-op *)
                  [IR.ASSGN(lhs, IR.VAR(lookup env x))]
              | _ => raise Fail(concat[
                    "unsupported type coercion: ", Ty.toString srcTy,
                    " ==> ", Ty.toString dstTy
                  ])
            (* end case *))
      (* attach a border-control policy to an image; only Default takes an extra argument *)
        | S.E_BorderCtl(ctl, img) => let
            val img = lookup env img
            val DstTy.ImageTy info = IR.Var.ty img
            val (rator, args) = (case ctl
                   of BorderCtl.Default x => (Op.BorderCtlDefault info, [lookup env x, img])
                    | BorderCtl.Clamp => (Op.BorderCtlClamp info, [img])
                    | BorderCtl.Mirror => (Op.BorderCtlMirror info, [img])
                    | BorderCtl.Wrap => (Op.BorderCtlWrap info, [img])
                  (* end case *))
            in
              [IR.ASSGN(lhs, IR.OP(rator, args))]
            end
        | S.E_LoadSeq(ty, nrrd) => [IR.ASSGN(lhs, IR.OP(Op.LoadSeq(cvtTy ty, nrrd), []))]
        | S.E_LoadImage(_, nrrd, info) =>
            [IR.ASSGN(lhs, IR.OP(Op.LoadImage(DstTy.ImageTy info, nrrd), []))]
        | S.E_InsideImage(pos, img, s) => let
            val Ty.T_Image info = SV.typeOf img
            in
              [IR.ASSGN(lhs, IR.OP(Op.Inside(info, s), [lookup env pos, lookup env img]))]
            end
      (* a field defined by a differentiable function: inline the statements of the
       * function's (already translated) body and apply a cfexpMix EIN operator to the
       * body's result variable and the function's parameters.  The definition must have
       * been recorded by cvtFunc (via setFieldFnDef) before this point.
       *)
        | S.E_FieldFn f => let
          (* Variable convention used
           * - "alphas" tensor size
           * - "stmt" statements
           * - "lhs" high-ir variable
           * - "_comp" function body
           * - "_PF" parameters treated like fields
           * - "_PT " parameters treated like a tensor
           *)
          (* function body _comp *)
            val IR.Func{params, body,...} = getFieldFnDef f
          (* Decompose function body *)
            val IR.CFG{entry, exit} = body
          (* get computation inside field definition _comp; the node before the EXIT is
           * the last statement of the body
           *)
            val IR.ND{kind as IR.EXIT {pred, ...}, ...} = exit
            val IR.ND{kind, ...} = !pred
          (* get last variable name used: the lhs of the last assignment, or the single
           * parameter when the body has no assignments.
           * NOTE(review): this match is non-exhaustive (e.g., an empty body with several
           * parameters raises Match).
           *)
            val lhs_comp = (case (kind, params)
                   of (IR.ASSIGN{stm as (lhs_comp, _), ...}, _) => lhs_comp
                    | (IR.ENTRY _, [lhs_comp]) => lhs_comp
                  (* end case *))
          (* get all the statements used in the function body (gather returns them in
           * reverse order)
           *)
            val stmt_comp = List.rev(gather (!pred))
          (* analyze parameters*)
            val lhs_PF = params
            val lhs_PT = [] (*Fixed for now *)
            val lhs_allP = lhs_PF@lhs_PT
          (* for each argument set equal to a dummy var: each parameter is assigned a
           * string literal holding its own name
           *)
            val stmt_allP = List.map (fn v as IR.V{name,...} => IR.ASSGN(v, IR.LIT(Literal.String (name)))) (lhs_allP)
          (* get tensor size of arguments*)
            val alphas_PF = List.map tensorSize lhs_PF
            val alphas_PT = List.map tensorSize lhs_PT
            val alphas_allP = alphas_PF@alphas_PT
            val alpha_comp = tensorSize lhs_comp
          (* create ein operator; its arguments are the body result followed by the
           * parameters
           *)
            val rator = MkOperators.cfexpMix (alpha_comp, alphas_PF, alphas_PT)
            val args = lhs_comp::lhs_allP
            val ein = IR.EINAPP(rator, args)
            in
              stmt_allP @ stmt_comp @ [IR.ASSGN(lhs, ein)]
            end
      (* end case *))
468 :    
(* build a CFG that saves the varying strand state and then ends with the given exit node.
 * srcState are the Simple state variables and dstState the matching HighIR state
 * variables (the lists must have equal length).  A SAVE node is generated only for
 * variables that AnalyzeSimple reports as varying.
 *)
fun saveStrandState (env, (srcState, dstState), exit) = let
      fun save (sv, x, cfg) =
            if not (AnalyzeSimple.varyingStateVar x)
              then cfg (* invariant variables never need to be saved *)
              else IR.CFG.appendNode (cfg, IR.Node.mkSAVE(sv, lookup env x))
      val saves = ListPair.foldlEq save IR.CFG.empty (dstState, srcState)
      in
        IR.CFG.appendNode (saves, exit)
      end
      (*DEBUG*)handle ex => raise ex
481 :    
(* convert a block to a CFG. The parameters are:
 *      state   -- a pair of the src/dst state variables for saving the state of a strand.
 *                 These are empty if the block is not in a strand.
 *      env     -- environment for mapping SimpleIR variables to HighIR locals
 *      joinStk -- a stack of pending joins
 *      blk     -- the block to translate
 * Returns the CFG for the block paired with the environment at the end of the block.
 * Statements that end a path (continue, die, stabilize, return) kill the current path
 * of the innermost pending join and stop the translation of the rest of the block.
 *)
fun cvtBlock (state, env : env, joinStk, blk as S.Block{code, ...}) = let
      fun cvt (env : env, cfg, []) = (cfg, env)
        | cvt (env, cfg, stm::stms) = (case stm
             of S.S_Var(x, NONE) => let
                (* an uninitialized declaration only allocates a fresh IR variable *)
                  val x' = newVar x
                  in
                    cvt (insert (env, x, x'), cfg, stms)
                  end
              | S.S_Var(x, SOME e) => let
                  val x' = newVar x
                  val assigns = cvtExp (env, x', e)
                  in
                    recordAssign (joinStk, x, x');
                    cvt (
                      insert(env, x, x'),
                      IR.CFG.concat(cfg, IR.CFG.mkBlock assigns),
                      stms)
                  end
              | S.S_Assign(lhs, rhs) => let
                  val lhs' = newVar lhs
                  val assigns = cvtExp (env, lhs', rhs)
                  in
                  (* check for assignment to global (i.e., constant, input, or other global) *)
                  (* FIXME: for the global initialization block, we should batch up the saving of globals until
                   * the end so that we can properly set the bindings (i.e., so that we avoid conflicts between
                   * branches of an if statement).
                   *)
                  (* a global assignment also stores the value with GASSGN and is not
                   * recorded in the pending join; a local assignment is.
                   *)
                    if isGlobalVar lhs
                      then cvt (
                        insert(env, lhs, lhs'),
                        IR.CFG.concat(
                          cfg,
                          IR.CFG.mkBlock(assigns @ [IR.GASSGN(cvtGlobalVar lhs, lhs')])),
                        stms)
                      else (
                        recordAssign (joinStk, lhs, lhs');
                        cvt (
                          insert(env, lhs, lhs'),
                          IR.CFG.concat(cfg, IR.CFG.mkBlock assigns),
                          stms))
                  end
              | S.S_IfThenElse(x, b0, b1) => let
                  val x' = lookup env x
                  val join as JOIN{nd=joinNd, ...} = newJoin (env, 2)
                (* both branches start from env; the environments they return are not used,
                 * because the post-join environment comes from commitJoin below
                 *)
                  val (cfg0, _) = cvtBlock (state, env, (0, join)::joinStk, b0)
                  val (cfg1, _) = cvtBlock (state, env, (1, join)::joinStk, b1)
                  val cond = IR.Node.mkCOND x'
                (* connect the end of a branch to the join; a branch that ends in an EXIT
                 * node gets its succ set and is added as a predecessor of the join
                 *)
                  fun addEdgeToJoin nd = (case IR.Node.kind nd
                         of IR.EXIT{succ, ...} => (
                              succ := SOME joinNd;
                              IR.Node.setPred (joinNd, nd))  (* will be converted to fake later *)
                          | _ => IR.Node.addEdge(nd, joinNd)
                        (* end case *))
                (* package the CFG that represents the conditional (cond, two blocks, and join);
                 * an empty branch becomes a direct edge from cond to the join
                 *)
                  val condCFG = (
                        if IR.CFG.isEmpty cfg0
                          then (
                            IR.Node.setTrueBranch (cond, joinNd);
                            IR.Node.setPred (joinNd, cond))
                          else (
                            IR.Node.setTrueBranch (cond, IR.CFG.entry cfg0);
                            IR.Node.setPred (IR.CFG.entry cfg0, cond);
                            addEdgeToJoin (IR.CFG.exit cfg0));
                        if IR.CFG.isEmpty cfg1
                          then (
                            IR.Node.setFalseBranch (cond, joinNd);
                            IR.Node.setPred (joinNd, cond))
                          else (
                            IR.Node.setFalseBranch (cond, IR.CFG.entry cfg1);
                            IR.Node.setPred (IR.CFG.entry cfg1, cond);
                            addEdgeToJoin (IR.CFG.exit cfg1));
                        IR.CFG{entry = cond, exit = joinNd})
                  val env = commitJoin (joinStk, join)
                  val cfg = IR.CFG.concat (cfg, condCFG)
                  in
                  (* add an UNREACHABLE exit node when the join is the final node in the
                   * graph and it is unreachable.
                   *)
                    if List.null joinStk andalso not(IR.Node.isReachable joinNd)
                      then (* NOTE: this case implies that stms is empty! *)
                        (IR.CFG.appendNode(cfg, IR.Node.mkUNREACHABLE()), env)
                      else cvt (env, cfg, stms)
                  end
              | S.S_Foreach(x, xs, b) => let
                  val x' = newVar x
                  val xs' = lookup env xs
                (* For any local variable y that is both live on exit of the block b and
                 * assigned to in b, we will need a phi node for y.
                 *)
                  val phiVars = VSet.listItems(
                        VSet.intersection(AnalyzeSimple.assignedVars b, AnalyzeSimple.liveOut b))
                (* note that env is rebound here to the join's environment, which maps each
                 * phi variable to its loop-carried placeholder (see newForeach)
                 *)
                  val join as JOIN{env, nd=foreachNd, ...} = newForeach (env, x', xs', phiVars)
                (* the body is translated on path 1 of the loop join *)
                  val (body, _) = cvtBlock (state, insert(env, x, x'), (1, join)::joinStk, b)
                  val body = IR.CFG.appendNode (body, IR.Node.mkNEXT())
                  val env = commitJoin (joinStk, join)
                  in
                  (* link in CFG edges *)
                    IR.Node.setBodyEntry (foreachNd, IR.CFG.entry body);  (* loop header to body *)
                    IR.Node.setPred (IR.CFG.entry body, foreachNd);       (* back edge *)
                    IR.Node.setSucc (IR.CFG.exit body, foreachNd);
                    IR.Node.setBodyExit (foreachNd, IR.CFG.exit body);
                  (* process the rest of the block *)
                    cvt (env, IR.CFG.concat (cfg, IR.CFG{entry=foreachNd, exit=foreachNd}), stms)
                  end
              | S.S_New(strandId, args) => let
                  val nd = IR.Node.mkNEW(strandId, List.map (lookup env) args)
                  in
                    cvt (env, IR.CFG.appendNode (cfg, nd), stms)
                  end
              | S.S_KillAll => let
                  val nd = IR.Node.mkMASSIGN([], IR.OP(Op.KillAll, []))
                  in
                    cvt (env, IR.CFG.appendNode (cfg, nd), stms)
                  end
              | S.S_StabilizeAll => let
                  val nd = IR.Node.mkMASSIGN([], IR.OP(Op.StabilizeAll, []))
                  in
                    cvt (env, IR.CFG.appendNode (cfg, nd), stms)
                  end
            (* the terminating statements below do not translate the remaining stms *)
              | S.S_Continue => (
                  killPath joinStk;
                (* in a method: save the state and end with ACTIVE; in a global update:
                 * end with NEXTSTEP; otherwise: end with RETURN NONE
                 *)
                  if inMethod env
                    then (
                      IR.CFG.concat (cfg, saveStrandState (env, state, IR.Node.mkACTIVE())),
                      env
                    )
                  else if inGlobalUpdate env
                    then (IR.CFG.appendNode (cfg, IR.Node.mkNEXTSTEP()), env)
                    else (IR.CFG.appendNode (cfg, IR.Node.mkRETURN NONE), env))
              | S.S_Die => (
                  killPath joinStk;
                  (IR.CFG.appendNode (cfg, IR.Node.mkDIE ()), env))
              | S.S_Stabilize => (
                  killPath joinStk;
                  (IR.CFG.concat (cfg, saveStrandState (env, state, IR.Node.mkSTABILIZE())), env))
              | S.S_Return x => (
                  killPath joinStk;
                  (IR.CFG.appendNode (cfg, IR.Node.mkRETURN(SOME(lookup env x))), env))
              | S.S_Print args => let
                  val args = List.map (lookup env) args
                  val nd = IR.Node.mkMASSIGN([], IR.OP(Op.Print(List.map IR.Var.ty args), args))
                  in
                    cvt (env, IR.CFG.appendNode (cfg, nd), stms)
                  end
              | S.S_MapReduce mrs => let
                (* translate one map-reduce: bind its strand source to a fresh sequence
                 * variable (visible only to the map arguments via env') and its result to
                 * a fresh variable (visible to the rest of the block via env)
                 *)
                  fun cvtMR (mr, (env, assigns, lhs, mrs')) = let
                        val (S.MapReduce{
                                result, reduction, mapf=S.Func{f, ...}, args, source, domain
                              }) = mr
                      (* note that we are making the source of strands explicit and changing the
                       * type of the first argument to the map function.
                       *)
                        val strandTy = cvtTy(SV.typeOf source)
                        val src = IR.Var.new (SV.nameOf source, DstTy.SeqTy(strandTy, NONE))
                        val env' = insert (env, source, src)
                        val srcAssign = IR.ASSGN(src, IR.OP(Op.Strands(strandTy, domain), []))
                        val result' = newVar result
                        val env = insert(env, result, result')
                        val mr' = (reduction, cvtFuncVar f, List.map (lookup env') args)
                        in
                          (env, srcAssign :: assigns, result'::lhs, mr'::mrs')
                        end
                  val (env, assigns, lhs, mrs) = List.foldl cvtMR (env, [], [], []) mrs
                (* all of the reductions are performed by a single MASSGN of a MAPREDUCE,
                 * which follows the source assignments
                 *)
                  val assigns = IR.MASSGN(List.rev lhs, IR.MAPREDUCE(List.rev mrs)) :: assigns
                  in
                    cvt (env, IR.CFG.appendBlock (cfg, List.rev assigns), stms)
                  end
            (* end case *))
      in
        cvt (env, IR.CFG.empty, code)
      end
      (*DEBUG*)handle ex => raise ex
661 :    
(* generate a block of assignments that load the globals referenced in a SimpleIR block
 * (as reported by AnalyzeSimple.globalsOfBlock) into fresh variables.  Returns the block
 * and the environment extended with the fresh variables.
 *)
fun loadGlobals (env, blk) = let
      fun load (x, (env, stms)) = let
            val x' = newVar x
            in
              (insert (env, x, x'), IR.ASSGN(x', IR.GLOBAL(cvtGlobalVar x)) :: stms)
            end
      val (env', stms) = VSet.foldr load (env, []) (AnalyzeSimple.globalsOfBlock blk)
      in
        (IR.CFG.mkBlock stms, env')
      end
678 :    
(* convert a strand method body to a CFG.  The CFG starts with an ENTRY node, loads the
 * referenced globals and the strand state (state are the Simple state variables and
 * svars the matching HighIR state variables), and then runs the translated body.  If the
 * body falls through, the varying state is saved and the method ends with RETURN NONE
 * (when isStabilize) or ACTIVE (otherwise).
 *)
fun cvtMethod (env, isStabilize, state, svars, blk) = let
    (* load the globals into fresh variables *)
      val (globalsCFG, env) = loadGlobals (env, blk)
    (* shadow each state variable with a fresh local and load its value *)
      val (env, stateIn) = freshVars (env, state)
      val stateCFG = IR.CFG.mkBlock (
            ListPair.map (fn (x, sv) => IR.ASSGN(x, IR.STATE(NONE, sv))) (stateIn, svars))
      val loadCFG = IR.CFG.concat (globalsCFG, stateCFG)
    (* convert the body of the method *)
      val (bodyCFG, env) = cvtBlock ((state, svars), env, [], blk)
    (* add the entry/exit nodes *)
      val loadCFG = IR.CFG.prependNode (IR.Node.mkENTRY (), loadCFG)
      val exit = if isStabilize then IR.Node.mkRETURN NONE else IR.Node.mkACTIVE ()
      val body = IR.CFG.concat (loadCFG, bodyCFG)
      in
        if IR.Node.hasSucc (IR.CFG.exit body)
          then IR.CFG.concat (body, saveStrandState (env, (state, svars), exit))
          else IR.CFG{entry = IR.CFG.entry body, exit = exit}
      end
      (*DEBUG*)handle ex => (print "error in cvtMethod\n"; raise ex)
692 :     val (cfg, env) = cvtBlock ((state, svars), env, [], blk)
693 :     (* add the entry/exit nodes *)
694 : jhr 3476 val entry = IR.Node.mkENTRY ()
695 :     val loadCFG = IR.CFG.prependNode (entry, loadCFG)
696 : jhr 3505 val exit = if isStabilize
697 : jhr 4317 then IR.Node.mkRETURN NONE
698 :     else IR.Node.mkACTIVE()
699 : jhr 3476 val body = IR.CFG.concat (loadCFG, cfg)
700 :     val body = if IR.Node.hasSucc(IR.CFG.exit body)
701 :     then IR.CFG.concat (body, saveStrandState (env, (state, svars), exit))
702 :     else IR.CFG{entry = IR.CFG.entry body, exit = exit}
703 : jhr 3471 in
704 : jhr 3505 body
705 : jhr 3471 end
706 : jhr 3505 (*DEBUG*)handle ex => (print "error in cvtMethod\n"; raise ex)
707 : jhr 3471
708 : jhr 3506 (* convert global code *)
709 : jhr 5286 fun cvtGlobalBlock cxt block = let
710 : jhr 3506 (* load the globals into fresh variables *)
711 : jhr 5286 val (loadCFG, env) = loadGlobals (emptyEnv cxt, block)
712 : jhr 4317 (* convert the code *)
713 :     val (cfg, _) = cvtBlock (([], []), env, [], block)
714 :     val cfg = IR.CFG.concat (loadCFG, cfg)
715 :     val cfg = IR.CFG.prependNode (IR.Node.mkENTRY(), cfg)
716 : jhr 5286 val cfg = if inGlobalUpdate env
717 :     then IR.CFG.appendNode (cfg, IR.Node.mkNEXTSTEP())
718 :     else IR.CFG.appendNode (cfg, IR.Node.mkRETURN NONE)
719 : jhr 4317 in
720 :     cfg
721 :     end
722 : jhr 4570 (*DEBUG*)handle ex => raise ex
723 : jhr 3471
724 : jhr 4163 (* extend the global environment with the strand's parameters *)
725 :     fun initEnvFromParams params = let
726 : jhr 4317 fun cvtParam (x, (env, xs)) = let
727 :     val x' = newVar x
728 :     in
729 : jhr 5286 (insert(env, x, x'), x'::xs)
730 : jhr 4317 end
731 : jhr 5286 val (env, params) = List.foldl cvtParam (emptyEnv Other, []) params
732 : jhr 4317 in
733 :     (env, List.rev params)
734 :     end
735 : jhr 4163
(* convert a Simple function definition to a HighIR function definition.  The body of the
 * result is a CFG consisting of an ENTRY node, followed by code that loads the globals the
 * function body uses, followed by the converted body.  When the function is differentiable,
 * its translated definition is also recorded with setFieldFnDef.
 *)
fun cvtFunc (S.Func{f, params, body}) = let
      (* bind the parameters to fresh IR variables *)
      val (paramEnv, irParams) = initEnvFromParams params
      (* then load the globals that the body references *)
      val (loadCFG, bodyEnv) = loadGlobals (paramEnv, body)
      val (bodyCFG, _) = cvtBlock (([], []), bodyEnv, [], body)
      val fnCFG = IR.CFG.concat (IR.CFG.prependNode (IR.Node.mkENTRY(), loadCFG), bodyCFG)
      val funcDef = IR.Func{name = cvtFuncVar f, params = irParams, body = fnCFG}
      in
        (if SimpleFunc.isDifferentiable f then setFieldFnDef (f, funcDef) else ());
        funcDef
      end
749 : jhr 4359
(* lift functions used in map-reduce expressions.  Walk the optional block, descending into
 * both arms of conditionals and into the bodies of foreach loops, and convert the map
 * function of every map-reduce statement into a HighIR function using cvtFunc.  The
 * converted functions are returned in reverse order of occurrence; NONE yields [].
 *)
fun liftFuncs NONE = []
  | liftFuncs (SOME blk) = let
      (* convert the map function of one map-reduce and push it onto the accumulator *)
      fun liftMapReduce (S.MapReduce{mapf, ...}, acc) = cvtFunc mapf :: acc
      fun walkBlock (S.Block{code, ...}, acc) = List.foldl walkStm acc code
      and walkStm (stm, acc) = (case stm
             of S.S_IfThenElse(_, thenBlk, elseBlk) =>
                  walkBlock (elseBlk, walkBlock (thenBlk, acc))
              | S.S_Foreach(_, _, bodyBlk) => walkBlock (bodyBlk, acc)
              | S.S_MapReduce mrs => List.foldl liftMapReduce acc mrs
              | _ => acc
            (* end case *))
      in
        walkBlock (blk, [])
      end
762 : jhr 4359
(* translate a Simple-AST program into a HighIR program.  The Simple program is first passed
 * to AnalyzeSimple.analyze; then its constants, inputs, globals, functions (plus, for programs
 * with global reductions, the functions lifted from the start and update code), constant and
 * global initialization code, strand, and create/start/update code are converted.  Census.init
 * is applied to the resulting program before it is returned.
 *)
fun translate prog = let
      val S.Program{
              props, consts, inputs, constInit, globals, funcs,
              globInit, strand, create, start, update
            } = prog
      val _ = AnalyzeSimple.analyze prog
      val consts' = List.map cvtGlobalVar consts
      val inputs' = List.map (Inputs.map cvtGlobalVar) inputs
      (* convert the constant initialization code *)
      val constInit = let
            val (cfg, _) = cvtBlock (([], []), emptyEnv Other, [], constInit)
            val cfg = IR.CFG.prependNode (IR.Node.mkENTRY(), cfg)
            in
              IR.CFG.appendNode (cfg, IR.Node.mkRETURN NONE)
            end
      val globals' = List.map cvtGlobalVar globals
      val funcs' = List.map cvtFunc funcs
      (* if the program has global reductions, then lift those functions *)
      val funcs' = if Properties.hasProp Properties.GlobalReduce props
            then liftFuncs start @ liftFuncs update @ funcs'
            else funcs'
      (* create the global initialization code; we start by loading the input globals,
       * since they may be needed to compute the other globals
       *)
      val globInit = let
            val (loadBlk, env) = loadGlobals (emptyEnv Other, globInit)
            val (globBlk, _) = cvtBlock (([], []), env, [], globInit)
            val cfg = IR.CFG.prependNode (IR.Node.mkENTRY(), loadBlk)
            val cfg = IR.CFG.concat (cfg, globBlk)
            in
              IR.CFG.appendNode (cfg, IR.Node.mkRETURN NONE)
            end
      (* convert the strand definition *)
      fun cvtStrand strand = let
            val S.Strand{
                    name, params, spatialDim, state, stateInit, startM, updateM, stabilizeM
                  } = strand
            (* initialize the environment with the strand's parameters *)
            val (env, params) = initEnvFromParams params
            (* create the state variables *)
            val svars = List.map getStateVar state
            (* convert the state initialization code *)
            val (stateInit, env) = let
                  (* load globals into local variables *)
                  val (loadGlobsCFG, env) = loadGlobals (env, stateInit)
                  (* bind each state variable to a fresh local variable *)
                  val env = List.foldl (fn (x, env) => insert(env, x, newVar x)) env state
                  val (cfg, env) = cvtBlock (([], []), env, [], stateInit)
                  val cfg = IR.CFG.concat(loadGlobsCFG, cfg)
                  val cfg = IR.CFG.prependNode (IR.Node.mkENTRY(), cfg)
                  (* add nodes to initialize the strand state *)
                  val cfg = let
                        val lookup = lookup env
                        fun save (x, x', cfg) =
                              IR.CFG.appendNode (cfg, IR.Node.mkSAVE(x, lookup x'))
                        in
                          IR.CFG.appendNode (
                            ListPair.foldlEq save cfg (svars, state),
                            IR.Node.mkRETURN NONE)
                        end
                  in
                    (cfg, env)
                  end
            fun cvtMeth isStabilize blk =
                  cvtMethod (context(env, Method), isStabilize, state, svars, blk)
            in
              IR.Strand{
                  name = name,
                  params = params,
                  spatialDim = spatialDim,
                  state = svars,
                  stateInit = stateInit,
                  startM = Option.map (cvtMeth false) startM,
                  updateM = cvtMeth false updateM,
                  stabilizeM = Option.map (cvtMeth true) stabilizeM
                }
            end
      val create = Create.map (cvtGlobalBlock Other) create
      val prog = IR.Program{
              props = props,
              consts = consts',
              inputs = inputs',
              globals = globals',
              funcs = funcs',
              constInit = constInit,
              globInit = globInit,
              strand = cvtStrand strand,
              create = create,
              start = Option.map (cvtGlobalBlock Other) start,
              update = Option.map (cvtGlobalBlock GlobalUpdate) update
            }
      in
        Census.init prog;
        prog
      end
861 :    
862 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0