(* Currently under construction
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

(*
 * During the transition from high-IL to mid-IL, complicated EIN expressions are split
 * into simpler ones in order to better identify methods for code generation and common
 * subexpressions.  Combining EIN operators in the optimization phase can lead to large
 * and complicated EIN operators.  A general code generator would need to expand every
 * operation to work on scalars, which could miss the opportunity for vectorization and
 * lead to poor code generation.  Instead, every EIN operator is split into a set of
 * simple EIN operators.  Each EIN expression then only has one operation working on
 * constants, tensors, deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is $\in {--, +, -, *, /, \sum}$ then for each
 *     subexpression analyze to see if they need to be rewritten.
 * (1a) When a subexpression is a field expression $\circledast,\nabla$ then it becomes 0.
 *      When it is another operation ${@ --, +, -, *, /, \sum}$ then we lift that
 *      subexpression and create a new EIN operator.  We replace the subexpression with
 *      a tensor expression that represents its size.
 * (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape
 *      for the tensor replacement.
 * (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)
structure Split = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

    (* flag handed to split by the drivers below; in lift, the value 1 selects the
     * einSet.rtnVar lookup (remove common subexpression) *)
    val numFlag = 1
    (* debug switch read by testp: 0 is silent, any other value prints *)
    val testing = 0

    (* thin wrappers around helpers defined in other structures *)
    fun mkEin e = E.mkEin e
    (* EIN application whose body is the constant 0 and which takes no arguments *)
    val einappzero = DstIL.EINAPP(mkEin([], [], E.B(E.Const 0)), [])
    (* bind y to the constant-zero EIN application *)
    fun setEinZero y = (y, einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e
    fun itos i = Int.toString i
    fun err str = raise Fail str

    (* counter used by genName to make generated names distinct *)
    val cnt = ref 0

    (* increment the use count of a variable *)
    fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)

    (* genName prefix returns "<prefix>_<n>" and advances the counter *)
    fun genName prefix = let
        val n = !cnt
        in
            cnt := n + 1;
            String.concat[prefix, "_", Int.toString n]
        end

    (* testp strs : prints the concatenation of strs when testing is non-zero;
     * always returns 1 *)
    fun testp n = (case testing
        of 0 => 1
         | _ => (print(String.concat n); 1)
        (* end case *))

    (* lift : (name, e, params, index, sx, args, fieldset, flag)
     *          -> (Re, Rparams, Rargs, newbies, fieldset)
     * Lifts the sub-expression e out of its enclosing EIN operator.
     *   - cleanIndex yields the replacement tensor shape, the sizes and the cleaned body
     *   - a new parameter E.TEN(1, sizes) is appended to params; its position
     *     (length params) is the id used by the replacement expression Re
     *   - a fresh variable M is created and cleanParams builds the binding for it
     *   - when flag is 1, einSet.rtnVar is consulted: if it returns SOME v, v is
     *     used as the argument instead of M (its use count is incremented) and no
     *     new binding is emitted
     * Result: replacement expression, extended params, extended args, the new
     * bindings (zero or one) and the fieldset returned by einSet.rtnVar (or the
     * incoming fieldset when flag is not 1).
     *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
        val (tshape, sizes, body) = cleanIndex(e, index, sx)
        val id = length params
        val Rparams = params @ [E.TEN(1, sizes)]
        val Re = E.Tensor(id, tshape)
        val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
        val einapp = cleanParams(M, body, Rparams, sizes, args @ [M])
        val (_, einapp0) = einapp
        val (Rargs, newbies, fieldset) = (case flag
            of 1 => let
                val (fieldset, var) = einSet.rtnVar(fieldset, M, einapp0)
                in
                    (case var
                        of NONE => (args @ [M], [einapp], fieldset)
                         | SOME v => (incUse v; (args @ [v], [], fieldset))
                        (* end case *))
                end
             | _ => (args @ [M], [einapp], fieldset)
            (* end case *))
        in
            (Re, Rparams, Rargs, newbies, fieldset)
        end

    (* isOp : ein_exp -> int
     * Classifies a sub-expression for rewriteOp:
     *   0 - replaced by the constant zero
     *   1 - lifted into its own EIN operator (see lift)
     *   2 - left in place unchanged
     * Raises Fail for forms that should not appear at this stage.
     *)
    fun isOp e = (case e
        of E.Field _ => 0
         | E.Conv _ => 0
         | E.Apply _ => 0
         | E.Lift _ => 0
         | E.Op1 _ => 1
         | E.Op2 _ => 1
         | E.Opn _ => 1
         | E.Sum _ => 1
         | E.Probe _ => 1
         | E.Partial _ => err(" Partial used after normalize")
         | E.Krn _ => err("Krn used before expand")
         | E.Value _ => err("Value used before expand")
         (* message previously said "Probe" although the arm matches E.Img *)
         | E.Img _ => err("Img used before expand")
         | _ => 2
        (* end case *))

    (* *************************************** helpers ******************************** *)

    (* rewriteOp : returns (e', params', args', code, fieldset) for one sub-expression,
     * dispatching on isOp: zero it, lift it, or keep it as is *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case (isOp e1)
        of 0 => (E.B(E.Const 0), params, args, [], fieldset)
         | 1 => lift(name, e1, params, index, sx, args, fieldset, flag)
         | _ => (e1, params, args, [], fieldset)  (* not lifted *)
        (* end case *))

    (* unaryOp : rewrite the single operand e1 of the EIN application carried in x,
     * where x = ((y, EINAPP(ein, args)), fieldset, flag) *)
    fun unaryOp (name, sx, e1, x) = let
        val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
        in
            rewriteOp(name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
        end

    (* multOp : rewrite every operand in list1 from left to right, threading the
     * params, args and fieldset through and collecting the generated code in order *)
    fun multOp (name, sx, list1, x) = let
        val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
        val params = Ein.params ein
        val index = Ein.index ein
        fun m ([], rest, params, args, code, fieldset) = (rest, params, args, code, fieldset)
          | m (e1::es, rest, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset) =
                        rewriteOp(name, e1, params, index, sx, args, fieldset, flag)
                in
                    m(es, rest @ [e1'], params', args', code @ code', fieldset)
                end
        in
            m(list1, [], params, args, [], fieldset)
        end

    (* cleanOrig : clean the params of the rewritten body, reusing the variable and
     * the index of the original EIN application carried in x *)
    fun cleanOrig (body, params, args, x) = let
        val ((y, DstIL.EINAPP(ein, _)), _, _) = x
        in
            cleanParams(y, body, params, Ein.index ein, args)
        end

    (* *************************************** general handle Ops ******************************** *)

    (* each handler returns (einapp, code, fieldset): the rewritten original binding,
     * the bindings created for lifted operands, and the updated fieldset *)

    fun handleUnaryOp (name, opp, x, e1) = let
        val (e1', params', args', code, fieldset) = unaryOp(name, [], e1, x)
        val einapp = cleanOrig(E.Op1(opp, e1'), params', args', x)
        in
            (einapp, code, fieldset)
        end

    (* es must have exactly two elements; multOp returns one result per operand,
     * so the two-element pattern below matches whenever that holds *)
    fun handleBinaryOp (name, opp, x, es) = let
        val ([e1', e2'], params', args', code, fieldset) = multOp(name, [], es, x)
        val einapp = cleanOrig(E.Op2(opp, e1', e2'), params', args', x)
        in
            (einapp, code, fieldset)
        end

    fun handleMultOp (name, opp, x, es) = let
        val (es', params', args', code, fieldset) = multOp(name, [], es, x)
        val einapp = cleanOrig(E.Opn(opp, es'), params', args', x)
        in
            (einapp, code, fieldset)
        end

    (* *************************************** specific handle Ops ******************************** *)

    (* numerator and denominator are rewritten separately so that generated names
     * carry the "div-num" / "div-denom" prefixes *)
    fun handleDiv (e1, e2, x) = let
        val ((_, DstIL.EINAPP(ein, args)), fieldset, flag) = x
        val params = Ein.params ein
        val index = Ein.index ein
        val (e1', params1', args1', code1', fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
        val (e2', params2', args2', code2', fieldset) =
                rewriteOp("div-denom", e2, params1', index, [], args1', fieldset, flag)
        val einapp = cleanOrig(E.Op2(E.Div, e1', e2'), params2', args2', x)
        in
            (einapp, code1' @ code2', fieldset)
        end

    (* summation over a product: the operands are rewritten with the summation
     * indices sx passed down, and the sum is kept around the rebuilt product *)
    fun handleSumProd (e1, sx, x) = let
        val (e1', params', args', code, fieldset) = multOp("sumprod", sx, e1, x)
        val einapp = cleanOrig(E.Sum(sx, E.Opn(E.Prod, e1')), params', args', x)
        in
            (einapp, code, fieldset)
        end

    (* *************************************** Split ******************************** *)

    (* split : ((var * ein_app) * fieldset * flag) -> ((var * ein_app) * code) * fieldset
     * Splits an EIN expression into smaller pieces.
     * Note we leave summation around probe exp.
     * A right-hand side that is not an EINAPP of an Ein.EIN is returned unchanged
     * with no new code.
     *)
    fun split ((y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)), fieldset, flag) = let
        val x = ((y, einapp), fieldset, flag)
        val default = ((y, einapp), [], fieldset)
        val str = "Poorly formed EIN operator. Argument needs to be applied in High-IL" ^ (P.printbody body)
        val _ = testp["\n\nStarting split", P.printbody body]
        fun rewrite b = (case b
            of E.B _ => default
             | E.Tensor _ => default
             | E.G _ => default
             | E.Field _ => raise Fail "should have been swept"
             | E.Lift e => raise Fail "should have been swept"
             | E.Conv _ => raise Fail "should have been swept"
             | E.Partial _ => err(" Partial used after normalize")
             | E.Apply _ => raise Fail "should have been swept"
             | E.Probe(E.Conv _, _) => default
             | E.Probe(E.Field _, _) => raise Fail str
             | E.Probe _ => raise Fail str
             | E.Value _ => err("Value used before expand")
             (* message previously said "Probe" although the arm matches E.Img *)
             | E.Img _ => err("Img used before expand")
             | E.Krn _ => err("Krn used before expand")
             | E.Sum(_, E.Probe(E.Conv _, _)) => default
             | E.Sum(sx, E.Tensor _) => default
             (* | E.Sum(_,E.Opn(E.Prod,[E.Eps2 _, E.Probe(E.Conv _,_)])) => default
                | E.Sum(_,E.Opn(E.Prod,[E.Epsilon _, E.Probe(E.Conv _,_) ])) => default *)
             | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd(e1, sx, x)
             | E.Sum(sx, E.G(E.Delta d)) => handleSumProd([E.G(E.Delta d)], sx, x)
             | E.Sum(sx, _) => err(" summation not distributed:" ^ str)
             | E.Op1(op1, e1) => handleUnaryOp("op1", op1, x, e1)
             | E.Op2(E.Sub, e1, e2) => handleBinaryOp("subtract", E.Sub, x, [e1, e2])
             | E.Op2(E.Div, e1, e2) => handleDiv(e1, e2, x)
             | E.Opn(E.Add, es) => handleMultOp("add", E.Add, x, es)
             (* scalar * vector * scalar: group the two scalars first.
              * The constructor must be written E.Prod; an unqualified Prod is a
              * variable pattern and would match any n-ary operator. *)
             | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                    rewrite (E.Opn(E.Prod, [
                        E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]),
                        E.Tensor(id1, [i])
                      ]))
             | E.Opn(E.Prod, es) => handleMultOp("prod", E.Prod, x, es)
            (* end case *))
        val (einapp2, newbies, fieldset) = rewrite body
        in
            ((einapp2, newbies), fieldset)
        end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

    (* *************************************** main ******************************** *)

    (* limitSplit2 : (var * ein_app) * int * fieldset -> (var * ein_app) list
     * Repeatedly splits einapp2 and the bindings produced by each split.  The
     * bindings generated for lifted sub-expressions are placed before the binding
     * they were lifted from.  Once the number of pending and finished bindings
     * exceeds splitlimit, a "SplitCount" message is printed and the remaining
     * pending bindings are returned without being split.
     * NOTE(review): the fieldset returned by split is discarded, so every call to
     * split starts from the same fieldset -- confirm this is intended.
     *)
    fun limitSplit2 (einapp2, splitlimit, fieldset) = let
        fun itercode ([], rest, code) = (rest, code)
          | itercode (e1::newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                (* fully split the code produced for e1 before moving on *)
                val (rest4, code4) = itercode(code3, [], [])
                val _ = testp [
                        toStringBind(e1), "\n\t===>\n", toStringBind(einapp3), "\nand\n",
                        (String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4)))
                      ]
                in
                    if (length rest + length newbies + length code > splitlimit)
                        then (
                            print("\n SplitCount: " ^ Int.toString(splitlimit));
                            (rest @ [einapp3], code4 @ rest4 @ code @ newbies))
                        else itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
        val (rest, code) = itercode([einapp2], [], [])
        in
            code @ rest
        end

    (* limitSplit : same as limitSplit2, starting from an empty fieldset *)
    fun limitSplit (einapp2, splitlimit) =
            limitSplit2(einapp2, splitlimit, einSet.EinSet.empty)

    (* splitEinApp : (var * ein_app) -> (var * ein_app) list
     * Splits einapp0 and everything produced from it, with no limit, starting
     * from an empty fieldset.  Generated bindings precede the binding they were
     * lifted from.
     *)
    fun splitEinApp einapp0 = let
        val fieldset = einSet.EinSet.empty
        fun itercode ([], rest, code) = (rest, code)
          | itercode (e1::newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [])
                val _ = testp [
                        toStringBind(e1), "\n\t===>\n", toStringBind(einapp3), "\nand\n",
                        (String.concatWith ",\n\t" (List.map toStringBind (code4 @ rest4)))
                      ]
                in
                    itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
        val (rest, code) = itercode([einapp0], [], [])
        in
            code @ rest
        end

    end; (* local *)

end (* struct *)
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: einapp2,[],[],0) in (code@rest) end end; (* local *) end (* local *)