Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /trunk/src/compiler/simplify/simplify.sml
ViewVC logotype

Annotation of /trunk/src/compiler/simplify/simplify.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2476 - (view) (download)

1 : jhr 171 (* simplify.sml
2 :     *
3 : jhr 435 * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 : jhr 171 * All rights reserved.
5 :     *
6 :     * Simplify the AST representation.
7 :     *)
8 :    
9 :     structure Simplify : sig
10 :    
11 : jhr 1140 val transform : Error.err_stream * AST.program -> Simple.program
12 : jhr 171
13 :     end = struct
14 :    
(* short names for the structures used throughout this module *)
structure S = Simple
structure TU = TypeUtil
structure VMap = Var.Map

(* convert an AST type to the corresponding Simple IR type *)
fun cvtTy ty = SimpleTypes.simplify ty

(* allocate a fresh temporary of type ty; every temporary is created as a
 * LocalVar with the base name "_t"
 *)
fun newTemp ty = let
      val tempName = "_t"
      in
        SimpleVar.new (tempName, SimpleVar.LocalVar, ty)
      end
22 : jhr 171
(* convert an AST variable to a Simple variable.  Returns the new variable together
 * with env extended by the mapping x -> x'.  Only variables whose type scheme has an
 * empty parameter list (presumably: monomorphic variables) can be converted; any
 * other variable raises Fail naming the offending variable, rather than the
 * anonymous Match exception a non-exhaustive pattern would produce.
 *)
fun cvtVar (env, x as Var.V{name, kind, ty=([], ty), ...}) = let
      val x' = SimpleVar.new (name, kind, cvtTy ty)
      in
        (x', VMap.insert(env, x, x'))
      end
  | cvtVar (_, x) = raise Fail(concat[
        "cvtVar(", Var.uniqueNameOf x, "): unexpected polymorphic variable"
      ])
29 :    
(* convert a list of AST variables.  Returns the Simple variables (in the same order
 * as xs) and env extended with all of the new bindings.  The list is processed
 * from right to left, so the last variable is converted first.
 *)
fun cvtVars (env, xs) = let
      fun cvt [] = ([], env)
        | cvt (y::ys) = let
            val (ys', env') = cvt ys
            val (y', env'') = cvtVar (env', y)
            in
              (y'::ys', env'')
            end
      in
        cvt xs
      end
36 :    
(* return the Simple variable that env binds to the AST variable x; raises Fail
 * (naming x) when x has no binding in env
 *)
fun lookupVar (env, x) = (case VMap.find (env, x)
       of NONE => raise Fail(concat["lookupVar(", Var.uniqueNameOf x, ")"])
        | SOME x' => x'
      (* end case *))
41 :    
(* build a block from a list of statements that is in reverse order: the list is
 * reversed back into program order and wrapped in S.Block
 *)
val mkBlock = S.Block o List.rev
44 : jhr 171
(* does control flow from stm continue to the syntactically following statement?
 * die, stabilize, and return never fall through.  A block falls through only when
 * every statement in it does.  A conditional falls through when at least one of
 * its branches does.  All other statements fall through.
 *)
fun contIsNext stm = (case stm
       of AST.S_Block stms => List.all contIsNext stms
        | AST.S_IfThenElse(_, thenStm, elseStm) =>
            contIsNext thenStm orelse contIsNext elseStm
        | AST.S_Die => false
        | AST.S_Stabilize => false
        | AST.S_Return _ => false
        | _ => true
      (* end case *))
52 :    
(* simplify a whole AST program into a Simple program.  The top-level declarations
 * are processed in order, threading an environment (AST variable -> Simple
 * variable) from each declaration to the ones that follow it.  Globals, global
 * initialization statements, functions and strands are accumulated in reverse
 * order and put back in program order when the Simple.Program is built.
 * Raises Fail if there are zero or several initially declarations.
 *)
fun simplifyProgram (AST.Program dcls) = let
      val globals = ref []
      val globalInit = ref []   (* reverse-order list of global-initialization statements *)
      val funcs = ref []
      val initially = ref NONE
      val strands = ref []
      fun setInitially init = (case !initially
             of NONE => initially := SOME init
(* FIXME: the check for multiple initially decls should happen in type checking *)
              | SOME _ => raise Fail "multiple initially declarations"
            (* end case *))
      (* simplify one top-level declaration; returns the environment that is in
       * effect for the following declarations
       *)
      fun simplifyDecl (dcl, env) = (case dcl
             of AST.D_Input(x, desc, NONE) => let
                  val (x', env) = cvtVar(env, x)
                  val e' = S.E_Input(SimpleVar.typeOf x', SimpleVar.nameOf x', desc, NONE)
                  in
                    globals := x' :: !globals;
                    globalInit := S.S_Assign(x', e') :: !globalInit;
                    env
                  end
              | AST.D_Input(x, desc, SOME e) => let
                  val (x', env) = cvtVar(env, x)
                  (* the default value's code is pushed directly onto the
                   * reverse-order global-initialization list (the accumulator)
                   *)
                  val (stms, x'') = simplifyExpToVar (env, e, !globalInit)
                  val e' = S.E_Input(SimpleVar.typeOf x', SimpleVar.nameOf x', desc, SOME x'')
                  in
                    globals := x' :: !globals;
                    globalInit := S.S_Assign(x', e') :: stms;
                    env
                  end
              | AST.D_Var(AST.VD_Decl(x, e)) => let
                  val (x', env) = cvtVar(env, x)
                  val (stms, e') = simplifyExp (env, e, !globalInit)
                  in
                    globals := x' :: !globals;
                    globalInit := S.S_Assign(x', e') :: stms;
                    env
                  end
              | AST.D_Func(f, params, body) => let
                  (* f is bound before the body is simplified, so the body can
                   * refer to f itself
                   *)
                  val (f', env) = cvtVar(env, f)
                  (* the parameters are in scope only for the body; they must not
                   * leak into the environment of later top-level declarations
                   *)
                  val (params', bodyEnv) = cvtVars (env, params)
                  val body' = simplifyBlock(bodyEnv, body)
                  in
                    funcs := S.Func{f=f', params=params', body=body'} :: !funcs;
                    env
                  end
              | AST.D_Strand info => (
                  strands := simplifyStrand(env, info) :: !strands;
                  env)
              | AST.D_InitialArray(creat, iters) => (
                  setInitially (simplifyInit(env, true, creat, iters));
                  env)
              | AST.D_InitialCollection(creat, iters) => (
                  setInitially (simplifyInit(env, false, creat, iters));
                  env)
            (* end case *))
      (* run for its side effects on the accumulators above; the final env is unused *)
      val _ = List.foldl simplifyDecl VMap.empty dcls
      in
        S.Program{
            globals = List.rev(!globals),
            globalInit = mkBlock (!globalInit),
            funcs = List.rev(!funcs),
            init = (case !initially
(* FIXME: the check for the initially block should really happen in typechecking *)
               of NONE => raise Fail "missing initially declaration"
                | SOME blk => blk
              (* end case *)),
            strands = List.rev(!strands)
          }
      end

(* simplify an initially declaration.  isArray distinguishes the array form from the
 * collection form.  The iteration ranges are processed in order: each range's
 * bounds are bound to variables, then its iterator is added to the environment.
 * The strand-creation arguments are simplified in the scope of all the iterators.
 * NOTE(review): a range sees the earlier iterators in env, but all range code goes
 * into one rangeInit block; presumably the typechecker keeps ranges independent of
 * the iterators.  Confirm.
 *)
and simplifyInit (env, isArray, AST.C_Create(strand, exps), iters) = let
      fun simplifyIter (AST.I_Range(x, e1, e2), (env, iters, stms)) = let
            val (stms, lo) = simplifyExpToVar (env, e1, stms)
            val (stms, hi) = simplifyExpToVar (env, e2, stms)
            val (x', env) = cvtVar (env, x)
            in
              (env, {param=x', lo=lo, hi=hi}::iters, stms)
            end
      val (env, iters, iterStms) = List.foldl simplifyIter (env, [], []) iters
      val (stms, xs) = simplifyExpsToVars (env, exps, [])
      val creat = S.C_Create{
              argInit = mkBlock stms,
              name = strand,
              args = xs
            }
      in
        S.Initially{
            isArray = isArray,
            rangeInit = mkBlock iterStms,
            iters = List.rev iters,
            create = creat
          }
      end

(* simplify a strand definition.  The parameters are converted first.  The state
 * variables are then processed in declaration order: each initializer is simplified
 * before its variable is bound, so it sees the parameters and the earlier state
 * variables.  Finally the methods are simplified in the scope of parameters and state.
 *)
and simplifyStrand (env, AST.Strand{name, params, state, methods}) = let
      val (params', env) = cvtVars (env, params)
      fun simplifyState (env, [], xs, stms) = (List.rev xs, mkBlock stms, env)
        | simplifyState (env, AST.VD_Decl(x, e) :: r, xs, stms) = let
            val (stms, e') = simplifyExp (env, e, stms)
            val (x', env) = cvtVar(env, x)
            in
              simplifyState (env, r, x'::xs, S.S_Assign(x', e') :: stms)
            end
      val (xs, stm, env) = simplifyState (env, state, [], [])
      in
        S.Strand{
            name = name,
            params = params',
            state = xs, stateInit = stm,
            methods = List.map (simplifyMethod env) methods
          }
      end

(* simplify a strand method's body in the given environment *)
and simplifyMethod env (AST.M_Method(name, body)) =
      S.Method(name, simplifyBlock(env, body))

(* simplify a statement into a block: the simplified statements are always wrapped
 * in an S.Block, however many statements stm expands into.
 *)
and simplifyBlock (env, stm) = mkBlock (#1 (simplifyStmt (env, stm, [])))

(* simplify the statement stm, where stms is a reverse-order list of the preceding
 * simplified statements.  Returns a reverse-order list of simplified statements and
 * the environment in effect after stm (only a declaration extends it).  Error
 * reporting happens in the typechecker, but the typechecker does not prune
 * unreachable code.  So statements in a block that follow one whose continuation is
 * not the next statement (see contIsNext) are dropped here.
 *)
and simplifyStmt (env, stm, stms) = (case stm
       of AST.S_Block body => let
            (* bindings made inside the block are discarded: the outer env is returned *)
            fun simplify (_, [], stms) = stms
              | simplify (env', stm::r, stms) = let
                  val (stms, env') = simplifyStmt (env', stm, stms)
                  in
                    if contIsNext stm
                      then simplify (env', r, stms)
                      else stms (* prune the unreachable statements "r" *)
                  end
            in
              (simplify (env, body, stms), env)
            end
        | AST.S_Decl(AST.VD_Decl(x, e)) => let
            val (stms, e') = simplifyExp (env, e, stms)
            val (x', env) = cvtVar(env, x)
            in
              (S.S_Assign(x', e') :: stms, env)
            end
        | AST.S_IfThenElse(e, s1, s2) => let
            val (stms, x) = simplifyExpToVar (env, e, stms)
            val s1 = simplifyBlock (env, s1)
            val s2 = simplifyBlock (env, s2)
            in
              (S.S_IfThenElse(x, s1, s2) :: stms, env)
            end
        | AST.S_Assign(x, e) => let
            val (stms, e') = simplifyExp (env, e, stms)
            in
              (S.S_Assign(lookupVar(env, x), e') :: stms, env)
            end
        | AST.S_New(name, args) => let
            val (stms, xs) = simplifyExpsToVars (env, args, stms)
            in
              (S.S_New(name, xs) :: stms, env)
            end
        | AST.S_Die => (S.S_Die :: stms, env)
        | AST.S_Stabilize => (S.S_Stabilize :: stms, env)
        | AST.S_Return e => let
            val (stms, x) = simplifyExpToVar (env, e, stms)
            in
              (S.S_Return x :: stms, env)
            end
        | AST.S_Print args => let
            val (stms, xs) = simplifyExpsToVars (env, args, stms)
            in
              (S.S_Print xs :: stms, env)
            end
      (* end case *))

(* simplify the expression exp, where stms is a reverse-order list of the preceding
 * simplified statements.  Returns the extended reverse-order statement list and a
 * Simple expression whose operands are variables.
 *)
and simplifyExp (env, exp, stms) = (
      case exp
       of AST.E_Var x => (case Var.kindOf x
             of Var.BasisVar => let
                  (* a reference to a basis variable becomes a nullary primitive
                   * application bound to a fresh temporary
                   *)
                  val ty = cvtTy(Var.monoTypeOf x)
                  val x' = newTemp ty
                  val stm = S.S_Assign(x', S.E_Prim(x, [], [], ty))
                  in
                    (stm::stms, S.E_Var x')
                  end
              | _ => (stms, S.E_Var(lookupVar(env, x)))
            (* end case *))
        | AST.E_Lit lit => (stms, S.E_Lit lit)
        | AST.E_Tuple es => raise Fail "E_Tuple not yet implemented"
        | AST.E_Apply(f, tyArgs, args, ty) => let
            val (stms, xs) = simplifyExpsToVars (env, args, stms)
            in
              case Var.kindOf f
               of Var.FunVar => (stms, S.E_Apply(lookupVar(env, f), xs, cvtTy ty))
                | Var.BasisVar => let
                    (* resolve each meta-variable instantiation to a concrete Simple argument *)
                    fun cvtTyArg (Types.TYPE tv) = S.TY(cvtTy(TU.resolve tv))
                      | cvtTyArg (Types.DIFF dv) = S.DIFF(TU.monoDiff(TU.resolveDiff dv))
                      | cvtTyArg (Types.SHAPE sv) = S.SHAPE(TU.monoShape(TU.resolveShape sv))
                      | cvtTyArg (Types.DIM dv) = S.DIM(TU.monoDim(TU.resolveDim dv))
                    val tyArgs = List.map cvtTyArg tyArgs
                    in
                      (stms, S.E_Prim(f, tyArgs, xs, cvtTy ty))
                    end
                | _ => raise Fail "bogus application"
              (* end case *)
            end
        | AST.E_Cons es => let
            val (stms, xs) = simplifyExpsToVars (env, es, stms)
            in
              (stms, S.E_Cons xs)
            end
        | AST.E_Slice(e, indices, ty) => let (* tensor slicing *)
            val (stms, x) = simplifyExpToVar (env, e, stms)
            (* bind each present index to a variable; absent (NONE) indices are kept *)
            fun f ([], ys, stms) = (stms, List.rev ys)
              | f (NONE::es, ys, stms) = f (es, NONE::ys, stms)
              | f (SOME e::es, ys, stms) = let
                  val (stms, y) = simplifyExpToVar (env, e, stms)
                  in
                    f (es, SOME y::ys, stms)
                  end
            val (stms, indices) = f (indices, [], stms)
            in
              (stms, S.E_Slice(x, indices, cvtTy ty))
            end
        | AST.E_Cond(e1, e2, e3, ty) => let
            (* a conditional expression gets turned into an if-then-else statement;
             * the result variable is declared (S_Var) before the test so that both
             * branches can assign it
             *)
            val result = newTemp(cvtTy ty)
            val (stms, x) = simplifyExpToVar (env, e1, S.S_Var result :: stms)
            fun simplifyBranch e = let
                  val (stms, e) = simplifyExp (env, e, [])
                  in
                    mkBlock (S.S_Assign(result, e)::stms)
                  end
            val s1 = simplifyBranch e2
            val s2 = simplifyBranch e3
            in
              (S.S_IfThenElse(x, s1, s2) :: stms, S.E_Var result)
            end
        | AST.E_Coerce{srcTy, dstTy, e} => let
            val (stms, x) = simplifyExpToVar (env, e, stms)
            val dstTy = cvtTy dstTy
            val result = newTemp dstTy
            val rhs = S.E_Coerce{srcTy = cvtTy srcTy, dstTy = dstTy, x = x}
            in
              (S.S_Assign(result, rhs)::stms, S.E_Var result)
            end
      (* end case *))

(* simplify exp so that its value is held in a variable.  A variable result is
 * returned directly; any other expression is assigned to a fresh temporary.
 *)
and simplifyExpToVar (env, exp, stms) = let
      val (stms, e) = simplifyExp (env, exp, stms)
      in
        case e
         of S.E_Var x => (stms, x)
          | _ => let
              val x = newTemp (S.typeOf e)
              in
                (S.S_Assign(x, e)::stms, x)
              end
        (* end case *)
      end

(* simplify a list of expressions from left to right, returning the extended
 * reverse-order statement list and the result variables in the original order
 *)
and simplifyExpsToVars (env, exps, stms) = let
      fun f ([], xs, stms) = (stms, List.rev xs)
        | f (e::es, xs, stms) = let
            val (stms, x) = simplifyExpToVar (env, e, stms)
            in
              f (es, x::xs, stms)
            end
      in
        f (exps, [], stms)
      end
325 : jhr 171
(* entry point of this phase.  Simplify the AST, then run the inliner and then the
 * lifting pass, writing the program to the log file after each step (DEBUG).  If
 * lifting raises Eval.Error, the message is reported on errStrm and the program as
 * it was before lifting is returned.
 *)
fun transform (errStrm, ast) = let
      (* log prog under the given phase label and pass it through unchanged *)
      fun dump (phase, prog) = (
            ignore (SimplePP.output (Log.logFile(), phase, prog)); (* DEBUG *)
            prog)
      val simplified = dump ("simplify", simplifyProgram ast)
      val inlined = dump ("inlining", Inliner.transform simplified)
      val lifted = Lift.transform inlined
            handle Eval.Error msg => (Error.error(errStrm, msg); inlined)
      in
        dump ("lifting", lifted)
      end
337 : jhr 227
338 : jhr 171 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0