Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/low-il/ein-to-low.sml
ViewVC logotype

View of /branches/ein16/src/compiler/low-il/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4257 - (download) (annotate)
Mon Jul 25 15:23:32 2016 UTC (2 years, 11 months ago) by cchiw
File size: 12741 byte(s)
fixed image shape issue
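(*
* Translates EIN operator applications into LowIL code.
* The entry point q (through its local gen function) does a preliminary scan of the
* body of the EIN expression for vectorization potential.
* Fields are handled by FieldToLow (exported here as evalField); q itself does not call it.
* If the body is a tensor expression, q passes it to the handle*() functions, which check
* whether the indices match, i.e. <A_ij+B_ij>_ij (match) vs. <A_ji+B_ij>_ij (no match):
*
*     (1) If the indices match, the expression is passed to Iter with a VecToLow function,
*            which creates LowIL vector operators.
*     (2) Otherwise it is passed to Iter with a ScaToLow function,
*           which creates LowIL scalar operators.
* Note: the Iter functions create LowIL.CONS and therefore bind the indices in the EIN body.
*)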
structure EinToLowSet = struct
    local

    structure Var = LowIL.Var
    structure E = Ein
    structure P=Printer
    structure Iter=IterSet
    structure EtoFld= FieldToLowSet
    structure EtoSca= ScaToLowSet
    structure EtoVec= VecToLowSet
    structure H=HelperSet

    in

    (* Thin wrappers over the helper structures. *)
    fun iter e=Iter.prodIter e
    fun evalField e= EtoFld.evalField e
    fun intToReal n=H.intToReal n
    fun testp p=(String.concat p)

    (* scaFlag: when true, q skips the vectorization scan and always uses
     * runGeneralCase.
     *)
    val scaFlag= ref false

    (* debugFlag: when true, q prints each EIN operator and the names of its
     * arguments before translating it (this print used to be unconditional).
     *)
    val debugFlag= ref false

    val controls = [("scaFlag",scaFlag,"scaFlag")]


    (*dropIndex: 'a list -> int * 'a * 'a list
    * alpha@[i] -> (length alpha, i, alpha), i.e. returns
    * (length of the list - 1, its last element, all but its last element).
    * Raises Fail on an empty list; callers only pass a non-empty index.
    *)
    fun dropIndex alpha=(case List.rev alpha
        of last::restRev => (length alpha-1, last, List.rev restRev)
        | [] => raise Fail "dropIndex: empty index"
        (*end case*))

    (*matchLast:E.alpha*int -> E.alpha option
    * If the last index of alpha is E.V n, returns SOME of the remaining
    * indices (in their original order); otherwise, including when alpha
    * is empty, returns NONE.
    *)
    fun matchLast(alpha, n)=(case List.rev alpha
        of (E.V v)::es => if n=v then SOME(List.rev es) else NONE
        | _ => NONE
        (*end case*))

    (*matchFindLast:E.alpha*int -> E.alpha option * E.mu option
    * First component: matchLast(alpha, n), i.e. is the last index of alpha E.V n?
    * Second component: SOME(E.V n) if E.V n also occurs among the other
    * (non-last) indices of alpha, NONE otherwise.
    * An empty alpha yields (NONE, NONE).
    *)
    fun matchFindLast(alpha, n)=let
        (* every index of alpha except the last one *)
        val front=(case List.rev alpha of _::es => es | [] => [])
        val f=List.find(fn E.V e=>e=n|_=>false) front
        in
            (matchLast(alpha,n),f)
        end

    (*runGeneralCase:string*E.EIN*Var list -> (set, var, code) triple from iter
    * Does not attempt any vector projection; translates the whole EIN like a
    * general EIN, using ScaToLow's generalfn over every output index.
    *)
    fun runGeneralCase(lhs:string,e:Ein.ein,args:LowIL.var list)=let
        val index=Ein.index e
        in
            iter(lowSet.LowSet.empty,index,index,EtoSca.generalfn,(lhs,e,args))
        end

    (*handleNeg:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for scaling a vector by -1.
    * Vectorizes when the tensor's last index is E.V (length index - 1);
    * otherwise, or for any other body shape, falls back to runGeneralCase.
    *)
    fun handleNeg(E.Op1(E.Neg,E.Tensor(id,alpha)),index,info)=let
        val (n,vecIndex,index')=dropIndex index
        in (case matchLast(alpha,n)
            of SOME ix1 => let
                val (setA,vA,A)=intToReal(lowSet.LowSet.empty, ~1)
                val (lhs,e,args)=info
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,vA,id,ix1)
                val (setB,vB,B)=iter(setA,index,index',EtoVec.negV,nextfnargs)
                in
                    (* code for the -1 constant precedes the vector code *)
                    (setB,vB,A@B)
                end
            | NONE => runGeneralCase info
            (*end case*))
        end
      | handleNeg(_,_,info)=runGeneralCase info

    (*handleSub:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for subtracting two vectors.
    * Both tensors must end in E.V (length index - 1); otherwise, or for any
    * other body shape, falls back to runGeneralCase.
    *)
    fun handleSub(E.Op2(E.Sub,E.Tensor(id1,alpha),E.Tensor(id2,beta)),index,info)=let
        val (n,vecIndex,index')=dropIndex index
        in (case (matchLast(alpha,n), matchLast(beta,n))
            of (SOME ix1,SOME ix2) => let
                val (lhs,e,args)=info
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,id1,ix1,id2,ix2)
                in
                    iter(lowSet.LowSet.empty,index,index',EtoVec.subV,nextfnargs)
                end
            | _ => runGeneralCase info
            (*end case*))
        end
      | handleSub(_,_,info)=runGeneralCase info

    (*handleAdd:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for adding vectors.
    * Every summand must be a tensor ending in E.V (length index - 1);
    * otherwise, or for any other body shape, falls back to runGeneralCase.
    *)
    fun handleAdd(E.Opn(E.Add,es),index,info)=let
        val (n,vecIndex,index')=dropIndex index
        (* Walk the summands, collecting (id, remaining indices) for each
         * matching tensor; give up on the first summand that does not match.
         *)
        fun sample([],rest)=let
              val (lhs,e,args)=info
              val nextfnargs=(lhs,Ein.params e,args,vecIndex,rest)
              in
                  iter(lowSet.LowSet.empty,index,index',EtoVec.addV,nextfnargs)
              end
          | sample(E.Tensor(id,alpha)::ts,rest)=(case matchLast(alpha,n)
              of SOME ix1 => sample(ts,rest@[(id,ix1)])
              | NONE => runGeneralCase info
              (*end case*))
          | sample _ = runGeneralCase info
        in
            sample(es,[])
        end
      | handleAdd(_,_,info)=runGeneralCase info

    (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for scaling vector id2 (indices alpha2) by scalar id1.
    * Falls back to runGeneralCase unless alpha2 ends in E.V (length index - 1).
    *)
    fun handleScale(id1,id2,alpha2,index,info)=let
        val (n,vecIndex,index')=dropIndex index
        in (case matchLast(alpha2,n)
            of SOME ix2 => let
                val (lhs,e,args)=info
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,id1,[],id2,ix2)
                in
                    iter(lowSet.LowSet.empty,index,index',EtoVec.scaleV,nextfnargs)
                end
            | NONE => runGeneralCase info
            (*end case*))
        end

    (*handleProd:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for a product of two tensors, vectorized over index n = length index - 1.
    *)
    fun handleProd(E.Opn(E.Prod,[E.Tensor(id1,alpha),E.Tensor(id2,beta)]),index,info)=let
        val (lhs,e,args)=info
        val (n,vecIndex,index')=dropIndex index
        val setT=lowSet.LowSet.empty
        in (case (matchFindLast(alpha,n), matchFindLast(beta,n))
            of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
                (* n is the last index of alpha and of beta and occurs nowhere
                 * else: element-wise vector product *)
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,id1,ix1,id2,ix2)
                in
                    iter(setT,index,index',EtoVec.prodV,nextfnargs)
                end
            | ((NONE,NONE),(SOME ix2,NONE)) => let
                (* n does not occur in alpha and is only the last index of beta:
                 * scale vector id2 by id1 *)
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,id1,alpha,id2,ix2)
                in
                    iter(setT,index,index',EtoVec.scaleV,nextfnargs)
                end
            | ((SOME ix1,NONE),(NONE,NONE)) => let
                (* n does not occur in beta and is only the last index of alpha:
                 * scale vector id1 by id2 *)
                val nextfnargs=(lhs,Ein.params e,args,vecIndex,id2,beta,id1,ix1)
                in
                    iter(setT,index,index',EtoVec.scaleV,nextfnargs)
                end
            | _ => runGeneralCase info
            (*end case*))
        end
      | handleProd(_,_,info)=runGeneralCase info

    (* Disabled special cases of handleSumProd1, kept for reference:
     *
     * handleSumProd1(E.Sum([(E.V 1,_,ub)],E.Opn(E.Prod,[E.Tensor(id1 , [E.V 1]), E.Tensor(id2, [E.V 1,E.V 0])])    ),[i],info)=let
     *     val (lhs,e,args)=info
     *     val setT= lowSet.LowSet.empty
     *     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,[],id2,[E.V 0])
     *     in
     *         iter(setT,[i],[i],EtoVec.VM,nextfnargs)
     *     end
     *   | handleSumProd1(E.Sum([(E.V 2,_,ub)],E.Prod[E.Tensor(id1 , [E.V 0,E.V 2]), E.Tensor(id2, [E.V 2,E.V 1])]),index as [_,_],info)=let
     *     val (lhs,e,args)=info
     *     val setT= lowSet.LowSet.empty
     *     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,[E.V 0],id2,[E.V 1],index)
     *     val _ ="\nuses projFirst"
     *     in
     *         iter(setT,index,index,EtoVec.MM3,nextfnargs)
     *     end
     *)

    (*handleSumProd1:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for a dot product: Sigma_v A_..v B_..v, vectorized over the
    * summation variable v when it is the last index of both tensors and
    * occurs nowhere else in them.
    *)
    fun handleSumProd1(E.Sum([(E.V v,_,ub)],E.Opn(E.Prod,[E.Tensor(id1,alpha),E.Tensor(id2,beta)])),index,info)=
            (case (matchFindLast(alpha,v), matchFindLast(beta,v))
              of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
                  val (lhs,e,args)=info
                  val nextfnargs=(lhs,Ein.params e,args,ub+1,id1,ix1,id2,ix2)
                  in
                      iter(lowSet.LowSet.empty,index,index,EtoVec.dotV,nextfnargs)
                  end
              | _ => runGeneralCase info
            (*end case*))
      | handleSumProd1(_,_,info)=runGeneralCase info

    (*handleSumProd2:E.body*int list*info -> (set, var, code)
    * info:(string*E.EIN*Var list)
    * low-IL code for a double dot product: Sigma_{i,j} A_ij B_ij.
    * Tries to vectorize over the first summation variable, then the second;
    * falls back to runGeneralCase when neither qualifies.
    *)
    fun handleSumProd2(E.Sum([(E.V v1,lb1,ub1),(E.V v2,lb2,ub2)],E.Opn(E.Prod,[E.Tensor(id1,alpha),E.Tensor(id2,beta)])),index,info)=let
        (* v/ub: candidate summation variable and its upper bound;
         * sx: the other summation variable, passed through to sumDotV *)
        fun check(v,ub,sx)=(case (matchFindLast(alpha,v), matchFindLast(beta,v))
            of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
                val (lhs,e,args)=info
                val nextfnargs=(lhs,Ein.params e,args,sx,ub+1,id1,ix1,id2,ix2)
                in
                    SOME(iter(lowSet.LowSet.empty,index,index,EtoVec.sumDotV,nextfnargs))
                end
            | _ => NONE
            (*end case*))
        in (case check(v1,ub1,(E.V v2,lb2,ub2))
            of SOME rtn => rtn
            | NONE => (case check(v2,ub2,(E.V v1,lb1,ub1))
                of SOME rtn => rtn
                | NONE => runGeneralCase info
                (*end case*))
            (*end case*))
        end
      | handleSumProd2(_,_,info)=runGeneralCase info


    (*q:Var*E.EIN*Var list -> (set, var, code)
    * Translates y = e(args) into LowIL.
    * Unless scaFlag is set, scans the body of e for vectorization potential
    * and dispatches to the matching handle* function; anything else goes to
    * runGeneralCase.
    *)
    fun q(y,e:Ein.ein,args:LowIL.var list)= let
        val lhs=LowIL.Var.name y
        val b=Ein.body e
        val index=Ein.index e
        val info=(lhs,e,args)
        val all=(b,index,info)
        val _ = if (!debugFlag)
            then print(String.concat["\n\n*** ", lhs, "=", P.printerE(e),
                String.concatWith "," (List.map LowIL.Var.name args)])
            else ()
        (* The element-wise cases (Sub/Add/Neg/Scale/Prod) vectorize over the
         * last output index, so they require a non-empty index (dropIndex
         * fails on []); the summation cases vectorize over a summation
         * variable and accept any index.
         *)
        fun gen body=(case (index,body)
            of (_::_,E.Op2(E.Sub,E.Tensor(_,_::_),E.Tensor(_,_::_)))
                => handleSub all
            | (_::_,E.Opn(E.Add,(E.Tensor(_,E.V _::_)::_)))
                => handleAdd all
            | (_::_,E.Op1(E.Neg,E.Tensor(_,_::_)))
                => handleNeg all
            | (_::_,E.Opn(E.Prod,[E.Tensor(s,[]),E.Tensor(v,(j as E.V _)::jx)]))
                => handleScale(s,v,j::jx,index,info)
            | (_::_,E.Opn(E.Prod,[E.Tensor(v,(j as E.V _)::jx),E.Tensor(s,[])]))
                => handleScale(s,v,j::jx,index,info)
            | (_::_,E.Opn(E.Prod,[E.Tensor(_,(E.V _)::_),E.Tensor(_,(E.V _)::_)]))
                => handleProd all
            | (_,E.Sum([_],E.Opn(E.Prod,[E.Tensor(_,(E.V _)::_),E.Tensor(_,(E.V _)::_)])))
                => handleSumProd1 all
            | (_,E.Sum([_,_],E.Opn(E.Prod,[E.Tensor(_,(E.V _)::_),E.Tensor(_,(E.V _)::_)])))
                => handleSumProd2 all
            | _ => runGeneralCase info
            (*end case*))
        in
            if (!scaFlag) then runGeneralCase info else gen b
        end

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0