Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2533 - (download) (annotate)
Thu Jan 30 04:58:56 2014 UTC (5 years, 5 months ago) by cchiw
File size: 16752 byte(s)
type checker
(*hashs Ein Function after substitution*)
structure genEin = struct
    local
    structure E = Ein
    structure genHelper=genHelper
    structure genKrn=genKrn

    

structure DstIL = LowIL
structure DstTy = LowILTypes
structure DstOp = LowOps
structure Var = LowIL.Var
    in



(*Iterate over the outside index*)
(*nextfn is the next function*)
(*args are the arguments passed through to nextfn*)
fun prodIter(origIndex,index,nextfn,args)=let
    (* prodIter enumerates every position of the index list `index`, calls
     * nextfn(position, args) for each one, and groups the resulting variables
     * with "Cons" operations whose types come from origIndex.
     * Returns (result variable, generated code). *)

    (* Highest position along each axis; positions are visited from here down to 0. *)
    val topPos = List.map (fn bound => bound - 1) index

    (* Emit one "Cons" of type TensorTy dims over the variables in elems. *)
    fun consOf (dims, elems) =
        genHelper.aaV(DstOp.cons(DstTy.TensorTy dims, 0), elems, "Cons", DstTy.TensorTy dims)

    (* walk(prefix, remaining, elems, acc, dims):
     *   prefix    - positions already fixed for the outer axes
     *   remaining - highest position still to visit on each remaining axis
     *   elems     - variables collected so far along the innermost axis
     *   acc       - code generated so far
     *   dims      - shape used for the Cons types at this depth *)
    fun walk (prefix, [], _, acc, _) = let
            (* No axis left: a single call to nextfn. *)
            val (v, stmts) = nextfn(prefix, args)
            in (v, stmts @ acc) end
      | walk (prefix, [0], elems, acc, dims) = let
            (* Last position of the innermost axis: close it with a Cons. *)
            val (v, stmts) = nextfn(prefix @ [0], args)
            val (vCons, consStmts) = consOf(dims, v :: elems)
            in (vCons, stmts @ acc @ consStmts) end
      | walk (prefix, [k], elems, acc, dims) = let
            val (v, stmts) = nextfn(prefix @ [k], args)
            in walk(prefix, [k-1], v :: elems, stmts @ acc, dims) end
      | walk (prefix, k :: inner, _, acc, d :: dims) = let
            (* Outer axis: build one sub-result per position, then Cons them. *)
            fun slices (0, elems, stmts) = let
                    val (v, sub) = walk(prefix @ [0], inner, [], [], dims)
                    val (vCons, consStmts) = consOf(d :: dims, v :: elems)
                    in (vCons, sub @ stmts @ consStmts) end
              | slices (i, elems, stmts) = let
                    val (v, sub) = walk(prefix @ [i], inner, [], [], dims)
                    in slices(i-1, v :: elems, sub @ stmts) end
            val (vCons, stmts) = slices(k, [], [])
            in (vCons, stmts @ acc) end
    in walk([], topPos, [], [], origIndex)
    end



(* general expressions*)
fun generalfn(ap,(body,_,origargs, args))= let
    (* generalfn generates scalar code for an arbitrary Ein body.
     *   ap       - initial index mapping; E.Value v and the tensor/delta/epsilon
     *              cases look positions up in it
     *   body     - expression to generate code for
     *   origargs - argument list used to fetch kernel/image arguments by id
     *   args     - passed through to the genHelper/genKrn routines
     * Returns (result variable, generated code).
     * Raises Fail for field-level expressions (Field, Partial, Apply, Probe,
     * Conv, and a bare Krn or Img).
     *
     * genHelper.skeleton is used below as a constant detector: the cases treat
     * a result of 0, 1 or ~1 as "this code is that constant"
     * (NOTE(review): inferred from how the results are used here; confirm
     * against genHelper). *)
    val a=print "in general fn"
    (* Current index mapping. E.Sum extends it while generating its body. *)
    val  mappA= ref ap
    val scaTy=DstTy.TensorTy([])

    (* Emit a scalar integer constant. *)
    fun constSca n=genHelper.aaV(DstOp.C n,[],"Const",scaTy)

    fun gen body=(case body
        of  E.Field _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Partial _ =>raise Fail(concat["Invalid FieldPartial here "]   )
        | E.Apply _ =>raise Fail(concat["Invalid FieldApply here "]   )
        | E.Probe _ =>raise Fail(concat["Invalid FieldProbe here "]   )
        | E.Conv _ =>raise Fail(concat["Invalid FieldConv here "]   )
        | E.Krn _ =>raise Fail(concat["Invalid FieldKrn here "]   )
        | E.Img _=> raise Fail(concat["Invalid FieldImg here "]   )
        | E.Value v => genHelper.mkC (List.nth(!mappA,v))

        (*| E.Const c=> []*)
        | E.Tensor(id,[])=>  genHelper.mkSca([],(id,[],args))
        | E.Tensor(id,ix)=> genHelper.mkSca(!mappA,(id,ix,args))
        | E.Delta(i,j)=> genHelper.evalDelta2(i,j,!mappA)
        | E.Epsilon(i,j,k)=> constSca(genHelper.evalEps(i,j,k,!mappA))
        | E.Neg e => let
            val (vA,A)=gen e
            in (case genHelper.skeleton A
                of 0 => (vA,A)
                | ~1 => genHelper.mkC 1
                | 1 =>  genHelper.mkC ~1
                |  _=> let
                    (* general case: multiply by the constant ~1 *)
                    val (vB,B)=constSca (~1)
                    val (vD,D)=genHelper.aaV(DstOp.prodSca,[vB,vA],"prodSca",scaTy)
                    in (vD,A@B@D) end
                (*end case*))
            end

        | E.Add e=> let
            (* checkO generates every term and drops the ones that fold to 0 *)
            fun checkO ([],[],[])=let
                    (* No term survived (all were 0, or the list was empty):
                     * the sum is the additive identity 0.  The previous code
                     * produced the constant 1 here. *)
                    val (vA,A)=constSca 0
                    in ([vA],A) end
                | checkO(ids,code,[])=(ids,code)
                | checkO(ids,code, e1::es)=let
                    val (a,b)=gen e1
                    in (case genHelper.skeleton b
                        of 0 => checkO(ids,code,es)
                        |  _ => checkO(ids@[a],code@b,es)
                        (*end case*))
                    end
            val (ids,code)=checkO([],[],e)
            in  (case ids
                of [id1]=> (id1,code)
                | _=>let  val (vB,B)=genHelper.mkMultiple(ids,DstOp.addSca,scaTy)
                    in (vB,code@B) end
                (*end case*))
            end

        | E.Sub (e1,e2)=>let
            val (vA,A)=gen e1
            val (vB,B)=gen e2
            val sA=genHelper.skeleton A
            val sB=genHelper.skeleton B

            (* checks if either expression evaluates to 0*)
            in (case (sA,sB)
                of (0,0)=> genHelper.mkC 0
                |(0,_)=> let
                    (* 0 - B is emitted as ~1 * B *)
                    val (vD,D)= genHelper.mkC ~1
                    val (vE,E)= genHelper.aaV(DstOp.prodSca,[vD,vB],"prodSca",scaTy)
                    in (vE,B@D@E) end

                | (_,0)=> (vA,A)
                | _ => let
                    val (vD,D)= genHelper.aaV(DstOp.subSca,[vA,vB],"subSca",scaTy)
                    in (vD, A@B@D) end
                (*end case*))
            end
      (* Summation over an image/kernel product is delegated to genKrn.
       * This clause must stay ahead of the generic E.Sum clause below. *)
      | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(id,del,pos)::es))=>let
            val mapp= !mappA
            val harg=List.nth(origargs,id)
            val h=genHelper.getKernel(harg)
            val imgarg=List.nth(origargs,Vid)
            val  v=genHelper.getImage(imgarg)
            in genKrn.evalField(mapp,(body,v,h,args))
            end

        | E.Prod e => let
            (*checkO removes 1 from list, and returns 0  if there is one*)
            fun checkO ([],[],[])=let val (vA,A)=constSca 1 in ([vA],A) end
                | checkO(ids,code,[])=(ids,code)
                | checkO(ids,code, e1::es)=let
                    val (a,b)=gen e1
                    in (case genHelper.skeleton b
                        of 0 => ([a],b)
                        | 1 => checkO(ids,code,es)
                        | _ => checkO(ids@[a],code@b,es)
                        (*end case*))
                    end
            val (ids,code)=checkO([],[],e)
            in  (case ids
                of [id1]=> (id1,code)
                | _=>let
                    val (vB,B)=genHelper.mkMultiple(ids,DstOp.prodSca,scaTy)
                    in (vB,code@B) end
                (*end case*))
            end
        | E.Div(e1,e2)=>(let
            val (vA,A)=gen e1
            in (case genHelper.skeleton A
                of 0=> genHelper.mkC 0
                | _=> let
                    val (vB,B)=gen e2
                    val (vD,D)= genHelper.aaV(DstOp.divSca,[vA,vB],"divSca",scaTy)
                    in (vD, A@B@D) end
                (*end case*))

            end)
        | E.Sum(sumx, e)=> let
            val m=print "in general sum"
            (* Mapping in effect outside this summation. *)
            val orig= !mappA
            (* Generate the summand with the summation positions appended to
             * the outer mapping; a summand that folds to 0 contributes nothing. *)
            fun sumloop(mapsum)= (mappA:=(orig@mapsum); let
                val(vA,A)=gen e
                in (case genHelper.skeleton A
                    of 0 =>([],[])
                    | _=>([vA],A)
                    (*end case*))

                end )
            (* sumI1(left,(i,lb1),sx,rest,code) walks the current summation
             * index from lb1+i down to lb1.  `left` holds the positions fixed
             * for earlier summation indices and `sx` the indices still to be
             * expanded. *)
            fun sumI1(left,(0,lb1),[],rest,code)=let
                val mapp=left@[lb1]
                val (vD,pre)= sumloop(mapp)
                in (vD@rest,pre@code) end
            |  sumI1(left,(i,lb1),[],rest,code)=let
                val mapp=left@[i+lb1]
                val (vD,pre)=sumloop(mapp)
                in sumI1(left,(i-1,lb1),[],vD@rest,pre@code) end
            | sumI1(left,(0,lb1),(a,lb2,ub)::sx,rest,code)=
                sumI1(left@[lb1],(ub-lb2,lb2),sx,rest,code)
            | sumI1(left,(s,lb1),(a,lb2,ub)::sx,rest,code)=let
                val (rest',code')=sumI1(left@[s+lb1],(ub-lb2,lb2),sx,rest,code)
                in sumI1(left,(s-1,lb1),(E.V 0,lb2,ub)::sx,rest',code') end
            val (_,lb,ub)=hd(sumx)
            val(li, code)=sumI1([],(ub-lb,lb),tl(sumx),[],[])
            (* Restore the outer mapping.  Otherwise a sibling expression
             * generated after this summation would start from a mapping that
             * still carries this summation's last positions. *)
            val _ = mappA:=orig
            in (case li
                of [] => constSca 0  (* every summand folded to 0 *)
                | [l1] => (l1,code)
                |_=>let val(vF,F)=genHelper.mkMultiple(li,DstOp.addSca,scaTy)
                in (vF,code@F) end
                (*end case*))
                end

            (*end case*))

        in gen body end

(*Below functions are used to check for vectorization*)
(*Check Addition expression *)
fun handleSimpleAdd(E.Add body,index,origargs, args)=let
    (* handleSimpleAdd decides whether an addition can take the vector path.
     *   body     - the terms of the E.Add
     *   index    - index list of the enclosing EIN operator
     *   origargs - passed through to generalfn on the general path
     *   args     - passed through to the code generators
     * If every term is a tensor whose last index is the variable E.V n, with
     * n the last position of `index`, the terms (minus that last index) are
     * handed to genHelper.handleAddVec and iteration covers all but the last
     * entry of `index`.  Any other term shape uses generalfn over the full
     * index list.
     * Only the E.Add constructor is accepted; anything else raises Match. *)
    val n=length(index)-1

    (* General path: generate the whole addition through generalfn. *)
    fun general () = prodIter(index,index,generalfn,(E.Add body,[],origargs, args))

    fun add (lft,[])=let
            (* every term qualified: drop the last axis from the iteration *)
            val index'=List.take(index,n)
            val m=List.nth(index,n)
            in prodIter(index,index',genHelper.handleAddVec,(lft,index,m,args)) end
        | add(lft,E.Tensor(id,[])::es) = general ()
        | add(lft,E.Tensor(id,list1)::es) =let
            val n1=length(list1)-1
            in (case List.nth(list1, n1)
                of E.V i =>
                    if(i=n)
                    then add(lft@[(id,List.take(list1,n1))],es)
                    else general ()
                (* A last index that is not a variable cannot be vectorized.
                 * Use the general path instead of failing with Bind. *)
                | _ => general ()
                (*end case*))
            end
        | add(lft,e::es)= general ()
    in  add([],body)
    end


(*Addition and subtraction, with just two tensors *)
fun handleSimpleOp (orig,index,f1,f2,args)=let
    (* handleSimpleOp picks the code generator for a binary operation on two tensors.
     *   orig  - exactly two (tensor id, index list) pairs
     *   index - index list of the enclosing EIN operator
     *   f1    - generator used when the last axis can be handled as a vector
     *   f2    - generator used for the element-by-element path
     * The vector path is taken when both tensors end in the same variable
     * index and it names the last position of `index`.
     * Raises Bind if orig does not hold two pairs or a last index is not an
     * E.V; raises Subscript if an index list is empty. *)
    val [(idA,ixA),(idB,ixB)]=orig
    val lastA=length(ixA)-1
    val lastB=length(ixB)-1
    val lastOut=length(index)-1
    val muA=List.nth(ixA,lastA)
    val muB=List.nth(ixB,lastB)
    val E.V posA=muA
    val E.V posB=muB
    in
        if(posB<>lastOut orelse posA<>posB)
        then prodIter(index,index,f2,(orig,[],args))
        else let
            (* strip the shared last index; it becomes the vector axis *)
            val ixA'=List.take(ixA,lastA)
            val ixB'=List.take(ixB,lastB)
            val outer=List.take(index,lastOut)
            val width=List.nth(index,lastOut)
            in prodIter(index,outer,f1,([(idA,ixA'),(idB,ixB')],[],width,args)) end
    end

(*Need to double check here*)
fun handleNeg(orig,index,id, ix,origargs, args)=let
    (* handleNeg generates the negation of tensor `id` indexed by `ix`.
     *   orig     - the whole expression, used only on the general path
     *   index    - index list of the enclosing EIN operator
     *   origargs - passed through to generalfn on the general path
     * The entry of `ix` at the last position of `index` selects the path:
     * if it is the variable naming that position, the last axis is dropped
     * from the iteration and the work goes to genHelper.mkNegV; otherwise
     * generalfn is used over the full index list. *)
    val last=(length index)-1

    (* General path over every output position. *)
    fun viaGeneral ()=prodIter(index,index,generalfn,(orig,[],origargs, args))

    (* Vector path; the statement order keeps the debug output unchanged. *)
    fun viaVector ()=let
        val (vZero,zeroCode)= genHelper.mkC(0)
        val _=print "post genHelper Call"
        val outer=List.take(index,last)
        val ixOuter=List.take(ix,last)
        val width=List.nth(index,last)
        val _=print "IN HANDLE NEG-- PUPPY\n"
        val _=print(Int.toString(width))
        val (vNeg,negCode)=prodIter(index,outer,genHelper.mkNegV,((vZero,id,ixOuter),[],width,args))
        in (vNeg,zeroCode@negCode)
        end
    in (case List.nth(ix, last)
        of E.V v => if(v=last) then viaVector () else viaGeneral ()
        |_ => viaGeneral ()
        (*end case *))
    end


(*Product of two tensors*)



fun handleProd(orig,index,sx,origargs, args)=let
    (* handleProd picks the code generator for a product of two tensors.
     *   orig  - exactly two (tensor id, index list) pairs
     *   index - index list of the enclosing EIN operator
     *   sx    - summation indices, as (index, lower bound, upper bound)
     * The choice depends on how many summation indices there are and on
     * whether the tensors' last indices coincide with each other, with the
     * last position of `index`, or with a summation index.
     * Raises Bind if orig does not hold two pairs (or, without summation, a
     * last index is not an E.V); raises Subscript on an empty index list. *)
    val [(idA,ixA),(idB,ixB)]=orig
    val lastA=length(ixA)-1
    val lastB=length(ixB)-1
    val muA=List.nth(ixA,lastA)
    val muB=List.nth(ixB,lastB)
    val ixA'=List.take(ixA,lastA)
    val ixB'=List.take(ixB,lastB)
    (* both tensors with their last index removed *)
    val trimmed=[(idA,ixA'),(idB,ixB')]

    (* element-by-element product, no summation *)
    fun scalarProd ()=prodIter(index,index,genHelper.mkprodSca,(orig,[],args))
    (* element-by-element product under the full summation list *)
    fun scalarSum ()=prodIter(index,index,genHelper.sum,(orig,sx,args))
    (* genHelper.sumDot over the shared last index, with `rest` still to be summed *)
    fun dotOver (rest,ub)=prodIter(index,index,genHelper.sumDot,(trimmed,rest,ub,args))
    in (case sx
        of [] => let
            val dup=genHelper.findDup(ixA,ixB)
            val last=length(index)-1
            val E.V i=muA
            val E.V j=muB
            val outer=List.take(index,last)
            val width=List.nth(index,last)
            in (case dup
                of NONE =>
                    (*{A_.. B_..j}_...j ? i.e, outproduct*)
                    (* s*v  otherwise s*s *)
                    if(j=last)
                    then prodIter(index,outer,genHelper.mkprodScaV,([(idA,ixA),(idB,ixB')],[],width,args))
                    else scalarProd ()
                | _ =>
                    (*{A_i B_i}_i? i.e. modulate*)
                    (* v*v  otherwise s*s*)
                    if(i=j andalso i=last)
                    then prodIter(index,outer,genHelper.mkprodVec,(trimmed,[],width,args))
                    else scalarProd ()
                (*end case*))
            end
        | [(s1,_,ub)] =>
            if(muA=muB andalso muA=s1)
            then prodIter(index,index,genHelper.mkprodSumVec,(trimmed,[],ub,args)) (*v,v*)
            else scalarSum () (*s,s *)
        | [(s1,lb1,ub1),(s2,lb2,ub2)] =>
            if(muA<>muB) then scalarSum ()
            else if(muA=s1) then dotOver([(s2,lb2,ub2)],ub1)
            else if(muA=s2) then dotOver([(s1,lb1,ub1)],ub2)
            else scalarSum ()
        | _ => scalarSum ()
        (*end case*))
    end



fun handleScVProd(body,orig,index,sx,origargs, args)=let
    (* handleScVProd generates a product of a scalar tensor and an indexed tensor.
     *   body     - the whole expression, used on the general path
     *   orig     - (scalar tensor id, indexed tensor id, its index list)
     *   index    - index list of the enclosing EIN operator
     *   sx       - summation indices collected so far
     *   origargs - passed through to generalfn on the general path
     * The vector path (genHelper.mkprodScaV over all but the last axis) is
     * taken only when there is no summation and the indexed tensor's last
     * index is the variable naming the last position of `index`.  Everything
     * else goes through generalfn.
     *
     * The lookups into `index` and `list2` happen only after checking that
     * both are non-empty and that no summation is pending.  Earlier they ran
     * unconditionally, so an empty `index` (scalar result under a summation)
     * raised Subscript before the general path could be chosen. *)
    val (id1,id2,list2)=orig

    (* General path over every output position. *)
    fun general () = prodIter(index,index,generalfn,(body,[],origargs, args))
    in
        if null sx andalso not (null index) andalso not (null list2)
        then let
            val n2=length(list2)-1
            val n=length(index)-1
            in (case List.nth(list2,n2)
                of E.V j =>
                    if(j=n)
                    then let
                        val index'=List.take(index,n)
                        val list2'=List.take(list2,n2)
                        val m=List.nth(index,n)
                        val q=print(String.concat["Puppy-Make Vector index",Int.toString(m)])
                        in prodIter(index,index',genHelper.mkprodScaV,([(id1,[]),(id2,list2')],[],m, args)) end
                    else general ()
                (* a last index that is not a variable cannot be vectorized *)
                | _ => general ()
                (*end case*))
            end
        else general ()
    end


(*Simple Operators on two tensors, examine to see if we could use vectors *)
(*Have to pass orig args to everyone in case we have a kernel or image in a later stage of iteration*)

fun genfn(y,Ein.EIN{params, index, body},origargs,a)= let
    (* genfn is the entry point: it generates code for one EIN operator.
     *   y        - not used in this function
     *   params, index, body - the fields of the EIN operator
     *   origargs - original argument list; kernel and image arguments are
     *              looked up in it by id
     *   a        - paired with params to form the `args` handed to the generators
     * With an empty index list the scalar-only dispatcher `single` is used;
     * otherwise `gen` looks for expression shapes that have a vectorized
     * handler and falls back to generalfn for the rest. *)

    (* Summation indices peeled off enclosing E.Sum nodes while descending. *)
    val  sx= ref[]
    val n=length index
    val args=(a,params)

    (* General path: generate the complete body, summations included, through generalfn. *)
    fun general () = prodIter(index,index,generalfn,(body,[],origargs, args))

    (* True while no summation index has been peeled off. *)
    fun noSum () = null (!sx)

   (*Potential for Vectorization here *)
   fun gen b=(case b
        of  E.Field _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Partial _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Apply _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Probe _ =>raise Fail(concat["Invalid Field here "]   )
        | E.Conv _ =>raise Fail(concat["Invalid Field here "]   )

        (*| E.Const _=>[]*)

        (* The Neg/Add/Sub shortcuts take no summation list, so they are only
         * valid when no E.Sum encloses b.  Under a summation the old code
         * either raised Match (handleSimpleAdd was handed the E.Sum body) or
         * generated code that ignored the summation.  Those cases now use
         * the general path. *)
        | E.Neg(E.Tensor(id,ix))=>
            if noSum () then handleNeg(b,index,id, ix,origargs, args)
            else general ()
        | E.Add _ =>
            if noSum () then handleSimpleAdd(b,index,origargs,  args)
            else general ()
        | E.Sub(E.Tensor(id1, ix1), E.Tensor(id2, ix2)) =>
            if noSum ()
            then handleSimpleOp([(id1,ix1),(id2,ix2)],index,genHelper.mksubVec,genHelper.mksubSca,args)
            else general ()
        | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)] =>
                handleScVProd(body,(id1,id2,ix2),index,!sx,origargs, args)
        | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, [])] =>
                handleScVProd(body,(id2,id1,ix1),index,!sx, origargs, args)
       (* | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)] =>
                let
            val ref x=sx
            in
                handleProd([(id1,ix1),(id2,ix2)],index,x,args)
            end*)
        (*| E.Div(E.Tensor _,E.Tensor _ )=>[]*)

        (* Summation over an image/kernel product.  This clause must stay
         * ahead of the generic E.Sum clause below. *)
        | E.Sum(ss,E.Prod(E.Img(Vid,_,_)::E.Krn(id,del,pos)::es))=>let
            val m=print "\n match img"
            in
                if noSum () then let
                    val harg=List.nth(origargs,id)
                    val h=genHelper.getKernel(harg)
                    val imgarg=List.nth(origargs,Vid)
                    val  v=genHelper.getImage(imgarg)
                    in prodIter(index,index,genKrn.evalField,(b,v, h,args)) end
                else general ()
            end
        (* Record the summation indices, then inspect the summand. *)
        | E.Sum(sx', e)=> (sx:= !sx@sx' ; gen e)
        | _ => general ()
        (*end case*))

    (*Scalars only, not vectorization potential*)
    fun single b=(case b
        of E.Tensor(id,[]) =>genHelper.mkSca([],(id,[], args))
        | E.Const _=>generalfn([],(b,[],origargs, args))
        | E.Neg _ => generalfn([],(b,[],origargs, args))
        | E.Add _ =>generalfn([],(b,[],origargs, args))
        | E.Sub _=> generalfn([],(b,[],origargs, args))
        | E.Div _=> generalfn([],(b,[],origargs, args))
        | E.Prod [E.Tensor(_,[]),E.Tensor(_,[])] => generalfn([],(b,[],origargs, args))
        | _=> gen b

        (*end case*))

   in (case n of 0 =>single body  | _=> gen body) end

end (* local *)

end 

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0