Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/clean-index.sml
 [diderot] / branches / vis15 / src / compiler / high-to-mid / clean-index.sml

# Diff of /branches/vis15/src/compiler/high-to-mid/clean-index.sml

revision 3567, Mon Jan 11 04:45:02 2016 UTC revision 3568, Mon Jan 11 05:10:30 2016 UTC
# Line 5  Line 5
5   * COPYRIGHT (c) 2016 The University of Chicago   * COPYRIGHT (c) 2016 The University of Chicago
7   *)   *)

(*
cleanIndex.sml cleans the indices in the EIN expression. This process is a bit more complicated because the vast number of possibilities. We need to both know all indices used in the subexpression in order to create a mapp, and the shape of the subexpression.

For each subexpression we look for the\\
ashape(mu list): all the indices mentioned in body
tshape(mu list):shape of tensor replacement
size(int list): TensorType of tensor replacement\\

*)
8  structure CleanIndex : sig  structure CleanIndex : sig
9
10      val clean : ? -> ?      val clean : ? -> ?
# Line 25  Line 14
14      structure E = Ein      structure E = Ein
15      structure IMap = IntRedBlackMap      structure IMap = IntRedBlackMap
16
17
18    (* dictionary to lookup mapp *)    (* dictionary to lookup mapp *)
19      fun lookupIx (e1, mapp, str) = (case IMap.find(mapp, e1)      fun lookupIx (e1, mapp, str) = (case IMap.find(mapp, e1)
20             of SOME l => l             of SOME l => l
21              | _ => raise Fail(str^ Integer.toString e1)              | _ => raise Fail(str^ Integer.toString e1)
22            (* endcase *))            (* endcase *))
23        val empty = IMap.empty
24      (*sizeMapp:->int      fun lookup k d = IMap.find(d, k)
25      * create sizeMapp for bound on mu      fun insert (k, v) d = IMap.insert(d, k, v)
26      *)      fun lkupId(e1,mapp,str) = (case (lookup e1 mapp)
27      fun mkSizeMapp (index, sx) = let          of SOME l => l
28            fun m ([], mapp) = mapp          | _ => raise Fail(str^Int.toString(e1))
| m ((E.V v, _, ub)::es, mapp) = m (es, IMap.insert(mapp, v, ub+1))
| m ( _, _) = err "Non-V-index in sx"
fun f (_, [], mapp) = mapp
| f (counter, ix::es, mapp) = f (counter+1, es, IMap.insert(mapp, counter, ix))
val mapp = f (0, index, IMap.empty)
in
m (sx, mapp)
end

(* mkIndexMapp:int list, sum_id list*mu list*mu list=>dict*int list *mu list
* A map is created to match index-ids in e1 to their new ids.
* First we iterate over indices in $tshape\beta$ and set a counter to 0.
* i.e. [$\beta_0 \longrightarrow 0,\beta_1 \longrightarrow 1,..,\beta_n \longrightarrow n$]$\in$indexMapp
Then from i=0 to i=maxium possible index-id.
if (i $\not \in$indexMapp and i $\in \gamma$) then $i \longrightarrow counter \in indexMapp$
*)
fun mkIndexMapp (index, sx, ashape, tshape) = let
fun f (mapp, [],tocounter) = (mapp, tocounter)
| f (mapp, (E.V e1)::es, tocounter) = let
val dict = IMap.insert(mapp, e1, tocounter)
in
f(dict,es,tocounter+1)
end
fun m (mapp, [], tocounter) = mapp
| m (mapp, e1::es, tocounter) = (case IMap.lookup(mapp, e1)
of SOME _=> m (mapp, es, tocounter)
| _ => (case (List.find (fn x => x = E.V e1) ashape)
of NONE => m (mapp, es, tocounter)
| SOME _ => let
val dict = IMap.insert (mapp, e1, tocounter)
in
m(dict,es,tocounter+1)
end
(* end case *))
29                    (* end case *))                    (* end case *))
30            val (mapp, tocounter) = f (IMap.empty, tshape, 0)      fun lkupVx(E.V e1, mapp, str) = E.V (lkupId(e1, mapp, str))
31            val pp = List.map (fn E.V v => v | _ => 0) ashape        | lkupVx(E.C e1, mapp, _ ) = E.C e1
32            val max = List.foldl (fn (a, b) => Int.max(a, b)) (length index-1) pp      fun lkupSx([],mapp,str)=[]
33            (*finds max element in ashape and creates [0,1,2,....,max]*)        | lkupSx((E.V e1, ub, lb)::es, mapp, str) = (case (lookup e1 mapp)
34            val maxmu = List.tabulate (max+1, fn e => e)              of SOME l => [(E.V l, ub, lb)]@lkupSx(es, mapp, str)
35            val indexMapp = m (mapp, maxmu, tocounter)              |_ => []@lkupSx(es, mapp,str)
in
indexMapp
end

(* rewriteIx:dict*ein_exp ->ein_exp
*  rewrites indices in e using mapp
*)
fun rewriteIx (mapp, e) = let
fun getAlpha alpha = List.map (fn e => lkupIndexV(e, mapp, str)) alpha
fun getIx ix = lkupIndexSingle (ix, mapp, str)
fun getVx ix = lkupIndexV (ix, mapp, str)
fun getSx sx = lkupIndexSx (sx, mapp, str)
fun rewrite b = (case b
of E.Const _ => b
| E.ConstR _ => b
| E.Tensor(id, alpha) => E.Tensor(id, getAlpha alpha)
| E.Delta(i, j) => E.Delta(getVx i, getVx j)
| E.Epsilon(i, j, k) => E.Epsilon(getIx i, getIx j, getIx k)
| E.Eps2(i, j) => E.Eps2(getIx i, getIx j)
| E.Field(id, alpha) => E.Field(id, getAlpha alpha)
| E.Lift e1 => E.Lift(rewrite e1)
| E.Conv(v, alpha, h, dx) => E.Conv(v, getAlpha alpha, h, getAlpha dx)
| E.Partial dx => E.Partial(getAlpha dx)
| E.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
| E.Probe(E.Conv(v, alpha, h, dx), t) =>
E.Probe(E.Conv(v, getAlpha alpha,h, getAlpha dx), rewrite t)
| E.Probe(e1, e2) => E.Probe(rewrite e1, rewrite e2)
| E.Value e1 => raise Fail "unexpected Value"
| E.Img _ => raise Fail "unexpected Img"
| E.Krn _ => raise Fail "unexpected Krn"
| E.Sum(sx ,e1) => E.Sum(getSx sx ,rewrite e1)
| E.Op1(E.PowEmb(sx1, n1), e1) => E.Op1(E.PowEmb(getSx sx1, n1), rewrite e1)
| E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
| E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
| E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
36                  (* end case *))                  (* end case *))
in
rewrite e
end
37
38      (* ashape:ein_exp -> mu list      (* ashape:ein_exp -> mu list
39      * get all indices used in expression      * returns list of all indices used in b
*  Σ_3 e. 3 is in aShape  even when 3 doesn't appear in e.
40      *)      *)
41      fun aShape b = let      fun aShape b = let
42            fun shape (b, ixs) = (case b            fun shape (b, ixs) = (case b
# Line 146  Line 63
63              shape (b, [])              shape (b, [])
64            end            end
65
66      (*      (* eShape: list of index-ids with potential to be in tshape
67      * potential tshape for e      *   of T_α -> eshape = α
68      * list of index-ids with potential to be in $\beta$. $\rho \in \gamma$.      *   |  e1 +.. -> eshape = eShape(e1)
69      * Right now the eshape/$\rho$ of einexpression e. case e      *   |  e1/..  -> eshape = eShape(e1)
70      *   of $A_\alpha \longrightarrow \rho=\alpha$      *   |  e1 *e2 ->
71      *   |  $e1 +.. \longrightarrow \rho(e1)$.      *       eshape = eShape(e1) and b = eShape(e2).
72      *   |  $e1 /.. \longrightarrow \rho(e1)$.      *       forall i in b. if i not in eshape then add i to eshape
*   |  $e1 *e2 \longrightarrow$.
*       $\rho=\rho(e1)$ and $\beta=\rho(e2). * \forall i \in \beta$  if $(i \not \in \rho)$ then $i \in \rho$
73      *)      *)
74      fun eShape b=let      fun eShape b=let
val _ =testp["\n eshape",P.printbody b]
75          fun iterList list=let          fun iterList list=let
76              val llist=(List.map (fn e1=>eShape  e1) list)              fun f ([], rest) = List.rev rest
fun f ([],rest)=rest
77                | f(E.C _::es,rest)=f(es,rest)                | f(E.C _::es,rest)=f(es,rest)
78                | f(e1::es,rest)=(case (List.find (fn x =>  x = e1) rest)                | f(e1::es,rest)=(case (List.find (fn x =>  x = e1) rest)
79                      of NONE=> f(es,rest@[e1])                  of NONE => f (es, e1::rest)
80                      | SOME _ =>f(es,rest)                      | SOME _ =>f(es,rest)
81                      (*end case*))                      (*end case*))
82                val llist = (List.map (fn e1 => eShape  e1) list)
83              in              in
84                  foldl f (List.hd(llist)) llist                  foldl f (List.hd(llist)) llist
85              end              end
# Line 182  Line 95
95              | E.Partial alpha           => alpha              | E.Partial alpha           => alpha
96              | E.Apply(E.Partial dx,e1)  => (eShape e1)@dx              | E.Apply(E.Partial dx,e1)  => (eShape e1)@dx
97              | E.Probe (e1,_)            => eShape e1              | E.Probe (e1,_)            => eShape e1
98              | E.Value e1                => err "Error in Eshape"              | E.Value e1               => raise Fail "raise Failor in Eshape"
99              | E.Img _                   => err "Error in Eshape"              | E.Img _                  => raise Fail "raise Failor in Eshape"
100              | E.Krn _                   => err "Error in Eshape"              | E.Krn _                  => raise Fail "raise Failor in Eshape"
101              | E.Sum(_ ,e1)              => eShape  e1              | E.Sum(_ ,e1)              => eShape  e1
102              | E.Op1 (_,e1)              => eShape e1              | E.Op1 (_,e1)              => eShape e1
103              | E.Op2 (_,e1,e2)           => iterList[e1,e2]              | E.Op2 (_,e1,e2)           => iterList[e1,e2]
# Line 192  Line 105
105              (*end case*))              (*end case*))
106          end          end
107
108      (* tShape: ->mu list      (* tShape: get shape of tensor replacement
109      * get shape of tensor replacement      * :int list, sumrange list, ein expression -> mu list
110      * outerAlpha= List of indices supported by original EIN. Created with index and sx.      *)
111      * Simply,      fun tShape(index, sx, e) = let
112      * For every index i in eShape:          (*outerAlpha = List of indices supported by original EIN *)
if i in outerAlpha then it is in tShape
otherwise it must be an index supported by the subexpression alone
and not in tShape.
*)
fun tShape(index,sx,e,eshape)=let
val outerAlpha=List.map (fn ( v,_,_)=>v) sx
val n'=(length index)
113          val outerAlpha=(case index          val outerAlpha=(case index
114              of []=>outerAlpha              of [] => List.map (fn ( v, _, _) =>v) sx
115              | _=>(List.tabulate(n',fn e=>E.V e))@outerAlpha              | _ => (List.tabulate(length index, fn e =>E.V e))@(List.map (fn ( v, _, _) =>v) sx)
116              (*end case*))              (*end case*))
117            (* getT: sorts eShape to create tShape
118          val removedup=false          * getT(eshape, accumulator)
119              fun getT([],rest) =rest          * for every i in eshape if it is in outerAlpha then i::tshape
120            *)
121            fun getT([], rest) = List.rev rest
122              | getT((E.C _)::es,rest)= getT(es,rest)              | getT((E.C _)::es,rest)= getT(es,rest)
123              | getT( (e1 as E.V v)::es,rest)= (case (List.find (fn x =>  x = e1) outerAlpha)              | getT( (e1 as E.V v)::es,rest)= (case (List.find (fn x =>  x = e1) outerAlpha)
124                  of SOME _ =>                  of SOME _ => getT(es, e1::rest)
if (removedup)
then (case (List.find (fn x =>  x = e1) rest)
of NONE=> getT(es,rest@[e1]) (*remove duplicates*)
| SOME _ => getT(es,rest)
(*end case*))
else getT(es,rest@[e1])
125                  | NONE =>getT(es,rest)                  | NONE =>getT(es,rest)
126                  (*end case*))                  (*end case*))
127
128            (* eShape(mu list): possible shape of tensor replacement *)
129            val eshape = eShape e
130          in          in
131              getT(eshape,[])              getT(eshape,[])
132          end          end
133
134      fun getShapes (e, index, sx)= let      (* sizeMapp: creates a map for index_id to dimension*)
135            val ashape = aShape e      fun mkSizeMapp (index, sx) = let
136            val eshape = eShape e          fun idToMapp (mapp, [],_ ) = mapp
137            val tshape = tShape(index, sx, e, eshape)            | idToMapp (mapp, ix::es, cnt) = idToMapp (insert (cnt, ix)  mapp, es,cnt+1)
138            fun sxToMapp (mapp, []) = mapp
139              | sxToMapp (mapp, (E.V v, _, ub) ::es) = sxToMapp (insert (v, ub+1)  mapp, es)
140            in            in
141              (ashape, tshape)              sxToMapp (idToMapp (empty, index, 0), sx)
142            end            end
143
144      (* cleanIndex:ein_exp*int list *sum_id ->mu list *int list*ein_exp      (* mkIndexMapp: maps the index variables in subexpression*)
145      * cleans index in body      fun mkIndexMapp (index, sx, ashape, tshape) =let
146      * returns shape of replacement in index variable list, size, and rewritten body          (* adds index e1 to the mapp E.V e1=> E.V cnt *)
147            fun vxToMapp (mapp, [], cnt) = (mapp, cnt)
148              | vxToMapp (mapp, (E.V e1)::es, cnt) = vxToMapp ((insert (e1, cnt)  mapp), es, cnt+1)
149            (*iff index e1 is in ashape add e1 the mapp E.V e1=> E.V cnt *)
150            fun intToMapp (mapp, [], _) = mapp
151              | intToMapp (mapp, e1::es, cnt) = (case (lookup e1 mapp)
152                of SOME _ => intToMapp (mapp, es, cnt)
153                | _  => (case (List.find (fn x  =>  x =E.V e1)  ashape)
154                    of NONE => intToMapp (mapp, es, cnt)
155                    | SOME _ => intToMapp ((insert (e1, cnt) mapp), es, cnt+1)
156                    (*end case*))
157                (*end case*))
158            (*Creates an map for indices in tshape first.*)
159            val (mapp, tocounter) = vxToMapp (empty, tshape, 0)
160            (*finds max element in ashape and creates list [0, 1, 2, ...., max]*)
161            val pp = List.map (fn E.V v =>v | _  => 0)  ashape
162            val max =List.foldl (fn (a, b)  => Int.max (a, b)) (length index-1)  pp
163            val maxmu = List.tabulate (max+1, (fn e => e))
164            (*creates a map for the rest of the indices that may be used in the ein expression *)
165            in
166                intToMapp (mapp, maxmu, tocounter)
167            end
168
169        (* rewriteIndices: rewrites indices in e using mapp *)
170        fun rewriteIx(mapp, e)=let
171            val str="Error indexMapp from expression:"^P.printbody e^"Index"
172            fun getAlpha alpha = List.map (fn e=> lkupVx (e, mapp, str)) alpha
173            fun getIx ix = lkupId(ix, mapp, str)
174            fun getVx ix = lkupVx (ix, mapp, str)
175            fun getSx sx = lkupSx(sx, mapp, str)
176            fun rewriteExp b=(case b
177                of E.B _                            => b
178                | E.Tensor(id,alpha)                => E.Tensor(id, getAlpha alpha)
179                | E.G(E.Delta(i, j) )               => E.G(E.Delta(getVx i, getVx j))
180                | E.G(E.Epsilon(i, j, k))           => E.G(E.Epsilon(getIx  i, getIx  j, getIx k))
181                | E.G(E.Eps2(i, j))                 => E.G(E.Eps2(getIx i, getIx j))
182                | E.Field (id, alpha)               => E.Field(id, getAlpha alpha)
183                | E.Lift e1                         => E.Lift(rewriteExp e1)
184                | E.Conv(v, alpha, h, dx)           => E.Conv (v, getAlpha alpha, h, getAlpha dx)
185                | E.Partial dx                      => E.Partial (getAlpha dx)
186                | E.Apply (e1, e2)                  => E.Apply(rewriteExp e1, rewriteExp e2)
187                | E.Probe(E.Conv(v, alpha, h,dx), t)  => E.Probe(E.Conv (v, getAlpha alpha,h, getAlpha dx), rewriteExp t)
188                | E.Probe (e1, e2)                  => E.Probe(rewriteExp e1, rewriteExp e2)
189                | E.Value e1                        => raise Fail"Should not be here"
190                | E.Img _                           => raise Fail "should not be here"
191                | E.Krn _                           => raise Fail"Should not be here"
192                | E.Sum(sx, e1)                     => E.Sum(getSx sx, rewriteExp e1)
193                | E.Op1(E.PowEmb(sx1, n1), e1)      => E.Op1(E.PowEmb(getSx sx1, n1), rewriteExp e1)
194                | E.Op1(op1, e1)                    => E.Op1(op1, rewriteExp e1)
195                | E.Op2(op2, e1, e2)                => E.Op2(op2, rewriteExp e1, rewriteExp e2)
196                | E.Opn(opn, es)                    => E.Opn(opn, List.map rewriteExp es)
197                (*end case*))
198            in
199                rewriteExp e
200            end
201
202
203        (*cleanIndex ()  cleans the indices in an EIN expression*)
204        (*input-  e:ein expression
205         index: int list for original EIN operator
206         sx:sumrange list for outer summation expression, if any exist
207         output- tshape:indices for tensor replacment,
208         sizes: Tensor type of new EIN operator,
209         e': rewritten e
210         Generic Example
211         x =λT {Σ_sx (e...)  ...)  }_index (arg0)
212         ===>
213         arg1 =λT {e'}_sizes (arg0),
214         x =λTT{Σ_sx (T1_{tshape}...) ...) }index (arg0, arg1)
215      *)      *)
216      fun clean (e, index, sx) = let      fun cleanIndex (e, index, sx) = let
217            val (ashape, tshape) = getShapes (e, index, sx)           (* Get shape of e
218             * ashape (mu list) : all the indices mentioned in body
219             * tshape (mu list) : shape of tensor replacement*)
220             val ashape = aShape e
221             val tshape = tShape(index, sx, e)
222
223             (* Create sizeMapp: index_id to dimension index_id is bound to*)
224            val sizeMapp = mkSizeMapp (index, sx)            val sizeMapp = mkSizeMapp (index, sx)
225            val sizes = List.map (fn E.V e1 => lookupIx(e1, sizeMapp, "Could not find Size of")) tshape           (*     Find size of e by looking up tshape in the sizeMapp*)
226             val sizes =List.map (fn E.V e1 => lkupId (e1, sizeMapp, "Could not find Size of"))  tshape
227             (* size (int list) : TensorType of tensor replacement*)
228
229             (* Create indexMapp: Mapps the index variables e  => e'*)
230            val indexMapp = mkIndexMapp (index, sx, ashape, tshape)            val indexMapp = mkIndexMapp (index, sx, ashape, tshape)
231            val body = rewriteIx (indexMapp, e)           (* Rewrite subexpression: e  =>e' *)
232             val e' = rewriteIx (indexMapp, e)
233            in            in
234              (tshape,sizes,body)              (tshape, sizes, e')
235            end            end
236
237            (* Example 1
238    Input to cleanIndex()
239        e: Σ_14[0-2]Prod< T20_14*  T21_14,1>, index:[2,3], sx:[E.V(6)[0-2]]
240    Analyzing
241        Get shape of e, getShapes()
242            aShape : 14,14,14,1,tShape : 1
243        Create sizeMapp: index_id to dimension, mkSizeMapp()
244            0 => 2, 1 => 3, 6 => 3
245        Find size of e by looking up tshape in the sizeMapp
246            sizes=[3]
247        Create indexMapp: Map the index variables e => e', mkIndexMap()
248            Set map for tshape indices first, vxToMapp()
249            E.V(1) => E.V(0)
250            Checks indices from E.V 0 to E.V15, intToMapp()
251            E.V(14) => E.V(1)
252        Rewrite subexpression: e =>e', rewriteIx()
253            e =>  Σ_1[0-2]Prod< T20_1*  T21_1,0>
254    Output: tshape:[E.V(1)],sizes:[3],e': Σ_1[0-2]Prod< T20_1*  T21_1,0>
255
256            b=<Σ_[E.V(6)[0-2]]( Σ_14[0-2]Prod< T20_14*  T21_14,1>)... >2,3 (args)
257            ===>
258            a=< Σ_1[0-2]Prod< T20_1*  T21_1,0>>_3 (args)
259            b'=<T_{E.V(1)}..>2,3 (args,a)
260
261    * Example 2
262    Input to cleanIndex()
263        e:Add( T23_6,1+ T24_6,1), index:[2,3], sx:[E.V(6)[0-2]]
264    Analyzing
265        Get shape of e
266            aShape : 6,1,6,1,tShape : 6,1
267        Create sizeMapp: index_id to dimension
268            0 => 2,1 => 3,6 => 3
269        Find size of e by looking up tshape in the sizeMapp
270            sizes=[3,3]
271        Create indexMapp: Map the index variables e => e'
272            Set map for tshape indices first
273            E.V(6) => E.V(0),E.V(1) => E.V(1)
274            Checks indices from E.V 0 to E.V7
275        Rewrite subexpression: e =>e'
276            e => Add( T23_0,1+ T24_0,1)
277    Output:
278        tshape:[E.V(6),E.V(1)], sizes:[3,3], e':Add( T23_0,1+ T24_0,1)
279            *)
280    end (* CleanIndex *)    end (* CleanIndex *)

Legend:
 Removed from v.3567 changed lines Added in v.3568