Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2576 - (download) (annotate)
Wed Apr 2 04:36:30 2014 UTC (5 years, 6 months ago) by cchiw
File size: 10287 byte(s)
Added field product
(* shiftHtM: clean-up passes over an EIN body.
 *   cleanParams - drops the parameters/arguments that the body never references
 *                 and renumbers the remaining tensor/image/kernel ids densely.
 *   cleanIndex  - renumbers the variable indices (E.V) of the body densely,
 *                 starting from 0, and keeps only the index bindings whose
 *                 position is used by the body.
 *   clean       - cleanParams followed by cleanIndex.
 *)
structure shiftHtM = struct
    local
    structure E = Ein
    structure P = Printer
    in

(* Functional dictionaries: a dictionary is a function from keys to an option.
 * insert (k, v) d returns a dictionary in which k maps to v and every other key
 * is looked up in d (so later inserts shadow earlier ones).
 *)
fun insert (key, value) d = fn s =>
    if s = key then SOME value
    else d s

fun lookup k d = d k
val empty = fn _ => NONE

(* Concatenate a list of lists. *)
fun flat xs = List.concat xs

(* cleanParams (body, params, args)
 *   Removes every parameter (and its matching argument) whose position is not
 *   used as an id by a Tensor, Img, Krn or Conv node of body, and rewrites each
 *   id in body to its position among the kept parameters.
 *   params and args are walked in lock step; processing stops at the end of the
 *   shorter of the two lists.
 *   Returns (params', body', args').
 *   Raises Fail if body still contains a Probe.
 *   Raises Subscript if body uses an id that has no corresponding parameter.
 *)
fun cleanParams (body, params, args) = let
    (* Record (as a dictionary id -> 1) every id referenced by the body. *)
    fun build (body, occur) = (case body
        of E.Tensor(id, _) => insert(id, 1) occur
         | E.Sum(_, e) => build(e, occur)
         | E.Neg e => build(e, occur)
         | E.Add es => buildAll(es, occur)
         | E.Sub(e1, e2) => build(e2, build(e1, occur))
         | E.Div(e1, e2) => build(e2, build(e1, occur))
         | E.Prod es => buildAll(es, occur)
         | E.Img(id, _, pos) => buildAll(pos, insert(id, 1) occur)
         | E.Krn(id, _, pos) => build(pos, insert(id, 1) occur)
         | E.Conv(id, _, h, _) => insert(h, 1) (insert(id, 1) occur)
         | E.Probe _ => raise Fail "Probe- Should have been expanded"
         | _ => occur
        (* end case *))
    (* Thread the dictionary through a list of sub-terms, left to right. *)
    and buildAll (es, occur) = List.foldl build occur es

    val occur = build(body, empty)

    (* Walk params/args in lock step.  newIds records, for each old position,
     * the new position of that parameter; dropped parameters get 0, which is
     * harmless because by construction the body never refers to them.
     * All accumulators are built in reverse and flipped at the end.
     *)
    fun removeP (pos, next, newIds, keptP, keptA, p::ps, a::rest) =
          (case lookup pos occur
            of NONE => removeP(pos+1, next, 0::newIds, keptP, keptA, ps, rest)
             | SOME _ => removeP(pos+1, next+1, next::newIds, p::keptP, a::keptA, ps, rest)
          (* end case *))
      | removeP (_, _, newIds, keptP, keptA, _, _) =
          (List.rev newIds, List.rev keptP, List.rev keptA)

    val (newIds, params', args') = removeP(0, 0, [], [], [], params, args)

    (* Constant-time old-id -> new-id lookup; raises Subscript when out of range. *)
    val newId = Vector.fromList newIds
    fun shift id = Vector.sub(newId, id)

    (* Rewrite every id in the body to its new position. *)
    fun remap body = (case body
        of E.Tensor(id, ix) => E.Tensor(shift id, ix)
         | E.Neg e => E.Neg(remap e)
         | E.Sum(sx, e) => E.Sum(sx, remap e)
         | E.Add es => E.Add(List.map remap es)
         | E.Prod es => E.Prod(List.map remap es)
         | E.Sub(e1, e2) => E.Sub(remap e1, remap e2)
         | E.Div(e1, e2) => E.Div(remap e1, remap e2)
         | E.Img(id, alpha, pos) => E.Img(shift id, alpha, List.map remap pos)
         | E.Krn(id, delta, pos) => E.Krn(shift id, delta, remap pos)
         | E.Conv(id, alpha, h, pos) => E.Conv(shift id, alpha, shift h, pos)
         | E.Probe _ => raise Fail "Probe- Should have been expanded"
         | _ => body
        (* end case *))
    in
      (params', remap body, args')
    end

(* cleanIndex (e, intialn, index)
 *   Renumbers the variable indices of e densely.  Positions n = 0, 1, ... up to
 *   the largest variable index found in e are scanned in order; every E.V n
 *   that occurs in e is mapped to E.V c, where c counts the occurring positions
 *   already seen.  E.V 0 is always mapped to E.V 0.  Constant indices (E.C)
 *   are left untouched.
 *   index is treated as the bindings for positions 0, 1, ...; only the bindings
 *   whose position occurs in e are kept.
 *   intialn is currently unused (kept for interface compatibility).
 *   Returns (outer, index', e'), where outer lists the kept binding positions
 *   as E.V n, in increasing order.
 *   Raises Fail if e contains a Probe, or if an index of e (e.g. a summation
 *   variable that the summand never mentions) has no mapping.
 *)
fun cleanIndex (e, intialn, index) = let
    (* Variable indices of xs, first occurrence only; constants are dropped. *)
    fun uniq xs = let
        fun keep ([], acc) = List.rev acc
          | keep ((x as E.V _) :: rest, acc) =
              if List.exists (fn y => y = x) acc
                then keep(rest, acc)
                else keep(rest, x :: acc)
          | keep (_ :: rest, acc) = keep(rest, acc)
        in
          keep(xs, [])
        end

    (* Every index mentioned in body.  Tensor, Delta, Krn and Conv return their
     * indices unfiltered, so the result may contain constants (E.C).
     *)
    fun findOuterIndex body = (case body
        of E.Tensor(_, ix) => ix
         | E.Const _ => []
         | E.Add es => uniq (flat (List.map findOuterIndex es))
         | E.Sub(e1, e2) => uniq (findOuterIndex e1 @ findOuterIndex e2)
         | E.Div(e1, e2) => uniq (findOuterIndex e1 @ findOuterIndex e2)
         | E.Value i => [E.V i]
         | E.Sum(_, e1) => findOuterIndex e1
         | E.Prod es => uniq (flat (List.map findOuterIndex es))
         | E.Delta(i, j) => [i, j]
         | E.Epsilon(i, j, k) => [E.V i, E.V j, E.V k]
         | E.Neg e1 => findOuterIndex e1
         | E.Img(_, alpha, pos) => uniq (alpha @ flat (List.map findOuterIndex pos))
         | E.Krn(_, dels, pos) => List.map (fn (_, d) => d) dels @ findOuterIndex pos
         | E.Conv(_, alpha, _, dx) => alpha @ dx
         | E.Probe _ => raise Fail "Probe- Should have been expanded"
         | _ => []
        (* end case *))

    val ix = findOuterIndex e

    (* Largest variable index in ix (0 if there is none).  Constants are skipped;
     * matching only E.V here used to raise Match on any E.C entry.
     *)
    val maxIx = List.foldl (fn (E.V n, m) => Int.max(n, m) | (_, m) => m) 0 ix

    (* Does the variable index E.V n occur in ix?  (Constants never match.) *)
    fun occurs n = List.exists (fn E.V v => v = n | _ => false) ix

    (* Scan positions n = 0 .. maxIx, pairing position n with the n-th element
     * of index (if any).  Occurring positions receive consecutive new ids c.
     * Accumulators are built in reverse and flipped at the end.
     *)
    fun renumber (n, pending, kept, c, mapp, outer) =
          if n > maxIx
            then (List.rev kept, mapp, List.rev outer)
          else (case (occurs n, pending)
             of (false, []) => renumber(n+1, [], kept, c, mapp, outer)
              | (false, _ :: bs) => renumber(n+1, bs, kept, c, mapp, outer)
              | (true, []) =>
                  renumber(n+1, [], kept, c+1, insert(E.V n, E.V c) mapp, outer)
              | (true, b :: bs) =>
                  renumber(n+1, bs, b :: kept, c+1, insert(E.V n, E.V c) mapp, E.V n :: outer)
            (* end case *))

    val (index', mapp, outer) =
          renumber(0, index, [], 0, insert(E.V 0, E.V 0) empty, [])

    (* Map one index through mapp; constants are returned unchanged. *)
    fun rewriteIndex (E.V v) = (case lookup (E.V v) mapp
          of NONE => raise Fail("error Could not find :" ^ Int.toString v)
           | SOME s => s
          (* end case *))
      | rewriteIndex (c as E.C _) = c

    (* Map a bare integer index (as used by Value and Epsilon) through mapp.
     * mapp only ever produces E.V values, so the wildcard arm is defensive.
     *)
    fun singleIndex v = (case lookup (E.V v) mapp
          of SOME (E.V s) => s
           | _ => raise Fail" error could not find index"
          (* end case *))

    (* Apply the index renaming to every index position in the body. *)
    fun rewrite body = (case body
        of E.Tensor(id, ix) => E.Tensor(id, List.map rewriteIndex ix)
         | E.Epsilon(i, j, k) => E.Epsilon(singleIndex i, singleIndex j, singleIndex k)
         | E.Value i => E.Value(singleIndex i)
         | E.Delta(i, j) => E.Delta(rewriteIndex i, rewriteIndex j)
         | E.Add es => E.Add(List.map rewrite es)
         | E.Sub(e1, e2) => E.Sub(rewrite e1, rewrite e2)
         | E.Div(e1, e2) => E.Div(rewrite e1, rewrite e2)
         | E.Sum(sx, e1) =>
             E.Sum(List.map (fn (v, lb, ub) => (rewriteIndex v, lb, ub)) sx, rewrite e1)
         | E.Prod es => E.Prod(List.map rewrite es)
         | E.Neg e1 => E.Neg(rewrite e1)
         | E.Krn(h, dx, pos) =>
             E.Krn(h, List.map (fn (d, i) => (d, rewriteIndex i)) dx, rewrite pos)
         | E.Img(v, alpha, pos) =>
             E.Img(v, List.map rewriteIndex alpha, List.map rewrite pos)
         | E.Conv(v, alpha, h, dx) =>
             E.Conv(v, List.map rewriteIndex alpha, h, List.map rewriteIndex dx)
         | E.Probe _ => raise Fail "Probe- Should have been expanded"
         | _ => body
        (* end case *))
    in
      (outer, index', rewrite e)
    end

(* clean (params, index, body, args)
 *   Runs cleanParams, then cleanIndex on the resulting body.
 *   Returns (params', index', body', args').
 *)
fun clean (params, index, body, args) = let
    val (params', body', args') = cleanParams(body, params, args)
    val (_, index', body'') = cleanIndex(body', length index, index)
    in
      (params', index', body'', args')
    end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0