Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/step1.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/step1.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2665 - (download) (annotate)
Tue Jun 3 02:37:46 2014 UTC (6 years, 5 months ago) by cchiw
File size: 7018 byte(s)
made changes to ld and mkvec functions
(* step1: looks for vectorization potential in an EIN expression.
 *
 * genfn matches the EIN body against a small set of shapes (negation,
 * subtraction, addition, scalar/vector products, one- and two-index sums of
 * products, image*kernel sums).  When the tensor operands' LAST index is the
 * variable for the LAST output axis (E.V n with n = length index - 1), the
 * handler calls S2.prodIter over the shortened index together with a step3
 * emitter (mkNegV, mksubVec, handleAddVec, mkprodScaV, mkprodVec, ...).
 * Presumably this emits one vector operation per remaining index position;
 * confirm in step2/step3.  Every other shape, including scalars and tensors
 * whose last index is a constant, falls back to
 * S2.prodIter(index, index, S2.generalfn, nextfnArgs).
 *)
structure step1 = struct
    local
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure E = Ein
    structure P = Printer
    structure genKrn = genKrn
    structure S2 = step2
    structure S3 = step3

    in

(* Scalar (rank-0) LowIL tensor type. *)
val Sca = DstTy.TensorTy([])

(* Raise Fail carrying the given message. *)
fun errS str = raise Fail(str)

(* getLast xs = (last element of xs, the other elements in their original order).
 * Raises Empty when xs is empty.
 *)
fun getLast list = (case List.rev list
    of last::revRest => (last, List.rev revRest)
     | [] => raise Empty
    (*end case*))

(* getLastAll ix = (i, E.V i, rest), where E.V i is the last index of ix and
 * rest holds the preceding indices.
 * Raises Empty on an empty list, and Fail when the last index is not an
 * index variable.  Retained unchanged for backward compatibility; the
 * handlers below use the non-raising splitLastVar instead.
 *)
fun getLastAll list = let
    val (last, rest) = getLast list
    in (case last
        of E.V i => (i, last, rest)
         | _ => errS("Last Index is not Variable")
        (*end case*))
    end

(* splitLastVar ix = SOME (i, E.V i, rest) when ix is non-empty and its last
 * index is the variable E.V i (rest = the preceding indices); NONE otherwise
 * (empty list, or a non-variable last index).  Never raises, so callers can
 * fall back to the general path instead of aborting compilation.
 *)
fun splitLastVar list = (case List.rev list
    of (last as E.V i)::revRest => SOME(i, last, List.rev revRest)
     | _ => NONE
    (*end case*))

(* findDup(l1, l2) = SOME 1 when at least one index of l1 also occurs in l2,
 * NONE otherwise.
 *)
fun findDup(list1, list2) =
    if List.exists (fn v => List.exists (fn x => x = v) list2) list1
        then SOME 1
        else NONE

(*-----------Handle Functions for vectorization-----------*)
(* Each handler takes the output index (shape) list `index` and
 * nextfnArgs = (body, origargs, info).  When vectorization does not apply,
 * the handler returns S2.prodIter(index, index, <general fn>, nextfnArgs).
 *)

(* Negative tensor: -id_{ix}.
 * This case is vectorized when the last index of ix is the last output
 * variable.  S3.mkInt ~1 is emitted first, and its code is prepended to the
 * code of the per-vector S3.mkNegV.
 *)
fun handleNeg(id, ix, index, nextfnArgs) = let
    val n = length index - 1
    fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
    in (case splitLastVar ix
        of SOME(v, _, ix') =>
            if v = n
                then let
                    val (vA, A) = S3.mkInt ~1
                    val (LastIndex, index') = getLast index
                    val (_, _, info) = nextfnArgs
                    val (vB, B) = S2.prodIter(index, index', S3.mkNegV, ((vA, id, ix'), LastIndex, info))
                    in (vB, A@B) end
                else default 0
         | NONE => default 0    (* scalar or constant last index *)
        (*end case*))
    end

(* Binary element-wise operation on two tensors, e.g. subtraction.
 * f1 is the vectorized step3 emitter.  It is used only when both operands'
 * last index is the last output variable.  f2 is the fallback function
 * passed to S2.prodIter over the full index.
 *)
fun handleSimpleOp(id1, list1, id2, list2, index, f1, f2, nextfnArgs) = let
    val n = length index - 1
    fun default _ = S2.prodIter(index, index, f2, nextfnArgs)
    in (case (splitLastVar list1, splitLastVar list2)
        of (SOME(i, _, list1'), SOME(j, _, list2')) =>
            if j = i andalso j = n
                then let
                    val (LastIndex, index') = getLast index
                    val (_, _, info) = nextfnArgs
                    in S2.prodIter(index, index', f1, (id1, list1', id2, list2', LastIndex, info)) end
                else default 0
         | _ => default 0       (* an operand is a scalar or has a constant last index *)
        (*end case*))
    end

(* Addition of tensors.
 * The sum is vectorized only when EVERY term is a tensor whose last index is
 * the last output variable.  The (id, remaining indices) pairs are passed, in
 * term order, to S3.handleAddVec.  Any other term shape falls back to the
 * general path.
 *)
fun handleAdd(terms, index, nextfnArgs) = let
    val n = length index - 1
    fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
    (* acc collects the vectorizable terms seen so far, in reverse order *)
    fun add ([], acc) =
          if n < 0
              then default 0    (* no output axis to vectorize along *)
              else let
                  val (LastIndex, index') = getLast index
                  val (_, _, info) = nextfnArgs
                  in S2.prodIter(index, index', S3.handleAddVec, (List.rev acc, LastIndex, info)) end
      | add (E.Tensor(id, alpha)::es, acc) = (case splitLastVar alpha
          of SOME(i, _, alpha') =>
              if i = n then add(es, (id, alpha')::acc) else default 0
           | NONE => default 0  (* scalar term or constant last index *)
          (*end case*))
      | add (_, _) = default 0  (* non-tensor term *)
    in
        add(terms, [])
    end

(* Scaling: scalar id1 times tensor id2_{alpha}.
 * This case is vectorized with S3.mkprodScaV when alpha's last index is the
 * last output variable.
 *)
fun handleScVProd(id1, id2, alpha, index, nextfnArgs) = let
    val n = length index - 1
    fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
    in (case splitLastVar alpha
        of SOME(j, _, list2') =>
            if j = n
                then let
                    val (LastIndex, index') = getLast index
                    val (_, _, info) = nextfnArgs
                    in S2.prodIter(index, index', S3.mkprodScaV, (id1, [], id2, list2', LastIndex, info)) end
                else default 0
         | NONE => default 0
        (*end case*))
    end

(* Product of two tensors, optionally under a summation sx.
 *   sx = [], no index shared between the operands ({A_.. B_..j}_...j):
 *       outer-product style; S3.mkprodScaV when B's last index is the last
 *       output variable.
 *   sx = [], some index shared ({A_..i B_..i}_...i):
 *       modulate (element-wise v*v); S3.mkprodVec when both last indices are
 *       the last output variable.
 *   sx = [(s, 0, ub)]: S3.mkprodSumVec when both operands' last index is the
 *       summation variable s; the length passed is ub+1.
 *   sx = two summations: S3.sumDot over the summation matching both last
 *       indices, with the other summation passed through.
 * Every other case, including scalar operands and constant last indices,
 * falls back to the general path.
 *)
fun handleProd(id1, list1, id2, list2, index, sx, nextfnArgs) = let
    val n = length index - 1
    val (_, _, info) = nextfnArgs
    fun default _ = S2.prodIter(index, index, S2.generalfn, nextfnArgs)
    in (case (splitLastVar list1, splitLastVar list2, sx)
        of (SOME(i, _, list1'), SOME(j, _, list2'), []) =>
            (case findDup(list1, list2)
              of NONE =>
                  if j = n
                      then let
                          val (LastIndex, index') = getLast index
                          in S2.prodIter(index, index', S3.mkprodScaV, (id1, list1, id2, list2', LastIndex, info)) end
                      else default 0
               | SOME _ =>
                  if i = j andalso i = n
                      then let
                          val (LastIndex, index') = getLast index
                          in S2.prodIter(index, index', S3.mkprodVec, (id1, list1', id2, list2', LastIndex, info)) end
                      else default 0
              (*end case*))
         | (SOME(_, vi, list1'), SOME(_, vj, list2'), [(sx1, 0, ub)]) =>
            if vi = vj andalso vi = sx1
                then S2.prodIter(index, index, S3.mkprodSumVec, (id1, list1', id2, list2', ub+1, info))
                else default 0
         | (SOME(_, vi, list1'), SOME(_, vj, list2'), [(sx1, lb1, ub1), (sx2, lb2, ub2)]) =>
            if vi = vj andalso vi = sx1
                then S2.prodIter(index, index, S3.sumDot, ((sx2, lb2, ub2), (id1, list1', id2, list2', ub1+1, info)))
            else if vi = vj andalso vi = sx2
                then S2.prodIter(index, index, S3.sumDot, ((sx1, lb1, ub1), (id1, list1', id2, list2', ub2+1, info)))
            else default 0
         | _ => default 0
        (*end case*))
    end

(* genfn(EIN{params, index, body}, args, origargs): entry point.
 *   info       = (params, args)
 *   nextfnArgs = (body, origargs, info), handed to every handler and to the
 *                general path S2.prodIter(index, index, S2.generalfn, nextfnArgs).
 * A singleton product is unwrapped and re-examined.  A sum over a product that
 * starts with an image and a kernel (Img::Krn::_) is routed to
 * genKrn.evalField, with the image and kernel recovered from origargs through
 * S3.getImage / S3.getKernel.
 * The order of the case arms matters: the scalar*scalar product must be
 * matched before the scalar*tensor products.
 *)
fun genfn(Ein.EIN{params, index, body}, args, origargs) = let
    val info = (params, args)
    val nextfnArgs = (body, origargs, info)
    val iterArgs = (index, index, S2.generalfn, nextfnArgs)
    fun gen body = (case body
        of E.Neg(E.Tensor(id1, ix1)) =>
            handleNeg(id1, ix1, index, nextfnArgs)
         | E.Sub(E.Tensor(id1, ix1), E.Tensor(id2, ix2)) =>
            handleSimpleOp(id1, ix1, id2, ix2, index, S3.mksubVec, S2.generalfn, nextfnArgs)
         | E.Add es =>
            handleAdd(es, index, nextfnArgs)
         | E.Prod[e] => gen e
         | E.Prod[E.Tensor(id1, []), E.Tensor(id2, [])] =>
            S2.prodIter iterArgs
         | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)] =>
            handleScVProd(id1, id2, ix2, index, nextfnArgs)
         | E.Prod[E.Tensor(id2, ix2), E.Tensor(id1, [])] =>
            handleScVProd(id1, id2, ix2, index, nextfnArgs)
         | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)] =>
            handleProd(id1, ix1, id2, ix2, index, [], nextfnArgs)
         | E.Sum(sx, E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)]) =>
            handleProd(id1, ix1, id2, ix2, index, sx, nextfnArgs)
         | E.Sum(x, E.Prod(E.Img(Vid, _, _)::E.Krn(Hid, _, _)::_)) =>
            let
                val harg = List.nth(origargs, Hid)
                val h = S3.getKernel(harg)
                val imgarg = List.nth(origargs, Vid)
                val v = S3.getImage(imgarg)
            in
                S2.prodIter(index, index, genKrn.evalField, (body, v, h, info))
            end
         | _ => S2.prodIter iterArgs
        (*end case*))
    in gen body
    end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0