Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/codegen/collect-info.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/codegen/collect-info.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5286 - (view) (download)

(* collect-info.sml
 *
 * Collect information about the types and operations used in a program.  We need this
 * information to figure out what utility code to generate.
 *
 * The types are ordered so that base types are first, followed by TensorRefTy, followed
 * by TensorTy, followed by image types, and then sequences and tuples.  Furthermore, the
 * argument type of a sequence (and the component types of a tuple) appear before the
 * sequence (or tuple) type.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)
15 :    
structure CollectInfo : sig

  (* the information collected from a program: the set of types it uses and the set
   * of operations that require generated utility code.
   *)
    type t

    datatype operation
      = Print of TreeTypes.t
      | RClamp | RLerp
      | VScale of int * int     (* `VScale(w, pw)`: scalar times vector; `w` is width of
                                 * vector and `pw` is the padded-width supported by
                                 * the hardware.
                                 *)
      | VSum of int * int       (* `VSum(w, pw)`: vector addition *)
      | VDot of int * int
      | VCeiling of int * int
      | VFloor of int * int
      | VRound of int * int
      | VTrunc of int * int
      | VToInt of TreeTypes.vec_layout
      | VLoad of int * int
      | VCons of int * int
      | VPack of TreeTypes.vec_layout
      | TensorCopy of int list
      | Transform of int
      | Translate of int
      | Inside of VectorLayout.t * int
      | EigenVals2x2
      | EigenVals3x3
      | EigenVecs2x2
      | EigenVecs3x3
      | SphereQuery of int * string

  (* walk a program and record the types and the selected operations that it uses *)
    val collect : TreeIR.program -> t

  (* the recorded types, sorted by a depth metric so that base types come first and
   * sequence/tuple types come after their component types.
   *)
    val listTypes : t -> TreeTypes.t list

  (* the recorded operations, in hash-table traversal order (no particular order) *)
    val listOps : t -> operation list

  end = struct

    structure IR = TreeIR
    structure Ty = TreeTypes
    structure Op = TreeOps

    datatype operation
      = Print of TreeTypes.t
      | RClamp | RLerp
      | VScale of int * int
      | VSum of int * int
      | VDot of int * int
      | VCeiling of int * int
      | VFloor of int * int
      | VRound of int * int
      | VTrunc of int * int
      | VToInt of TreeTypes.vec_layout
      | VLoad of int * int
      | VCons of int * int
      | VPack of TreeTypes.vec_layout
      | TensorCopy of int list
      | Transform of int
      | Translate of int
      | Inside of VectorLayout.t * int
      | EigenVals2x2
      | EigenVals3x3
      | EigenVecs2x2
      | EigenVecs3x3
      | SphereQuery of int * string

  (* operator to string (for debugging) *)
    local
    (* vector operations print their padded width in braces only when it differs
     * from the vector width.
     *)
      fun vop2s (rator, w, pw) = if (w = pw)
            then rator ^ Int.toString w
            else concat[rator, Int.toString w, "{", Int.toString pw, "}"]
    in
    fun toString rator = (case rator
           of Print ty => concat["Print<", TreeTypes.toString ty, ">"]
            | RClamp => "RClamp"
            | RLerp => "RLerp"
            | VScale(w, pw) => vop2s ("VScale", w, pw)
            | VSum(w, pw) => vop2s ("VSum", w, pw)
            | VDot(w, pw) => vop2s ("VDot", w, pw)
            | VCeiling(w, pw) => vop2s ("VCeiling", w, pw)
            | VFloor(w, pw) => vop2s ("VFloor", w, pw)
            | VRound(w, pw) => vop2s ("VRound", w, pw)
            | VTrunc(w, pw) => vop2s ("VTrunc", w, pw)
            | VToInt layout => "VToInt" ^ VectorLayout.toString layout
            | VLoad(w, pw) => vop2s ("VLoad", w, pw)
            | VCons(w, pw) => vop2s ("VCons", w, pw)
            | VPack layout => "VPack" ^ VectorLayout.toString layout
            | TensorCopy shape =>
                concat["TensorCopy[", String.concatWithMap "," Int.toString shape, "]"]
            | Transform d => concat["Transform", Int.toString d, "D"]
            | Translate d => concat["Translate", Int.toString d, "D"]
            | Inside({wid, ...}, s) =>
                concat["Inside", Int.toString wid, "D<", Int.toString s, ">"]
            | EigenVals2x2 => "EigenVals2x2"
            | EigenVals3x3 => "EigenVals3x3"
            | EigenVecs2x2 => "EigenVecs2x2"
            | EigenVecs3x3 => "EigenVecs3x3"
            | SphereQuery(d, s) => concat["SphereQuery", Int.toString d, "<", s, ">"]
          (* end case *))
    end (* local *)

  (* hash tables keyed by operations.  For the vector operations, keys are compared on
   * the vector width only (the padded width is ignored), and the hash function is
   * consistent with that equality.
   *)
    structure OpTbl = HashTableFn (
      struct
        type hash_key = operation
        fun hashVal rator = (case rator
               of Print ty => 0w5 * TreeTypes.hash ty
                | RClamp => 0w13
                | RLerp => 0w17
                | VScale(w, _) => 0w19 + 0w7 * Word.fromInt w
                | VSum(w, _) => 0w23 + 0w7 * Word.fromInt w
                | VDot(w, _) => 0w29 + 0w7 * Word.fromInt w
                | VCeiling(w, _) => 0w43 + 0w7 * Word.fromInt w
                | VFloor(w, _) => 0w47 + 0w7 * Word.fromInt w
                | VRound(w, _) => 0w53 + 0w7 * Word.fromInt w
                | VTrunc(w, _) => 0w59 + 0w7 * Word.fromInt w
                | VToInt{wid, ...} => 0w61 + 0w7 * Word.fromInt wid
                | VLoad(w, _) => 0w67 + 0w7 * Word.fromInt w
                | VCons(w, _) => 0w71 + 0w7 * Word.fromInt w
                | VPack{wid, ...} => 0w79 + 0w7 * Word.fromInt wid
                | TensorCopy dd =>
                    0w83 + List.foldl (fn (i, s) => (Word.fromInt i + 0w3 * s)) 0w0 dd
                | Transform d => 0w89 + 0w7 * Word.fromInt d
                | Translate d => 0w97 + 0w7 * Word.fromInt d
                | Inside(layout, s) =>
                    0w101 + 0w7 * VectorLayout.hash layout + 0w13 * Word.fromInt s
                | EigenVals2x2 => 0w103
                | EigenVals3x3 => 0w107
                | EigenVecs2x2 => 0w109
                | EigenVecs3x3 => 0w113
                | SphereQuery(d, s) => 0w117 + 0w7 * Word.fromInt d + HashString.hashString s
              (* end case *))
        fun sameKey (op1, op2) = (case (op1, op2)
               of (Print ty1, Print ty2) => TreeTypes.same(ty1, ty2)
                | (RClamp, RClamp) => true
                | (RLerp, RLerp) => true
                | (VScale(w1, _), VScale(w2, _)) => (w1 = w2)
                | (VSum(w1, _), VSum(w2, _)) => (w1 = w2)
                | (VDot(w1, _), VDot(w2, _)) => (w1 = w2)
                | (VCeiling(w1, _), VCeiling(w2, _)) => (w1 = w2)
                | (VFloor(w1, _), VFloor(w2, _)) => (w1 = w2)
                | (VRound(w1, _), VRound(w2, _)) => (w1 = w2)
                | (VTrunc(w1, _), VTrunc(w2, _)) => (w1 = w2)
                | (VToInt{wid=w1, ...}, VToInt{wid=w2, ...}) => (w1 = w2)
                | (VLoad(w1, _), VLoad(w2, _)) => (w1 = w2)
                | (VCons(w1, _), VCons(w2, _)) => (w1 = w2)
                | (VPack{wid=w1, ...}, VPack{wid=w2, ...}) => (w1 = w2)
                | (TensorCopy dd1, TensorCopy dd2) => ListPair.allEq (op =) (dd1, dd2)
                | (Transform d1, Transform d2) => (d1 = d2)
                | (Translate d1, Translate d2) => (d1 = d2)
                | (Inside(l1, s1), Inside(l2, s2)) =>
                    VectorLayout.same(l1, l2) andalso (s1 = s2)
                | (EigenVals2x2, EigenVals2x2) => true
                | (EigenVals3x3, EigenVals3x3) => true
                | (EigenVecs2x2, EigenVecs2x2) => true
                | (EigenVecs3x3, EigenVecs3x3) => true
                | (SphereQuery(d1, s1), SphereQuery(d2, s2)) => (d1 = d2) andalso (s1 = s2)
                | _ => false
              (* end case *))
      end)

    datatype t = Info of {
        tys : unit Ty.Tbl.hash_table,   (* mapping for types in program *)
        ops : unit OpTbl.hash_table     (* mapping for selected operations in the program *)
      }

  (* `addType info` returns a function that records a type, plus the auxiliary types
   * it requires, in `info`'s type table.  Bool, int, string, `VecTy(1, 1)`, and strand-id
   * types are not recorded.
   *)
    fun addType (Info{tys, ...}) = let
          val find = Ty.Tbl.find tys
          val ins = Ty.Tbl.insert tys
        (* record `ty` only if no equal type is already in the table *)
          fun insert ty = (case find ty
                 of NONE => ins (ty, ())
                  | SOME () => ()
                (* end case *))
        (* insert a TensorTy or TensorRefTy, which means inserting both types (plus
         * the last-dimension vector type for 2nd-order and higher tensors)
         *)
          fun insertTensorTy [] =
              (* an order-0 shape used to fall outside this function's match and raise
               * `Match`; it is now skipped like the scalar base types in `add`.
               * NOTE(review): assumes no declaration is needed for an order-0 tensor
               * type -- confirm against the code generator.
               *)
                ()
            | insertTensorTy (shp as [_]) = (
                insert (Ty.TensorTy shp);
                insert (Ty.TensorRefTy shp))
            | insertTensorTy (shp as _::dd) = let
                val d = List.last dd
                in
                  insert (Ty.TensorTy shp);
                  insert (Ty.TensorRefTy shp);
                (* we also need the vector types for the "last" member function *)
                  insert (Ty.TensorTy[d]);
                  insert (Ty.TensorRefTy[d])
                end
        (* record a type and, for tuples and sequences, their component types *)
          fun add ty = (case ty
                 of Ty.BoolTy => ()
                  | Ty.IntTy => ()
                  | Ty.StringTy => ()
                  | Ty.VecTy(1, 1) => ()
                  | Ty.TensorTy shp => insertTensorTy shp
                  | Ty.TensorRefTy shp => insertTensorTy shp
                  | Ty.StrandIdTy _ => ()
                  | Ty.TupleTy tys => (insert ty; List.app add tys)
                  | Ty.SeqTy(ty', _) => (insert ty; add ty')
                  | _ => insert ty
                (* end case *))
          in
            add
          end

  (* `insertOp info` returns a function that records an operation in `info`'s
   * operation table, unless an equal operation (per `OpTbl`'s `sameKey`) is
   * already present.
   *)
    fun insertOp (Info{ops, ...}) = let
          val find = OpTbl.find ops
          val ins = OpTbl.insert ops
          in
            fn rator => (case find rator
                 of NONE => ins (rator, ())
                  | SOME() => ()
                (* end case *))
          end

  (* `addOp info` returns a function that maps a TreeIR operator to the operation(s)
   * and type(s) that it requires; operators not listed here need nothing.
   *)
    fun addOp info = let
          val insert = insertOp info
          val addTy = addType info
          fun add' rator = (case rator
                 of Op.RClamp => insert RClamp
                  | Op.RLerp => insert RLerp
                  | Op.VScale(w, pw) => insert (VScale(w, pw))
                  | Op.VSum(w, pw) => insert (VSum(w, pw))
                  | Op.VDot(w, pw) => insert (VDot(w, pw))
                  | Op.VCeiling(w, pw) => insert (VCeiling(w, pw))
                  | Op.VFloor(w, pw) => insert (VFloor(w, pw))
                  | Op.VRound(w, pw) => insert (VRound(w, pw))
                  | Op.VTrunc(w, pw) => insert (VTrunc(w, pw))
                  | Op.VToInt layout => insert (VToInt layout)
                (* projecting the last axis of a 2nd-order or higher tensor needs the
                 * reference type for a vector of the last dimension
                 *)
                  | Op.ProjectLast(Ty.TensorTy(_::(dd as _::_)), _) =>
                      addTy (Ty.TensorRefTy[List.last dd])
                  | Op.ProjectLast(Ty.TensorRefTy(_::(dd as _::_)), _) =>
                      addTy (Ty.TensorRefTy[List.last dd])
                  | Op.TensorCopy shp => insert (TensorCopy shp)
                  | Op.EigenVecs2x2 => insert EigenVecs2x2
                  | Op.EigenVecs3x3 => insert EigenVecs3x3
                  | Op.EigenVals2x2 => insert EigenVals2x2
                  | Op.EigenVals3x3 => insert EigenVals3x3
                  | Op.SphereQuery(d, Ty.StrandIdTy s) =>
                      insert (SphereQuery(d, Atom.toString s))
                  | Op.Transform info => insert (Transform(ImageInfo.dim info))
                  | Op.Translate info => insert (Translate(ImageInfo.dim info))
                  | Op.Inside(layout, _, s) => insert (Inside(layout, s))
                  | _ => ()
                (* end case *))
          in
            add'
          end

  (* walk every variable, function, method, and block of the program, recording the
   * types and operations that they use.
   *)
    fun collect prog = let
          val IR.Program{
                  consts, inputs, constInit, globals, funcs, globInit,
                  strand, create, start, update, ...
                } = prog
          val IR.Strand{params, state, stateInit, startM, updateM, stabilizeM, ...} = strand
          val info = Info{
                  tys = TreeTypes.Tbl.mkTable (64, Fail "tys"),
                  ops = OpTbl.mkTable (64, Fail "ops")
                }
          val addType = addType info
          val addOp = addOp info
          val insertOp = insertOp info
        (* record a printer for a type (tensor values are printed via their reference
         * type), plus printers for the element types of tuples and sequences.
         *)
          fun insertPrint (Ty.TensorTy shp) = insertPrint (Ty.TensorRefTy shp)
            | insertPrint ty = (
                addType ty; insertOp (Print ty);
              (* add a printer for elements (when necessary) *)
                case ty
                 of Ty.TupleTy tys => List.app insertPrint tys
                  | Ty.SeqTy(ty', _) => insertPrint ty'
                  | _ => ()
                (* end case *))
          fun doGlobalV x = addType(TreeGlobalVar.ty x)
          fun doStateV x = addType(TreeStateVar.ty x)
          fun doV x = addType(TreeVar.ty x)
          fun doExp e = (case e
                 of IR.E_State(SOME e, _) => doExp e
                  | IR.E_Op(rator, args) => (
                    (* image-space transforms of dimension > 1 need a matrix (Transform)
                     * or vector (Translate) reference type
                     *)
                      case rator
                       of Op.Transform info => (case ImageInfo.dim info
                             of 1 => ()
                              | d => addType (Ty.TensorRefTy[d, d])
                            (* end case *))
                        | Op.Translate info => (case ImageInfo.dim info
                             of 1 => ()
                              | d => addType (Ty.TensorRefTy[d])
                            (* end case *))
                        | _ => ()
                      (* end case *);
                      addOp rator;
                      List.app doExp args)
                  | IR.E_Vec(w, pw, es) => (
                      addType(Ty.VecTy(w, pw));
                      insertOp (VCons(w, pw));
                      List.app doExp es)
                  | IR.E_Cons(es, ty) => (addType ty; List.app doExp es)
                  | IR.E_Seq(es, ty) => (addType ty; List.app doExp es)
                  | IR.E_Pack(layout, es) => (
                      List.app (fn ty => addType ty) (Ty.piecesOf layout);
                      insertOp (VPack layout);
                      List.app doExp es)
                  | IR.E_VLoad(layout, e, i) => let
                      val ty as Ty.VecTy(w, pw) = Ty.nthVec(layout, i)
                      in
                        addType ty;
                        insertOp (VLoad(w, pw));
                        doExp e
                      end
                  | _ => ()
                (* end case *))
          fun doStm stm = (case stm
                 of IR.S_Assign(isDecl, x, e) => (
                      if isDecl then doV x else ();
                      doExp e)
                  | IR.S_MAssign(_, e) => doExp e
                  | IR.S_GAssign(_, e) => doExp e
                  | IR.S_IfThen(e, b) => (doExp e; doBlk b)
                  | IR.S_IfThenElse(e, b1, b2) => (doExp e; doBlk b1; doBlk b2)
                  | IR.S_For(x, lo, hi, b) => (doV x; doExp lo; doExp hi; doBlk b)
                  | IR.S_Foreach(x, e, b) => (doV x; doExp e; doBlk b)
                  | IR.S_Input(_, _, _, SOME e) => doExp e
                  | IR.S_New(_, es) => List.app doExp es
                  | IR.S_Save(_, e) => doExp e
                  | IR.S_Print(tys, es) => (
                      List.app insertPrint tys;
                      List.app doExp es)
                  | IR.S_Return(SOME e) => doExp e
                  | _ => ()
                (* end case *))
          and doBlk (IR.Block{locals, body}) = (
                List.app doV (!locals);
                List.app doStm body)
          and doFunc (IR.Func{name, body, ...}) = let
                val (resTy, tys) = TreeFunc.ty name
                in
                  List.app addType (resTy::tys);
                  doBlk body
                end
          and doMethod (IR.Method{body, ...}) = doBlk body
          in
            List.app doGlobalV consts;
            List.app (doGlobalV o Inputs.varOf) inputs;
            List.app doGlobalV globals;
            List.app doStateV state;
            List.app doFunc funcs;
            doBlk constInit;
            doBlk globInit;
            doMethod stateInit;
            Option.app doMethod startM;
            doMethod updateM;
            Option.app doMethod stabilizeM;
            Create.app doBlk create;
            Option.app doBlk start;
            Option.app doBlk update;
            info
          end

  (* sort function for types; we need to sort the types to ensure that types are
   * declared before being used in another type.  `ListMergeSort.sort gt` returns its
   * argument in ascending order with respect to `gt`, so lower-depth types come first.
   * NOTE(review): the generated decls may be accumulated last-to-first by the consumer
   * of `listTypes`; confirm that it expects this ascending order.
   *)
    val tySort = let
        (* partial ordering on types is defined by a "depth" metric; a sequence or tuple
         * is always deeper than its component types.
         *)
          fun depth (Ty.TupleTy tys) = 4 + List.foldl (fn (ty, d) => Int.max(depth ty, d)) 0 tys
            | depth (Ty.SeqTy(ty, _)) = 4 + depth ty
            | depth (Ty.TensorTy[]) = 0
            | depth (Ty.TensorRefTy _) = 1
            | depth (Ty.TensorTy _) = 2
            | depth (Ty.ImageTy _) = 3
            | depth _ = 0
        (* tensor types of the same kind are ordered by the length of their shape *)
          fun gt (Ty.TensorTy dd1, Ty.TensorTy dd2) = List.length dd1 > List.length dd2
            | gt (Ty.TensorRefTy dd1, Ty.TensorRefTy dd2) = List.length dd1 > List.length dd2
            | gt (ty1, ty2) = (depth ty1 > depth ty2)
          in
            ListMergeSort.sort gt
          end

    fun listTypes (Info{tys, ...}) =
          tySort (TreeTypes.Tbl.foldi (fn (ty, _, acc) => ty::acc) [] tys)

    fun listOps (Info{ops, ...}) = OpTbl.foldi (fn (k, _, acc) => k::acc) [] ops

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0