Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3174 - (download) (annotate)
Mon Mar 30 11:46:58 2015 UTC (4 years, 2 months ago) by cchiw
File size: 16241 byte(s)
hack
(* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)
 
 (*
  During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators.  Each EIN expression then only has one operation working on  constants, tensors, deltas, epsilons, images and kernels.
 
 (1) When the outer EIN operator is $ \in \{--, +, -, *, /, \sum\}$ then analyze each subexpression to see whether it needs to be rewritten.
 
 (1a.) When a subexpression is a field expression $\circledast,\nabla $ then it becomes 0. When it is another operation $ \in \{--, +, -, *, /, \sum\}$ then we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represents its size.
 
 (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
 
 (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

    local
   
    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var

    structure P=Printer
    structure cleanP=cleanParams
    structure cleanI=cleanIndex


    in

    (* Flag value handed to split by the iteration drivers; lift consults
     * einSet for an already-known variable only when the flag is 1. *)
    val numFlag = 1
    (* Debug switch read by testp: 0 keeps the trace output silent. *)
    val testing = 0

    (* Build an EIN operator from its three components. *)
    fun setEin (ps, ix, e) = E.EIN{params=ps, index=ix, body=e}

    (* Pair the variable y with an application of a freshly built EIN operator. *)
    fun assignEinApp (y, ps, ix, e, args) = let
        val ein = setEin(ps, ix, e)
        in
            (y, DstIL.EINAPP(ein, args))
        end

    (* Application of the constant-zero EIN operator: no params, no indices, no args. *)
    val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])
    (* Bind y to the constant-zero application. *)
    fun setEinZero y = (y, einappzero)

    (* Short local names for helpers defined in other structures. *)
    val cleanParams = cleanP.cleanParams
    val cleanIndex = cleanI.cleanIndex
    val printEINAPP = MidToString.printEINAPP
    val itos = Int.toString
    val filterSca = Filter.filterSca

    (* Abort with a Fail exception carrying msg. *)
    fun err msg = raise Fail msg

    (* Counter consumed by genName to number generated names. *)
    val cnt = ref 0
    (* Bump the use count stored in a MidIL variable. *)
    fun incUse (DstIL.V{useCnt=uses, ...}) = uses := !uses + 1
    (* genName: string -> string
     * Returns prefix ^ "_" ^ k, where k is the current value of the counter
     * cnt; the counter is incremented before the name is built.
     *)
    fun genName prefix = let
        val k = !cnt
        val () = cnt := k + 1
        in
            prefix ^ "_" ^ Int.toString k
        end
    (* testp: string list -> int
     * Debug trace. When testing is non-zero the strings in n are concatenated
     * and printed; when it is 0 nothing is printed. The result is always 1.
     *)
    fun testp n =
        if testing = 0
            then 1
            else (print(String.concat n); 1)

    (* lift: name*ein_exp*params*index*sum_id*args*fieldset*flag
     *         -> (ein_exp*params*args*code*fieldset)
     * Pulls the subexpression e out into its own EIN application.
     *  - cleanIndex supplies the shape of the replacement tensor, the sizes of
     *    the lifted result and the cleaned body.
     *  - A new tensor parameter is appended to params, and the returned
     *    expression is a Tensor referring to that new parameter slot.
     *  - A fresh variable is created for the lifted result and cleanParams
     *    builds the lifted application.
     *  - When flag is 1, einSet.rtnVar is asked whether an equivalent
     *    application is already known; if it returns SOME v, v is reused (its
     *    use count is incremented) and no new code is emitted.
     *)
    fun lift(name,e,params,index,sx,args,fieldset,flag)=let
        val (tshape,sizes,body)=cleanIndex(e,index,sx)
        val slot=length params
        val params'=params@[E.TEN(1,sizes)]
        val placeholder=E.Tensor(slot,tshape)
        val fresh=DstV.new(genName(name^"_l_"^itos slot), DstTy.TensorTy sizes)
        val withFresh=args@[fresh]
        val lifted as (_,rhs)=cleanParams(fresh,body,params',sizes,withFresh)
        (* result used whenever the lifted application has to be emitted *)
        val emitNew=(withFresh,[lifted],fieldset)
        val (args',code,fieldset')=
            if flag = 1
                then let
                    val (fieldset,known)=einSet.rtnVar(fieldset,fresh,rhs)
                    in (case known
                        of SOME v => (incUse v; (args@[v],[],fieldset))
                         | NONE => (withFresh,[lifted],fieldset)
                        (*end case*))
                    end
                else emitNew
        in
            (placeholder,params',args',code,fieldset')
        end


    (* isOp: ein_exp -> int
     * Classifies a subexpression for rewriteOp/rewriteOp3:
     *   0 - field-level term (Field, Conv, Apply, Lift): replaced by the constant 0
     *   1 - operator (Neg, Sqrt, trig functions, powers, Add, Sub, Prod, Div,
     *       Sum, Probe): lifted out into its own EIN application
     *   2 - anything else: left in place unchanged
     * Raises Fail for Partial, Krn, Value and Img, which are not expected at
     * this point.
     *)
    fun isOp e =(case e
        of E.Field _  => 0
        | E.Conv _    => 0
        | E.Apply _   => 0
        | E.Lift _    => 0
        | E.Neg _     => 1
        | E.Sqrt _    => 1
        | E.Cosine _      => 1
        | E.ArcCosine _   => 1
        | E.Sine _        => 1
        | E.ArcSine _     => 1
        | E.PowInt _      => 1
        | E.PowReal _     => 1
        | E.Add _     => 1
        | E.Sub _     => 1
        | E.Prod _    => 1
        | E.Div _     => 1
        | E.Sum _     => 1
        | E.Probe _   => 1
        | E.Partial _ => err(" Partial used after normalize")
        | E.Krn _     => err("Krn used before expand")
        | E.Value _   => err("Value used before expand")
        (* the message now names the constructor that was actually matched *)
        | E.Img _     => err("Img used before expand")
        | _           => 2
    (*end case*))



    (* rewriteOp3: name*sum_id*ein_exp*((var*einapp)*fieldset*flag)
     *               -> ein_exp*params*args*code*fieldset
     * Variant of rewriteOp that reads params, index and args from the EIN
     * application packed inside x. Depending on isOp e1 the subexpression is
     * replaced by the constant 0, lifted, or returned untouched.
     *)
    fun rewriteOp3(name,sx,e1,x)=let
        val ((_, DstIL.EINAPP(ein,args)),fieldset,flag)=x
        val params=Ein.params ein
        val index=Ein.index ein
        (* result shape shared by the two branches that emit no new code *)
        fun keep e=(e,params,args,[],fieldset)
        in (case isOp e1
            of 0 => keep (E.Const 0)
             | 1 => lift(name,e1,params,index,sx,args,fieldset,flag)
             | 2 => keep e1
            (*end*))
        end

    (* rewriteOp: name*ein_exp*params*index*sum_id*args*fieldset*flag
     *              -> ein_exp*params*args*code*fieldset
     * Dispatches on isOp e1:
     *   0 - e1 is replaced by the constant 0
     *   1 - e1 is lifted via lift()
     *   2 - e1 is returned as is (not lifted)
     *)
    fun rewriteOp(name,e1,params,index,sx,args,fieldset,flag)=let
        (* result shape shared by the two branches that emit no new code *)
        fun keep e=(e,params,args,[],fieldset)
        in (case isOp e1
            of 0 => keep (E.Const 0)
             | 1 => lift(name,e1,params,index,sx,args,fieldset,flag)
             | 2 => keep e1
            (*end*))
        end

      


    (* rewriteOps: name*ein_exp list*params*index*sum_id*args*fieldset*flag
     *               -> ein_exp list*params*args*code*fieldset
     * Applies rewriteOp to each expression of list1 from left to right,
     * threading the growing params, args and fieldset through the calls.
     * The rewritten expressions and the emitted code are collected in order.
     *)
    fun rewriteOps(name,list1,params,index,sx,args,fieldset0,flag)=let
        fun step(e1,(done,ps,ars,code,fs))=let
            val (e1',ps',ars',code',fs')=rewriteOp(name,e1,ps,index,sx,ars,fs,flag)
            in
                (done@[e1'],ps',ars',code@code',fs')
            end
        in
            List.foldl step ([],params,args,[],fieldset0) list1
        end


    (* rewriteOrig: var*ein_exp*params*index*sum_id*args
     * Rebuilds the application for y from the rewritten body by handing
     * everything to cleanParams. The summation-index argument is accepted so
     * that callers keep one calling convention, but it is not used here.
     *)
    fun rewriteOrig(y,body,params,index,_,args)=
        cleanParams(y,body,params,index,args)

    (* rewriteOrig3: sum_id*ein_exp*params*args*((var*einapp)*fieldset*flag)
     * Variant of rewriteOrig that takes the target variable and the index
     * from the EIN application packed inside x. sx is not used.
     *)
    fun rewriteOrig3(sx,body,params,args,x)=let
        val ((y,DstIL.EINAPP(ein,_)),_,_)=x
        in
            cleanParams(y,body,params,Ein.index ein,args)
        end

    (* handleNeg: ein_exp*((var*einapp)*fieldset*flag) -> (var*einapp)*code*fieldset
     * Rewrites the operand with rewriteOp3 (possibly lifting it), wraps the
     * result in Neg and rebuilds the original application with rewriteOrig3.
     *)
    fun handleNeg(e1,x)=let
        val (operand,params',args',lifted,fieldset')=rewriteOp3("neg",[],e1,x)
        val negated=rewriteOrig3([],E.Neg operand,params',args',x)
        in
            (negated,lifted,fieldset')
        end

    (* handleUnaryOp: string*(ein_exp->ein_exp)
     *                  -> var*ein_exp*params*index*args*fieldset*flag
     *                  -> (var*einapp)*code*fieldset
     * Shared body of the unary handlers below, which previously were five
     * copies of the same code. The operand e1 is rewritten with rewriteOp
     * (using name as the prefix for any lifted variable and an empty
     * summation list), wrapped with the constructor mk, and the application
     * for y is rebuilt with rewriteOrig.
     *)
    fun handleUnaryOp (name,mk) (y,e1,params,index,args,fieldset,flag)=let
        val (e1',params',args',code,fieldset)=rewriteOp(name,e1,params,index,[],args,fieldset,flag)
        val einapp=rewriteOrig(y,mk e1',params',index,[],args')
        in
            (einapp,code,fieldset)
        end

    (* handleSqrt: var*ein_exp*params*index*args*fieldset*flag -> (var*einapp)*code*fieldset
     * Rewrites the operand of a Sqrt and rebuilds the application.
     *)
    fun handleSqrt(y,e1,params,index,args,fieldset,flag)=
        handleUnaryOp("sqrt",E.Sqrt)(y,e1,params,index,args,fieldset,flag)

    (* handleCosine: var*ein_exp*params*index*args*fieldset*flag -> (var*einapp)*code*fieldset
     * Rewrites the operand of a Cosine and rebuilds the application.
     *)
    fun handleCosine(y,e1,params,index,args,fieldset,flag)=
        handleUnaryOp("cosine",E.Cosine)(y,e1,params,index,args,fieldset,flag)

    (* handleArcCosine: var*ein_exp*params*index*args*fieldset*flag -> (var*einapp)*code*fieldset
     * Rewrites the operand of an ArcCosine and rebuilds the application.
     *)
    fun handleArcCosine(y,e1,params,index,args,fieldset,flag)=
        handleUnaryOp("ArcCosine",E.ArcCosine)(y,e1,params,index,args,fieldset,flag)

    (* handleSine: var*ein_exp*params*index*args*fieldset*flag -> (var*einapp)*code*fieldset
     * Rewrites the operand of a Sine and rebuilds the application.
     *)
    fun handleSine(y,e1,params,index,args,fieldset,flag)=
        handleUnaryOp("sine",E.Sine)(y,e1,params,index,args,fieldset,flag)

    (* handleArcSine: var*ein_exp*params*index*args*fieldset*flag -> (var*einapp)*code*fieldset
     * Rewrites the operand of an ArcSine and rebuilds the application.
     *)
    fun handleArcSine(y,e1,params,index,args,fieldset,flag)=
        handleUnaryOp("ArcSine",E.ArcSine)(y,e1,params,index,args,fieldset,flag)


   (* handleSub: var*ein_exp*ein_exp*params*index*args*fieldset*flag
    *              -> (var*einapp)*code*fieldset
    * Rewrites the left and then the right operand with rewriteOp (both under
    * the name "subt"), threading params, args and fieldset from the first
    * call into the second, then rebuilds the application around Sub.
    *)
    fun handleSub(y,e1,e2,params,index,args,fieldset,flag)=let
        val (lhs,pL,aL,codeL,fsL)=rewriteOp("subt",e1,params,index,[],args,fieldset,flag)
        val (rhs,pR,aR,codeR,fsR)=rewriteOp("subt",e2,pL,index,[],aL,fsL,flag)
        val difference=rewriteOrig(y,E.Sub(lhs,rhs),pR,index,[],aR)
        in
            (difference,codeL@codeR,fsR)
        end

    (* handleDiv: var*ein_exp*ein_exp*params*index*args*fieldset*flag
     *              -> (var*einapp)*code*fieldset
     * Rewrites the numerator and then the denominator with rewriteOp,
     * threading params, args and fieldset from the first call into the
     * second, then rebuilds the application around Div.
     *)
    fun handleDiv(y,e1,e2,params,index,args,fieldset,flag)=let
        val (num,pNum,aNum,codeNum,fsNum)=rewriteOp("div-num",e1,params,index,[],args,fieldset,flag)
        val (den,pDen,aDen,codeDen,fsDen)=rewriteOp("div-denom",e2,pNum,index,[],aNum,fsNum,flag)
        val quotient=rewriteOrig(y,E.Div(num,den),pDen,index,[],aDen)
        in
            (quotient,codeNum@codeDen,fsDen)
        end

    (* handleAdd: var*ein_exp list*params*index*args*fieldset*flag
     *              -> (var*einapp)*code*fieldset
     * Rewrites every summand with rewriteOps and rebuilds the application
     * around Add of the rewritten summands.
     *)
    fun handleAdd(y,e1,params,index,args,fieldset,flag)=let
        val (terms,params',args',lifted,fieldset')=
            rewriteOps("add",e1,params,index,[],args,fieldset,flag)
        val total=rewriteOrig(y,E.Add terms,params',index,[],args')
        in
            (total,lifted,fieldset')
        end

    (* handleProd: var*ein_exp list*params*index*args*fieldset*flag
     *               -> (var*einapp)*code*fieldset
     * Rewrites every factor with rewriteOps and rebuilds the application
     * around Prod of the rewritten factors.
     *)
    fun handleProd(y,e1,params,index,args,fieldset,flag)=let
        val (factors,params',args',lifted,fieldset')=
            rewriteOps("prod",e1,params,index,[],args,fieldset,flag)
        val product=rewriteOrig(y,E.Prod factors,params',index,[],args')
        in
            (product,lifted,fieldset')
        end

   (* handleSumProd: var*ein_exp list*params*index*sum_id*args*fieldset*flag
    *                  -> (var*einapp)*code*fieldset
    * Rewrites every factor of a summed product with rewriteOps, passing the
    * summation indices sx along, and rebuilds the application around
    * Sum(sx, Prod factors).
    *)
    fun handleSumProd(y,e1,params,index,sx,args,fieldset,flag)=let
        val (factors,params',args',lifted,fieldset')=
            rewriteOps("sumprod",e1,params,index,sx,args,fieldset,flag)
        val summed=rewriteOrig(y,E.Sum(sx,E.Prod factors),params',index,sx,args')
        in
            (summed,lifted,fieldset')
        end

    (* split: (var*rhs)*fieldset*flag -> ((var*rhs)*code)*fieldset
     * Splits one EIN application into simpler pieces.
     *  - Field-level bodies (Conv, Field, Apply, Lift) become the zero application.
     *  - Leaf bodies (Delta, Epsilon, Eps2, Tensor, Const, ConstR) and probes of
     *    a convolution, including the summed probe forms listed below, are
     *    returned unchanged; the summation is left around the probe expression.
     *  - Operator bodies are passed to the matching handle* function, which
     *    returns the rewritten application together with the applications that
     *    were lifted out of it.
     *  - Any other Probe, a Sum over something other than Prod/Delta/probe,
     *    PowInt, PowReal, Partial, Krn, Value and Img raise Fail.
     * A right-hand side that is not an EINAPP is returned unchanged with no
     * new code.
     *)
    fun split((y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args)),fieldset,flag) =let
        val x= ((y,einapp),fieldset,flag)
        val zero=   (setEinZero y,[],fieldset)
        val default=((y,einapp),[],fieldset)
        val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
        val _=testp["\n\nStarting split",P.printbody body]
        fun rewrite b=(case b
            of E.Probe (E.Conv _,_)   => default
            (* every other probe is malformed at this stage *)
            | E.Probe _               => raise Fail str
            | E.Conv _                => zero
            | E.Field _               => zero
            | E.Apply _               => zero
            | E.Lift _                => zero
            | E.Delta _               => default
            | E.Epsilon _             => default
            | E.Eps2 _                => default
            | E.Tensor _              => default
            | E.Const _               => default
            | E.ConstR _              => default
            | E.Neg e1                => handleNeg(e1,x)
            | E.Sqrt e1               => handleSqrt(y,e1,params,index,args,fieldset,flag)
            | E.Cosine e1             => handleCosine(y,e1,params,index,args,fieldset,flag)
            | E.ArcCosine e1          => handleArcCosine(y,e1,params,index,args,fieldset,flag)
            | E.Sine e1               => handleSine(y,e1,params,index,args,fieldset,flag)
            | E.ArcSine e1            => handleArcSine(y,e1,params,index,args,fieldset,flag)
            | E.PowInt _              => err(" PowInt unsupported")
            | E.PowReal _             => err(" PowReal unsupported")
            | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args,fieldset,flag)
            | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args,fieldset,flag)
            | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default
            | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default
            | E.Sum(_,E.Probe(E.Conv _,_))    => default
            | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args,fieldset,flag)
            | E.Sum(sx,E.Delta d)     => handleSumProd(y,[E.Delta d],params,index,sx,args,fieldset,flag)
            | E.Sum(_,_)              => err(" summation not distributed:"^str)
            | E.Add e1                => handleAdd(y,e1,params,index,args,fieldset,flag)
            | E.Prod e1               => handleProd(y,e1,params,index,args,fieldset,flag)
            | E.Partial _             => err(" Partial used after normalize")
            | E.Krn _                 => err("Krn used before expand")
            | E.Value _               => err("Value used before expand")
            (* the message now names the constructor that was actually matched *)
            | E.Img _                 => err("Img used before expand")
            (*end case *))
        val (einapp2,newbies,fieldset) =rewrite body
        in
            ((einapp2,newbies),fieldset)
        end
        |split((y,app),fieldset,_) =(((y,app),[]),fieldset)


    (* iterMultiple: (var*rhs)*(var*rhs) list*fieldset -> (var*rhs) list
     * Splits every application in newbies2 with flag numFlag and recursively
     * splits whatever each of those splits lifts out. Every call to split
     * receives the fieldset passed in here; the fieldset that split returns
     * is discarded. The result lists the recursively produced code first,
     * then the split versions of newbies2, and einapp2 last.
     *)
    fun iterMultiple(einapp2,newbies2,fieldset)=let
        fun walk([],done,lifted)=(done,lifted)
          | walk(app::pending,done,lifted)=let
                val ((app',fresh),_)=split(app,fieldset,numFlag)
                val (subDone,subLifted)=walk(fresh,[],[])
                in
                    walk(pending,done@[app'],subLifted@subDone@lifted)
                end
        val (done,lifted)=walk(newbies2,[],[])
        in
            lifted@done@[einapp2]
        end


    (* iterAll: (var*rhs) list*fieldset -> (var*rhs) list
     * Splits every application in einapp2 with flag numFlag and recursively
     * splits whatever each of those splits lifts out. Every call to split
     * receives the fieldset passed in here; the fieldset that split returns
     * is discarded. The result lists the recursively produced code first,
     * followed by the split versions of the input applications.
     *)
    fun iterAll(einapp2,fieldset)=let
        fun walk([],done,lifted)=(done,lifted)
          | walk(app::pending,done,lifted)=let
                val ((app',fresh),_)=split(app,fieldset,numFlag)
                val (subDone,subLifted)=walk(fresh,[],[])
                in
                    walk(pending,done@[app'],subLifted@subDone@lifted)
                end
        val (done,lifted)=walk(einapp2,[],[])
        in
            lifted@done
        end

    (* splitEinApp: var*rhs -> (var*rhs) list
     * Entry point of the structure: splits einapp3 and everything lifted out
     * of it, starting from an empty einSet, and returns the resulting list of
     * applications.
     *)
    fun splitEinApp einapp3=
        (* "split all at once" mode. The staged alternative -- one call to
         * split with flag 0 followed by iterMultiple on the lifted pieces --
         * is not used. *)
        iterAll([einapp3],einSet.EinSet.empty)


  end; (* local *)

end (* structure Split *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0