Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/clean-index.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/clean-index.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3978 - (view) (download)

1 : jhr 3561 (* clean-index.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 : jhr 3570
9 : jhr 3575 (* Example 1
10 :     Input to clean ()
11 :     e: Σ_14[0-2]Prod< T20_14* T21_14,1>, index:[2,3], sx:[E.V(6)[0-2]]
12 :     Analyzing
13 :     Get shape of e, getShapes()
14 :     aShape : 14,14,14,1,tShape : 1
15 :     Create sizeMapp: index_id to dimension, mkSizeMapp()
16 :     0 => 2, 1 => 3, 6 => 3
17 :     Find size of e by looking up tshape in the sizeMapp
18 :     sizes=[3]
19 :     Create indexMapp: Map the index variables e => e', mkIndexMap()
20 :     Set map for tshape indices first, vxToMapp()
21 :     E.V(1) => E.V(0)
22 :     Checks indices from E.V(0) up to the largest index in aShape (E.V(14)), intToMapp()
23 :     E.V(14) => E.V(1)
24 :     Rewrite subexpression: e =>e', rewriteIx()
25 :     e => Σ_1[0-2]Prod< T20_1* T21_1,0>
26 :     Output: tshape:[E.V(1)],sizes:[3],e': Σ_1[0-2]Prod< T20_1* T21_1,0>
27 :    
28 :     b=<Σ_[E.V(6)[0-2]]( Σ_14[0-2]Prod< T20_14* T21_14,1>)... >2,3 (args)
29 :     ===>
30 :     a=< Σ_1[0-2]Prod< T20_1* T21_1,0>>_3 (args)
31 :     b'=<T_{E.V(1)}..>2,3 (args,a)
32 :    
33 :     * Example 2
34 :     Input to clean ()
35 :     e:Add( T23_6,1+ T24_6,1), index:[2,3], sx:[E.V(6)[0-2]]
36 :     Analyzing
37 :     Get shape of e
38 :     aShape : 6,1,6,1,tShape : 6,1
39 :     Create sizeMapp: index_id to dimension
40 :     0 => 2,1 => 3,6 => 3
41 :     Find size of e by looking up tshape in the sizeMapp
42 :     sizes=[3,3]
43 :     Create indexMapp: Map the index variables e => e'
44 :     Set map for tshape indices first
45 :     E.V(6) => E.V(0),E.V(1) => E.V(1)
46 :     Checks indices from E.V(0) up to the largest index in aShape (E.V(6))
47 :     Rewrite subexpression: e =>e'
48 :     e => Add( T23_0,1+ T24_0,1)
49 :     Output:
50 :     tshape:[E.V(6),E.V(1)], sizes:[3,3], e':Add( T23_0,1+ T24_0,1)
51 :     *)
52 :    
53 : jhr 3561 structure CleanIndex : sig
54 :    
55 : jhr 3575 val clean : Ein.ein_exp * int list * Ein.sumrange list -> Ein.mu list * int list * Ein.ein_exp
56 : jhr 3561
57 :     end = struct
58 :    
59 :     structure E = Ein
60 : jhr 3574 structure ISet = IntRedBlackSet
61 : jhr 3565 structure IMap = IntRedBlackMap
62 : jhr 3561
(* lookupId (id, env)
 * Returns the value that the integer-keyed map `env` binds to `id`.
 * Raises Fail, with the offending id in the message, when `id` is unbound.
 *)
fun lookupId (id, env) = let
      fun missing () = raise Fail (concat["lookupId: ", Int.toString id, " not found"])
      in
        case IMap.find (env, id)
         of NONE => missing ()
          | SOME v => v
        (* end case *)
      end
67 : cchiw 3568
(* lkupVx (mu, env)
 * Renames a single index.  A variable index (E.V) is replaced by the variable
 * whose id is the one `env` assigns to it (lookupId raises Fail if there is no
 * binding); a constant index (E.C) is returned unchanged.
 *)
fun lkupVx (mu, env) = (case mu
       of E.V id => E.V(lookupId (id, env))
        | E.C _ => mu
      (* end case *))
70 : jhr 3561
(* lkupSx (sx, env)
 * Renames the bound index of every summation range in `sx` using `env`.
 * A range whose index has no binding in `env` is silently dropped; the two
 * bounds of each surviving range are copied through untouched and the order
 * of the list is preserved.
 *)
fun lkupSx (sx, env) = let
      fun rename (id, b1, b2) =
            Option.map (fn id' => (id', b1, b2)) (IMap.find (env, id))
      in
        List.mapPartial rename sx
      end
76 :    
(* aShape e
 * Computes the set of index ids (both parameter and summation indices) that
 * are mentioned anywhere in the ein expression `e`.  Constant indices (E.C)
 * contribute nothing to the set.
 * Raises Fail "Error in Ashape" on Value, Img and Krn nodes, and
 * Fail "impossible" on any form not listed below (this includes an Apply whose
 * first argument is not a Partial).
 *)
fun aShape b = let
    (* add the ids of the variable indices of a mu list to the set *)
      fun addMus (s, mus) = let
            fun addMu (E.V i, acc) = ISet.add(acc, i)
              | addMu (E.C _, acc) = acc
            in
              List.foldl addMu s mus
            end
    (* walk the expression, threading the set of ids seen so far *)
      fun walk (exp, acc) = (case exp
             of E.Const _ => acc
              | E.ConstR _ => acc
              | E.Tensor(_, alpha) => addMus(acc, alpha)
              | E.Delta(i, j) => ISet.addList(acc, [i, j])
              | E.Epsilon(i, j, k) => ISet.addList(acc, [i, j, k])
              | E.Eps2(i, j) => ISet.addList(acc, [i, j])
              | E.Field(_, alpha) => addMus(acc, alpha)
              | E.Lift e => walk(e, acc)
              | E.Conv(_, alpha, _, dx) => addMus(addMus(acc, alpha), dx)
              | E.Partial alpha => addMus(acc, alpha)
              | E.Apply(E.Partial alpha, e1) => walk(e1, addMus(acc, alpha))
              | E.Probe(e, _) => walk(e, acc)
              | E.Value _ => raise Fail "Error in Ashape"
              | E.Img _ => raise Fail "Error in Ashape"
              | E.Krn _ => raise Fail "Error in Ashape"
            (* the indices bound by the summation count as mentioned too *)
              | E.Sum(sx, e) => walk(e, ISet.addList(acc, List.map #1 sx))
              | E.Op1(_, e) => walk(e, acc)
            (* the second operand is visited before the first *)
              | E.Op2(_, e1, e2) => let
                  val acc' = walk(e2, acc)
                  in
                    walk(e1, acc')
                  end
              | E.Opn(_, es) => List.foldl walk acc es
              | _ => raise Fail "impossible"
            (* end case *))
      in
        walk (b, ISet.empty)
      end
111 : jhr 3566
(* eShape e
 * Lists, in order, the indices that have the potential to be in the shape of
 * the tensor that replaces `e` (tShape filters this list afterwards):
 *      T_α                 -> α
 *      e1 + ...            -> eShape(e1); only the first summand is inspected
 *      e1 op e2  (any Op2) -> the variable indices of eShape(e1) followed by
 *                             those of eShape(e2) that have not been seen yet
 *      e1 * e2 * ...       -> same rule as Op2, over all of the factors
 * Summations and unary operators are transparent; the indices of a Partial
 * that is applied to an expression come before those of that expression.
 * Raises Fail "unexpected Value" / "unexpected Img" / "unexpected Krn" on those
 * nodes and Fail "impossible" on any other form not handled below.
 *)
fun eShape b = let
    (* collect (exp, tail): the indices of `exp` placed in front of `tail` *)
      fun collect (exp, tail) = (case exp
             of E.Const _ => tail
              | E.ConstR _ => tail
              | E.Tensor(_, alpha) => alpha @ tail
              | E.Delta(i, j) => E.V i :: E.V j :: tail
              | E.Epsilon(i, j, k) => E.V i :: E.V j :: E.V k :: tail
              | E.Eps2(i, j) => E.V i :: E.V j :: tail
              | E.Field(_, alpha) => alpha @ tail
              | E.Lift e => collect (e, tail)
              | E.Conv(_, alpha, _, dx) => alpha @ dx @ tail
              | E.Partial alpha => alpha @ tail
              | E.Apply(E.Partial dx, e) => collect (e, dx @ tail)
              | E.Probe(e, _) => collect (e, tail)
              | E.Value _ => raise Fail "unexpected Value"
              | E.Img _ => raise Fail "unexpected Img"
              | E.Krn _ => raise Fail "unexpected Krn"
              | E.Sum(_, e) => collect (e, tail)
              | E.Op1(_, e) => collect (e, tail)
              | E.Op2(_, e1, e2) => collectDistinct ([e1, e2], tail)
              | E.Opn(E.Add, e::_) => collect (e, tail)
              | E.Opn(E.Prod, es) => collectDistinct (es, tail)
              | _ => raise Fail "impossible"
            (* end case *))
    (* collectDistinct (es, tail): handles the operands of a binary operator or
     * of a product.
     *   es   -- the operand subexpressions, processed left to right
     *   tail -- indices to the right of the parent operator
     * Each operand's own index list is computed independently; its variable
     * indices are kept the first time their id is met, and its constant
     * indices are discarded.  The survivors are placed in front of `tail`.
     *)
      and collectDistinct (es, tail) = let
          (* state = (ids already kept, kept indices in reverse order) *)
            fun keepNew (E.V i, st as (seen, revKept)) =
                  if ISet.member(seen, i)
                    then st
                    else (ISet.add(seen, i), E.V i :: revKept)
              | keepNew (E.C _, st) = st
            fun perOperand (e, st) = List.foldl keepNew st (collect (e, []))
            val (_, revKept) = List.foldl perOperand (ISet.empty, []) es
            in
              List.revAppend (revKept, tail)
            end
      in
        collect (b, [])
      end
168 : jhr 3566
(* tShape (index, sx, e): get the shape of the tensor replacement
 *      : int list * sumrange list * ein expression -> mu list
 * The result is eShape(e) restricted to the variable indices that the
 * original EIN supports; constant indices are removed and the order of
 * eShape(e) is kept.
 *)
fun tShape (index, sx, e) = let
    (* outerAlpha = set of indices supported by the original EIN: one id per
     * position of `index` (0, 1, ..., length index - 1) plus every id bound by
     * the outer summation ranges `sx`.
     *)
      val outerAlpha = let
            val positional = List.tabulate (List.length index, fn i => i)
            val summed = List.map (fn (v, _, _) => v) sx
            in
              ISet.addList (ISet.addList (ISet.empty, positional), summed)
            end
    (* an index survives iff it is a variable whose id is in outerAlpha *)
      fun isOuter (E.V v) = ISet.member(outerAlpha, v)
        | isOuter (E.C _) = false
      in
        List.filter isOuter (eShape e)
      end
193 : jhr 3566
(* mkSizeMapp (index, sx): creates a map from index_id to dimension.
 *   - the k-th element of `index` (counting from 0) is recorded as the
 *     dimension of id k;
 *   - every summation range (v, _, ub) in `sx` records ub+1 for id v.
 * The ranges are inserted after the positional entries, so a range overrides
 * a positional entry that has the same id.
 * NOTE(review): ub+1 is the extent of the range only if its lower bound is 0;
 * the lower bound is ignored here -- confirm that callers guarantee this.
 *)
fun mkSizeMapp (index, sx) = let
      fun addDim (dim, (m, pos)) = (IMap.insert (m, pos, dim), pos+1)
      val (positional, _) = List.foldl addDim (IMap.empty, 0) index
      fun addRange ((v, _, ub), m) = IMap.insert (m, v, ub+1)
      in
        List.foldl addRange positional sx
      end
203 : cchiw 3568
(* mkIndexMapp (index, sx, ashape, tshape): builds the renaming old id -> new id
 * for the index variables of the subexpression.
 *   1. the indices of `tshape` are numbered first, 0, 1, ... in list order;
 *   2. the remaining non-negative ids of `ashape` are then numbered in
 *      increasing order of id, continuing from the next free number.
 * `index` and `sx` are accepted but not consulted.  `tshape` is expected to
 * hold only E.V indices (the match is deliberately left non-exhaustive).
 *)
fun mkIndexMapp (index, sx, ashape, tshape) = let
    (* state = (map so far, next fresh id) *)
      fun bindShape (E.V id, (m, next)) = (IMap.insert (m, id, next), next+1)
      val (shapeMap, firstFree) = List.foldl bindShape (IMap.empty, 0) tshape
    (* give a fresh id to a member of ashape unless it already has one; ids
     * below 0 are never numbered.
     *)
      fun bindRest (id, st as (m, next)) =
            if id < 0 orelse IMap.inDomain(m, id)
              then st
              else (IMap.insert (m, id, next), next+1)
    (* ISet.listItems yields the ids in increasing order *)
      val (fullMap, _) = List.foldl bindRest (shapeMap, firstFree) (ISet.listItems ashape)
      in
        fullMap
      end
229 : jhr 3566
(* rewriteIx (mapp, e): rewrites the indices in `e` using `mapp`.
 * Variable indices and the ids of Delta/Epsilon/Eps2 are replaced by their
 * image under `mapp` (lookupId raises Fail when an id has no binding);
 * constant indices are kept.  Summation ranges are renamed with lkupSx, so a
 * range whose index is unbound in `mapp` is dropped rather than reported.
 * Raises Fail "unexpected Value" / "unexpected Img" / "unexpected Krn" on
 * those nodes.
 *)
fun rewriteIx (mapp, e) = let
      fun renameMus mus = List.map (fn mu => lkupVx (mu, mapp)) mus
      fun renameId id = lookupId (id, mapp)
      fun rw exp = (case exp
             of E.Const _ => exp
              | E.ConstR _ => exp
              | E.Tensor(id, alpha) => E.Tensor(id, renameMus alpha)
              | E.Delta(i, j) => E.Delta(renameId i, renameId j)
              | E.Epsilon(i, j, k) => E.Epsilon(renameId i, renameId j, renameId k)
              | E.Eps2(i, j) => E.Eps2(renameId i, renameId j)
              | E.Field(id, alpha) => E.Field(id, renameMus alpha)
              | E.Lift e1 => E.Lift(rw e1)
              | E.Conv(v, alpha, h, dx) => E.Conv(v, renameMus alpha, h, renameMus dx)
              | E.Partial dx => E.Partial(renameMus dx)
              | E.Apply(e1, e2) => E.Apply(rw e1, rw e2)
            (* a probed convolution is covered by the Conv case above *)
              | E.Probe(e1, e2) => E.Probe(rw e1, rw e2)
              | E.Value _ => raise Fail "unexpected Value"
              | E.Img _ => raise Fail "unexpected Img"
              | E.Krn _ => raise Fail "unexpected Krn"
              | E.Sum(sx, e1) => E.Sum(lkupSx (sx, mapp), rw e1)
              | E.Op1(rator, e1) => E.Op1(rator, rw e1)
              | E.Op2(rator, e1, e2) => E.Op2(rator, rw e1, rw e2)
              | E.Opn(rator, es) => E.Opn(rator, List.map rw es)
            (* end case *))
      in
        rw e
      end
262 : jhr 3566
(* clean (e, index, sx): cleans the indices in an EIN expression.
 *   input-  e      : ein expression
 *           index  : int list for the original EIN operator
 *           sx     : sumrange list for the outer summation expression, if any
 *   output- tshape : indices for the tensor replacement,
 *           sizes  : tensor type of the new EIN operator,
 *           e'     : rewritten e
 * Generic example
 *      x = λT < Σ_sx (e...) ...) >_{index} (arg0)
 *      ===>
 *      arg1 = λT <e'>_{sizes} (arg0),
 *      x = λT T' < Σ_sx (T1_{tshape}...) ...) >_{index} (arg0, arg1)
 *)
fun clean (e, index, sx) = let
    (* shape of e
     *   ashape (ISet.set) : all the indices mentioned in the body
     *   tshape (mu list)  : shape of the tensor replacement
     *)
      val ashape = aShape e
      val tshape = tShape (index, sx, e)
    (* sizeMapp: the dimension that each index_id is bound to *)
      val sizeMapp = mkSizeMapp (index, sx)
    (* sizes (int list): the tensor type of the replacement, obtained by looking
     * up each tshape index in sizeMapp (tshape holds only E.V indices).
     *)
      fun sizeOf (E.V id) = lookupId (id, sizeMapp)
      val sizes = List.map sizeOf tshape
    (* renaming of the index variables, then the rewritten subexpression *)
      val indexMapp = mkIndexMapp (index, sx, ashape, tshape)
      val e' = rewriteIx (indexMapp, e)
      in
        (tshape, sizes, e')
      end
296 : cchiw 3568
297 : jhr 3561 end (* CleanIndex *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0