Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/clean-index.sml
 [diderot] / branches / vis15 / src / compiler / high-to-mid / clean-index.sml

# View of /branches/vis15/src/compiler/high-to-mid/clean-index.sml

Mon Jan 11 14:31:58 2016 UTC (4 years, 1 month ago) by jhr
File size: 11310 byte(s)
`working on merge`
```(* clean-index.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2016 The University of Chicago
*)

structure CleanIndex : sig

val clean : ? -> ?

end = struct

structure E = Ein
structure IMap = IntRedBlackMap

(* lkupId (id, mapp, str): look up the integer key `id` in the map `mapp`.
 * Returns the value bound to `id`.  When `id` has no binding, raises Fail
 * whose message is `str` immediately followed by the decimal form of `id`.
 *)
fun lkupId (id, mapp, str) = let
      val found = IMap.find (mapp, id)
      in
        case found
         of NONE => raise Fail (str ^ Int.toString id)
          | SOME id' => id'
        (* end case *)
      end

(* lkupVx (ix, mapp, str): rewrite a single index through `mapp`.
 * A variable index `E.V id` becomes `E.V id'`, where id' is what lkupId returns
 * (so a missing binding raises Fail via lkupId); a constant index `E.C _` is
 * returned unchanged and the map is not consulted.
 *)
fun lkupVx (ix, mapp, str) = (case ix
       of E.C c => E.C c
        | E.V id => E.V (lkupId (id, mapp, str))
      (* end case *))

(* lkupSx (sx, mapp, str): rewrite the index of every summation range in `sx`.
 * Each element is a triple whose first component is a variable index `E.V id`;
 * the other two components are copied through unchanged.
 *   - when `id` is bound in `mapp`, the range is kept with its index replaced;
 *   - when `id` is not bound, the range is dropped from the result (no error).
 * `str` is accepted so the call shape matches lkupId/lkupVx, but it is never used.
 * NOTE(review): a range whose index is `E.C _` is not matched and raises Match.
 *)
fun lkupSx (sx, mapp, str) = let
      fun remap (E.V id, bnd1, bnd2) =
            Option.map (fn id' => (E.V id', bnd1, bnd2)) (IMap.find (mapp, id))
      in
        List.mapPartial remap sx
      end

(* aShape: ein_exp -> mu list
 * Returns the list of all indices mentioned in the body b, with repetitions,
 * in left-to-right order of occurrence.  The index bound by a summation is
 * listed before the indices of the summation body.
 *
 * Raises Fail "Error in Ashape" on E.Value, E.Img and E.Krn.
 *
 * NOTE(review): this function matches E.Const / E.ConstR / E.Delta / E.Epsilon /
 * E.Eps2 directly, whereas eShape and rewriteIx in this file match E.B and
 * E.G(E.Delta ...) etc.  Only one of the two spellings can agree with Ein's
 * datatype; confirm which and align the three functions.
 *)
fun aShape b = let
      (* shape (e, ixs): prepend the indices of e to the accumulator ixs *)
      fun shape (b, ixs) = (case b
             of E.Const _ => ixs
              | E.ConstR _ => ixs
              (* a tensor contributes all of its indices; without this case any
               * body containing a tensor argument falls through to Match
               *)
              | E.Tensor(_, alpha) => alpha @ ixs
              | E.Delta(i, j) => i::j::ixs
              (* Epsilon/Eps2 carry bare index ids (eShape and rewriteIx treat
               * them as ints), so they are wrapped in E.V to fit the mu list
               *)
              | E.Epsilon(i, j, k) => E.V i :: E.V j :: E.V k :: ixs
              | E.Eps2(i, j) => E.V i :: E.V j :: ixs
              | E.Field(_, alpha) => alpha @ ixs
              | E.Lift e => shape (e, ixs)
              | E.Conv(_, alpha, _, dx) => alpha @ dx @ ixs
              | E.Partial alpha => alpha @ ixs
              | E.Apply(E.Partial dx, e1) => shape (e1, dx @ ixs)
              | E.Probe(e, _) => shape (e, ixs)
              | E.Value _ => raise Fail "Error in Ashape"
              | E.Img _ => raise Fail "Error in Ashape"
              | E.Krn _ => raise Fail "Error in Ashape"
              (* summation indices first, then the indices of the summed body *)
              | E.Sum(sx, e) => List.foldr (fn ((v, _, _), ixs) => v::ixs) (shape (e, ixs)) sx
              | E.Op1(_, e) => shape (e, ixs)
              | E.Op2(_, e1, e2) => shape (e1, shape (e2, ixs))
              | E.Opn(_, es) => List.foldr shape ixs es
            (* end case *))
      in
        shape (b, [])
      end

(* eShape: list of indices with the potential to be in tshape
 *   T_alpha   -> eshape = alpha
 *   Sum, Op1, Lift, Probe -> eshape of the (first) sub-expression
 *   Op2 (e1, e2) and Opn es ->
 *       eshape starts as eShape of the first operand; then, for each further
 *       operand in order, every variable index not already in eshape is
 *       appended to it.  Constant indices of the further operands are skipped.
 *       An empty operand list yields [].
 *
 * Raises Fail "Error in Eshape" on E.Value, E.Img and E.Krn.
 *)
fun eShape b = let
      (* addNew (ixs, acc): append to acc, preserving order, each variable
       * index of ixs that acc does not contain yet
       *)
      fun addNew (ixs, acc) = let
            fun add (E.C _, acc) = acc
              | add (ix, acc) =
                  if List.exists (fn x => x = ix) acc
                    then acc
                    else acc @ [ix]
            in
              List.foldl add acc ixs
            end
      (* merge the shapes of a list of operands.  The accumulator is kept in
       * forward order across operands, so the result order does not depend on
       * how many operands there are.
       *)
      fun iterList es = (case List.map eShape es
             of [] => []
              | first::rest => List.foldl addNew first rest
            (* end case *))
      in (case b
         of E.B _ => []
          | E.Tensor(_, alpha) => alpha
          | E.G(E.Delta(i, j)) => [i, j]
          | E.G(E.Epsilon(i, j, k)) => [E.V i, E.V j, E.V k]
          | E.G(E.Eps2(i, j)) => [E.V i, E.V j]
          | E.Field(_, alpha) => alpha
          | E.Lift e1 => eShape e1
          | E.Conv(_, alpha, _, dx) => alpha @ dx
          | E.Partial alpha => alpha
          | E.Apply(E.Partial dx, e1) => (eShape e1) @ dx
          | E.Probe(e1, _) => eShape e1
          | E.Value _ => raise Fail "Error in Eshape"
          | E.Img _ => raise Fail "Error in Eshape"
          | E.Krn _ => raise Fail "Error in Eshape"
          | E.Sum(_, e1) => eShape e1
          | E.Op1(_, e1) => eShape e1
          | E.Op2(_, e1, e2) => iterList [e1, e2]
          | E.Opn(_, es) => iterList es
        (* end case *))
      end

(* tShape: get the shape of the tensor replacement
 *   int list * sumrange list * ein expression -> mu list
 *
 * The result is eShape e restricted to the indices that the original EIN
 * supports, keeping eShape's order:
 *   - constant indices are discarded;
 *   - a variable index is kept only if it occurs in outerAlpha.
 *)
fun tShape (index, sx, e) = let
      (* outerAlpha: E.V 0 .. E.V (length index - 1) followed by the index of
       * every summation range in sx.  With an empty index list this is just
       * the summation indices.
       *)
      val sumIxs = List.map (fn (v, _, _) => v) sx
      val outerAlpha = List.tabulate (length index, E.V) @ sumIxs
      (* candidate indices for the tensor replacement *)
      val candidates = eShape e
      fun supported (E.C _) = false
        | supported (ix as E.V _) = List.exists (fn x => x = ix) outerAlpha
      in
        List.filter supported candidates
      end

(* mkSizeMapp (index, sx): creates a map from index_id to dimension.
 *   - the k-th element of `index` (counting from 0) is bound to key k;
 *   - each summation range (E.V v, _, ub) in sx binds key v to ub+1.
 * Summation bindings are inserted last, so they win if a key occurs in both.
 * NOTE(review): ub+1 as the dimension presumably relies on the range starting
 * at 0 -- confirm against how sumranges are built.
 * NOTE(review): a range whose index is `E.C _` is not matched and raises Match.
 *)
fun mkSizeMapp (index, sx) = let
      (* bind position cnt, cnt+1, ... to the successive entries of the list *)
      fun idToMapp (mapp, [], _) = mapp
        | idToMapp (mapp, ix::es, cnt) = idToMapp (IMap.insert (mapp, cnt, ix), es, cnt+1)
      fun sxToMapp (mapp, []) = mapp
        | sxToMapp (mapp, (E.V v, _, ub)::es) = sxToMapp (IMap.insert (mapp, v, ub+1), es)
      in
        (* start from IMap.empty: there is no unqualified `empty` in scope *)
        sxToMapp (idToMapp (IMap.empty, index, 0), sx)
      end

(* mkIndexMapp (index, sx, ashape, tshape): builds the map used to renumber the
 * index variables of a subexpression.
 *   1. the indices of tshape are numbered first, in order, from 0;
 *   2. then every id in 0 .. max that occurs in ashape and is not yet mapped
 *      gets the next free number, in increasing id order,
 *      where max is the largest of (length index - 1) and the variable ids in
 *      ashape.
 * Ids that do not occur in ashape are left out of the map.
 * `sx` is accepted but not used.
 * NOTE(review): tshape is expected to hold only `E.V _` entries (tShape filters
 * constants out); an `E.C _` entry would raise Match in vxToMapp.
 *)
fun mkIndexMapp (index, sx, ashape, tshape) = let
      (* adds each index E.V e1 to the map as e1 => cnt, returning the map and
       * the next unused number
       *)
      fun vxToMapp (mapp, [], cnt) = (mapp, cnt)
        | vxToMapp (mapp, (E.V e1)::es, cnt) = vxToMapp (IMap.insert (mapp, e1, cnt), es, cnt+1)
      (* iff id e1 is unmapped and occurs in ashape, add e1 => cnt to the map *)
      fun intToMapp (mapp, [], _) = mapp
        | intToMapp (mapp, e1::es, cnt) = (case IMap.find (mapp, e1)
             of SOME _ => intToMapp (mapp, es, cnt)
              | NONE =>
                  if List.exists (fn x => x = E.V e1) ashape
                    then intToMapp (IMap.insert (mapp, e1, cnt), es, cnt+1)
                    else intToMapp (mapp, es, cnt)
            (* end case *))
      (* create the map for the indices in tshape first *)
      val (mapp, tocounter) = vxToMapp (IMap.empty, tshape, 0)
      (* find the max id in ashape and create the list [0, 1, 2, ..., max];
       * constant indices contribute 0
       *)
      val pp = List.map (fn E.V v => v | _ => 0) ashape
      val max = List.foldl (fn (a, b) => Int.max (a, b)) (length index - 1) pp
      val maxmu = List.tabulate (max+1, fn e => e)
      in
        (* map the rest of the indices that may be used in the ein expression *)
        intToMapp (mapp, maxmu, tocounter)
      end

(* rewriteIx (mapp, e): rewrites the indices in e using mapp.
 *   - variable indices in index lists and in Delta go through lkupVx
 *     (constant indices are left as they are);
 *   - the bare ids of Epsilon and Eps2 go through lkupId;
 *   - summation ranges (in Sum and in the PowEmb operator) go through lkupSx.
 * Raises Fail from lkupId when a variable index has no binding in mapp, and
 * Fail "unexpected ..." on E.Value, E.Img and E.Krn.
 *)
fun rewriteIx (mapp, e) = let
      (* prefix of the Fail message produced by lkupId for an unbound index;
       * lkupId appends the offending id
       *)
      val str = "rewriteIx: no mapping for index "
      fun getAlpha alpha = List.map (fn ix => lkupVx (ix, mapp, str)) alpha
      fun getIx ix = lkupId (ix, mapp, str)
      fun getVx ix = lkupVx (ix, mapp, str)
      fun getSx sx = lkupSx (sx, mapp, str)
      fun rewrite b = (case b
             of E.B _ => b
              | E.Tensor(id, alpha) => E.Tensor(id, getAlpha alpha)
              | E.G(E.Delta(i, j)) => E.G(E.Delta(getVx i, getVx j))
              | E.G(E.Epsilon(i, j, k)) => E.G(E.Epsilon(getIx i, getIx j, getIx k))
              | E.G(E.Eps2(i, j)) => E.G(E.Eps2(getIx i, getIx j))
              | E.Field(id, alpha) => E.Field(id, getAlpha alpha)
              | E.Lift e1 => E.Lift(rewrite e1)
              | E.Conv(v, alpha, h, dx) => E.Conv(v, getAlpha alpha, h, getAlpha dx)
              | E.Partial dx => E.Partial(getAlpha dx)
              | E.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
              (* a probed convolution needs no case of its own: rewriting the
               * Conv operand through the E.Conv case gives the same term
               *)
              | E.Probe(e1, e2) => E.Probe(rewrite e1, rewrite e2)
              | E.Value _ => raise Fail "unexpected Value"
              | E.Img _ => raise Fail "unexpected Img"
              | E.Krn _ => raise Fail "unexpected Krn"
              | E.Sum(sx, e1) => E.Sum(getSx sx, rewrite e1)
              (* PowEmb carries its own summation ranges, which must be renamed too *)
              | E.Op1(E.PowEmb(sx1, n1), e1) => E.Op1(E.PowEmb(getSx sx1, n1), rewrite e1)
              | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
              | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
              | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
            (* end case *))
      in
        rewrite e
      end

(* clean (e, index, sx): cleans the indices in an EIN expression.
 *
 * input
 *   e:     ein expression
 *   index: int list for the original EIN operator
 *   sx:    sumrange list for the outer summation expression, if any exist
 * output (tshape, sizes, e')
 *   tshape: indices for the tensor replacement
 *   sizes:  tensor type of the new EIN operator, one entry per tshape index
 *   e':     e with its indices rewritten
 *
 * Generic example
 *      x = λT {Σ_sx (e...)  ...)  }_index (arg0)
 *   ===>
 *      arg1 = λT {e'}_sizes (arg0),
 *      x = λTT{Σ_sx (T1_{tshape}...) ...) }index (arg0, arg1)
 *)
fun clean (e, index, sx) = let
      (* shape of e:
       *   ashape (mu list): all the indices mentioned in the body
       *   tshape (mu list): shape of the tensor replacement
       *)
      val ashape = aShape e
      val tshape = tShape (index, sx, e)
      (* sizeMapp: the dimension each index_id is bound to *)
      val sizeMapp = mkSizeMapp (index, sx)
      (* sizes (int list): size of e, found by looking each tshape index up in sizeMapp *)
      fun sizeOf (E.V ix) = lkupId (ix, sizeMapp, "Could not find Size of")
      val sizes = List.map sizeOf tshape
      (* indexMapp maps the index variables of e to those of e' *)
      val indexMapp = mkIndexMapp (index, sx, ashape, tshape)
      (* rewrite the subexpression: e => e' *)
      val e' = rewriteIx (indexMapp, e)
      in
        (tshape, sizes, e')
      end

(* Example 1
Input to cleanIndex()
e: Σ_14[0-2]Prod< T20_14*  T21_14,1>, index:[2,3], sx:[E.V(6)[0-2]]
Analyzing
Get shape of e, getShapes()
aShape : 14,14,14,1,tShape : 1
Create sizeMapp: index_id to dimension, mkSizeMapp()
0 => 2, 1 => 3, 6 => 3
Find size of e by looking up tshape in the sizeMapp
sizes=[3]
Create indexMapp: Map the index variables e => e', mkIndexMapp()
Set map for tshape indices first, vxToMapp()
E.V(1) => E.V(0)
Checks indices from E.V 0 to E.V 14 (the largest index in aShape), intToMapp()
E.V(14) => E.V(1)
Rewrite subexpression: e =>e', rewriteIx()
e =>  Σ_1[0-2]Prod< T20_1*  T21_1,0>
Output: tshape:[E.V(1)],sizes:[3],e': Σ_1[0-2]Prod< T20_1*  T21_1,0>

b=<Σ_[E.V(6)[0-2]]( Σ_14[0-2]Prod< T20_14*  T21_14,1>)... >2,3 (args)
===>
a=< Σ_1[0-2]Prod< T20_1*  T21_1,0>>_3 (args)
b'=<T_{E.V(1)}..>2,3 (args,a)

* Example 2
Input to cleanIndex()
Analyzing
Get shape of e
aShape : 6,1,6,1,tShape : 6,1
Create sizeMapp: index_id to dimension
0 => 2,1 => 3,6 => 3
Find size of e by looking up tshape in the sizeMapp
sizes=[3,3]
Create indexMapp: Map the index variables e => e'
Set map for tshape indices first
E.V(6) => E.V(0),E.V(1) => E.V(1)
Checks indices from E.V 0 to E.V 6 (the largest index in aShape)
Rewrite subexpression: e =>e'