Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/simplify/simplify.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4368 - (view) (download)

1 : jhr 3437 (* simplify.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *
8 : jhr 3468 * Simplify the AST representation. This phase involves the following transformations:
9 :     *
10 : jhr 4317 * - types are simplified by removing meta variables (which will have been resolved)
11 : jhr 3468 *
12 : jhr 4317 * - expressions are simplified to involve a single operation on variables
13 : jhr 3468 *
14 : jhr 4317 * - global reductions are converted to MapReduce statements
15 : jhr 3468 *
16 : jhr 4317 * - other comprehensions and reductions are converted to foreach loops
17 : jhr 3468 *
18 : jhr 4317 * - unreachable code is pruned
19 : jhr 3468 *
20 : jhr 4317 * - negations of integer and real literals are constant folded
21 : jhr 3437 *)
22 :    
structure Simplify : sig

    val transform : Error.err_stream * AST.program -> Simple.program

  end = struct

    structure TU = TypeUtil
    structure S = Simple
    structure STy = SimpleTypes
    structure Ty = Types
    structure VMap = Var.Map
    structure II = ImageInfo
    structure BV = BasisVars

  (* convert a Types.ty to a SimpleTypes.ty; bound type variables are followed and
   * an unbound (unresolved) type variable is an internal error.
   *)
    fun cvtTy ty = (case ty
           of Ty.T_Var(Ty.TV{bind, ...}) => (case !bind
                 of NONE => raise Fail "unresolved type variable"
                  | SOME ty => cvtTy ty
                (* end case *))
            | Ty.T_Bool => STy.T_Bool
            | Ty.T_Int => STy.T_Int
            | Ty.T_String => STy.T_String
            | Ty.T_Sequence(ty, NONE) => STy.T_Sequence(cvtTy ty, NONE)
            | Ty.T_Sequence(ty, SOME dim) => STy.T_Sequence(cvtTy ty, SOME(TU.monoDim dim))
            | Ty.T_Strand id => STy.T_Strand id
            | Ty.T_Kernel _ => STy.T_Kernel
            | Ty.T_Tensor shape => STy.T_Tensor(TU.monoShape shape)
            | Ty.T_Image{dim, shape} =>
                STy.T_Image(II.mkInfo(TU.monoDim dim, TU.monoShape shape))
            | Ty.T_Field{diff, dim, shape} => STy.T_Field{
                  diff = TU.monoDiff diff,
                  dim = TU.monoDim dim,
                  shape = TU.monoShape shape
                }
            | Ty.T_Fun _ => raise Fail "unexpected T_Fun in Simplify"
            | Ty.T_Error => raise Fail "unexpected T_Error in Simplify"
          (* end case *))

  (* map the SimpleTypes type of a variable to its APITypes type; only booleans,
   * integers, strings, sequences, tensors, and images have an API type.
   *)
    fun apiTypeOf x = let
          fun cvtTy STy.T_Bool = APITypes.BoolTy
            | cvtTy STy.T_Int = APITypes.IntTy
            | cvtTy STy.T_String = APITypes.StringTy
            | cvtTy (STy.T_Sequence(ty, len)) = APITypes.SeqTy(cvtTy ty, len)
            | cvtTy (STy.T_Tensor shape) = APITypes.TensorTy shape
            | cvtTy (STy.T_Image info) = APITypes.ImageTy(II.dim info, II.voxelShape info)
            | cvtTy _ = raise Fail "bogus API type"
          in
            cvtTy (SimpleVar.typeOf x)
          end

  (* create a fresh local temporary of the given type; image temporaries are named "img" *)
    fun newTemp (ty as STy.T_Image _) = SimpleVar.new ("img", SimpleVar.LocalVar, ty)
      | newTemp ty = SimpleVar.new ("_t", SimpleVar.LocalVar, ty)

  (* a property to map AST function variables to SimpleAST functions *)
    local
      fun cvt x = let
            val Ty.T_Fun(paramTys, resTy) = Var.monoTypeOf x
            in
              SimpleFunc.new (Var.nameOf x, cvtTy resTy, List.map cvtTy paramTys)
            end
    in
    val {getFn = cvtFunc, ...} = Var.newProp cvt
    end

  (* a property to map AST variables to SimpleAST variables.  The default mapping
   * converts the variable's type; newVarWithType overrides the mapping with an
   * explicitly typed variable (used when a proxy image refines the image type).
   *)
    local
      fun cvt x = SimpleVar.new (Var.nameOf x, Var.kindOf x, cvtTy(Var.monoTypeOf x))
      val {getFn, setFn, ...} = Var.newProp cvt
    in
    val cvtVar = getFn
    fun newVarWithType (x, ty) = let
          val x' = SimpleVar.new (Var.nameOf x, Var.kindOf x, ty)
          in
            setFn (x, x');
            x'
          end
    end

    fun cvtVars xs = List.map cvtVar xs

  (* make a block out of a list of statements that are in reverse order *)
    fun mkBlock stms = S.Block{props = PropList.newHolder(), code = List.rev stms}

  (* simplify a statement into a single statement (i.e., a block if it expands
   * into more than one new statement).
   *)
    fun simplifyBlock (errStrm, stm) = mkBlock (simplifyStmt (errStrm, stm, []))

  (* convert the lhs variable of a var decl or assignment; if the rhs is a LoadImage,
   * then we use the info from the proxy image to determine the type of the lhs
   * variable.
   *)
    and cvtLHS (lhs, S.E_LoadImage(_, _, info)) = newVarWithType(lhs, STy.T_Image info)
      | cvtLHS (lhs, _) = cvtVar lhs

  (* simplify the statement stm where stms is a reverse-order list of preceding simplified
   * statements.  This function returns a reverse-order list of simplified statements.
   * Note that error reporting is done in the typechecker, but it does not prune unreachable
   * code.
   *)
    and simplifyStmt (errStrm, stm, stms) : S.stmt list = (case stm
           of AST.S_Block body => let
                fun simplify ([], stms) = stms
                  | simplify (stm::r, stms) = simplify (r, simplifyStmt (errStrm, stm, stms))
                in
                  simplify (body, stms)
                end
            | AST.S_Decl(x, NONE) => let
                val x' = cvtVar x
                in
                  S.S_Var(x', NONE) :: stms
                end
            | AST.S_Decl(x, SOME e) => let
                val (stms, e') = simplifyExp (errStrm, e, stms)
                val x' = cvtLHS (x, e')
                in
                  S.S_Var(x', SOME e') :: stms
                end
(* FIXME: we should also define a "boolean negate" operation on AST expressions so that we can
 * handle both cases!
 *)
            | AST.S_IfThenElse(AST.E_Orelse(e1, e2), s1 as AST.S_Block[], s2) =>
              (* "if (e1 || e2) {} else s2" ==> "if e1 {} else (if e2 {} else s2)" *)
                simplifyStmt (errStrm, AST.S_IfThenElse(e1, s1, AST.S_IfThenElse(e2, s1, s2)), stms)
            | AST.S_IfThenElse(AST.E_Andalso(e1, e2), s1, s2 as AST.S_Block[]) =>
              (* "if (e1 && e2) s1 else {}" ==> "if e1 (if e2 s1 else {}) else {}" *)
                simplifyStmt (errStrm, AST.S_IfThenElse(e1, AST.S_IfThenElse(e2, s1, s2), s2), stms)
            | AST.S_IfThenElse(e, s1, s2) => let
                val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                val s1 = simplifyBlock (errStrm, s1)
                val s2 = simplifyBlock (errStrm, s2)
                in
                  S.S_IfThenElse(x, s1, s2) :: stms
                end
            | AST.S_Foreach((x, e), body) => let
                val (stms, xs') = simplifyExpToVar (errStrm, e, stms)
                val body' = simplifyBlock (errStrm, body)
                in
                  S.S_Foreach(cvtVar x, xs', body') :: stms
                end
            | AST.S_Assign((x, _), e) => let
                val (stms, e') = simplifyExp (errStrm, e, stms)
                val x' = cvtLHS (x, e')
                in
                  S.S_Assign(x', e') :: stms
                end
            | AST.S_New(name, args) => let
                val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
                in
                  S.S_New(name, xs) :: stms
                end
            | AST.S_Continue => S.S_Continue :: stms
            | AST.S_Die => S.S_Die :: stms
            | AST.S_Stabilize => S.S_Stabilize :: stms
            | AST.S_Return e => let
                val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                in
                  S.S_Return x :: stms
                end
            | AST.S_Print args => let
                val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
                in
                  S.S_Print xs :: stms
                end
          (* end case *))

  (* simplify the expression exp, where stms is a reverse-order list of preceding
   * simplified statements.  Returns the extended reverse-order statement list paired
   * with a simple expression for the value of exp.
   *)
    and simplifyExp (errStrm, exp, stms) = let
          fun doBorderCtl (f, args) = let
                val (ctl, arg) = if Var.same(BV.image_border, f)
                        then (BorderCtl.Default(hd args), hd(tl args))
                      else if Var.same(BV.image_clamp, f)
                        then (BorderCtl.Clamp, hd args)
                      else if Var.same(BV.image_mirror, f)
                        then (BorderCtl.Mirror, hd args)
                      else if Var.same(BV.image_wrap, f)
                        then (BorderCtl.Wrap, hd args)
                        else raise Fail "impossible"
                in
                  S.E_BorderCtl(ctl, arg)
                end
          fun doPrimApply (f, tyArgs, args, ty) = let
                val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
                in
                  if Basis.isBorderCtl f
                    then (stms, doBorderCtl (f, xs))
                  else if Var.same(f, BV.fn_sphere_im)
                    then raise Fail "FIXME: implicit sphere query"
                    else (case Var.kindOf f
                       of Var.BasisVar => let
                            fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
                              | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
                              | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
                              | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
                            val tyArgs = List.map cvtTyArg tyArgs
                            in
                              (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
                            end
                        | _ => raise Fail "bogus prim application"
                      (* end case *))
                end
          in
            case exp
             of AST.E_Var(x, _) => (case Var.kindOf x
                   of Var.BasisVar => let
                      (* a basis variable used as a value becomes a nullary primitive application *)
                        val ty = cvtTy(Var.monoTypeOf x)
                        val x' = newTemp ty
                        val stm = S.S_Var(x', SOME(S.E_Prim(x, [], [], ty)))
                        in
                          (stm::stms, S.E_Var x')
                        end
                    | _ => (stms, S.E_Var(cvtVar x))
                  (* end case *))
              | AST.E_Lit lit => (stms, S.E_Lit lit)
              | AST.E_Kernel h => (stms, S.E_Kernel h)
              | AST.E_Select(e, (fld, _)) => let
                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                  in
                    (stms, S.E_Select(x, cvtVar fld))
                  end
              | AST.E_Prim(rator, tyArgs, args as [arg], ty) => (case arg
                   of AST.E_Lit(Literal.Int n) => if Var.same(BV.neg_i, rator)
                        then (stms, S.E_Lit(Literal.Int(~n))) (* constant-fold negation of integer literals *)
                        else doPrimApply (rator, tyArgs, args, ty)
                    | AST.E_Lit(Literal.Real f) => if Var.same(BV.neg_t, rator)
                        then (stms, S.E_Lit(Literal.Real(RealLit.negate f))) (* constant-fold negation of real literals *)
                        else doPrimApply (rator, tyArgs, args, ty)
(* QUESTION: is there common code in handling a reduction over a sequence of strands vs. over a strand set? *)
                    | AST.E_Comprehension(body, (x, seqExp), _) => if Basis.isReductionOp rator
                        then let
                        (* a reduction over a comprehension becomes a foreach loop over the
                         * sequence that folds the reduction operator over the comprehension's
                         * body values.
                         *)
                          val {rator = redOp, init, mvs} = Util.reductionInfo rator
                          val (stms, xs) = simplifyExpToVar (errStrm, seqExp, stms)
                        (* the loop body computes one element (the comprehension's body), not
                         * the whole comprehension.
                         *)
                          val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, body, [])
                          val accTy = cvtTy ty
                          val acc = SimpleVar.new ("accum", Var.LocalVar, accTy)
                          val initStm = S.S_Var(acc, SOME(S.E_Lit init))
                        (* the combining operator produces a value of the accumulator's type *)
                          val updateStm = S.S_Assign(acc,
                                S.E_Prim(redOp, mvs, [acc, bodyResult], accTy))
                          val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
                          in
                            (foreachStm :: initStm :: stms, S.E_Var acc)
                          end
                        else doPrimApply (rator, tyArgs, args, ty)
                    | AST.E_ParallelMap(body, x, xs, _) => if Basis.isReductionOp rator
                        then let
                          val (result, stm) = simplifyReduction (errStrm, rator, body, x, xs, ty)
                          in
                            (stm :: stms, S.E_Var result)
                          end
                        else raise Fail "unsupported operation on parallel map"
                    | _ => doPrimApply (rator, tyArgs, args, ty)
                  (* end case *))
              | AST.E_Prim(f, tyArgs, args, ty) => doPrimApply (f, tyArgs, args, ty)
              | AST.E_Apply((f, _), args, ty) => let
                  val (stms, xs) = simplifyExpsToVars (errStrm, args, stms)
                  in
                    case Var.kindOf f
                     of Var.FunVar => (stms, S.E_Apply(SimpleFunc.use(cvtFunc f), xs))
                      | _ => raise Fail "bogus application"
                    (* end case *)
                  end
              | AST.E_Comprehension(e, (x, e'), seqTy) => let
                (* convert a comprehension to a foreach loop over the sequence defined by e' *)
                  val (stms, xs) = simplifyExpToVar (errStrm, e', stms)
                  val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])
                  val seqTy' as STy.T_Sequence(elemTy, NONE) = cvtTy seqTy
                  val acc = SimpleVar.new ("accum", Var.LocalVar, seqTy')
                  val initStm = S.S_Var(acc, SOME(S.E_Seq([], seqTy')))
                  val updateStm = S.S_Assign(acc,
                        S.E_Prim(BV.at_dT, [S.TY elemTy], [acc, bodyResult], seqTy'))
                  val foreachStm = S.S_Foreach(cvtVar x, xs, mkBlock(updateStm :: bodyStms))
                  in
                    (foreachStm :: initStm :: stms, S.E_Var acc)
                  end
              | AST.E_ParallelMap(e, x, xs, ty) => raise Fail "FIXME: ParallelMap"
              | AST.E_Tensor(es, ty) => let
                  val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)
                  in
                    (stms, S.E_Tensor(xs, cvtTy ty))
                  end
              | AST.E_Seq(es, ty) => let
                  val (stms, xs) = simplifyExpsToVars (errStrm, es, stms)
                  in
                    (stms, S.E_Seq(xs, cvtTy ty))
                  end
              | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                  fun f NONE = NONE
                    | f (SOME(AST.E_Lit(Literal.Int i))) = SOME(Int.fromLarge i)
                    | f _ = raise Fail "expected integer literal in slice"
                  val indices = List.map f indices
                  in
                    (stms, S.E_Slice(x, indices, cvtTy ty))
                  end
              | AST.E_Cond(e1, e2, e3, ty) => let
                (* a conditional expression gets turned into an if-then-else statement *)
                  val result = newTemp(cvtTy ty)
                  val (stms, x) = simplifyExpToVar (errStrm, e1, S.S_Var(result, NONE) :: stms)
                  fun simplifyBranch e = let
                        val (stms, e) = simplifyExp (errStrm, e, [])
                        in
                          mkBlock (S.S_Assign(result, e)::stms)
                        end
                  val s1 = simplifyBranch e2
                  val s2 = simplifyBranch e3
                  in
                    (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
                  end
              | AST.E_Orelse(e1, e2) => simplifyExp (
                  errStrm,
                  AST.E_Cond(e1, AST.E_Lit(Literal.Bool true), e2, Ty.T_Bool),
                  stms)
              | AST.E_Andalso(e1, e2) => simplifyExp (
                  errStrm,
                  AST.E_Cond(e1, e2, AST.E_Lit(Literal.Bool false), Ty.T_Bool),
                  stms)
              | AST.E_LoadNrrd(_, nrrd, ty) => (case cvtTy ty
                   of ty as STy.T_Sequence(_, NONE) => (stms, S.E_LoadSeq(ty, nrrd))
                    | ty as STy.T_Image info => let
                        val dim = II.dim info
                        val shape = II.voxelShape info
                        in
                          case NrrdInfo.getInfo (errStrm, nrrd)
                           of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
                                 of NONE => (
                                      Error.error (errStrm, [
                                          "nrrd file \"", nrrd, "\" does not have expected type"
                                        ]);
                                      (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
                                  | SOME imgInfo =>
                                      (stms, S.E_LoadImage(STy.T_Image imgInfo, nrrd, imgInfo))
                                (* end case *))
                            | NONE => (
                                Error.warning (errStrm, [
                                    "nrrd file \"", nrrd, "\" does not exist"
                                  ]);
                                (stms, S.E_LoadImage(ty, nrrd, II.mkInfo(dim, shape))))
                          (* end case *)
                        end
                    | _ => raise Fail "bogus type for E_LoadNrrd"
                  (* end case *))
              | AST.E_Coerce{dstTy, e=AST.E_Lit(Literal.Int n), ...} => (case cvtTy dstTy
                   of STy.T_Tensor[] => (stms, S.E_Lit(Literal.Real(RealLit.fromInt n)))
                    | _ => raise Fail "impossible: bad coercion"
                  (* end case *))
              | AST.E_Coerce{srcTy, dstTy, e} => let
                  val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                  val dstTy = cvtTy dstTy
                  val result = newTemp dstTy
                  val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
                  in
                    (S.S_Var(result, SOME rhs)::stms, S.E_Var result)
                  end
            (* end case *)
          end

  (* simplify an expression and bind its value to a variable (reusing the variable when
   * the simplified expression is already a variable).
   *)
    and simplifyExpToVar (errStrm, exp, stms) = let
          val (stms, e) = simplifyExp (errStrm, exp, stms)
          in
            case e
             of S.E_Var x => (stms, x)
              | _ => let
                  val x = newTemp (S.typeOf e)
                  in
                    (S.S_Var(x, SOME e)::stms, x)
                  end
            (* end case *)
          end

  (* simplify a list of expressions to variables, preserving their left-to-right order *)
    and simplifyExpsToVars (errStrm, exps, stms) = let
          fun f ([], xs, stms) = (stms, List.rev xs)
            | f (e::es, xs, stms) = let
                val (stms, x) = simplifyExpToVar (errStrm, e, stms)
                in
                  f (es, x::xs, stms)
                end
          in
            f (exps, [], stms)
          end

  (* simplify a parallel map-reduce *)
    and simplifyReduction (errStrm, rator, e, x, xs, resTy) = let
          val rator' = if Var.same(BV.red_all, rator) then Reductions.ALL
                else if Var.same(BV.red_exists, rator) then Reductions.EXISTS
                else if Var.same(BV.red_max, rator) then Reductions.MAX
                else if Var.same(BV.red_mean, rator) then raise Fail "FIXME: mean reduction"
                else if Var.same(BV.red_min, rator) then Reductions.MIN
                else if Var.same(BV.red_product, rator) then Reductions.PRODUCT
                else if Var.same(BV.red_sum, rator) then Reductions.SUM
                else if Var.same(BV.red_variance, rator) then raise Fail "FIXME: variance reduction"
                else raise Fail "impossible: not a reduction"
          val x' = cvtVar x
          val result = SimpleVar.new ("res", Var.LocalVar, cvtTy resTy)
          val (bodyStms, bodyResult) = simplifyExpToVar (errStrm, e, [])
(* FIXME: need to handle reductions over active/stable subsets of strands *)
          val (func, args) = Util.makeFunction(
                Var.nameOf rator, mkBlock(S.S_Return bodyResult :: bodyStms),
                SimpleVar.typeOf bodyResult)
          val mapReduceStm = S.S_MapReduce{
                  results = [result],
                  reductions = [rator'],
                  body = func,
                  args = args,
                  source = x'
                }
          in
            (result, mapReduceStm)
          end

  (* simplify a block and then prune unreachable and dead code *)
    fun simplifyAndPruneBlock errStrm blk =
          DeadCode.eliminate (simplifyBlock (errStrm, blk))

  (* simplify a strand definition; state-variable initializers are collected (in
   * declaration order) into the stateInit block.
   *)
    fun simplifyStrand (errStrm, strand) = let
          val AST.Strand{
                  name, params, spatialDim, state, stateInit, initM, updateM, stabilizeM
                } = strand
          val params' = cvtVars params
          fun simplifyState ([], xs, stms) = (List.rev xs, mkBlock stms)
            | simplifyState ((x, optE) :: r, xs, stms) = let
                val x' = cvtVar x
                in
                  case optE
                   of NONE => simplifyState (r, x'::xs, stms)
                    | SOME e => let
                        val (stms, e') = simplifyExp (errStrm, e, stms)
                        in
                          simplifyState (r, x'::xs, S.S_Assign(x', e') :: stms)
                        end
                  (* end case *)
                end
          val (xs, stm) = simplifyState (state, [], [])
          in
            S.Strand{
                name = name,
                params = params',
                spatialDim = spatialDim,
                state = xs,
                stateInit = stm,
                initM = Option.map (simplifyAndPruneBlock errStrm) initM,
                updateM = simplifyAndPruneBlock errStrm updateM,
                stabilizeM = Option.map (simplifyAndPruneBlock errStrm) stabilizeM
              }
          end

    fun transform (errStrm, prog) = let
          val AST.Program{
                  props, const_dcls, input_dcls, globals, globInit, strand, create, init, update
                } = prog
        (* accumulators; all of these lists are kept in reverse order *)
          val consts' = ref[]
          val constInit = ref[]
          val inputs' = ref[]
          val globals' = ref[]
          val globalInit = ref[]
          val funcs = ref[]
          fun simplifyConstDcl (x, SOME e) = let
                val (stms, e') = simplifyExp (errStrm, e, [])
                val x' = cvtVar x
                in
                  consts' := x' :: !consts';
                  constInit := S.S_Assign(x', e') :: (stms @ !constInit)
                end
            | simplifyConstDcl (x, NONE) =
                raise Fail(concat["constant '", Var.nameOf x, "' has no definition"])
          fun simplifyInputDcl ((x, NONE), desc) = let
                val x' = cvtVar x
                val dflt = (case SimpleVar.typeOf x'
                       of STy.T_Image info => S.Image info
                        | _ => S.NoDefault
                      (* end case *))
                val inp = S.INP{
                        var = x',
                        name = Var.nameOf x,
                        ty = apiTypeOf x',
                        desc = desc,
                        init = dflt
                      }
                in
                  inputs' := inp :: !inputs'
                end
            | simplifyInputDcl ((x, SOME(AST.E_LoadNrrd(_, nrrd, _))), desc) = let
              (* the nrrd file is a proxy that (when present) refines the image type *)
                val (x', dflt) = (case Var.monoTypeOf x
                       of Ty.T_Sequence(_, NONE) => (cvtVar x, S.LoadSeq nrrd)
                        | Ty.T_Image{dim, shape} => let
                            val dim = TU.monoDim dim
                            val shape = TU.monoShape shape
                            in
                              case NrrdInfo.getInfo (errStrm, nrrd)
                               of SOME nrrdInfo => (case II.fromNrrd(nrrdInfo, dim, shape)
                                     of NONE => (
                                          Error.error (errStrm, [
                                              "proxy nrrd file \"", nrrd,
                                              "\" does not have expected type"
                                            ]);
                                          (cvtVar x, S.Image(II.mkInfo(dim, shape))))
                                      | SOME info =>
                                          (newVarWithType(x, STy.T_Image info), S.Proxy(nrrd, info))
                                    (* end case *))
                                | NONE => (
                                    Error.warning (errStrm, [
                                        "proxy nrrd file \"", nrrd, "\" does not exist"
                                      ]);
                                    (cvtVar x, S.Image(II.mkInfo(dim, shape))))
                              (* end case *)
                            end
                        | _ => raise Fail "impossible"
                      (* end case *))
                val inp = S.INP{
                        var = x',
                        name = Var.nameOf x,
                        ty = apiTypeOf x',
                        desc = desc,
                        init = dflt
                      }
                in
                  inputs' := inp :: !inputs'
                end
            | simplifyInputDcl ((x, SOME e), desc) = let
                val x' = cvtVar x
                val (stms, e') = simplifyExp (errStrm, e, [])
                val inp = S.INP{
                        var = x',
                        name = Var.nameOf x,
                        ty = apiTypeOf x',
                        desc = desc,
                        init = S.ConstExpr
                      }
                in
                  inputs' := inp :: !inputs';
                  constInit := S.S_Assign(x', e') :: (stms @ !constInit)
                end
          fun simplifyGlobalDcl (AST.D_Var(x, NONE)) = globals' := cvtVar x :: !globals'
            | simplifyGlobalDcl (AST.D_Var(x, SOME e)) = let
                val (stms, e') = simplifyExp (errStrm, e, [])
                val x' = cvtLHS (x, e')
                in
                  globals' := x' :: !globals';
                  globalInit := S.S_Assign(x', e') :: (stms @ !globalInit)
                end
            | simplifyGlobalDcl (AST.D_Func(f, params, body)) = let
                val f' = cvtFunc f
                val params' = cvtVars params
                val body' = simplifyAndPruneBlock errStrm body
                in
                  funcs := S.Func{f=f', params=params', body=body'} :: !funcs
                end
          val () = (
                List.app simplifyConstDcl const_dcls;
                List.app simplifyInputDcl input_dcls;
                List.app simplifyGlobalDcl globals)
        (* make the global-initialization block *)
          val globInit = (case globInit
                 of SOME stm => mkBlock (simplifyStmt (errStrm, stm, !globalInit))
                  | NONE => mkBlock (!globalInit)
                (* end case *))
        (* if the globInit block is non-empty, record the fact in the property list *)
          val props = (case globInit
                 of S.Block{code=[], ...} => props
                  | _ => Properties.GlobalInit :: props
                (* end case *))
          in
            S.Program{
                props = props,
                consts = List.rev(!consts'),
                inputs = List.rev(!inputs'),
                constInit = mkBlock (!constInit),
                globals = List.rev(!globals'),
                globInit = globInit,
                funcs = List.rev(!funcs),
                strand = simplifyStrand (errStrm, strand),
                create = Create.map (simplifyAndPruneBlock errStrm) create,
                init = Option.map (simplifyAndPruneBlock errStrm) init,
                update = Option.map (simplifyAndPruneBlock errStrm) update
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0