Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/clean-index.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-to-mid/clean-index.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3570 - (download) (annotate)
Mon Jan 11 14:31:58 2016 UTC (4 years, 1 month ago) by jhr
File size: 11310 byte(s)
working on merge
(* clean-index.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure CleanIndex : sig

  (* clean (e, index, sx) = (tshape, sizes, e')
   *   e      -- a subexpression of the body of an EIN operator
   *   index  -- the index bindings (dimensions) of the original EIN operator
   *   sx     -- the summation ranges enclosing e in the original body, if any
   *   tshape -- the indices of e that are bound outside of e (by index or sx);
   *             these index the tensor that replaces e
   *   sizes  -- the dimension of each index in tshape
   *   e'     -- e with its index variables renumbered: the tshape indices
   *             become 0, 1, ..., and the remaining indices follow in
   *             increasing order of their original ids
   *)
    val clean : Ein.ein_exp * int list * Ein.sumrange list
          -> Ein.mu list * int list * Ein.ein_exp

  end = struct

    structure E = Ein
    structure IMap = IntRedBlackMap

  (* lkupId (id, mapp, msg): look up an index id in mapp; raises Fail with
   * msg followed by the id when it is not present.
   *)
    fun lkupId (id, mapp, msg) = (case IMap.find (mapp, id)
           of SOME id' => id'
            | NONE => raise Fail(msg ^ Int.toString id)
          (* end case *))

  (* rename a variable index through mapp; constant indices are unchanged *)
    fun lkupVx (E.V id, mapp, msg) = E.V(lkupId (id, mapp, msg))
      | lkupVx (ix as E.C _, _, _) = ix

  (* rename the index variables of a list of summation ranges (index, lb, ub);
   * ranges whose variable is not in mapp are dropped from the result.
   *)
    fun lkupSx (sx, mapp) = let
          fun rename ((E.V id, lb, ub), rest) = (case IMap.find (mapp, id)
                 of SOME id' => (E.V id', lb, ub) :: rest
                  | NONE => rest
                (* end case *))
            | rename ((E.C _, _, _), _) =
                raise Fail "lkupSx: constant index in summation range"
          in
            List.foldr rename [] sx
          end

  (* aShape: ein_exp -> mu list
   * returns every index occurrence in e (duplicates and constant indices
   * included).  Callers only test membership and take the maximum id, so the
   * order is not significant.
   *)
    fun aShape e = let
          fun shape (e, ixs) = (case e
                 of E.Const _ => ixs
                  | E.ConstR _ => ixs
                  | E.Tensor(_, alpha) => alpha @ ixs
                  | E.Delta(i, j) => i::j::ixs
                  | E.Epsilon(i, j, k) => i::j::k::ixs
                  | E.Eps2(i, j) => i::j::ixs
                  | E.Field(_, alpha) => alpha @ ixs
                  | E.Lift e1 => shape (e1, ixs)
                  | E.Conv(_, alpha, _, dx) => alpha @ dx @ ixs
                  | E.Partial alpha => alpha @ ixs
                  | E.Apply(E.Partial dx, e1) => shape (e1, dx @ ixs)
                  | E.Apply _ => raise Fail "Error in Ashape: unexpected Apply"
                  | E.Probe(e1, _) => shape (e1, ixs)
                  | E.Value _ => raise Fail "Error in Ashape"
                  | E.Img _ => raise Fail "Error in Ashape"
                  | E.Krn _ => raise Fail "Error in Ashape"
                  | E.Sum(sx, e1) =>
                      List.foldr (fn ((v, _, _), ixs) => v::ixs) (shape (e1, ixs)) sx
                  | E.Op1(_, e1) => shape (e1, ixs)
                  | E.Op2(_, e1, e2) => shape (e1, shape (e2, ixs))
                  | E.Opn(_, es) => List.foldr shape ixs es
                (* end case *))
          in
            shape (e, [])
          end

  (* eShape: list of indices with the potential to be in tshape
   *   T_alpha      -> alpha
   *   e1 + ...     -> eShape e1
   *   e1 / ...     -> eShape e1
   *   e1 * e2 ...  -> the ordered union of eShape e1, eShape e2, ...: the
   *                   indices of eShape e1 followed by every index of later
   *                   operands that has not already appeared
   *)
    fun eShape e = (case e
           of E.Const _ => []
            | E.ConstR _ => []
            | E.Tensor(_, alpha) => alpha
            | E.Delta(i, j) => [i, j]
            | E.Epsilon(i, j, k) => [i, j, k]
            | E.Eps2(i, j) => [i, j]
            | E.Field(_, alpha) => alpha
            | E.Lift e1 => eShape e1
            | E.Conv(_, alpha, _, dx) => alpha @ dx
            | E.Partial alpha => alpha
            | E.Apply(E.Partial dx, e1) => eShape e1 @ dx
            | E.Apply _ => raise Fail "Error in Eshape: unexpected Apply"
            | E.Probe(e1, _) => eShape e1
            | E.Value _ => raise Fail "Error in Eshape"
            | E.Img _ => raise Fail "Error in Eshape"
            | E.Krn _ => raise Fail "Error in Eshape"
            | E.Sum(_, e1) => eShape e1
            | E.Op1(_, e1) => eShape e1
            | E.Op2(_, e1, e2) => unionShapes [e1, e2]
            | E.Opn(_, es) => unionShapes es
          (* end case *))

  (* ordered union of the shapes of es: variable indices in order of first
   * occurrence, with duplicates and constant indices removed.  An empty list
   * yields [].
   *)
    and unionShapes es = let
          fun add (ix as E.V _, acc) =
                if List.exists (fn x => x = ix) acc then acc else ix :: acc
            | add (E.C _, acc) = acc
          fun addShape (e, acc) = List.foldl add acc (eShape e)
          in
            List.rev (List.foldl addShape [] es)
          end

  (* tShape (index, sx, e): shape of the tensor that replaces e.  It contains
   * the variable indices of eShape e (in that order, without duplicates) that
   * are bound outside of e: either one of the operator's own indices
   * 0 .. length index - 1, or a variable of an enclosing summation in sx.
   *)
    fun tShape (index, sx, e) = let
          val outerAlpha = List.tabulate (length index, E.V) @ List.map (fn (v, _, _) => v) sx
          fun isOuter ix = List.exists (fn x => x = ix) outerAlpha
          fun keep (ix as E.V _, acc) =
                if isOuter ix andalso not (List.exists (fn x => x = ix) acc)
                  then ix :: acc
                  else acc
            | keep (E.C _, acc) = acc
          in
            List.rev (List.foldl keep [] (eShape e))
          end

  (* mkSizeMapp (index, sx): map from index id to dimension.  Operator index i
   * maps to the i'th element of index; a summation variable with upper bound
   * ub maps to ub+1.
   *)
    fun mkSizeMapp (index, sx) = let
          fun addIndex (dim, (mapp, id)) = (IMap.insert (mapp, id, dim), id+1)
          val (mapp, _) = List.foldl addIndex (IMap.empty, 0) index
          fun addSum ((E.V id, _, ub), mapp) = IMap.insert (mapp, id, ub+1)
            | addSum ((E.C _, _, _), _) =
                raise Fail "mkSizeMapp: constant index in summation range"
          in
            List.foldl addSum mapp sx
          end

  (* mkIndexMapp (index, ashape, tshape): renaming map for the index variables
   * of the subexpression.  The tshape indices are mapped to 0, 1, ... in
   * order.  Every other index id occurring in ashape is then mapped to the
   * following numbers, in increasing order of its original id.
   *)
    fun mkIndexMapp (index, ashape, tshape) = let
          fun addT (E.V id, (mapp, cnt)) = (IMap.insert (mapp, id, cnt), cnt+1)
            | addT (E.C _, _) = raise Fail "mkIndexMapp: constant index in tshape"
          val (mapp, cnt) = List.foldl addT (IMap.empty, 0) tshape
          val ids = List.mapPartial (fn E.V id => SOME id | _ => NONE) ashape
        (* scanning 0 .. maxId visits the remaining ids in increasing order *)
          val maxId = List.foldl Int.max (length index - 1) ids
          fun addRest (id, (mapp, cnt)) =
                if IMap.inDomain (mapp, id) orelse not (List.exists (fn x => x = id) ids)
                  then (mapp, cnt)
                  else (IMap.insert (mapp, id, cnt), cnt+1)
          val (mapp, _) = List.foldl addRest (mapp, cnt) (List.tabulate (maxId+1, fn i => i))
          in
            mapp
          end

  (* rewriteIx (mapp, e): rename every index variable in e using mapp; raises
   * Fail if a variable index (outside of a summation range) is not mapped.
   *)
    fun rewriteIx (mapp, e) = let
          val msg = "rewriteIx: unmapped index "
          fun getVx ix = lkupVx (ix, mapp, msg)
          fun getAlpha alpha = List.map getVx alpha
          fun getSx sx = lkupSx (sx, mapp)
          fun rewrite b = (case b
                 of E.Const _ => b
                  | E.ConstR _ => b
                  | E.Tensor(id, alpha) => E.Tensor(id, getAlpha alpha)
                  | E.Delta(i, j) => E.Delta(getVx i, getVx j)
                  | E.Epsilon(i, j, k) => E.Epsilon(getVx i, getVx j, getVx k)
                  | E.Eps2(i, j) => E.Eps2(getVx i, getVx j)
                  | E.Field(id, alpha) => E.Field(id, getAlpha alpha)
                  | E.Lift e1 => E.Lift(rewrite e1)
                  | E.Conv(v, alpha, h, dx) => E.Conv(v, getAlpha alpha, h, getAlpha dx)
                  | E.Partial dx => E.Partial(getAlpha dx)
                  | E.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
                  | E.Probe(e1, e2) => E.Probe(rewrite e1, rewrite e2)
                  | E.Value _ => raise Fail "unexpected Value"
                  | E.Img _ => raise Fail "unexpected Img"
                  | E.Krn _ => raise Fail "unexpected Krn"
                  | E.Sum(sx, e1) => E.Sum(getSx sx, rewrite e1)
                  | E.Op1(E.PowEmb(sx1, n1), e1) => E.Op1(E.PowEmb(getSx sx1, n1), rewrite e1)
                  | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                  | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
                (* end case *))
          in
            rewrite e
          end

  (* clean: cleans the indices in an EIN subexpression
   * input-  e: ein expression
   *         index: int list for the original EIN operator
   *         sx: sumrange list for the outer summation expression, if any exist
   * output- tshape: indices for the tensor replacement
   *         sizes: tensor type (dimensions) of the new EIN operator
   *         e': rewritten e
   * Generic example
   *   x = λT {Σ_sx (e ...) ...}_index (arg0)
   *   ===>
   *   arg1 = λT {e'}_sizes (arg0),
   *   x = λTT {Σ_sx (T1_{tshape} ...) ...}_index (arg0, arg1)
   *)
    fun clean (e, index, sx) = let
        (* ashape: all the indices mentioned in e; tshape: shape of the replacement *)
          val ashape = aShape e
          val tshape = tShape (index, sx, e)
        (* sizes: the dimension that each tshape index is bound to *)
          val sizeMapp = mkSizeMapp (index, sx)
          val sizes = List.map
                (fn E.V id => lkupId (id, sizeMapp, "Could not find size of index ")
                  | E.C _ => raise Fail "clean: constant index in tshape")
                  tshape
        (* rename the index variables of e: e => e' *)
          val indexMapp = mkIndexMapp (index, ashape, tshape)
          val e' = rewriteIx (indexMapp, e)
          in
            (tshape, sizes, e')
          end

  (* Example 1
   * Input to clean()
   *     e: Σ_14[0-2]Prod< T20_14*  T21_14,1>, index:[2,3], sx:[E.V(6)[0-2]]
   * Analyzing
   *     Get the shapes of e, aShape() and tShape()
   *         aShape : 14,14,14,1, tShape : 1
   *     Create sizeMapp: index_id to dimension, mkSizeMapp()
   *         0 => 2, 1 => 3, 6 => 3
   *     Find the size of e by looking up tshape in the sizeMapp
   *         sizes=[3]
   *     Create indexMapp: map the index variables e => e', mkIndexMapp()
   *         Map the tshape indices first
   *             E.V(1) => E.V(0)
   *         Then check the indices from E.V(0) to E.V(14)
   *             E.V(14) => E.V(1)
   *     Rewrite the subexpression: e => e', rewriteIx()
   *         e => Σ_1[0-2]Prod< T20_1*  T21_1,0>
   * Output: tshape:[E.V(1)], sizes:[3], e': Σ_1[0-2]Prod< T20_1*  T21_1,0>
   *
   *     b=<Σ_[E.V(6)[0-2]]( Σ_14[0-2]Prod< T20_14*  T21_14,1>)... >2,3 (args)
   *     ===>
   *     a=< Σ_1[0-2]Prod< T20_1*  T21_1,0>>_3 (args)
   *     b'=<T_{E.V(1)}..>2,3 (args,a)
   *
   * Example 2
   * Input to clean()
   *     e:Add( T23_6,1+ T24_6,1), index:[2,3], sx:[E.V(6)[0-2]]
   * Analyzing
   *     Get the shapes of e
   *         aShape : 6,1,6,1, tShape : 6,1
   *     Create sizeMapp: index_id to dimension
   *         0 => 2, 1 => 3, 6 => 3
   *     Find the size of e by looking up tshape in the sizeMapp
   *         sizes=[3,3]
   *     Create indexMapp: map the index variables e => e'
   *         Map the tshape indices first
   *             E.V(6) => E.V(0), E.V(1) => E.V(1)
   *         Then check the indices from E.V(0) to E.V(6)
   *     Rewrite the subexpression: e => e'
   *         e => Add( T23_0,1+ T24_0,1)
   * Output:
   *     tshape:[E.V(6),E.V(1)], sizes:[3,3], e':Add( T23_0,1+ T24_0,1)
   *)
  end (* CleanIndex *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0