Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2611 - (download) (annotate)
Mon May 5 21:21:12 2014 UTC (5 years, 2 months ago) by cchiw
File size: 17580 byte(s)
InnerProduct, DoubleDot:
(* genEin: generates LowIL code for an Ein function after substitution. *)
structure genEin = struct
    local
    structure E = Ein
    structure gHelper=gHelper
    structure genKrn=genKrn

    

structure DstIL = LowIL
structure DstTy = LowILTypes
structure DstOp = LowOps
structure Var = LowIL.Var
    in


(* Finite maps represented directly as functions from keys to options.
 *   insert (key, value) d : a map binding key to value; every other key is
 *                           answered by the older map d.
 *   lookup key dict       : the binding of key in dict, or NONE.
 *   empty                 : the map with no bindings.
 *)
fun insert (key, value) d =
    fn probe => if key = probe then SOME value else d probe

fun lookup key dict = dict key

fun empty _ = NONE

(* prodIter (origIndex, index, nextfn, args)
 *
 * Iterates over the outer index space and packs the per-point results into
 * nested Cons values.
 *
 *   origIndex : dimensions used to type the Cons results.  The Cons built for
 *               axis k has type DstTy.TensorTy of origIndex with its first k
 *               dimensions dropped.
 *   index     : dimensions actually iterated; axis k takes the values
 *               0 .. List.nth (index, k) - 1.  Callers in this file pass either
 *               origIndex itself or origIndex without its last dimension (the
 *               vectorizing helpers, whose nextfn presumably yields a whole
 *               vector for that last axis -- TODO confirm in gHelper).
 *   nextfn    : called as nextfn (mapp, args) at every point; mapp is a map
 *               built with insert that binds each axis number k (0, 1, ...) to
 *               the current value on that axis.  Must return (var, code).
 *   args      : passed unchanged to every call of nextfn.
 *
 * Returns (var, code): var is the outermost Cons (or nextfn's own result when
 * index = []), and code is every generated instruction; along each axis the
 * code for value 0 comes first and the Cons instruction comes last.
 *
 * NOTE(review): a 0 entry in index becomes a start value of ~1, which never
 * reaches the base case 0, so the recursion would not terminate.
 *)
fun prodIter(origIndex,index,nextfn,args)=let

    (* per-axis upper bounds: each axis counts down from dim-1 to 0 *)
    val index'=List.map (fn (e)=>(e-1)) index
    (* M (mapp, bounds, rest, code, shape, n):
     *   n     - axis number of the first entry of bounds
     *   shape - the matching suffix of origIndex (type of the Cons on axis n)
     *   rest  - results already produced on axis n (for the higher values)
     *   code  - instructions accumulated so far
     * First clause: no axes left, so evaluate a single point. *)
    fun M(mapp,[],rest,code,shape,n)=let
        val (vF,code')=nextfn(mapp,args)
        in (vF, code'@code)
        end
    (* last axis, value 0: evaluate, then Cons [v0, v1, ..., vtop] *)
    | M(a,[0], rest, code,shape,n)=let
        (*      val mapp=a@[0]*)
        val mapp =insert(n, 0) a
        val (vF,code')=nextfn(mapp,args)
        val(vE,E)=gHelper.aaV(DstOp.cons(DstTy.TensorTy shape,0),[vF]@rest,"Cons",DstTy.TensorTy(shape))
        in (vE, code'@code@E)
        end
    (* last axis, any other value c: evaluate and continue with c-1; the new
     * result goes in front of rest and its code in front of code *)
    | M(a,[c],rest,code,shape,n)=let
        (*val mapp=a@[c]*)
        val mapp =insert(n, c) a
        val (vE,E)=nextfn(mapp,args)
      in  M(a, [c-1], [vE]@rest,E@code,shape,n) end
    (* outer axis n with bound b: for each value i = b, ..., 0 iterate the
     * remaining axes c (axis n+1 onwards) with axis n bound to i, then Cons
     * the b+1 results with type TensorTy (s::shape).  The rest parameter is
     * not used here (S keeps its own accumulator); ccode is appended after
     * the generated code. *)
    | M (a,b::c,rest,ccode,s::shape,n)=let
        val n'=n+1
        (* S (i, rest, code): i is the current value of axis n; rest/code hold
         * the results/instructions for values i+1 .. b *)
        fun S(0, rest,code)=let
            val mapp =insert(n, 0) a
            val (v',code')=M(mapp,c,[],[],shape,n')
            val(vA,A)=gHelper.aaV(DstOp.cons(DstTy.TensorTy (s::shape),0),[v']@rest,"Cons",DstTy.TensorTy(s::shape))
            in (vA, code'@code@A) end
        | S(i, rest, code)= let
            val mapp =insert(n, i) a
            val (v',code')=M(mapp,c,[],[],shape,n')
            in S(i-1,[v']@rest,code'@code) end
        val (vA,code')=S(b, [],[])
        in (vA,code'@ccode) end
    (* start at axis 0 with the empty map and the full origIndex shape *)
    val (rest',code')= M(empty,index',[],[],origIndex,0)
    in (rest',code')
    end

(*
fun createDic(n,[],dict)=dict
| createDic(n,a::ap, dict)=let
    val d'=insert(n, a) dict
    in createDic(n+1,ap, d')
    end
*)
(* find (v, mapp): the value bound to the index variable v in mapp.
 * Raises Fail "Outside Bound" when v has no binding. *)
fun find (v, mapp) =
    (case lookup v mapp
      of SOME bound => bound
       | NONE => raise Fail "Outside Bound"
    (*end case*))
        
(* general expressions*)
(* generalfn (ap, (body, _, origargs, args))
 *
 * Generates scalar (non-vectorized) code for the Ein expression `body` at one
 * point of the index space.
 *
 *   ap       : map from index variables to their current values (built with
 *              insert, e.g. by prodIter).
 *   origargs : the original argument list; entries are looked up by position
 *              to find the image and kernel of a Sum(_, Prod(Img::Krn::_)).
 *   args     : passed through to the gHelper / genKrn generators.
 *
 * Returns (var, code): the variable holding the scalar result and the LowIL
 * instructions that compute it.  Subterms whose code gHelper.skeleton reports
 * as 0, 1 or ~1 are simplified: 0 summands and 1 factors are dropped, a 0
 * factor makes the whole product that factor, and negations of 0/1/~1 fold.
 *
 * Raises Fail on field-level terms (Field, Partial, Apply, Probe, Conv, and
 * Krn/Img outside the Sum(_, Prod(Img :: Krn :: _)) form), and, via find, on
 * an E.Value whose index variable is unbound.
 *)
fun generalfn(ap,(body,_,origargs, args))= let
    (* Index bindings currently in force.  A summation rebinds its variables
     * here while generating its summand and restores the old value after. *)
    val jx= ref ap
    fun getMapp _ = !jx
    fun gen body=(case body
        of  E.Field _ =>raise Fail "Invalid Field here "
        | E.Partial _ =>raise Fail "Invalid FieldPartial here "
        | E.Apply _ =>raise Fail "Invalid FieldApply here "
        | E.Probe _ =>raise Fail "Invalid FieldProbe here "
        | E.Conv _ =>raise Fail "Invalid FieldConv here "
        | E.Krn _ =>raise Fail "Invalid FieldKrn here "
        | E.Img _=> raise Fail "Invalid FieldImg here "
        | E.Value v =>let
              val n=find(v ,getMapp(1))
              in gHelper.mkC(n) end
        | E.Const c=> gHelper.mkC(c)
        | E.Tensor(id,[])=>  gHelper.mkSca(empty,(id,[],args))
        | E.Tensor(id,ix)=>  gHelper.mkSca(getMapp(1),(id,ix,args))
        | E.Delta(i,j)=> gHelper.evalDelta2(i,j,getMapp(1))
        | E.Epsilon(i,j,k)=> let
            val n=gHelper.evalEps(i,j,k,getMapp(1))
            in gHelper.aaV(DstOp.C(n),[],"Const",DstTy.TensorTy([])) end
        | E.Neg e => let
            val (vA,A)=gen e
            val s=gHelper.skeleton A
            in (case s
                of 0 => (vA,A)
                | ~1 => gHelper.mkC 1
                | 1 =>  gHelper.mkC ~1
                |  _=> let
                    val (vB,B)=gHelper.aaV(DstOp.C (~1),[],"Const",DstTy.TensorTy([]))
                    val (vD,D)=gHelper.aaV(DstOp.prodSca,[vB,vA],"prodSca",DstTy.TensorTy([]))
                    in (vD,A@B@D) end
                (*end case*))
            end
        | E.Add e=> let
            (* checkO drops summands that are the constant 0.  If nothing is
             * left (all summands were 0, or there were none) the sum is the
             * constant 0 -- the neutral element of addition, not 1. *)
            fun checkO ([],[],[])=let
                    val (vA,A)=gHelper.aaV(DstOp.C(0),[],"Const",DstTy.TensorTy([]))
                    in ([vA],A) end
                | checkO(ids,code,[])=(ids,code)
                | checkO(ids,code, e1::es)=let
                    val (a,b)=gen e1
                    val s=gHelper.skeleton b
                    in (case s
                        of 0 => checkO(ids,code,es)
                        |  _ => checkO(ids@[a],code@b,es)
                        (*end case*))
                    end
            val (ids,code)=checkO([],[],e)
            in  (case ids
                of [id1]=> (id1,code)
                | _=>let  val (vB,B)=gHelper.mkMultiple(ids,DstOp.addSca,DstTy.TensorTy([]))
                    in (vB,code@B) end
                (*end case*))
            end
        | E.Sub (e1,e2)=>let
            val (vA,A)=gen e1
            val (vB,B)=gen e2
            val sA=gHelper.skeleton A
            val sB=gHelper.skeleton B
            (* simplify when either operand is the constant 0 *)
            in (case (sA,sB)
                of (0,0)=> gHelper.mkC 0
                |(0,_)=> let
                    val (vD,D)= gHelper.mkC ~1
                    val (vE,E)= gHelper.aaV(DstOp.prodSca,[vD,vB],"prodSca",DstTy.TensorTy([]))
                    in (vE,B@D@E) end
                | (_,0)=> (vA,A)
                | _ => let
                    val (vD,D)= gHelper.aaV(DstOp.subSca,[vA,vB],"subSca",DstTy.TensorTy([]))
                    in (vD, A@B@D) end
                (*end case*))
            end
        | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(id,del,pos)::es))=>let
            (* image/kernel summation: delegated to genKrn with the whole
             * summation term (this `body` is gen's own argument) *)
            val mapp= getMapp 1
            val harg=List.nth(origargs,id)
            val h=gHelper.getKernel(harg)
            val imgarg=List.nth(origargs,Vid)
            val  v=gHelper.getImage(imgarg)
            in genKrn.evalField(mapp,(body,v,h,args))
            end
        | E.Prod e => let
            (* checkO drops factors that are the constant 1 and stops at the
             * first constant-0 factor, returning just that factor; with no
             * factors left the product is the constant 1 *)
            fun checkO ([],[],[])=let
                    val (vA,A)=gHelper.aaV(DstOp.C 1,[],"Const",DstTy.TensorTy([]))
                    in ([vA],A) end
                | checkO(ids,code,[])=(ids,code)
                | checkO(ids,code, e1::es)=let
                    val (a,b)=gen e1
                    val sB=gHelper.skeleton b
                    in (case sB
                        of 0 => ([a],b)
                        | 1 => checkO(ids,code,es)
                        | _ => checkO(ids@[a],code@b,es)
                        (*end case*))
                    end
            val (ids,code)=checkO([],[],e)
            in  (case ids
                of [id1]=> (id1,code)
                | _=>let
                    val (vB,B)=gHelper.mkMultiple(ids,DstOp.prodSca,DstTy.TensorTy([]))
                    in (vB,code@B) end
                (*end case*))
            end
        | E.Div(e1,e2)=>let
            val (vA,A)=gen e1
            val sA=gHelper.skeleton A
            (* 0 / x is folded to 0 without generating the denominator *)
            in (case sA
                of 0=> gHelper.mkC 0
                | _=> let
                    val (vB,B)=gen e2
                    val (vD,D)= gHelper.aaV(DstOp.divSca,[vA,vB],"divSca",DstTy.TensorTy([]))
                    in (vD, A@B@D) end
                (*end case*))
            end
        | E.Sum(sumx, e)=> let
            (* bindings in force before the summation; restored afterwards so
             * the summation variables do not leak into sibling terms *)
            val saved = !jx
            (* generate the summand with the index bindings mapsum *)
            fun sumloop(mapsum)= (jx:=mapsum; let
                val(vA,A)=gen e
                in ([vA],A) end)
            (* sumI1 (left, (v, i, lb), remaining, vars, code): bind v to
             * lb+i, ..., lb (counting down) on top of left, nesting the loops
             * for the remaining summation variables; collects one result per
             * innermost iteration in vars and its instructions in code *)
            fun sumI1(left,(v,0,lb1),[],rest,code)=let
                val mapp=insert(v, lb1) left
                val (vD,pre)= sumloop(mapp)
                in (vD@rest,pre@code) end
              | sumI1(left,(v,i,lb1),[],rest,code)=let
                val mapp=insert(v, (i+lb1)) left
                val (vD,pre)=sumloop(mapp)
                in sumI1(mapp,( v,i-1,lb1),[],vD@rest,pre@code) end
              | sumI1(left,(v,0,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
                val mapp=insert(v, lb1) left
                in sumI1(mapp,(a,ub-lb2,lb2),sx,rest,code) end
              | sumI1(left,(v,s,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
                val mapp=insert(v, (s+lb1)) left
                val (rest',code')=sumI1(mapp,(a,ub-lb2,lb2),sx,rest,code)
                in sumI1(mapp,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end
            val (E.V v,lb,ub)=hd(sumx)
            val(li, code)=sumI1(saved,(v,ub-lb,lb),tl(sumx),[],[])
            val _ = (jx := saved)
            in (case li
                of [l1] => (l1,code)
                |_=>let val(vF,F)=gHelper.mkMultiple(li,DstOp.addSca,DstTy.TensorTy([]))
                    in (vF,code@F) end
                (*end case*))
            end
        (*end case*))
in (gen body) end

(*Below functions are used to check for vectorization*)
(* handleSimpleAdd (E.Add body, index, origargs, args):
 * vectorizes a sum of tensors when every summand is a non-scalar tensor whose
 * last index is the innermost free index, E.V (length index - 1).  The last
 * axis is then handled by gHelper.handleAddVec and prodIter iterates over the
 * remaining outer axes.  Otherwise scalar code is produced with generalfn at
 * every point of the index space.  A summand whose last index is not a
 * variable now takes the scalar path instead of raising Bind.
 *)
fun handleSimpleAdd(E.Add body,index,origargs, args)=let
    val n=length(index)-1
    (* scalar fallback for the whole addition *)
    fun general () = prodIter(index,index,generalfn,(E.Add body,[],origargs, args))
    (* lft accumulates (id, indices without the last one) for each summand
     * already known to be vectorizable *)
    fun add (lft,[])=let
            val index'=List.take(index,n)
            val m=List.nth(index,n)
            in prodIter(index,index',gHelper.handleAddVec,(lft,index,m,args)) end
      | add(lft,E.Tensor(id,list1)::es) =
            (* scalar tensors (list1 = []) and non-variable last indices fall back *)
            (case List.rev list1
              of E.V i :: _ =>
                   if i=n
                   then add(lft@[(id,List.take(list1,length list1 - 1))],es)
                   else general ()
               | _ => general ()
            (*end case*))
      | add(lft,_::es)=general ()
    in  add([],body)
    end


(* Addition and subtraction with exactly two tensors.
 * handleSimpleOp (orig, index, f1, f2, args), orig = [(id1, ix1), (id2, ix2)]:
 * when both tensors end in the same index variable and that variable is the
 * innermost free index (length index - 1), the last axis is handled by the
 * vector operation f1 and prodIter iterates over the remaining axes;
 * otherwise the scalar operation f2 is applied at every point.  A
 * non-variable (or missing) last index now takes the scalar path instead of
 * raising Bind/Subscript.
 *)
fun handleSimpleOp (orig,index,f1,f2,args)=let
    val [(id1,list1),(id2,list2)]=orig
    val n=length(index)-1
    fun scalar () = prodIter(index,index,f2,(orig,[],args))
    in (case (List.rev list1, List.rev list2)
         of (E.V i :: _, E.V j :: _) =>
              if(j=n andalso i=j)
              then let
                  val list1'=List.take(list1,length list1 - 1)
                  val list2'=List.take(list2,length list2 - 1)
                  val index'=List.take(index,n)
                  val m=List.nth(index,n)
                  in prodIter(index,index',f1,([(id1,list1'),(id2,list2')],[],m,args)) end
              else scalar ()
          | _ => scalar ()
        (*end case*))
    end

(* handleNeg (orig, index, id, ix, origargs, args): code for the negation of
 * tensor id indexed by ix, where orig is the whole Ein body (used by the
 * fallback).  When the tensor's own last index is the innermost free index
 * E.V (length index - 1), the last axis is negated as a vector with
 * gHelper.mkNegV and prodIter iterates over the remaining axes; otherwise
 * orig is evaluated with scalar code (generalfn).
 * The last index is taken from ix itself, as in handleSimpleAdd and
 * handleSimpleOp, so an ix whose length differs from index no longer raises
 * Subscript or vectorizes the wrong axis.
 *)
fun handleNeg(orig,index,id, ix,origargs, args)=let
    val n=(length index)-1
    fun general () = prodIter(index,index,generalfn,(orig,[],origargs, args))
    in (case List.rev ix
        of E.V v :: _ =>
            if(v=n) (*can use vectorization*)
            then let
                val (vA,A)= gHelper.mkC(~1)
                val index'=List.take(index,n)
                val ix'=List.take(ix,length ix - 1)
                val m=List.nth(index,n)
                val (vB,B)=prodIter(index,index',gHelper.mkNegV,((vA,id,ix'),[],m,args))
                in (vB,A@B)
                end
            else general ()
        |_ => general ()
        (*end case *))
    end


(* Product of two tensors.
 * handleProd (orig, index, sx, origargs, args), orig = [(id1, ix1), (id2, ix2)]
 * with both index lists non-empty; sx lists the pending summation variables
 * as (var, lb, ub).  Dispatch on the number of summations:
 *   0 : vectorize the last axis -- mkprodScaV when gHelper.findDup returns NONE
 *       (called "outer product" in the original comments) and ix2 ends in the
 *       innermost free index, mkprodVec when both tensors end in it --
 *       otherwise scalar products (mkprodSca).  With no free index, or a
 *       non-variable last index, the scalar path is taken (previously this
 *       raised Subscript/Bind).
 *   1 : dot product (mkprodSumVec) when both tensors end in the summation
 *       variable, else gHelper.sum.
 *   2 : gHelper.sumDot over the other variable when both tensors end in one
 *       of the summation variables, else gHelper.sum.
 *   more : gHelper.sum.
 * origargs is not used here.  Debug tracing is written to stdout via print.
 *)
fun handleProd(orig,index,sx,origargs, args)=let
    val [(id1,list1),(id2,list2)]=orig
    val n1=length(list1)-1
    val n2=length(list2)-1
    val vi=List.nth(list1, n1)
    val vj=List.nth(list2,n2)
    val list1'=List.take(list1,n1)
    val list2'=List.take(list2,n2)
    val ns=length(sx)
    (* scalar fallbacks *)
    fun scalarProd () = prodIter(index,index,gHelper.mkprodSca,(orig,[],args))
    fun scalarSum () = prodIter(index,index,gHelper.sum,(orig,sx,args))
    val _ =print "In Handle Product"
    in
        if(ns=0)
        then (case (index, vi, vj)
            of ([], _, _) =>
                (* no free index to vectorize over *)
                scalarProd ()
             | (_, E.V i, E.V j) => let
                val m=gHelper.findDup(list1,list2)
                val n=length(index)-1
                val index'=List.take(index,n)
                val mm=List.nth(index,n)
                in (case m
                    of NONE =>
                        (* {A_.. B_..j}_...j ? i.e, outproduct: s*v, otherwise s*s *)
                        if(j=n)
                        then prodIter(index,index',gHelper.mkprodScaV,([(id1,list1),(id2,list2')],[],mm,args))
                        else scalarProd ()
                    | _ =>
                        (* {A_i B_i}_i ? i.e. modulate: v*v, otherwise s*s *)
                        if(i=j andalso i=n)
                        then prodIter(index,index',gHelper.mkprodVec,([(id1,list1'),(id2,list2')],[],mm,args))
                        else scalarProd ()
                    (*end case*))
                end
             | _ => scalarProd ()
            (*end case*))
        else if (ns=1) then let
                val _ =print "In one "
                val [(sx1,lb,ub)]=sx
                in
                    if(vi=vj andalso vi=sx1)
                    then (print "In Dot product";
                          prodIter(index,index,gHelper.mkprodSumVec,([(id1,list1'),(id2,list2')],[],ub,args))) (*v,v*)
                    else (print "not in Dot Product"; scalarSum ()) (*s,s *)
                end
        else if (ns=2) then let
                val [(sx1,lb1,ub1),(sx2,lb2,ub2)]=sx
                in  if(vi=vj andalso vi=sx1)
                    then prodIter(index,index,gHelper.sumDot,([(id1,list1'),(id2,list2')],[(sx2,lb2,ub2)],ub1,args))
                    else if(vi=vj andalso vi=sx2)
                    then prodIter(index,index,gHelper.sumDot,([(id1,list1'),(id2,list2')],[(sx1,lb1,ub1)],ub2,args))
                    else scalarSum ()
                end
        else scalarSum ()
    end



(* handleScVProd (body, (idS, idT, ixT), index, sx, origargs, args):
 * product of the scalar tensor idS with the tensor idT indexed by ixT.  When
 * no summation is pending (sx = []) and the last index of ixT is the
 * innermost free index E.V (length index - 1), the last axis is computed as
 * scalar*vector with gHelper.mkprodScaV and prodIter iterates over the
 * remaining axes; otherwise the whole Ein body is evaluated with scalar code
 * (generalfn).  The vector length is only read on the vectorized path, so an
 * empty index (reachable under a summation) or a non-variable last index no
 * longer raises Subscript/Bind.
 *)
fun handleScVProd(body,orig,index,sx,origargs, args)=let
    val (id1,id2,list2)=orig
    val n2=length(list2)-1
    val n=length(index)-1
    val vectorize =
          n >= 0 andalso null sx andalso
          (case List.rev list2
            of E.V j :: _ => j=n
             | _ => false
          (*end case*))
    in if vectorize
        then let
            val index'=List.take(index,n)
            val list2'=List.take(list2,n2)
            val m=List.nth(index,n)
            in prodIter(index,index',gHelper.mkprodScaV,([(id1,[]),(id2,list2')],[],m, args)) end
        else prodIter(index,index,generalfn,(body,[],origargs, args))
    end


(*Simple Operators on two tensors, examine to see if we could use vectors *)
(*Have to pass orig args to everyone in case we have a kernel or image in a later stage of iteration*)
(* genfn (y, Ein.EIN{params, index, body}, origargs, a): entry point that
 * generates code for an Ein function applied to the arguments a.
 *   - index = [] : simple scalar bodies go straight to the scalar generators
 *     (`single`); anything else goes through `gen`.
 *   - otherwise `gen` looks for patterns that can be vectorized along the
 *     innermost free index; everything else is evaluated with generalfn at
 *     every point of the index space (`general`).
 * Summation variables met while descending into E.Sum are accumulated in sx.
 * Every pattern handler, including Neg and scalar*scalar Prod, now takes the
 * scalar path while a summation is pending, so an enclosing E.Sum is never
 * dropped.  y is not used here.  Debug tracing is written to stdout via print.
 *)
fun genfn(y,Ein.EIN{params, index, body},origargs,a)= let
    val  sx= ref[]
    val n=length index
    val args=(a,params)
    val _=print "\n in genfn "
    (* scalar code for the whole body at every point of the index space *)
    fun general () = prodIter(index,index,generalfn,(body,[],origargs, args))

    fun gen b=(case b
        of  E.Field _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Partial _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Apply _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Probe _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Conv _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Neg(E.Tensor(id,ix))=>(case !sx
            of []   => handleNeg(body,index,id, ix,origargs, args)
            | _     => general ()
            (*end case*))
        | E.Add _ =>(case !sx
            of []   => handleSimpleAdd(body,index,origargs,  args)
            | _     => general ()
            (*end case*))
        | E.Sub(E.Tensor(id1, ix1), E.Tensor(id2, ix2)) =>(case !sx
            of []   => handleSimpleOp([(id1,ix1),(id2,ix2)],index,gHelper.mksubVec,gHelper.mksubSca,args)
            | _     => general ()
            (*end case*))
        | E.Prod[E.Tensor(id1,[]), E.Tensor(id2, [])]=>(case !sx
            of []   => gHelper.mkprodSca(empty,([(id1,[]),(id2,[])],[],args))
            | _     => general ()
            (*end case*))
        | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)] =>
            handleScVProd(body,(id1,id2,ix2),index,!sx,origargs, args)
        | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, [])] =>
            handleScVProd(body,(id2,id1,ix1),index,!sx, origargs, args)
        | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)] =>
            if(length(index) =0)
            then handleProd([(id1,ix1),(id2,ix2)],index,!sx,origargs, args)
            else general ()
        | E.Sum(ss,E.Prod(E.Img(Vid,_,_)::E.Krn(id,del,pos)::es))=>
            if(null (!sx)) then let
                val harg=List.nth(origargs,id)
                val h=gHelper.getKernel(harg)
                val imgarg=List.nth(origargs,Vid)
                val  v=gHelper.getImage(imgarg)
                in prodIter(index,index,genKrn.evalField,(b,v, h,args)) end
            else general ()
        | E.Sum(sx', e)=> (print "\n in summation"; sx := !sx @ sx'; gen e)
        | _ => (print "other summation"; general ())
        (*end case*))

    (*Scalars only, not vectorization potential*)
    fun single b=(case b
        of E.Tensor(id,[]) =>gHelper.mkSca(empty,(id,[], args))
        | E.Const c=>gHelper.mkC(c)
        | E.Neg _ => generalfn(empty,(b,[],origargs, args))
        | E.Add _ =>generalfn(empty,(b,[],origargs, args))
        | E.Sub _=> generalfn(empty,(b,[],origargs, args))
        | E.Div _=> generalfn(empty,(b,[],origargs, args))
        | E.Prod [E.Tensor(_,[]),E.Tensor(_,[])] => generalfn(empty,(b,[],origargs, args))
        | _=> gen b
        (*end case*))

in (case n of 0 =>single body  | _=> gen body) end

end (* local *)

end 

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0