Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3448 - (download) (annotate)
Fri Nov 20 20:33:38 2015 UTC (3 years, 9 months ago) by cchiw
File size: 12828 byte(s)
merge over dev branch
(* Currently under construction
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)
 
 (*
  During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators. Each resulting EIN expression then has only one operation, working on constants, tensors, deltas, epsilons, images, and kernels.
 
 (1) When the outer EIN operator is $\in \{--, +, -, *, /, \sum\}$, analyze each subexpression to see whether it needs to be rewritten.
 
 (1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it becomes 0. When it is another operation $\in \{@, --, +, -, *, /, \sum\}$, we lift that subexpression into a new EIN operator and replace it with a tensor expression that refers to the lifted result and represents its size.
 
 (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
 
 (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

    (* Splits a MidIL EIN application whose body combines several operations
     * into a list of simpler EIN applications (see the file header).
     *)
    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var

    structure P=Printer
    structure cleanP=cleanParams
    structure cleanI=cleanIndex

    in

    (* numFlag: passed as `flag` by [limitSplit] and [splitEinApp]. When it is 1,
     * [lift] consults einSet.rtnVar and may reuse an existing variable instead
     * of emitting a new binding (common-subexpression removal).
     *)
    val numFlag=1
    (* testing: 0 silences [testp]; any other value turns on debug printing. *)
    val testing=0
    fun mkEin e = E.mkEin e
    (* An EIN application with no params, no index, a constant-0 body and no args. *)
    val einappzero= DstIL.EINAPP(mkEin([],[],E.B(E.Const 0)),[])
    (* Binds y to [einappzero]. *)
    fun setEinZero y=  (y,einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e= MidToString.toStringBind e
    fun itos i = Int.toString i
    fun err str = raise Fail str
    (* Counter shared by all calls to [genName]; it is never reset in this file. *)
    val cnt = ref 0
    (* Increments the use count of a MidIL variable. *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
    (* genName: string -> string
     * Returns prefix ^ "_" ^ n, where n is the current value of [cnt],
     * and advances [cnt] so that successive calls get different suffixes.
     *)
    fun genName prefix = let
        val n = !cnt
    in
        cnt := n+1;
        String.concat[prefix, "_", Int.toString n]
    end
    (* testp: string list -> int
     * Debug hook: when [testing] is non-zero, prints the concatenation of n.
     * Always returns 1, so callers write `val _ = testp [...]`.
     * The argument list is still built when testing = 0.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* lift: (name, e, params, index, sx, args, fieldset, flag)
     *         -> (replacement, params', args', newbies, fieldset')
     * Pulls the subexpression e out of the enclosing EIN operator.
     *   - cleanIndex (e, index, sx) yields the shape `tshape` used to index the
     *     replacement, the `sizes` of the new tensor, and the cleaned body.
     *   - A new parameter E.TEN(1,sizes) is appended to params. The replacement
     *     E.Tensor(id,tshape) refers to it by its position id.
     *   - A fresh variable M, named from `name` and id, is created, and
     *     cleanParams builds the binding for M from the cleaned body.
     * When flag = 1, einSet.rtnVar is consulted. If it returns SOME v, then v is
     * used as the new argument (its use count is incremented) and no binding is
     * emitted. Otherwise M is appended to args and its binding is returned in
     * `newbies`.
     *)
    fun lift(name,e,params,index,sx,args,fieldset,flag)=let
        val (tshape,sizes,body)=cleanIndex(e,index,sx)
        val id=length(params)
        val Rparams=params@[E.TEN(1,sizes)]
        val Re=E.Tensor(id,tshape)
        val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
        val Rargs=args@[M]
        val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
        val (_,einapp0)=einapp
        val (Rargs,newbies,fieldset) =(case flag
            of 1=> let
                val (fieldset,var) = einSet.rtnVar(fieldset,M,einapp0)
                in (case var
                    of NONE=> (args@[M],[einapp],fieldset)
                    | SOME v=> (incUse v ;(args@[v],[],fieldset))
                    (*end case*))
                end
            | _=>(args@[M],[einapp],fieldset)
              (*end case*))
        in
            (Re,Rparams,Rargs,newbies,fieldset)
        end

    (* isOp: ein_exp -> int
     * Classifies a subexpression of the operator being split. The result is
     * interpreted by [rewriteOp]:
     *   0 - field-level term (Field, Conv, Apply, Lift); replaced by constant 0
     *   1 - operator (Op1, Op2, Opn, Sum, Probe); lifted into its own EIN operator
     *   2 - anything else; left in place
     * Raises Fail for Partial, Krn, Value and Img, which should not reach this
     * phase.
     *)
    fun isOp e =(case e
        of E.Field _  => 0
        | E.Conv _    => 0
        | E.Apply _   => 0
        | E.Lift _    => 0
        | E.Op1 _     => 1
        | E.Op2 _     => 1
        | E.Opn _     => 1
        | E.Sum _     => 1
        | E.Probe _   => 1
        | E.Partial _ => err(" Partial used after normalize")
        | E.Krn _     => err("Krn used before expand")
        | E.Value _   => err("Value used before expand")
        | E.Img _     => err("Img used before expand")
        | _           => 2
    (*end case*))

    (* *************************************** helpers ******************************** *)
    (* Rewrites one subexpression according to [isOp].
     * Returns (e', params', args', newCode, fieldset').
     *)
    fun rewriteOp(name,e1,params,index,sx,args,fieldset,flag)=(case (isOp e1)
        of  0   => (E.B(E.Const 0),params,args,[],fieldset)
        | 1     => lift(name,e1,params,index,sx,args,fieldset,flag)
        | _     => (e1,params,args,[],fieldset)             (*not lifted*)
        (*end case*))

    (* Rewrites the single operand e1 of the EIN application carried in x,
     * using that application's params, index and args.
     *)
    fun unaryOp(name,sx,e1,x)=let
        val ((_, DstIL.EINAPP(ein,args)),fieldset,flag)=x
        val params=Ein.params ein
        val index=Ein.index ein
        in
            rewriteOp(name,e1,params,index,sx,args,fieldset,flag)
        end

    (* Rewrites each operand in list1 from left to right, threading params, args
     * and fieldset through the calls. Returns the rewritten operands in their
     * original order together with the concatenated new code.
     *)
    fun multOp(name,sx,list1,x)=let
        val ((_, DstIL.EINAPP(ein,args)),fieldset,flag)=x
        val params=Ein.params ein
        val index=Ein.index ein
        fun m([],rest,params,args,code,fieldset)=(rest,params,args,code,fieldset)
          | m(e1::es,rest,params,args,code,fieldset)=let
            val (e1',params',args',code',fieldset)= rewriteOp(name,e1,params,index,sx,args,fieldset,flag)
            in
                m(es,rest@[e1'],params',args',code@code',fieldset)
            end
        in
            m(list1,[],params,args,[],fieldset)
        end

    (* Rebuilds the binding for the original lhs y, keeping the original index,
     * and lets cleanParams clean the new body, params and args.
     *)
    fun cleanOrig(body,params,args,x) =let
        val ((y,DstIL.EINAPP(ein,_)),_,_)=x
        val index=Ein.index ein
        in  cleanParams(y,body,params,index,args)
        end

    (* *************************************** general handle Ops ******************************** *)
    (* Each handler returns (rewritten binding for y, new code, fieldset'). *)
    fun handleUnaryOp(name,opp,x,e1)=let
        val (e1',params',args',code,fieldset)=  unaryOp(name,[],e1,x)
        val body' =E.Op1(opp, e1')
        val einapp= cleanOrig(body',params',args',x)
        in
            (einapp,code,fieldset)
        end
    fun handleBinaryOp(name,opp,x,es)=let
        val ([e1',e2'],params',args',code,fieldset)= multOp(name,[],es,x)
        val body' =E.Op2(opp,e1',e2')
        val einapp= cleanOrig(body',params',args',x)
        in
            (einapp,code,fieldset)
        end
    fun handleMultOp(name,opp,x,es)= let
        val (e1',params',args',code,fieldset)= multOp(name,[],es,x)
        val body =E.Opn(opp ,e1')
        val einapp= cleanOrig(body,params',args',x)
        in
            (einapp,code,fieldset)
        end
    (* ***************************************specific handle Ops ******************************** *)
    (* Division: rewrites the numerator, then the denominator, threading the
     * updated params, args and fieldset from the first rewrite into the second.
     *)
    fun handleDiv(e1,e2,x)=let
        val ((_, DstIL.EINAPP(ein,args)),fieldset,flag)=x
        val params=Ein.params ein
        val index=Ein.index ein
        val (e1',params1',args1',code1',fieldset)=rewriteOp("div-num",e1,params,index,[],args,fieldset,flag)
        val (e2',params2',args2',code2',fieldset)=rewriteOp("div-denom",e2,params1',index,[],args1',fieldset,flag)
        val body' =E.Op2(E.Div,e1',e2')
        val einapp= cleanOrig(body',params2',args2',x)
        in
                (einapp,code1'@code2',fieldset)
        end
    (* Sum over sx of a product: the factors are rewritten with sx passed down
     * to cleanIndex, and the summation is kept around the rewritten product.
     *)
    fun handleSumProd(e1,sx,x)=let
        val (e1',params',args',code,fieldset)=  multOp("sumprod",sx,e1,x)
        val body'= E.Sum(sx,E.Opn(E.Prod, e1'))
        val einapp= cleanOrig(body',params',args',x)
        in
            (einapp,code,fieldset)
        end

    (* *************************************** Split ******************************** *)

    (* split: ((var * rhs) * fieldset * flag) -> (((var * rhs) * code) * fieldset)
     * Splits one EIN application into a rewritten binding for y plus the
     * bindings (code) for the subexpressions that were lifted out.
     * Summations directly around a probe of a convolution, or around a tensor,
     * are left as they are. A rhs that is not an EINAPP is returned unchanged
     * with no code.
     *)
    fun split((y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args)),fieldset,flag) =let
        val x= ((y,einapp),fieldset,flag)
        val default=((y,einapp),[],fieldset)
        val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
        val _=testp["\n\nStarting split",P.printbody body]
        fun rewrite b=(case b
            of E.B  _                 => default
            | E.Tensor _              => default
            | E.G _                   => default
            | E.Field _               => raise Fail "should have been swept"
            | E.Lift _                => raise Fail "should have been swept"
            | E.Conv _                => raise Fail "should have been swept"
            | E.Partial _             => err(" Partial used after normalize")
            | E.Apply _               => raise Fail "should have been swept"
            | E.Probe(E.Conv _,_)     => default
            | E.Probe(E.Field _,_)    => raise Fail str
            | E.Probe _               => raise Fail str
            | E.Value _               => err("Value used before expand")
            | E.Img _                 => err("Img used before expand")
            | E.Krn _                 => err("Krn used before expand")
            | E.Sum(_,E.Probe(E.Conv _,_)) => default
            | E.Sum(_,E.Tensor _)     => default
            | E.Sum(sx,E.Opn(E.Prod, e1))     => handleSumProd(e1,sx,x)
            | E.Sum(sx,E.G(E.Delta d))  => handleSumProd([E.G(E.Delta d)],sx,x)
            | E.Sum(sx,_)             => err(" summation not distributed:"^str)
            | E.Op1(op1,e1)           =>
                (case op1
                    of E.Neg          => handleUnaryOp("neg",op1,x,e1)
                    | E.Sqrt          => handleUnaryOp("sqrt",op1,x,e1)
                    | E.Exp           => handleUnaryOp("exp",op1,x,e1)
                    | E.PowInt _      => handleUnaryOp("PowInt",op1,x,e1)
                    | _               => handleUnaryOp("Trig",op1,x,e1)
                (*end case *))
            | E.Op2(E.Sub,e1,e2)      => handleBinaryOp("subtract",E.Sub,x,[e1,e2])
            | E.Op2(E.Div,e1,e2)      => handleDiv(e1,e2,x)
            | E.Opn(E.Add,es)         => handleMultOp("add",E.Add,x,es)
            (* Regroup s0 * t_i * s2, where s0 and s2 have no indices, as
             * (s0 * s2) * t_i. The inner product of the two index-free terms is
             * then lifted into its own operator.
             * The operator must be qualified (E.Prod): an unqualified Prod would
             * be a variable pattern that matches any n-ary operator.
             *)
            | E.Opn(E.Prod,[E.Tensor(id0,[]),E.Tensor(id1,[i]),E.Tensor(id2,[])])=>
                rewrite (E.Opn(E.Prod,[
                    E.Opn(E.Prod,[E.Tensor(id0,[]),E.Tensor(id2,[])]),E.Tensor(id1,[i])]))
            | E.Opn(E.Prod,es)        => handleMultOp("prod",E.Prod,x,es)
            (*end case *))
        val (einapp2,newbies,fieldset) =rewrite body
        in
            ((einapp2,newbies),fieldset)
        end
        |split((y,app),fieldset,_) =(((y,app),[]),fieldset)


    (* *************************************** main  ******************************** *)
    (* limitSplit: splits einapp2 like [splitEinApp], but stops descending once
     * the count of accumulated bindings (rest, pending newbies and code)
     * exceeds splitlimit. Pending bindings that were not split yet are then
     * appended to the code as they are. Returns fields2 @ code @ rest.
     * NOTE(review): each call to split starts from an empty fieldset, and the
     * fieldset it returns is discarded. Subexpressions are therefore only
     * shared within a single split call.
     *)
    fun limitSplit(einapp2,fields2,splitlimit)=let
        val fieldset= einSet.EinSet.empty
        val _ = testp ["\nSPLit with limit", Int.toString splitlimit]
        fun itercode([],rest,code,cnt)=let
                val _ = testp ["\n Empty-SplitCount: ", Int.toString cnt]
                in
                    (rest,code)
                end
        | itercode(e1::newbies,rest,code,cnt)=let
            val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
            val (rest4,code4)=itercode(code3,[],[],cnt+1)
            val _ =testp [toStringBind(e1),"\n\t===>\n",toStringBind(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map toStringBind (code4@rest4)))]
            in
                if (length rest + length newbies + length code > splitlimit) then let
                        val _ = testp ["\n SplitCount: ", Int.toString cnt]
                        val code5=code4@rest4@code
                        val rest5=rest@[einapp3]
                        in
                            (rest5,code5@newbies)
                        end
                else  itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)
            end
        val(rest,code)= itercode([einapp2],[],[],0)
        in
            fields2@code@rest
        end

    (* splitEinApp: splits einapp0, then every binding produced by that split,
     * recursively, until no new bindings appear. Returns the helper bindings
     * (code) followed by the rewritten bindings (rest).
     * NOTE(review): as in limitSplit, every split call uses a fresh, empty
     * fieldset, and the fieldset it returns is discarded.
     *)
    fun splitEinApp einapp0 =let
        val fieldset= einSet.EinSet.empty
        val einapp2=[einapp0]
        fun itercode([],rest,code,_)=(rest,code)
        | itercode(e1::newbies,rest,code,cnt)=let
            val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
            val (rest4,code4)=itercode(code3,[],[],cnt+1)
            val _ =testp [toStringBind(e1),"\n\t===>\n",toStringBind(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map toStringBind (code4@rest4)))]
            in
                itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)
            end
        val(rest,code)= itercode(einapp2,[],[],0)
        in
            (code@rest)
        end

  end; (* local *)

end (* struct *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0