Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/step1.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/step1.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2845 - (download) (annotate)
Fri Dec 12 06:46:23 2014 UTC (4 years, 7 months ago) by cchiw
File size: 8223 byte(s)
added norm
(* step1: looks for vectorization potential in an EIN operator.
 *
 * genfn pattern-matches the body of an EIN operator. When the last output
 * index can be handled as one vector operation, it asks S2.prodIter to
 * iterate over the remaining output indices and build each vector with a
 * step3 constructor (S3.mkNegV, S3.mksubVec, S3.handleAddVec, S3.mkprodScaV,
 * S3.mkprodVec, S3.mkprodSumVec, S3.sumDot), or with genKrn.evalField for a
 * summed product of an image and a kernel (E.Img, E.Krn).
 * Otherwise it iterates over every output index with the scalar generator
 * S2.generalfn.
 *
 * The handle* functions fall back to that scalar path whenever an operand
 * has no index or its last index is not an index variable (E.V), rather
 * than raising an exception.
 *)
structure step1 = struct
    local
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure E = Ein
    structure P=Printer
    structure genKrn=genKrn
    structure S2= step2
    structure S3= step3
    structure DstIL = LowIL

    in

    (* Scalar (rank-0 tensor) LowIL type. *)
    val Sca=DstTy.TensorTy([])

    (* Raise Fail with the given message. *)
    fun errS str=raise Fail(str)

    (* getLast list returns (last, prefix): the final element of list and the
     * elements before it, in their original order.
     * Raises List.Empty when list is empty.
     *)
    fun getLast list = (List.last list, List.take (list, length list - 1))

    (* getLastAll list returns (i, E.V i, prefix) when the last element of
     * list is the index variable E.V i; prefix is every element before it.
     * Raises List.Empty on an empty list and Fail when the last element is
     * not an index variable. The handlers below use getLastVar instead so
     * they can fall back to scalar code; this function keeps its original
     * behavior because it is part of this structure's interface.
     *)
    fun getLastAll list = let
        val (lastIx, prefix) = getLast list
        in (case lastIx
            of E.V i => (i, lastIx, prefix)
             | _ => errS("Last Index is not Variable")
            (*end case*))
        end

    (* getLastVar list is the non-raising form of getLastAll:
     * SOME (i, E.V i, prefix) when the last element of list is the index
     * variable E.V i, and NONE when list is empty or ends in any other kind
     * of index.
     *)
    fun getLastVar [] = NONE
      | getLastVar list = (case getLast list
          of (lastIx as E.V i, prefix) => SOME (i, lastIx, prefix)
           | _ => NONE
          (*end case*))

    (* findDup (list1, list2) returns SOME 1 when some element of list1 also
     * occurs in list2, and NONE otherwise (handleProd only tests for NONE).
     *)
    fun findDup (list1, list2) =
        if List.exists (fn v => List.exists (fn x => x = v) list2) list1
            then SOME 1
            else NONE

    (*-----------Handle Functions for vectorization-----------*)
    (* n = length index - 1 is the position of the last output index.
     * Each handler vectorizes over that index when it can, and otherwise
     * calls its local default, which iterates over every output index with
     * S2.generalfn (or the supplied f2 in handleSimpleOp).
     *)

    (* Negative tensor: -T_ix, where id names T.
     * Vectorizes with S3.mkNegV when the last element of ix is E.V n; T is
     * then indexed by the rest of ix. The last element of ix is examined
     * (not position n), so a trailing constant index is never dropped and a
     * short or empty ix cannot raise Subscript.
     *)
    fun handleNeg (id, ix, index, nextfnArgs) = let
        val n = length index - 1
        fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
        in (case getLastVar ix
            of SOME (v, _, ix') =>
                if v = n
                    then let
                        (* the -1 constant is created first; its list A is
                         * prepended to the list B returned by prodIter *)
                        val (vA, A) = S3.mkReal ~1
                        val (LastIndex, index') = getLast index
                        val (_, _, info) = nextfnArgs
                        val (vB, B) = S2.prodIter(index, index', S3.mkNegV, ((vA, id, ix'), LastIndex, info))
                        in (vB, A@B) end
                    else default 0
             | NONE => default 0
            (*end case*))
        end

    (* Element-wise binary operation on two tensors A_list1 and B_list2
     * (genfn uses it for subtraction with f1 = S3.mksubVec and
     * f2 = S2.generalfn).
     * Vectorizes with f1 when both operands end in E.V n; otherwise iterates
     * with f2 over all of index. An operand with no index, or whose last
     * index is not a variable, takes the f2 path.
     *)
    fun handleSimpleOp (id1, list1, id2, list2, index, f1, f2, nextfnArgs) = let
        val n = length index - 1
        fun default _ = S2.prodIter(index, index, f2, nextfnArgs)
        in (case (getLastVar list1, getLastVar list2)
            of (SOME (i, _, list1'), SOME (j, _, list2')) =>
                if i = n andalso j = n
                    then let
                        val (LastIndex, index') = getLast index
                        val (_, _, info) = nextfnArgs
                        in S2.prodIter(index, index', f1, (id1, list1', id2, list2', LastIndex, info)) end
                    else default 0
             | _ => default 0
            (*end case*))
        end

    (* Sum of tensor terms: E.Add [T1_a1, ..., Tk_ak].
     * Vectorizes with S3.handleAddVec when every term is a tensor whose last
     * index is E.V n; the (id, prefix) pairs are passed in term order.
     * Any other term shape takes the scalar path.
     *)
    fun handleAdd (terms, index, nextfnArgs) = let
        val n = length index - 1
        fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
        (* acc holds the (id, prefix) pairs seen so far, most recent first *)
        fun add ([], acc) = let
                val (LastIndex, index') = getLast index
                val (_, _, info) = nextfnArgs
                in S2.prodIter(index, index', S3.handleAddVec, (List.rev acc, LastIndex, info)) end
          | add (E.Tensor(id, alpha)::es, acc) = (case getLastVar alpha
                of SOME (i, _, alpha') =>
                    if i = n then add(es, (id, alpha')::acc) else default 0
                 | NONE => default 0
                (*end case*))
          | add _ = default 0
        in
            add(terms, [])
        end

    (* Scalar times tensor: s * T_alpha, where id1 names the scalar and id2
     * the tensor. Vectorizes with S3.mkprodScaV when alpha ends in E.V n.
     * The output index is split only once vectorization is chosen, so neither
     * a constant last index nor an empty output index raises.
     *)
    fun handleScVProd (id1, id2, alpha, index, nextfnArgs) = let
        val n = length index - 1
        fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
        in (case getLastVar alpha
            of SOME (j, _, alpha') =>
                if j = n
                    then let
                        val (LastIndex, index') = getLast index
                        val (_, _, info) = nextfnArgs
                        in S2.prodIter(index, index', S3.mkprodScaV, (id1, [], id2, alpha', LastIndex, info)) end
                    else default 0
             | NONE => default 0
            (*end case*))
        end

    (* Product of two tensors A_list1 * B_list2 (id1 names A, id2 names B),
     * optionally under the summation ranges sx:
     *   sx = [], no index shared by list1 and list2 (outer product, s*v):
     *       S3.mkprodScaV when B ends in E.V n.
     *   sx = [], some index shared (modulate, v*v):
     *       S3.mkprodVec when both A and B end in E.V n.
     *   sx = [(mu, 0, ub)]: S3.mkprodSumVec when both end in mu; ub+1 is
     *       passed where the other calls pass LastIndex.
     *   sx = [r1, r2]: S3.sumDot over the other range when both end in the
     *       variable of r1 (checked first) or of r2.
     * Everything else, including an operand with no index or a non-variable
     * last index, takes the scalar path.
     *)
    fun handleProd (id1, list1, id2, list2, index, sx, nextfnArgs) = let
        val n = length index - 1
        val (_, _, info) = nextfnArgs
        fun default () = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
        val last1 = getLastVar list1
        val last2 = getLastVar list2
        (* SOME (list1', list2') when both operands end in the index mu *)
        fun bothEndIn mu = (case (last1, last2)
            of (SOME (_, vi, list1'), SOME (_, vj, list2')) =>
                if vi = vj andalso vi = mu then SOME (list1', list2') else NONE
             | _ => NONE
            (*end case*))
        in (case sx
            of [] => (case (last1, last2)
                of (SOME (i, _, list1'), SOME (j, _, list2')) =>
                    (case findDup(list1, list2)
                        of NONE =>
                            if j = n
                                then let
                                    val (LastIndex, index') = getLast index
                                    in S2.prodIter(index, index', S3.mkprodScaV, (id1, list1, id2, list2', LastIndex, info)) end
                                else default ()
                         | SOME _ =>
                            if i = j andalso i = n
                                then let
                                    val (LastIndex, index') = getLast index
                                    in S2.prodIter(index, index', S3.mkprodVec, (id1, list1', id2, list2', LastIndex, info)) end
                                else default ()
                        (*end case*))
                 | _ => default ()
                (*end case*))
             | [(sx1, 0, ub)] => (case bothEndIn sx1
                of SOME (list1', list2') =>
                    S2.prodIter(index, index, S3.mkprodSumVec, (id1, list1', id2, list2', ub+1, info))
                 | NONE => default ()
                (*end case*))
             | [(sx1, lb1, ub1), (sx2, lb2, ub2)] => (case (bothEndIn sx1, bothEndIn sx2)
                of (SOME (list1', list2'), _) =>
                    S2.prodIter(index, index, S3.sumDot, ((sx2, lb2, ub2), (id1, list1', id2, list2', ub1+1, info)))
                 | (NONE, SOME (list1', list2')) =>
                    S2.prodIter(index, index, S3.sumDot, ((sx1, lb1, ub1), (id1, list1', id2, list2', ub2+1, info)))
                 | _ => default ()
                (*end case*))
             | _ => default ()
            (*end case*))
        end

    (* genfn (lhs, EIN{params, index, body}, args, origargs) generates code for
     * one EIN operator, choosing a vectorizing handler by the shape of body.
     *   info       = (lhs, params, index, args), passed on to the generators
     *   nextfnArgs = (body, origargs, info), the argument for S2.generalfn
     *   iterArgs   = scalar iteration over every output index
     * Patterns are tried in order: operands with no index go straight to the
     * scalar path, as does any body not matched below. For the image/kernel
     * sum, the kernel and image source are looked up in origargs, while the
     * image argument passed on (imgargNew) comes from args.
     *)
    fun genfn(lhs,Ein.EIN{params, index, body},args,origargs)= let
        val info=(lhs,params,index,args)
        val nextfnArgs=(body,origargs,info)
        val iterArgs=(index,index,S2.generalfn,nextfnArgs)
        fun gen body=(case body
            of E.Neg(E.Tensor(_ ,[]))       => S2.prodIter iterArgs
            | E.Neg(E.Tensor(id1,ix1))                            =>
                handleNeg(id1, ix1,index, nextfnArgs)
            | E.Sub(E.Tensor(_,[]),E.Tensor(_,[])) => S2.prodIter iterArgs
            | E.Sub(E.Tensor(id1,ix1),E.Tensor(id2,ix2))           =>
                handleSimpleOp(id1,ix1,id2,ix2,index,S3.mksubVec,S2.generalfn,nextfnArgs)
            |  E.Add es                                             =>
                handleAdd(es,index,nextfnArgs)
           (* | E.Div(E.Const 1,_)                                    =>
                S2.prodIter iterArgs
            | E.Div(a,b)                                            =>
                gen (E.Prod[E.Div(E.Const 1, b),a])*)
            | E.Prod[e]                                             =>
                gen e
            | E.Prod[E.Tensor(id1,[]), E.Tensor(id2, [])]           =>
                S2.prodIter iterArgs
            | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)]         =>
                handleScVProd(id1,id2,ix2,index,nextfnArgs)
            | E.Prod[E.Tensor(id2, ix2), E.Tensor(id1, [])]         =>
                (* scalar on the right: same handler, operands swapped *)
                handleScVProd(id1,id2,ix2,index,nextfnArgs)
            | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)]        =>
                handleProd(id1,ix1,id2,ix2,index,[],nextfnArgs)
            | E.Sum(sx,E.Prod[E.Tensor(id1,ix1),E.Tensor(id2,ix2)]) =>
                handleProd(id1,ix1,id2,ix2,index,sx,nextfnArgs)
            | E.Sum(x,E.Prod(E.Img(Vid,_,_)::E.Krn(Hid,_,_)::_))    =>
                let
                    val harg=List.nth(origargs,Hid)
                    val h=S3.getKernel harg
                    val imgarg=List.nth(origargs,Vid)
                    val imgargNew=List.nth(args,Vid)
                    val v=S3.getImageSrc imgarg
                in
                    S2.prodIter(index,index,genKrn.evalField,(body,(v,imgargNew),h,info))
                end
            | _ => S2.prodIter iterArgs
            (*end case*))
        in
            gen body
        end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0