Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory | Revision Log


Revision 5574 - (view) (download)

1 : jhr 3437 (* simplify.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 : jhr 3468 * Simplify the AST representation. This phase involves the following transformations:
9 :     *
10 : jhr 4317 * - types are simplified by removing meta variables (which will have been resolved)
11 : jhr 3468 *
12 : jhr 4317 * - expressions are simplified to involve a single operation on variables
13 : jhr 3468 *
14 : jhr 4317 * - global reductions are converted to MapReduce statements
15 : jhr 3468 *
16 : jhr 4317 * - other comprehensions and reductions are converted to foreach loops
17 : jhr 3468 *
18 : jhr 4317 * - unreachable code is pruned
19 : jhr 3468 *
20 : jhr 4317 * - negations of literal integers and reals are constant folded
21 : jhr 3437 *)
22 :    
23 :     structure Simplify : sig
24 :    
25 : jhr 4371 val transform : Error.err_stream * AST.program * GlobalEnv.t -> Simple.program
26 : jhr 3437
27 :     end = struct
28 :    
29 :     structure TU = TypeUtil
30 :     structure S = Simple
31 : jhr 3445 structure STy = SimpleTypes
32 :     structure Ty = Types
33 : jhr 3437 structure VMap = Var.Map
34 : jhr 4193 structure II = ImageInfo
35 : jhr 4359 structure BV = BasisVars
36 : jhr 3437
37 : jhr 4441 (* environment for mapping small global constants to AST expressions *)
38 :     type const_env = AST.expr VMap.map
39 :    
40 : jhr 4371 (* context for simplification *)
41 : jhr 4441 datatype context = Cxt of {
42 :     errStrm : Error.err_stream,
43 :     gEnv : GlobalEnv.t,
44 :     cEnv : const_env
45 :     }
46 : jhr 4371
47 : jhr 4441 fun getNrrdInfo (Cxt{errStrm, ...}, nrrd) = NrrdInfo.getInfo (errStrm, nrrd)
48 : jhr 4371
49 : jhr 4441 fun findStrand (Cxt{gEnv, ...}, s) = GlobalEnv.findStrand(gEnv, s)
50 :    
51 :     fun insertConst (Cxt{errStrm, gEnv, cEnv}, x, e) = Cxt{
52 :     errStrm = errStrm, gEnv = gEnv, cEnv = VMap.insert(cEnv, x, e)
53 :     }
54 :    
55 :     fun findConst (Cxt{cEnv, ...}, x) = VMap.find(cEnv, x)
56 :    
57 :     fun error (Cxt{errStrm, ...}, msg) = Error.error (errStrm, msg)
58 :     fun warning (Cxt{errStrm, ...}, msg) = Error.warning (errStrm, msg)
59 :    
60 : jhr 4432 (* error message for when a nrrd image file is incompatible with the declared image type *)
61 :     fun badImageNrrd (cxt, nrrdFile, nrrdInfo, expectedDim, expectedShp) = let
62 : jhr 4441 val NrrdInfo.NrrdInfo{dim, nElems, ...} = nrrdInfo
63 :     val expectedNumElems = List.foldl (op * ) 1 expectedShp
64 :     val prefix = String.concat[
65 :     "image file \"", nrrdFile, "\" is incompatible with expected type image(",
66 :     Int.toString expectedDim, ")[",
67 :     String.concatWithMap "," Int.toString expectedShp, "]"
68 :     ]
69 :     in
70 :     case (dim = expectedDim, nElems = expectedNumElems)
71 :     of (false, true) => error (cxt, [
72 :     prefix, "; its dimension is ", Int.toString dim
73 :     ])
74 :     | (true, false) => error (cxt, [
75 : jhr 4578 prefix, "; it has ", Int.toString nElems, " values per voxel"
76 : jhr 4441 ])
77 :     | _ => error (cxt, [
78 :     prefix, "; its dimension is ", Int.toString dim, " and it has ",
79 : jhr 4578 Int.toString nElems, " values per voxel"
80 : jhr 4441 ])
81 :     (* end case *)
82 :     end
83 : jhr 4428
84 : jhr 3445 (* convert a Types.ty to a SimpleTypes.ty *)
85 :     fun cvtTy ty = (case ty
86 :     of Ty.T_Var(Ty.TV{bind, ...}) => (case !bind
87 :     of NONE => raise Fail "unresolved type variable"
88 :     | SOME ty => cvtTy ty
89 :     (* end case *))
90 :     | Ty.T_Bool => STy.T_Bool
91 :     | Ty.T_Int => STy.T_Int
92 :     | Ty.T_String => STy.T_String
93 :     | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)
94 : jhr 3452 | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))
95 : jhr 4317 | Ty.T_Strand id => STy.T_Strand id
96 : jhr 4207 | Ty.T_Kernel _ => STy.T_Kernel
97 : jhr 3445 | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)
98 : jhr 4193 | Ty.T_Image{dim, shape} =>
99 : jhr 4317 STy.T_Image(II.mkInfo(TU.monoDim dim, TU.monoShape shape))
100 : jhr 3445 | Ty.T_Field{diff, dim, shape} => STy.T_Field{
101 :     diff = TU.monoDiff diff,
102 :     dim = TU.monoDim dim,
103 :     shape = TU.monoShape shape
104 :     }
105 : jhr 4163 | Ty.T_Fun(tys1, ty2) => raise Fail "unexpected T_Fun in Simplify"
106 : jhr 4317 | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"
107 : jhr 3445 (* end case *))
108 : jhr 3437
109 : jhr 3811 fun apiTypeOf x = let
110 : jhr 4317 fun cvtTy STy.T_Bool = APITypes.BoolTy
111 :     | cvtTy STy.T_Int = APITypes.IntTy
112 :     | cvtTy STy.T_String = APITypes.StringTy
113 :     | cvtTy (STy.T_Sequence(ty, len)) = APITypes.SeqTy(cvtTy ty, len)
114 :     | cvtTy (STy.T_Tensor shape) = APITypes.TensorTy shape
115 :     | cvtTy (STy.T_Image info) =
116 :     APITypes.ImageTy(II.dim info, II.voxelShape info)
117 :     | cvtTy ty = raise Fail "bogus API type"
118 :     in
119 :     cvtTy (SimpleVar.typeOf x)
120 :     end
121 : jhr 3811
122 : jhr 4117 fun newTemp (ty as STy.T_Image _) = SimpleVar.new ("img", SimpleVar.LocalVar, ty)
123 :     | newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)
124 : jhr 3437
125 : jhr 4163 (* a property to map AST function variables to SimpleAST functions *)
126 :     local
127 : jhr 5561 fun cvt f = (case Var.monoTypeOf f
128 :     of Ty.T_Fun(paramTys, resTy) =>
129 :     SimpleFunc.new (Var.nameOf f, cvtTy resTy, List.map cvtTy paramTys)
130 :     | ty as Ty.T_Field _ => let
131 :     val ty' as STy.T_Field{dim, shape, ...} = cvtTy ty
132 :     in
133 :     SimpleFunc.newDiff (Var.nameOf f, STy.T_Tensor shape, [STy.T_Tensor[dim]])
134 :     end
135 :     | ty => raise Fail "expected function or field type"
136 :     (* end case *))
137 : jhr 4163 in
138 :     val {getFn = cvtFunc, ...} = Var.newProp cvt
139 :     end
140 :    
141 : jhr 3456 (* a property to map AST variables to SimpleAST variables *)
142 :     local
143 :     fun cvt x = SimpleVar.new (Var.nameOf x, Var.kindOf x, cvtTy(Var.monoTypeOf x))
144 : jhr 4193 val {getFn, setFn, ...} = Var.newProp cvt
145 : jhr 3456 in
146 : jhr 4193 val cvtVar = getFn
147 :     fun newVarWithType (x, ty) = let
148 : jhr 4317 val x' = SimpleVar.new (Var.nameOf x, Var.kindOf x, ty)
149 :     in
150 :     setFn (x, x');
151 :     x'
152 :     end
153 : jhr 3456 end
154 : jhr 3452
155 : jhr 3456 fun cvtVars xs = List.map cvtVar xs
156 : jhr 3452
157 : jhr 3437 (* make a block out of a list of statements that are in reverse order *)
158 : jhr 3501 fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}
159 : jhr 3437
160 : jhr 4393 (* make a variable definition *)
161 :     fun mkDef (x, e) = S.S_Var(x, SOME e)
162 :    
163 : jhr 4394 fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
164 :     fun mkToReal (res, a) =
165 : jhr 4441 mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
166 : jhr 4394 fun mkLength (res, elemTy, xs) =
167 : jhr 4441 mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))
168 : jhr 4394
169 : jhr 3437 (* simplify a statement into a single statement (i.e., a block if it expands
170 :     * into more than one new statement).
171 :     *)
172 : jhr 4371 fun simplifyBlock (cxt, stm) = mkBlock (simplifyStmt (cxt, stm, []))
173 : jhr 3437
174 : jhr 4193 (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
175 :     * then we use the info from the proxy image to determine the type of the lhs
176 :     * variable.
177 :     *)
178 :     and cvtLHS (lhs, S.E_LoadImage(_, _, info)) = newVarWithType(lhs, STy.T_Image info)
179 :     | cvtLHS (lhs, _) = cvtVar lhs
180 :    
181 : jhr 3437 (* simplify the statement stm where stms is a reverse-order list of preceding simplified
182 :     * statements. This function returns a reverse-order list of simplified statements.
183 :     * Note that error reporting is done in the typechecker, but it does not prune unreachable
184 :     * code.
185 :     *)
186 : jhr 4441 and simplifyStmt (cxt : context, stm, stms) : S.stmt list = (case stm
187 : jhr 3437 of AST.S_Block body => let
188 : jhr 3456 fun simplify ([], stms) = stms
189 : jhr 4371 | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))
190 : jhr 3437 in
191 : jhr 3456 simplify (body, stms)
192 : jhr 3437 end
193 : jhr 3452 | AST.S_Decl(x, NONE) => let
194 : jhr 3456 val x' = cvtVar x
195 : jhr 3452 in
196 : jhr 3465 S.S_Var(x', NONE) :: stms
197 : jhr 3452 end
198 :     | AST.S_Decl(x, SOME e) => let
199 : jhr 4371 val (stms, e') = simplifyExp (cxt, e, stms)
200 : jhr 4193 val x' = cvtLHS (x, e')
201 : jhr 3437 in
202 : jhr 3465 S.S_Var(x', SOME e') :: stms
203 : jhr 3437 end
204 : jhr 4310 (* FIXME: we should also define a "boolean negate" operation on AST expressions so that we can
205 : jhr 4254 * handle both cases!
206 :     *)
207 :     | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
208 : jhr 4371 simplifyStmt (cxt, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
209 : jhr 4254 | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
210 : jhr 4371 simplifyStmt (cxt, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
211 : jhr 3437 | AST.S_IfThenElse(e, s1, s2) => let
212 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e, stms)
213 :     val s1 = simplifyBlock (cxt, s1)
214 :     val s2 = simplifyBlock (cxt, s2)
215 : jhr 3437 in
216 : jhr 3456 S.S_IfThenElse(x, s1, s2) :: stms
217 : jhr 3437 end
218 : jhr 4317 | AST.S_Foreach((x, e), body) => let
219 : jhr 4371 val (stms, xs') = simplifyExpToVar (cxt, e, stms)
220 :     val body' = simplifyBlock (cxt, body)
221 : jhr 4317 in
222 :     S.S_Foreach(cvtVar x, xs', body') :: stms
223 :     end
224 : jhr 3452 | AST.S_Assign((x, _), e) => let
225 : jhr 4371 val (stms, e') = simplifyExp (cxt, e, stms)
226 : jhr 4193 val x' = cvtLHS (x, e')
227 : jhr 3437 in
228 : jhr 4193 S.S_Assign(x', e') :: stms
229 : jhr 3437 end
230 :     | AST.S_New(name, args) => let
231 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
232 : jhr 3437 in
233 : jhr 3456 S.S_New(name, xs) :: stms
234 : jhr 3437 end
235 : jhr 4628 | AST.S_KillAll => S.S_KillAll :: stms
236 : jhr 4480 | AST.S_StabilizeAll => S.S_StabilizeAll :: stms
237 : jhr 3456 | AST.S_Continue => S.S_Continue :: stms
238 :     | AST.S_Die => S.S_Die :: stms
239 :     | AST.S_Stabilize => S.S_Stabilize :: stms
240 : jhr 3437 | AST.S_Return e => let
241 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e, stms)
242 : jhr 3437 in
243 : jhr 3456 S.S_Return x :: stms
244 : jhr 3437 end
245 :     | AST.S_Print args => let
246 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
247 : jhr 3437 in
248 : jhr 3456 S.S_Print xs :: stms
249 : jhr 3437 end
250 :     (* end case *))
251 :    
252 : jhr 4371 and simplifyExp (cxt, exp, stms) = let
253 : jhr 4317 fun doBorderCtl (f, args) = let
254 : jhr 4359 val (ctl, arg) = if Var.same(BV.image_border, f)
255 : jhr 4317 then (BorderCtl.Default(hd args), hd(tl args))
256 : jhr 4359 else if Var.same(BV.image_clamp, f)
257 : jhr 4317 then (BorderCtl.Clamp, hd args)
258 : jhr 4359 else if Var.same(BV.image_mirror, f)
259 : jhr 4317 then (BorderCtl.Mirror, hd args)
260 : jhr 4359 else if Var.same(BV.image_wrap, f)
261 : jhr 4317 then (BorderCtl.Wrap, hd args)
262 :     else raise Fail "impossible"
263 :     in
264 :     S.E_BorderCtl(ctl, arg)
265 :     end
266 :     fun doPrimApply (f, tyArgs, args, ty) = let
267 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
268 : jhr 4378 fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
269 :     | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
270 :     | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
271 :     | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
272 : jhr 4317 in
273 :     if Basis.isBorderCtl f
274 :     then (stms, doBorderCtl (f, xs))
275 : jhr 4368 else if Var.same(f, BV.fn_sphere_im)
276 : jhr 4371 then let
277 : jhr 4378 (* get the strand type for the query *)
278 :     val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs
279 :     (* get the strand environment for the strand *)
280 : jhr 4441 val SOME sEnv = findStrand(cxt, strand)
281 : jhr 4378 fun result (query, pos) =
282 :     (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))
283 :     in
284 :     (* extract the position variable and spatial dimension *)
285 :     case (StrandEnv.findPosVar sEnv, StrandEnv.getSpaceDim sEnv)
286 :     of (SOME pos, SOME 1) => result (BV.fn_sphere1_r, pos)
287 :     | (SOME pos, SOME 2) => result (BV.fn_sphere2_t, pos)
288 :     | (SOME pos, SOME 3) => result (BV.fn_sphere3_t, pos)
289 :     | _ => raise Fail "impossible"
290 :     (* end case *)
291 :     end
292 : jhr 4317 else (case Var.kindOf f
293 :     of Var.BasisVar => let
294 :     val tyArgs = List.map cvtTyArg tyArgs
295 :     in
296 :     (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
297 :     end
298 :     | _ => raise Fail "bogus prim application"
299 :     (* end case *))
300 :     end
301 : jhr 4589 fun doCoerce (srcTy, dstTy, e, stms) = let
302 :     val (stms, x) = simplifyExpToVar (cxt, e, stms)
303 :     val dstTy = cvtTy dstTy
304 :     val result = newTemp dstTy
305 :     val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
306 :     in
307 :     (S.S_Var(result, SOME rhs)::stms, S.E_Var result)
308 :     end
309 : jhr 4317 in
310 :     case exp
311 :     of AST.E_Var(x, _) => (case Var.kindOf x
312 :     of Var.BasisVar => let
313 :     val ty = cvtTy(Var.monoTypeOf x)
314 :     val x' = newTemp ty
315 :     val stm = S.S_Var(x', SOME(S.E_Prim(x, [], [], ty)))
316 :     in
317 :     (stm::stms, S.E_Var x')
318 :     end
319 : jhr 4441 | Var.ConstVar => (case findConst(cxt, x)
320 :     of SOME e => let
321 :     val (stms, x') = simplifyExpToVar (cxt, e, stms)
322 :     in
323 :     (stms, S.E_Var x')
324 :     end
325 :     | NONE => (stms, S.E_Var(cvtVar x))
326 :     (* end case *))
327 : jhr 4317 | _ => (stms, S.E_Var(cvtVar x))
328 :     (* end case *))
329 :     | AST.E_Lit lit => (stms, S.E_Lit lit)
330 :     | AST.E_Kernel h => (stms, S.E_Kernel h)
331 :     | AST.E_Select(e, (fld, _)) => let
332 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e, stms)
333 : jhr 4317 in
334 :     (stms, S.E_Select(x, cvtVar fld))
335 :     end
336 :     | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e
337 : jhr 4359 of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
338 : jhr 4317 then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
339 :     else doPrimApply (rator, tyArgs, args, ty)
340 :     | AST.E_Lit(Literal.Real f) =>
341 : jhr 4359 if Var.same(BV.neg_t, rator)
342 : jhr 4317 then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
343 :     else doPrimApply (rator, tyArgs, args, ty)
344 : jhr 4364 (* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
345 : jhr 4317 | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
346 :     then let
347 : jhr 4371 val (stms, xs) = simplifyExpToVar (cxt, e'', stms)
348 : jhr 4393 val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e', [])
349 : jhr 4317 val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
350 : jhr 4441 fun mkReductionLoop (redOp, bodyStms, bodyResult, stms) = let
351 :     val {rator, init, mvs} = Util.reductionInfo redOp
352 :     val acc = SimpleVar.new ("accum", Var.LocalVar, cvtTy ty)
353 :     val initStm = S.S_Var(acc, SOME(S.E_Lit init))
354 :     val updateStm = S.S_Assign(acc,
355 :     S.E_Prim(rator, mvs, [acc, bodyResult], seqTy'))
356 :     val foreachStm = S.S_Foreach(cvtVar x, xs,
357 :     mkBlock(updateStm :: bodyStms))
358 :     in
359 :     (foreachStm :: initStm :: stms, S.E_Var acc)
360 :     end
361 :     in
362 :     case Util.identifyReduction rator
363 :     of Util.MEAN => let
364 :     val (stms, S.E_Var resultV) = mkReductionLoop (
365 : jhr 4588 Reductions.RSUM, bodyStms, bodyResult, stms)
366 : jhr 4441 val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
367 :     val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
368 :     val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
369 :     val stms =
370 :     mkRDiv (mean, resultV, rNum) ::
371 :     mkToReal (rNum, num) ::
372 :     mkLength (num, elemTy, xs) ::
373 :     stms
374 :     in
375 :     (stms, S.E_Var mean)
376 :     end
377 :     | Util.VARIANCE => raise Fail "FIXME: VARIANCE"
378 :     | Util.RED red => mkReductionLoop (red, bodyStms, bodyResult, stms)
379 :     (* end case *)
380 :     end
381 : jhr 4317 else doPrimApply (rator, tyArgs, args, ty)
382 :     | AST.E_ParallelMap(e', x, xs, _) =>
383 :     if Basis.isReductionOp rator
384 :     then let
385 : jhr 4394 val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
386 : jhr 4317 in
387 : jhr 4394 (stms, S.E_Var result)
388 : jhr 4317 end
389 :     else raise Fail "unsupported operation on parallel map"
390 :     | _ => doPrimApply (rator, tyArgs, args, ty)
391 :     (* end case *))
392 :     | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)
393 :     | AST.E_Apply((f, _), args, ty) => let
394 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
395 : jhr 4317 in
396 :     case Var.kindOf f
397 :     of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
398 :     | _ => raise Fail "bogus application"
399 :     (* end case *)
400 :     end
401 :     | AST.E_Comprehension(e, (x, e'), seqTy) => let
402 :     (* convert a comprehension to a foreach loop over the sequence defined by e' *)
403 : jhr 4371 val (stms, xs) = simplifyExpToVar (cxt, e', stms)
404 :     val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
405 : jhr 4317 val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
406 :     val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
407 :     val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
408 :     val updateStm = S.S_Assign(acc,
409 : jhr 4359 S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
410 : jhr 4317 val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
411 :     in
412 :     (foreachStm :: initStm :: stms, S.E_Var acc)
413 :     end
414 : jhr 5200 | AST.E_ParallelMap(e, x, xs, ty) =>
415 :     (* a map over a strand set without a reduction should be converted to a
416 :     * foreach loop.
417 :     *)
418 : jhr 5199 raise Fail "FIXME: unexpected ParallelMap without reduction"
419 : jhr 4317 | AST.E_Tensor(es, ty) => let
420 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
421 : jhr 4317 in
422 :     (stms, S.E_Tensor(xs, cvtTy ty))
423 :     end
424 : jhr 5574 | AST.E_Field(es, ty) => let
425 :     val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
426 :     in
427 :     (stms, S.E_Field(xs, cvtTy ty))
428 :     end
429 : jhr 4317 | AST.E_Seq(es, ty) => let
430 : jhr 4371 val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
431 : jhr 4317 in
432 :     (stms, S.E_Seq(xs, cvtTy ty))
433 :     end
434 :     | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
435 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e, stms)
436 : jhr 4317 fun f NONE = NONE
437 :     | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
438 :     | f _ = raise Fail "expected integer literal in slice"
439 :     val indices = List.map f indices
440 :     in
441 :     (stms, S.E_Slice(x, indices, cvtTy ty))
442 :     end
443 :     | AST.E_Cond(e1, e2, e3, ty) => let
444 :     (* a conditional expression gets turned into an if-then-else statement *)
445 :     val result = newTemp(cvtTy ty)
446 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e1, S.S_Var(result, NONE) :: stms)
447 : jhr 4317 fun simplifyBranch e = let
448 : jhr 4371 val (stms, e) = simplifyExp (cxt, e, [])
449 : jhr 4317 in
450 :     mkBlock (S.S_Assign(result, e)::stms)
451 :     end
452 :     val s1 = simplifyBranch e2
453 :     val s2 = simplifyBranch e3
454 :     in
455 :     (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
456 :     end
457 :     | AST.E_Orelse(e1, e2) => simplifyExp (
458 : jhr 4371 cxt,
459 : jhr 4317 AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
460 :     stms)
461 :     | AST.E_Andalso(e1, e2) => simplifyExp (
462 : jhr 4371 cxt,
463 : jhr 4317 AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
464 :     stms)
465 :     | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
466 :     of ty as STy.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))
467 :     | ty as STy.T_Image info => let
468 :     val dim = II.dim info
469 :     val shape = II.voxelShape info
470 :     in
471 : jhr 4441 case getNrrdInfo (cxt, nrrd)
472 : jhr 4317 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
473 :     of NONE => (
474 : jhr 4441 badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
475 : jhr 4317 (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
476 :     | SOME imgInfo =>
477 :     (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
478 :     (* end case *))
479 :     | NONE => (
480 : jhr 4433 error (cxt, [
481 :     "proxy-image file \"", nrrd, "\" does not exist"
482 : jhr 4317 ]);
483 :     (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
484 :     (* end case *)
485 :     end
486 :     | _ => raise Fail "bogus type for E_LoadNrrd"
487 :     (* end case *))
488 :     | AST.E_Coerce{dstTy, e=AST.E_Lit(Literal.Int n), ...} => (case cvtTy dstTy
489 :     of SimpleTypes.T_Tensor[] => (stms, S.E_Lit(Literal.Real(RealLit.fromInt n)))
490 :     | _ => raise Fail "impossible: bad coercion"
491 :     (* end case *))
492 : jhr 4589 | AST.E_Coerce{dstTy, srcTy, e as AST.E_Seq(es, ty)} => let
493 :     val Ty.T_Sequence(dstTy', dstBnd) = TU.prune dstTy
494 :     val Ty.T_Sequence(srcTy', srcBnd) = TU.prune srcTy
495 :     in
496 :     if STy.same(cvtTy dstTy', cvtTy srcTy')
497 :     then (* static-size to dynamic coercion *)
498 :     doCoerce (srcTy, dstTy, e, stms)
499 : jhr 5561 else let
500 : jhr 4589 (* distribute the coercion over the sequence elements *)
501 :     val es = List.map
502 :     (fn e => AST.E_Coerce{dstTy=dstTy', srcTy=srcTy', e=e})
503 :     es
504 :     val eTy = Ty.T_Sequence(dstTy', srcBnd)
505 :     val e = AST.E_Seq(es, eTy)
506 :     in
507 :     if Option.isSome dstBnd
508 :     then simplifyExp (cxt, e, stms)
509 :     else simplifyExp (cxt, AST.E_Coerce{dstTy=dstTy, srcTy=eTy, e=e}, stms)
510 :     end
511 :     end
512 : jhr 4529 | AST.E_Coerce{srcTy, dstTy, e} => doCoerce (srcTy, dstTy, e, stms)
513 : jhr 4317 (* end case *)
514 :     end
515 : jhr 3437
516 : jhr 4371 and simplifyExpToVar (cxt, exp, stms) = let
517 :     val (stms, e) = simplifyExp (cxt, exp, stms)
518 : jhr 3437 in
519 :     case e
520 :     of S.E_Var x => (stms, x)
521 :     | _ => let
522 :     val x = newTemp (S.typeOf e)
523 :     in
524 : jhr 3465 (S.S_Var(x, SOME e)::stms, x)
525 : jhr 3437 end
526 :     (* end case *)
527 :     end
528 :    
529 : jhr 4371 and simplifyExpsToVars (cxt, exps, stms) = let
530 : jhr 3437 fun f ([], xs, stms) = (stms, List.rev xs)
531 :     | f (e::es, xs, stms) = let
532 : jhr 4371 val (stms, x) = simplifyExpToVar (cxt, e, stms)
533 : jhr 3437 in
534 :     f (es, x::xs, stms)
535 :     end
536 :     in
537 :     f (exps, [], stms)
538 :     end
539 :    
540 : jhr 4591 (* `simplifyReduction (cxt, rator, e, x, xs, resTy, stms)`
541 :     * simplify a parallel map-reduce, where `e` is the body of the map, `x` is the
542 :     * strand variable, and `xs` is the source of strands.
543 :     *)
544 : jhr 4394 and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
545 : jhr 4378 val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
546 : jhr 4368 val x' = cvtVar x
547 : jhr 4371 val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
548 : jhr 4378 (* convert the domain from a variable to a StrandSets.t value *)
549 :     val domain = if Var.same(BV.set_active, xs) then StrandSets.ACTIVE
550 :     else if Var.same(BV.set_all, xs) then StrandSets.ALL
551 :     else if Var.same(BV.set_stable, xs) then StrandSets.STABLE
552 :     else raise Fail "impossible: not a strand set"
553 : jhr 4368 val (func, args) = Util.makeFunction(
554 : jhr 4591 Var.nameOf rator, x', mkBlock(S.S_Return bodyResult :: bodyStms),
555 : jhr 4368 SimpleVar.typeOf bodyResult)
556 : jhr 4441 in
557 :     case Util.identifyReduction rator
558 :     of Util.MEAN => let
559 :     val mapReduceStm = S.S_MapReduce[
560 :     S.MapReduce{
561 : jhr 4588 result = result, reduction = Reductions.RSUM, mapf = func,
562 : jhr 4441 args = args, source = x', domain = domain
563 :     }]
564 :     val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
565 :     val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
566 :     val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
567 :     val numStrandsOp = (case domain
568 : jhr 4588 of StrandSets.ACTIVE => BV.fn_numActive
569 :     | StrandSets.ALL => BV.fn_numStrands
570 :     | StrandSets.STABLE => BV.fn_numStable
571 : jhr 4441 (* end case *))
572 :     val stms =
573 :     mkRDiv (mean, result, rNum) ::
574 :     mkToReal (rNum, num) ::
575 :     mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
576 :     mapReduceStm ::
577 :     stms
578 :     in
579 :     (mean, stms)
580 :     end
581 :     | Util.VARIANCE => raise Fail "FIXME: variance reduction"
582 :     | Util.RED rator' => let
583 :     val mapReduceStm = S.S_MapReduce[
584 :     S.MapReduce{
585 :     result = result, reduction = rator', mapf = func, args = args,
586 :     source = x', domain = domain
587 :     }]
588 :     in
589 :     (result, mapReduceStm :: stms)
590 :     end
591 :     (* end case *)
592 :     end
593 : jhr 4394
594 : jhr 4113 (* simplify a block and then prune unreachable and dead code *)
595 : jhr 4371 fun simplifyAndPruneBlock cxt blk =
596 :     DeadCode.eliminate (simplifyBlock (cxt, blk))
597 : jhr 4113
598 : jhr 4371 fun simplifyStrand (cxt, strand) = let
599 : jhr 4368 val AST.Strand{
600 : jhr 4494 name, params, spatialDim, state, stateInit, startM, updateM, stabilizeM
601 : jhr 4368 } = strand
602 : jhr 3456 val params' = cvtVars params
603 : jhr 4505 fun simplifyState ([], xs, stms) = (List.rev xs, stms)
604 : jhr 3456 | simplifyState ((x, optE) :: r, xs, stms) = let
605 :     val x' = cvtVar x
606 : jhr 4317 in
607 :     case optE
608 :     of NONE => simplifyState (r, x'::xs, stms)
609 :     | SOME e => let
610 : jhr 4371 val (stms, e') = simplifyExp (cxt, e, stms)
611 : jhr 4317 in
612 :     simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
613 :     end
614 :     (* end case *)
615 :     end
616 : jhr 4505 val (xs, stateInit) = let
617 : jhr 4589 (* simplify the state-variable initializations *)
618 :     val (xs, stms) = simplifyState (state, [], [])
619 :     (* simplify optional "initialize" block *)
620 :     val blk = (case stateInit
621 :     of SOME stm => mkBlock (simplifyStmt (cxt, stm, stms))
622 :     | NONE => mkBlock stms
623 :     (* end case *))
624 :     in
625 :     (xs, blk)
626 :     end
627 : jhr 3437 in
628 : jhr 3452 S.Strand{
629 :     name = name,
630 :     params = params',
631 : jhr 4368 spatialDim = spatialDim,
632 : jhr 3456 state = xs,
633 : jhr 4505 stateInit = stateInit,
634 : jhr 4494 startM = Option.map (simplifyAndPruneBlock cxt) startM,
635 : jhr 4371 updateM = simplifyAndPruneBlock cxt updateM,
636 :     stabilizeM = Option.map (simplifyAndPruneBlock cxt) stabilizeM
637 : jhr 3452 }
638 : jhr 3437 end
639 :    
640 : jhr 4371 fun transform (errStrm, prog, gEnv) = let
641 : jhr 4317 val AST.Program{
642 : jhr 4494 props, const_dcls, input_dcls, globals, globInit, strand, create, start, update
643 : jhr 4317 } = prog
644 :     val consts' = ref[]
645 :     val constInit = ref[]
646 :     val inputs' = ref[]
647 :     val globals' = ref[]
648 :     val globalInit = ref[]
649 :     val funcs = ref[]
650 : jhr 4441 (* simplify the constant dcls: the small constants will be added to the context
651 :     * while the large constants will be added to the consts' list.
652 :     *)
653 :     val cxt = let
654 :     val cxt = Cxt{errStrm = errStrm, gEnv = gEnv, cEnv = VMap.empty}
655 :     fun simplifyConstDcl ((x, SOME e), cxt) = if Util.isSmallExp e
656 :     then insertConst (cxt, x, e)
657 :     else let
658 :     val (stms, e') = simplifyExp (cxt, e, [])
659 :     val x' = cvtVar x
660 :     in
661 :     consts' := x' :: !consts';
662 :     constInit := S.S_Assign(x', e') :: (stms @ !constInit);
663 :     cxt
664 :     end
665 :     | simplifyConstDcl _ = raise Fail "impossble"
666 : jhr 4317 in
667 : jhr 4441 List.foldl simplifyConstDcl cxt const_dcls
668 : jhr 4317 end
669 :     fun simplifyInputDcl ((x, NONE), desc) = let
670 :     val x' = cvtVar x
671 :     val init = (case SimpleVar.typeOf x'
672 : jhr 4432 of STy.T_Image info => (
673 : jhr 4441 warning(cxt, [
674 :     "assuming a sample type of ", RawTypes.toString(II.sampleTy info),
675 :     " for '", SimpleVar.nameOf x',
676 :     "'; specify a proxy-image file to override the default sample type"
677 :     ]);
678 :     S.Image info)
679 : jhr 4317 | _ => S.NoDefault
680 :     (* end case *))
681 :     val inp = S.INP{
682 :     var = x',
683 :     name = Var.nameOf x,
684 :     ty = apiTypeOf x',
685 :     desc = desc,
686 :     init = init
687 :     }
688 :     in
689 :     inputs' := inp :: !inputs'
690 :     end
691 :     | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(tvs, nrrd, ty))), desc) = let
692 :     val (x', init) = (case Var.monoTypeOf x
693 :     of Ty.T_Sequence(_, NONE) => (cvtVar x, S.LoadSeq nrrd)
694 :     | Ty.T_Image{dim, shape} => let
695 :     val dim = TU.monoDim dim
696 :     val shape = TU.monoShape shape
697 :     in
698 : jhr 4441 case getNrrdInfo (cxt, nrrd)
699 : jhr 4317 of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
700 :     of NONE => (
701 : jhr 4441 badImageNrrd (cxt, nrrd, nrrdInfo, dim, shape);
702 : jhr 4317 (cvtVar x, S.Image(II.mkInfo(dim, shape))))
703 :     | SOME info =>
704 :     (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
705 :     (* end case *))
706 :     | NONE => (
707 : jhr 4433 error (cxt, [
708 :     "proxy-image file \"", nrrd, "\" does not exist"
709 : jhr 4317 ]);
710 :     (cvtVar x, S.Image(II.mkInfo(dim, shape))))
711 :     (* end case *)
712 :     end
713 :     | _ => raise Fail "impossible"
714 :     (* end case *))
715 :     val inp = S.INP{
716 :     var = x',
717 :     name = Var.nameOf x,
718 :     ty = apiTypeOf x',
719 :     desc = desc,
720 :     init = init
721 :     }
722 :     in
723 :     inputs' := inp :: !inputs'
724 :     end
725 :     | simplifyInputDcl ((x, SOME e), desc) = let
726 :     val x' = cvtVar x
727 : jhr 4371 val (stms, e') = simplifyExp (cxt, e, [])
728 : jhr 4317 val inp = S.INP{
729 :     var = x',
730 :     name = Var.nameOf x,
731 :     ty = apiTypeOf x',
732 :     desc = desc,
733 :     init = S.ConstExpr
734 :     }
735 :     in
736 :     inputs' := inp :: !inputs';
737 :     constInit := S.S_Assign(x', e') :: (stms @ !constInit)
738 :     end
739 : jhr 5561 (* simplify a global declaration *)
740 : jhr 4317 fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
741 :     | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
742 : jhr 4371 val (stms, e') = simplifyExp (cxt, e, [])
743 : jhr 4317 val x' = cvtLHS (x, e')
744 :     in
745 :     globals' := x' :: !globals';
746 :     globalInit := S.S_Assign(x', e') :: (stms @ !globalInit)
747 :     end
748 :     | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
749 :     val f' = cvtFunc f
750 :     val params' = cvtVars params
751 : jhr 4371 val body' = simplifyAndPruneBlock cxt body
752 : jhr 4317 in
753 :     funcs := S.Func{f=f', params=params', body=body'} :: !funcs
754 :     end
755 : jhr 5561 | simplifyGlobalDcl (AST.D_DiffFunc(f, params, body)) = let
756 : jhr 5565 (* differentiable field function: we map it to both a function definition and
757 :     * a field variable.
758 :     *)
759 : jhr 5564 val vf = cvtVar f
760 :     val f' = SimpleFunc.use(cvtFunc f)
761 : jhr 5561 val params' = cvtVars params
762 :     val body' = simplifyAndPruneBlock cxt (AST.S_Return body)
763 :     in
764 : jhr 5564 funcs := S.Func{f=f', params=params', body=body'} :: !funcs;
765 :     globals' := vf :: !globals';
766 :     globalInit := S.S_Assign(vf, S.E_FieldFn f') :: !globalInit
767 : jhr 5561 end
768 : jhr 4317 val () = (
769 :     List.app simplifyInputDcl input_dcls;
770 :     List.app simplifyGlobalDcl globals)
771 : jhr 5104 (* check if there are no remaining constants *)
772 :     val props = if List.null(!consts')
773 :     then Properties.clearProp Properties.HasConsts props
774 :     else props
775 : jhr 4317 (* make the global-initialization block *)
776 :     val globInit = (case globInit
777 : jhr 4371 of SOME stm => mkBlock (simplifyStmt (cxt, stm, !globalInit))
778 : jhr 4317 | NONE => mkBlock (!globalInit)
779 :     (* end case *))
780 :     (* if the globInit block is non-empty, record the fact in the property list *)
781 :     val props = (case globInit
782 :     of S.Block{code=[], ...} => props
783 :     | _ => Properties.GlobalInit :: props
784 :     (* end case *))
785 : jhr 3452 in
786 :     S.Program{
787 :     props = props,
788 : jhr 4317 consts = List.rev(!consts'),
789 : jhr 3452 inputs = List.rev(!inputs'),
790 : jhr 4317 constInit = mkBlock (!constInit),
791 : jhr 3452 globals = List.rev(!globals'),
792 : jhr 3995 globInit = globInit,
793 : jhr 3452 funcs = List.rev(!funcs),
794 : jhr 4371 strand = simplifyStrand (cxt, strand),
795 :     create = Create.map (simplifyAndPruneBlock cxt) create,
796 : jhr 4494 start = Option.map (simplifyAndPruneBlock cxt) start,
797 : jhr 4371 update = Option.map (simplifyAndPruneBlock cxt) update
798 : jhr 3452 }
799 :     end
800 :    
801 : jhr 3437 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0