Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4428 - (view) (download)

(* simplify.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * Simplify the AST representation.  This phase involves the following transformations:
 *
 *      - types are simplified by removing meta variables (which will have been resolved)
 *
 *      - expressions are simplified to involve a single operation on variables
 *
 *      - global reductions are converted to MapReduce statements
 *
 *      - other comprehensions and reductions are converted to foreach loops
 *
 *      - unreachable code is pruned
 *
 *      - negations of integer and real literals are constant folded
 *)

structure Simplify : sig

  (* `transform (errStrm, prog, gEnv)` converts the typechecked AST program `prog` into
   * the Simple IR.  Errors and warnings (e.g., for nrrd files that are missing or that
   * do not have the expected type) are reported on `errStrm`; `gEnv` is used to look
   * up strand environments.
   *)
    val transform : Error.err_stream * AST.program * GlobalEnv.t -> Simple.program

  end = struct

    structure TU = TypeUtil
    structure S = Simple
    structure STy = SimpleTypes
    structure Ty = Types
    structure VMap = Var.Map
    structure II = ImageInfo
    structure BV = BasisVars

  (* context for simplification *)
    type context = {errStrm : Error.err_stream, gEnv : GlobalEnv.t}

  (* report an error (resp. warning) message on the context's error stream *)
    fun error ({errStrm, ...} : context, msg) = Error.error (errStrm, msg)
    fun warning ({errStrm, ...} : context, msg) = Error.warning (errStrm, msg)

  (* format an image type, given its dimension and voxel shape, for use in error
   * messages (e.g., "image(3)[2,2]")
   *)
    fun imageTyToString (dim, shp) = String.concat[
            "image(", Int.toString dim, ")[", String.concatWithMap "," Int.toString shp, "]"
          ]

  (* convert a Types.ty to a SimpleTypes.ty; type variables must already be resolved *)
    fun cvtTy ty = (case ty
           of Ty.T_Var(Ty.TV{bind, ...}) => (case !bind
                 of NONE => raise Fail "unresolved type variable"
                  | SOME ty => cvtTy ty
                (* end case *))
            | Ty.T_Bool => STy.T_Bool
            | Ty.T_Int => STy.T_Int
            | Ty.T_String => STy.T_String
            | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)
            | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))
            | Ty.T_Strand id => STy.T_Strand id
            | Ty.T_Kernel _ => STy.T_Kernel
            | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)
            | Ty.T_Image{dim, shape} =>
                STy.T_Image(II.mkInfo(TU.monoDim dim, TU.monoShape shape))
            | Ty.T_Field{diff, dim, shape} => STy.T_Field{
                  diff = TU.monoDiff diff,
                  dim = TU.monoDim dim,
                  shape = TU.monoShape shape
                }
            | Ty.T_Fun _ => raise Fail "unexpected T_Fun in Simplify"
            | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"
          (* end case *))

  (* the API type of a Simple variable (used for program inputs); raises Fail for
   * types that have no API representation (e.g., strands, kernels, or fields)
   *)
    fun apiTypeOf x = let
          fun cvtTy STy.T_Bool = APITypes.BoolTy
            | cvtTy STy.T_Int = APITypes.IntTy
            | cvtTy STy.T_String = APITypes.StringTy
            | cvtTy (STy.T_Sequence(ty, len)) = APITypes.SeqTy(cvtTy ty, len)
            | cvtTy (STy.T_Tensor shape) = APITypes.TensorTy shape
            | cvtTy (STy.T_Image info) =
                APITypes.ImageTy(II.dim info, II.voxelShape info)
            | cvtTy _ = raise Fail "bogus API type"
          in
            cvtTy (SimpleVar.typeOf x)
          end

  (* create a fresh local temporary of the given type; image temporaries are named "img" *)
    fun newTemp (ty as STy.T_Image _) = SimpleVar.new ("img", SimpleVar.LocalVar, ty)
      | newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)

  (* a property to map AST function variables to SimpleAST functions *)
    local
      fun cvt x = let
            val Ty.T_Fun(paramTys, resTy) = Var.monoTypeOf x
            in
              SimpleFunc.new (Var.nameOf x, cvtTy resTy, List.map cvtTy paramTys)
            end
    in
    val {getFn = cvtFunc, ...} = Var.newProp cvt
    end

  (* a property to map AST variables to SimpleAST variables *)
    local
      fun cvt x = SimpleVar.new (Var.nameOf x, Var.kindOf x, cvtTy(Var.monoTypeOf x))
      val {getFn, setFn, ...} = Var.newProp cvt
    in
    val cvtVar = getFn
  (* create a SimpleAST variable for x that has the given type (instead of the type
   * derived from x's AST type) and record it as x's mapping, so that subsequent
   * calls to cvtVar on x return the new variable.
   *)
    fun newVarWithType (x, ty) = let
          val x' = SimpleVar.new (Var.nameOf x, Var.kindOf x, ty)
          in
            setFn (x, x');
            x'
          end
    end

  (* convert a list of AST variables to SimpleAST variables *)
    fun cvtVars xs = List.map cvtVar xs

  (* make a block out of a list of statements that are in reverse order *)
    fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}

  (* make a variable definition *)
    fun mkDef (x, e) = S.S_Var(x, SOME e)

  (* define `res` as the real division `a / b` *)
    fun mkRDiv (res, a, b) = mkDef (res, S.E_Prim(BV.div_rr, [], [a, b], STy.realTy))
  (* define `res` as the conversion of the integer variable `a` to a real *)
    fun mkToReal (res, a) =
          mkDef (res, S.E_Coerce{srcTy = STy.T_Int, dstTy = STy.realTy, x = a})
  (* define `res` as the length of the sequence `xs`, whose elements have type `elemTy` *)
    fun mkLength (res, elemTy, xs) =
          mkDef (res, S.E_Prim(BV.fn_length, [STy.TY elemTy], [xs], STy.T_Int))

  (* simplify a statement into a single statement (i.e., a block if it expands
   * into more than one new statement).
   *)
    fun simplifyBlock (cxt, stm) = mkBlock (simplifyStmt (cxt, stm, []))

  (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
   * then we use the info from the proxy image to determine the type of the lhs
   * variable.
   * NOTE(review): in the LoadImage case, newVarWithType rebinds the AST variable to a
   * fresh SimpleVar, so later uses refer to the new variable.  Confirm this is the
   * intended behavior when the lhs is the target of an assignment (rather than a
   * declaration).
   *)
    and cvtLHS (lhs, S.E_LoadImage(_, _, info)) = newVarWithType(lhs, STy.T_Image info)
      | cvtLHS (lhs, _) = cvtVar lhs

  (* simplify the statement stm where stms is a reverse-order list of preceding simplified
   * statements.  This function returns a reverse-order list of simplified statements.
   * Note that error reporting is done in the typechecker, but it does not prune unreachable
   * code.
   *)
    and simplifyStmt (cxt, stm, stms) : S.stmt list = (case stm
           of AST.S_Block body => let
                fun simplify ([], stms) = stms
                  | simplify (stm::r, stms) = simplify (r, simplifyStmt (cxt, stm, stms))
                in
                  simplify (body, stms)
                end
            | AST.S_Decl(x, NONE) => let
                val x' = cvtVar x
                in
                  S.S_Var(x', NONE) :: stms
                end
            | AST.S_Decl(x, SOME e) => let
                val (stms, e') = simplifyExp (cxt, e, stms)
                val x' = cvtLHS (x, e')
                in
                  S.S_Var(x', SOME e') :: stms
                end
(* FIXME: we should also define a "boolean negate" operation on AST expressions so that we can
 * handle both cases!
 *)
          (* a short-circuit condition whose duplicated branch is empty is expanded into
           * nested conditionals; the rewrite copies that (empty) branch, so no code is
           * duplicated.
           *)
            | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
                simplifyStmt (cxt, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
            | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
                simplifyStmt (cxt, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
            | AST.S_IfThenElse(e, s1, s2) => let
                val (stms, x) = simplifyExpToVar (cxt, e, stms)
                val s1 = simplifyBlock (cxt, s1)
                val s2 = simplifyBlock (cxt, s2)
                in
                  S.S_IfThenElse(x, s1, s2) :: stms
                end
            | AST.S_Foreach((x, e), body) => let
                val (stms, xs') = simplifyExpToVar (cxt, e, stms)
                val body' = simplifyBlock (cxt, body)
                in
                  S.S_Foreach(cvtVar x, xs', body') :: stms
                end
            | AST.S_Assign((x, _), e) => let
                val (stms, e') = simplifyExp (cxt, e, stms)
                val x' = cvtLHS (x, e')
                in
                  S.S_Assign(x', e') :: stms
                end
            | AST.S_New(name, args) => let
                val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
                in
                  S.S_New(name, xs) :: stms
                end
            | AST.S_Continue => S.S_Continue :: stms
            | AST.S_Die => S.S_Die :: stms
            | AST.S_Stabilize => S.S_Stabilize :: stms
            | AST.S_Return e => let
                val (stms, x) = simplifyExpToVar (cxt, e, stms)
                in
                  S.S_Return x :: stms
                end
            | AST.S_Print args => let
                val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
                in
                  S.S_Print xs :: stms
                end
          (* end case *))

  (* simplify the expression exp, where stms is a reverse-order list of preceding
   * simplified statements.  Returns the extended (reverse-order) statement list paired
   * with the simplified expression.
   *)
    and simplifyExp (cxt, exp, stms) = let
        (* convert an application of one of the image border-control operators *)
          fun doBorderCtl (f, args) = let
                val (ctl, arg) = if Var.same(BV.image_border, f)
                        then (BorderCtl.Default(hd args), hd(tl args))
                      else if Var.same(BV.image_clamp, f)
                        then (BorderCtl.Clamp, hd args)
                      else if Var.same(BV.image_mirror, f)
                        then (BorderCtl.Mirror, hd args)
                      else if Var.same(BV.image_wrap, f)
                        then (BorderCtl.Wrap, hd args)
                        else raise Fail "impossible"
                in
                  S.E_BorderCtl(ctl, arg)
                end
        (* convert the application of the basis function f to the given type and value
         * arguments; ty is the result type of the application
         *)
          fun doPrimApply (f, tyArgs, args, ty) = let
                val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
                fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
                  | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
                  | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
                  | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
                in
                  if Basis.isBorderCtl f
                    then (stms, doBorderCtl (f, xs))
                  else if Var.same(f, BV.fn_sphere_im)
                    then let
                    (* get the strand type for the query *)
                      val tyArgs as [S.TY(STy.T_Strand strand)] = List.map cvtTyArg tyArgs
                    (* get the strand environment for the strand *)
                      val SOME sEnv = GlobalEnv.findStrand(#gEnv cxt, strand)
                      fun result (query, pos) =
                            (stms, S.E_Prim(query, tyArgs, cvtVar pos::xs, cvtTy ty))
                      in
                      (* extract the position variable and spatial dimension *)
                        case (StrandEnv.findPosVar sEnv, StrandEnv.getSpaceDim sEnv)
                         of (SOME pos, SOME 1) => result (BV.fn_sphere1_r, pos)
                          | (SOME pos, SOME 2) => result (BV.fn_sphere2_t, pos)
                          | (SOME pos, SOME 3) => result (BV.fn_sphere3_t, pos)
                          | _ => raise Fail "impossible"
                        (* end case *)
                      end
                    else (case Var.kindOf f
                       of Var.BasisVar => let
                            val tyArgs = List.map cvtTyArg tyArgs
                            in
                              (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
                            end
                        | _ => raise Fail "bogus prim application"
                      (* end case *))
                end
          in
            case exp
             of AST.E_Var(x, _) => (case Var.kindOf x
                   of Var.BasisVar => let
                        val ty = cvtTy(Var.monoTypeOf x)
                        val x' = newTemp ty
                        val stm = S.S_Var(x', SOME(S.E_Prim(x, [], [], ty)))
                        in
                          (stm::stms, S.E_Var x')
                        end
                    | _ => (stms, S.E_Var(cvtVar x))
                  (* end case *))
              | AST.E_Lit lit => (stms, S.E_Lit lit)
              | AST.E_Kernel h => (stms, S.E_Kernel h)
              | AST.E_Select(e, (fld, _)) => let
                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
                  in
                    (stms, S.E_Select(x, cvtVar fld))
                  end
              | AST.E_Prim(rator, tyArgs, args as [e], ty) => (case e
                   of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
                        then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
                        else doPrimApply (rator, tyArgs, args, ty)
                    | AST.E_Lit(Literal.Real f) => if Var.same(BV.neg_t, rator)
                        then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
                        else doPrimApply (rator, tyArgs, args, ty)
(* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
                    | AST.E_Comprehension(e', (x, e''), seqTy) => if Basis.isReductionOp rator
                        then let
                          val (stms, xs) = simplifyExpToVar (cxt, e'', stms)
                          val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e', [])
                          val STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
                        (* generate a foreach loop over xs that folds the reduction operator
                         * for redOp over the body's results; returns the extended statement
                         * list paired with the accumulator variable.
                         *)
                          fun mkReductionLoop (redOp, bodyStms, bodyResult, stms) = let
                                val {rator=accOp, init, mvs} = Util.reductionInfo redOp
                              (* the accumulator has the type of the reduction's result, so
                               * the update operation has that type too (not the type of the
                               * sequence being reduced).
                               *)
                                val accTy = cvtTy ty
                                val acc = SimpleVar.new ("accum", Var.LocalVar, accTy)
                                val initStm = S.S_Var(acc, SOME(S.E_Lit init))
                                val updateStm = S.S_Assign(acc,
                                      S.E_Prim(accOp, mvs, [acc, bodyResult], accTy))
                                val foreachStm = S.S_Foreach(cvtVar x, xs,
                                      mkBlock(updateStm :: bodyStms))
                                in
                                  (foreachStm :: initStm :: stms, S.E_Var acc)
                                end
                          in
                            case Util.identifyReduction rator
                             of Util.MEAN => let
                                  val (stms, S.E_Var resultV) = mkReductionLoop (
                                        Reductions.SUM, bodyStms, bodyResult, stms)
                                  val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
                                  val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
                                  val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
                                (* mean = sum / real(length xs); the list is in reverse order *)
                                  val stms =
                                        mkRDiv (mean, resultV, rNum) ::
                                        mkToReal (rNum, num) ::
                                        mkLength (num, elemTy, xs) ::
                                        stms
                                  in
                                    (stms, S.E_Var mean)
                                  end
                              | Util.VARIANCE => raise Fail "FIXME: VARIANCE"
                              | Util.RED red => mkReductionLoop (red, bodyStms, bodyResult, stms)
                            (* end case *)
                          end
                        else doPrimApply (rator, tyArgs, args, ty)
                    | AST.E_ParallelMap(e', x, xs, _) =>
                        if Basis.isReductionOp rator
                          then let
                            val (result, stms) = simplifyReduction (cxt, rator, e', x, xs, ty, stms)
                            in
                              (stms, S.E_Var result)
                            end
                          else raise Fail "unsupported operation on parallel map"
                    | _ => doPrimApply (rator, tyArgs, args, ty)
                  (* end case *))
              | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)
              | AST.E_Apply((f, _), args, ty) => let
                  val (stms, xs) = simplifyExpsToVars (cxt, args, stms)
                  in
                    case Var.kindOf f
                     of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
                      | _ => raise Fail "bogus application"
                    (* end case *)
                  end
              | AST.E_Comprehension(e, (x, e'), seqTy) => let
                (* convert a comprehension to a foreach loop over the sequence defined by e' *)
                  val (stms, xs) = simplifyExpToVar (cxt, e', stms)
                  val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
                  val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
                  val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
                  val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
                  val updateStm = S.S_Assign(acc,
                        S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
                  val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
                  in
                    (foreachStm :: initStm :: stms, S.E_Var acc)
                  end
              | AST.E_ParallelMap _ => raise Fail "FIXME: ParallelMap"
              | AST.E_Tensor(es, ty) => let
                  val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
                  in
                    (stms, S.E_Tensor(xs, cvtTy ty))
                  end
              | AST.E_Seq(es, ty) => let
                  val (stms, xs) = simplifyExpsToVars (cxt, es, stms)
                  in
                    (stms, S.E_Seq(xs, cvtTy ty))
                  end
              | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
                  fun f NONE = NONE
                    | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
                    | f _ = raise Fail "expected integer literal in slice"
                  val indices = List.map f indices
                  in
                    (stms, S.E_Slice(x, indices, cvtTy ty))
                  end
              | AST.E_Cond(e1, e2, e3, ty) => let
                (* a conditional expression gets turned into an if-then-else statement *)
                  val result = newTemp(cvtTy ty)
                  val (stms, x) = simplifyExpToVar (cxt, e1, S.S_Var(result, NONE) :: stms)
                  fun simplifyBranch e = let
                        val (stms, e) = simplifyExp (cxt, e, [])
                        in
                          mkBlock (S.S_Assign(result, e)::stms)
                        end
                  val s1 = simplifyBranch e2
                  val s2 = simplifyBranch e3
                  in
                    (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
                  end
              | AST.E_Orelse(e1, e2) => simplifyExp (
                  cxt,
                  AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
                  stms)
              | AST.E_Andalso(e1, e2) => simplifyExp (
                  cxt,
                  AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
                  stms)
              | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
                   of ty as STy.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))
                    | ty as STy.T_Image info => let
                        val dim = II.dim info
                        val shape = II.voxelShape info
                        in
                          case NrrdInfo.getInfo (#errStrm cxt, nrrd)
                           of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
                                 of NONE => (
                                      error (cxt, [
                                          "nrrd file \"", nrrd, "\" does not have expected type ",
                                          imageTyToString (dim, shape)
                                        ]);
                                      (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
                                  | SOME imgInfo =>
                                      (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
                                (* end case *))
                            | NONE => (
                                warning (cxt, [
                                    "nrrd file \"", nrrd, "\" does not exist"
                                  ]);
                                (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
                          (* end case *)
                        end
                    | _ => raise Fail "bogus type for E_LoadNrrd"
                  (* end case *))
              | AST.E_Coerce{dstTy, e=AST.E_Lit(Literal.Int n), ...} => (case cvtTy dstTy
                   of STy.T_Tensor[] => (stms, S.E_Lit(Literal.Real(RealLit.fromInt n)))
                    | _ => raise Fail "impossible: bad coercion"
                  (* end case *))
              | AST.E_Coerce{srcTy, dstTy, e} => let
                  val (stms, x) = simplifyExpToVar (cxt, e, stms)
                  val dstTy = cvtTy dstTy
                  val result = newTemp dstTy
                  val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
                  in
                    (S.S_Var(result, SOME rhs)::stms, S.E_Var result)
                  end
            (* end case *)
          end

  (* like simplifyExp, but the result is a variable: if the simplified expression is not
   * already a variable, then it is bound to a fresh temporary.
   *)
    and simplifyExpToVar (cxt, exp, stms) = let
          val (stms, e) = simplifyExp (cxt, exp, stms)
          in
            case e
             of S.E_Var x => (stms, x)
              | _ => let
                  val x = newTemp (S.typeOf e)
                  in
                    (S.S_Var(x, SOME e)::stms, x)
                  end
            (* end case *)
          end

  (* simplify a list of expressions to variables, left to right; returns the extended
   * statement list paired with the variables (in the same order as exps).
   *)
    and simplifyExpsToVars (cxt, exps, stms) = let
          fun f ([], xs, stms) = (stms, List.rev xs)
            | f (e::es, xs, stms) = let
                val (stms, x) = simplifyExpToVar (cxt, e, stms)
                in
                  f (es, x::xs, stms)
                end
          in
            f (exps, [], stms)
          end

  (* simplify a parallel map-reduce of the reduction operator rator over the strand set
   * named by xs; e is the mapped expression and x its variable (presumably bound to each
   * strand of the set).  Returns the result variable paired with the extended statement
   * list.
   *)
    and simplifyReduction (cxt, rator, e, x, xs, resTy, stms) = let
          val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
          val x' = cvtVar x
          val (bodyStms, bodyResult) = simplifyExpToVar (cxt, e, [])
        (* convert the domain from a variable to a StrandSets.t value *)
          val domain = if Var.same(BV.set_active, xs) then StrandSets.ACTIVE
                else if Var.same(BV.set_all, xs) then StrandSets.ALL
                else if Var.same(BV.set_stable, xs) then StrandSets.STABLE
                else raise Fail "impossible: not a strand set"
          val (func, args) = Util.makeFunction(
                Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
                SimpleVar.typeOf bodyResult)
        (* the MapReduce statement that reduces into `result` using the given reduction *)
          fun mkMapReduce red = S.S_MapReduce[
                  S.MapReduce{
                      result = result, reduction = red, mapf = func,
                      args = args, source = x', domain = domain
                    }
                ]
          in
            case Util.identifyReduction rator
             of Util.MEAN => let
                  val num = SimpleVar.new ("num", Var.LocalVar, STy.T_Int)
                  val rNum = SimpleVar.new ("rNum", Var.LocalVar, STy.realTy)
                  val mean = SimpleVar.new ("mean", Var.LocalVar, STy.realTy)
                (* the basis operation that counts the strands in the domain *)
                  val numStrandsOp = (case domain
                         of StrandSets.ACTIVE => BV.numActive
                          | StrandSets.ALL => BV.numStrands
                          | StrandSets.STABLE => BV.numStable
                        (* end case *))
                (* mean = sum / real(num); the list is in reverse order *)
                  val stms =
                        mkRDiv (mean, result, rNum) ::
                        mkToReal (rNum, num) ::
                        mkDef (num, S.E_Prim(numStrandsOp, [], [], STy.T_Int)) ::
                        mkMapReduce Reductions.SUM ::
                        stms
                  in
                    (mean, stms)
                  end
              | Util.VARIANCE => raise Fail "FIXME: variance reduction"
              | Util.RED rator' => (result, mkMapReduce rator' :: stms)
            (* end case *)
          end

  (* simplify a block and then prune unreachable and dead code *)
    fun simplifyAndPruneBlock cxt blk =
          DeadCode.eliminate (simplifyBlock (cxt, blk))

  (* simplify a strand definition; the initializers of the state variables are collected
   * (in declaration order) into the Simple strand's stateInit block.
   * NOTE(review): the AST strand's own stateInit field is bound but not used here;
   * confirm that this is intended.
   *)
    fun simplifyStrand (cxt, strand) = let
          val AST.Strand{
                  name, params, spatialDim, state, stateInit, initM, updateM, stabilizeM
                } = strand
          val params' = cvtVars params
          fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
            | simplifyState ((x, optE) :: r, xs, stms) = let
                val x' = cvtVar x
                in
                  case optE
                   of NONE => simplifyState (r, x'::xs, stms)
                    | SOME e => let
                        val (stms, e') = simplifyExp (cxt, e, stms)
                        in
                          simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
                        end
                  (* end case *)
                end
          val (xs, stm) = simplifyState (state, [], [])
          in
            S.Strand{
                name = name,
                params = params',
                spatialDim = spatialDim,
                state = xs,
                stateInit = stm,
                initM = Option.map (simplifyAndPruneBlock cxt) initM,
                updateM = simplifyAndPruneBlock cxt updateM,
                stabilizeM = Option.map (simplifyAndPruneBlock cxt) stabilizeM
              }
          end

    fun transform (errStrm, prog, gEnv) = let
          val AST.Program{
                  props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
                } = prog
          val cxt = {errStrm = errStrm, gEnv = gEnv}
        (* the following lists (including the initialization statements) are accumulated
         * in reverse order
         *)
          val consts' = ref[]
          val constInit = ref[]
          val inputs' = ref[]
          val globals' = ref[]
          val globalInit = ref[]
          val funcs = ref[]
          fun simplifyConstDcl (x, SOME e) = let
                val (stms, e') = simplifyExp (cxt, e, [])
                val x' = cvtVar x
                in
                  consts' := x' :: !consts';
                  constInit := S.S_Assign(x', e') :: (stms @ !constInit)
                end
            | simplifyConstDcl (_, NONE) = raise Fail "impossible: constant without definition"
        (* record an input declaration for the AST variable x, which maps to x' *)
          fun addInput (x, x', desc, init) = inputs' := S.INP{
                  var = x',
                  name = Var.nameOf x,
                  ty = apiTypeOf x',
                  desc = desc,
                  init = init
                } :: !inputs'
          fun simplifyInputDcl ((x, NONE), desc) = let
                val x' = cvtVar x
                val init = (case SimpleVar.typeOf x'
                       of STy.T_Image info => S.Image info
                        | _ => S.NoDefault
                      (* end case *))
                in
                  addInput (x, x', desc, init)
                end
            | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(_, nrrd, _))), desc) = let
                val (x', init) = (case Var.monoTypeOf x
                       of Ty.T_Sequence(_, NONE) => (cvtVar x, S.LoadSeq nrrd)
                        | Ty.T_Image{dim, shape} => let
                            val dim = TU.monoDim dim
                            val shape = TU.monoShape shape
                            in
                              case NrrdInfo.getInfo (#errStrm cxt, nrrd)
                               of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
                                     of NONE => (
                                          error (cxt, [
                                              "proxy input file \"", nrrd,
                                              "\" does not have expected type ",
                                              imageTyToString (dim, shape)
                                            ]);
                                          (cvtVar x, S.Image(II.mkInfo(dim, shape))))
                                      | SOME info =>
                                          (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
                                    (* end case *))
                                | NONE => (
                                    warning (cxt, [
                                        "proxy input file \"", nrrd, "\" does not exist"
                                      ]);
                                    (cvtVar x, S.Image(II.mkInfo(dim, shape))))
                              (* end case *)
                            end
                        | _ => raise Fail "impossible"
                      (* end case *))
                in
                  addInput (x, x', desc, init)
                end
            | simplifyInputDcl ((x, SOME e), desc) = let
                val x' = cvtVar x
                val (stms, e') = simplifyExp (cxt, e, [])
                in
                  addInput (x, x', desc, S.ConstExpr);
                  constInit := S.S_Assign(x', e') :: (stms @ !constInit)
                end
          fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
            | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
                val (stms, e') = simplifyExp (cxt, e, [])
                val x' = cvtLHS (x, e')
                in
                  globals' := x' :: !globals';
                  globalInit := S.S_Assign(x', e') :: (stms @ !globalInit)
                end
            | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
                val f' = cvtFunc f
                val params' = cvtVars params
                val body' = simplifyAndPruneBlock cxt body
                in
                  funcs := S.Func{f=f', params=params', body=body'} :: !funcs
                end
          val () = (
                List.app simplifyConstDcl const_dcls;
                List.app simplifyInputDcl input_dcls;
                List.app simplifyGlobalDcl globals)
        (* make the global-initialization block *)
          val globInit = (case globInit
                 of SOME stm => mkBlock (simplifyStmt (cxt, stm, !globalInit))
                  | NONE => mkBlock (!globalInit)
                (* end case *))
        (* if the globInit block is non-empty, record the fact in the property list *)
          val props = (case globInit
                 of S.Block{code=[], ...} => props
                  | _ => Properties.GlobalInit :: props
                (* end case *))
          in
            S.Program{
                props = props,
                consts = List.rev(!consts'),
                inputs = List.rev(!inputs'),
                constInit = mkBlock (!constInit),
                globals = List.rev(!globals'),
                globInit = globInit,
                funcs = List.rev(!funcs),
                strand = simplifyStrand (cxt, strand),
                create = Create.map (simplifyAndPruneBlock cxt) create,
                init = Option.map (simplifyAndPruneBlock cxt) init,
                update = Option.map (simplifyAndPruneBlock cxt) update
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0