(* Currently under construction
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(* During the transition from high-IL to mid-IL, complicated EIN expressions are
 * split into simpler ones in order to better identify methods for code
 * generation and common subexpressions. Combining EIN operators in the
 * optimization phase can lead to large and complicated EIN operators. A general
 * code generator would need to expand every operation to work on scalars, which
 * could miss the opportunity for vectorization and lead to poor code generation.
 * Instead, every EIN operator is split into a set of simple EIN operators. Each
 * EIN expression then has only one operation working on constants, tensors,
 * deltas, epsilons, images and kernels.
 *
 * (1) When the outer EIN operator is one of negation, $+$, $-$, $*$, $/$ or
 *     $\sum$, analyze each subexpression to see whether it needs to be rewritten.
 * (1a) When a subexpression is a field expression ($\circledast$, $\nabla$) it
 *      becomes 0. When it is another operation (negation, $+$, $-$, $*$, $/$,
 *      $\sum$), we lift that subexpression into a new EIN operator and replace
 *      it with a tensor expression that represents its size.
 * (1b) Call cleanIndex.sml to clean the indices in the subexpression and to get
 *      the shape of the replacement tensor.
 * (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split = struct

  local
    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstV = DstIL.Var
    structure P = Printer
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure handleE = handleEin
  in

  (* debugging switch: 0 disables the diagnostic printing in testp and gettest *)
  val testing = 0

  (* setEin: builds an EIN operator from its params, index and body *)
  fun setEin (params, index, body) = Ein.EIN{params=params, index=index, body=body}

  (* assignEinApp: binds y to the EIN application of (params, index, body) to args *)
  fun assignEinApp (y, params, index, body, args) =
        (y, DstIL.EINAPP(setEin(params, index, body), args))

  (* the constant 0 as an EIN application with no params, no index and no args *)
  val einappzero = DstIL.EINAPP(setEin([], [], E.Const 0), [])

  (* setEinZero: binds y to the constant-zero EIN application *)
  fun setEinZero y = (y, einappzero)

  fun cleanParams e = cleanP.cleanParams e
  fun cleanIndex e = cleanI.cleanIndex e
  fun printEINAPP e = MidToString.printEINAPP e
  fun isZero e = handleE.isZero e
  fun sweep e = handleE.sweep e
  fun itos i = Int.toString i
  fun err str = raise Fail str

  (* counter used by genName to make fresh variable names *)
  val cnt = ref 0

  (* genName: returns prefix ^ "_" ^ n for a fresh counter value n *)
  fun genName prefix = let
        val n = !cnt
        in
          cnt := n+1;
          String.concat[prefix, "_", Int.toString n]
        end

  (* testp: prints the concatenation of n when testing is nonzero; always returns 1 *)
  fun testp n = (case testing
         of 0 => 1
          | _ => (print(String.concat n); 1)
        (* end case *))

  (* lift: string*ein_exp*params*index*sum_id list*args -> ein_exp*params*args*code
   * Lifts the subexpression e into its own EIN application and returns the
   * tensor expression that replaces e in the original operator.
   *  - cleanIndex(e, index, sx) yields the replacement tensor's index (tshape),
   *    the sizes of the new variable, and the body of the lifted operator;
   *  - a fresh variable M of type TensorTy sizes is created; a new parameter
   *    E.TEN(1, sizes) is appended to params and M is appended to args, so the
   *    replacement E.Tensor(id, tshape) refers to M (id = old length of params);
   *  - cleanParams builds the binding that defines M, returned as the code.
   *)
  fun lift (name, e, params, index, sx, args) = let
        val (tshape, sizes, body) = cleanIndex(e, index, sx)
        val id = length params
        val Rparams = params @ [E.TEN(1, sizes)]
        val Re = E.Tensor(id, tshape)
        val M = DstV.new (genName (name ^ "_l_" ^ itos id), DstTy.TensorTy sizes)
        val Rargs = args @ [M]
        val einapp = cleanParams(M, body, Rparams, sizes, Rargs)
        in
          (Re, Rparams, Rargs, [einapp])
        end

  (* isOp: ein_exp -> int
   * Classifies a subexpression of the operator being split; rewriteOp acts on
   * the result as follows:
   *   0 - field-level term (Field, Conv, Apply, Lift): replaced by zero;
   *   1 - operator (Neg, Sqrt, PowInt, PowReal, Add, Sub, Prod, Div, Sum,
   *       Probe): lifted into its own EIN operator;
   *   2 - any other term (e.g. Tensor, Const, Delta, Epsilon): kept as is.
   * Raises Fail for terms that earlier passes should already have removed.
   *)
  fun isOp e = (case e
         of E.Field _ => 0
          | E.Conv _ => 0
          | E.Apply _ => 0
          | E.Lift _ => 0
          | E.Neg _ => 1
          | E.Sqrt _ => 1
          | E.PowInt _ => 1
          | E.PowReal _ => 1
          | E.Add _ => 1
          | E.Sub _ => 1
          | E.Prod _ => 1
          | E.Div _ => 1
          | E.Sum _ => 1
          | E.Probe _ => 1
          | E.Partial _ => err(" Partial used after normalize")
          | E.Krn _ => err("Krn used before expand")
          | E.Value _ => err("Value used before expand")
          | E.Img _ => err("Img used before expand")
          | _ => 2
        (* end case *))

  (* rewriteOp: string*ein_exp*params*index*sum_id list*args
   *              -> ein_exp*params*args*code
   * Lifts e1 when it is an operator; otherwise rewrites it to zero or keeps it
   * unchanged (see isOp). Only lifting extends params, args and code.
   *)
  fun rewriteOp (name, e1, params, index, sx, args) = (case isOp e1
         of 0 => (E.Const 0, params, args, [])
          | 2 => (e1, params, args, [])
          | _ => lift(name, e1, params, index, sx, args)
        (* end case *))

  (* rewriteOps: string*ein_exp list*params*index*sum_id list*args
   *               -> ein_exp list*params*args*code
   * Applies rewriteOp to each expression from left to right, threading the
   * growing params/args through and concatenating the generated code in order.
   *)
  fun rewriteOps (name, list1, params, index, sx, args) = let
        fun m ([], rest, params, args, code) = (rest, params, args, code)
          | m (e1::es, rest, params, args, code) = let
              val (e1', params', args', code') = rewriteOp(name, e1, params, index, sx, args)
              in
                m(es, rest @ [e1'], params', args', code @ code')
              end
        in
          m(list1, [], params, args, [])
        end

  (* rewriteOrig: var*ein_exp*params*index*sum_id list*args -> var*einapp
   * Builds the binding for the rewritten original operator. When isZero
   * reports the body as zero (result 1), y is bound to the constant 0;
   * otherwise cleanParams tidies the operator's params and args.
   * The sx argument is not used.
   *)
  fun rewriteOrig (y, body, params, index, sx, args) = (case isZero body
         of 1 => setEinZero y
          | _ => cleanParams(y, body, params, index, args)
        (* end case *))

  (* handleNeg: var*ein_exp*params*index*args -> (var*einapp)*code
   * Splits a negation: lifts the operand if needed, then rebuilds E.Neg.
   *)
  fun handleNeg (y, e1, params, index, args) = let
        val (e1', params', args', code) = rewriteOp("neg", e1, params, index, [], args)
        val body = E.Neg e1'
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handleSqrt: var*ein_exp*params*index*args -> (var*einapp)*code
   * Splits a square root: lifts the operand if needed, then rebuilds E.Sqrt.
   *)
  fun handleSqrt (y, e1, params, index, args) = let
        val (e1', params', args', code) = rewriteOp("sqrt", e1, params, index, [], args)
        val body = E.Sqrt e1'
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handlePowInt: var*(ein_exp*n)*params*index*args -> (var*einapp)*code
   * Splits an integer power: lifts the base if needed, keeps the exponent.
   *)
  fun handlePowInt (y, (e1, n1), params, index, args) = let
        val (e1', params', args', code) = rewriteOp("powint", e1, params, index, [], args)
        val body = E.PowInt(e1', n1)
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handlePowReal: var*(ein_exp*n)*params*index*args -> (var*einapp)*code
   * Splits a real power: lifts the base if needed, keeps the exponent.
   *)
  fun handlePowReal (y, (e1, n1), params, index, args) = let
        val (e1', params', args', code) = rewriteOp("powreal", e1, params, index, [], args)
        val body = E.PowReal(e1', n1)
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handleSub: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
   * Splits a subtraction: lifts both operands if needed, then rebuilds E.Sub.
   *)
  fun handleSub (y, e1, e2, params, index, args) = let
        val ([e1', e2'], params', args', code) =
              rewriteOps("subt", [e1, e2], params, index, [], args)
        val body = E.Sub(e1', e2')
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handleDiv: var*ein_exp*ein_exp*params*index*args -> (var*einapp)*code
   * Splits a division. The numerator is lifted with the operator's index; the
   * denominator is lifted with an empty index, i.e. it is treated as having no
   * free indices (presumably because only division by scalars reaches here).
   *)
  fun handleDiv (y, e1, e2, params, index, args) = let
        val (e1', params1', args1', code1') =
              rewriteOp(DstV.name y, e1, params, index, [], args)
        val (e2', params2', args2', code2') =
              rewriteOp("div", e2, params1', [], [], args1')
        val body = E.Div(e1', e2')
        val einapp = rewriteOrig(y, body, params2', index, [], args2')
        in
          (einapp, code1' @ code2')
        end

  (* handleAdd: var*ein_exp list*params*index*args -> (var*einapp)*code
   * Splits an n-ary addition: lifts each summand if needed, rebuilds E.Add.
   *)
  fun handleAdd (y, e1, params, index, args) = let
        val (e1', params', args', code) = rewriteOps("add", e1, params, index, [], args)
        val body = E.Add e1'
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handleProd: var*ein_exp list*params*index*args -> (var*einapp)*code
   * Splits an n-ary product: lifts each factor if needed, rebuilds E.Prod.
   *)
  fun handleProd (y, e1, params, index, args) = let
        val (e1', params', args', code) = rewriteOps("prod", e1, params, index, [], args)
        val body = E.Prod e1'
        val einapp = rewriteOrig(y, body, params', index, [], args')
        in
          (einapp, code)
        end

  (* handleSumProd: var*ein_exp list*params*index*sum_id list*args
   *                  -> (var*einapp)*code
   * Splits E.Sum(sx, E.Prod e1): each factor is lifted with the summation
   * indices sx passed to cleanIndex, and the summation over the product of
   * the replacement terms is rebuilt.
   *)
  fun handleSumProd (y, e1, params, index, sx, args) = let
        val (e1', params', args', code) = rewriteOps("sumprod", e1, params, index, sx, args)
        val body = E.Sum(sx, E.Prod e1')
        val einapp = rewriteOrig(y, body, params', index, sx, args')
        in
          (einapp, code)
        end

  (* handleSumOp: string*var*ein_exp*params*index*sum_id list*args
   *                -> (var*einapp)*code
   * Splits E.Sum(sx, e1) where e1 is a single operator that the summation does
   * not distribute over (e.g. Div, Sqrt, PowReal). As in handleSumProd, e1 is
   * lifted with sx passed to cleanIndex, and the remaining operator sums over
   * the replacement term. This keeps the value exact: sum(f(x)) is kept as a
   * sum of the lifted f(x) rather than rewritten to f(sum(x)).
   *)
  fun handleSumOp (name, y, e1, params, index, sx, args) = let
        val (e1', params', args', code) = rewriteOp(name, e1, params, index, sx, args)
        val body = E.Sum(sx, e1')
        val einapp = rewriteOrig(y, body, params', index, sx, args')
        in
          (einapp, code)
        end

  (* split: var*einapp -> (var*einapp)*code
   * Splits one EIN application into a simpler application for y plus the
   * list of lifted applications (code) that it refers to. Summations are
   * pushed through negation, addition, subtraction, Lift and nested sums;
   * summations around probes (alone or multiplied by an epsilon) are left
   * in place. A right-hand side that is not an EINAPP is returned unchanged.
   *)
  fun split (y, einapp as DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
        val zero = (setEinZero y, [])
        val default = ((y, einapp), [])
        (* build the (possibly large) error message only when it is needed *)
        fun malformed () = raise Fail(
              "Poorly formed EIN operator. Argument needs to be applied in High-IL"
                ^ P.printbody body)
        fun rewrite b = (case b
               of E.Probe(E.Conv _, _) => default
                | E.Probe(E.Field _, _) => malformed ()
                | E.Probe _ => malformed ()
                | E.Conv _ => zero
                | E.Field _ => zero
                | E.Apply _ => zero
                | E.Lift _ => zero
                | E.Delta _ => default
                | E.Epsilon _ => default
                | E.Eps2 _ => default
                | E.Tensor _ => default
                | E.Const _ => default
                | E.ConstR _ => default
                | E.Neg e1 => handleNeg(y, e1, params, index, args)
                | E.Sqrt e1 => handleSqrt(y, e1, params, index, args)
                | E.PowInt e1 => handlePowInt(y, e1, params, index, args)
                | E.PowReal e1 => handlePowReal(y, e1, params, index, args)
                | E.Sub(e1, e2) => handleSub(y, e1, e2, params, index, args)
                | E.Div(e1, e2) => handleDiv(y, e1, e2, params, index, args)
                (* operators that commute with summation: push the sum inside *)
                | E.Sum(sx, E.Neg n) => rewrite (E.Neg(E.Sum(sx, n)))
                | E.Sum(sx, E.Add a) => rewrite (E.Add(List.map (fn e => E.Sum(sx, e)) a))
                | E.Sum(sx, E.Sub(e1, e2)) => rewrite (E.Sub(E.Sum(sx, e1), E.Sum(sx, e2)))
                | E.Sum(sx, E.Lift e) => rewrite (E.Lift(E.Sum(sx, e)))
                (* non-linear operators: the sum does NOT distribute over them
                 * (the sum of square roots is not the square root of the sum,
                 * a sum of quotients is not a quotient of sums), so lift the
                 * operator and sum over its replacement *)
                | E.Sum(sx, e1 as E.Div _) =>
                    handleSumOp("sumdiv", y, e1, params, index, sx, args)
                | E.Sum(sx, e1 as E.PowReal _) =>
                    handleSumOp("sumpowreal", y, e1, params, index, sx, args)
                | E.Sum(sx, e1 as E.Sqrt _) =>
                    handleSumOp("sumsqrt", y, e1, params, index, sx, args)
                | E.Sum(c1, E.Sum(c2, e)) => rewrite (E.Sum(c1 @ c2, e))
                | E.Sum(_, E.Prod[E.Eps2 _, E.Probe(E.Conv _, _)]) => default
                | E.Sum(_, E.Prod[E.Epsilon _, E.Probe(E.Conv _, _)]) => default
                | E.Sum(_, E.Probe(E.Conv _, _)) => default
                | E.Sum(_, E.Conv _) => zero
                | E.Sum(sx, E.Prod e1) => handleSumProd(y, e1, params, index, sx, args)
                | E.Sum(sx, _) => default
                | E.Add e1 => handleAdd(y, e1, params, index, args)
                | E.Prod e1 => handleProd(y, e1, params, index, args)
                | E.Partial _ => err(" Partial used after normalize")
                | E.Krn _ => err("Krn used before expand")
                | E.Value _ => err("Value used before expand")
                | E.Img _ => err("Img used before expand")
              (* end case *))
        in
          rewrite body
        end
    | split (y, app) = ((y, app), [])

  (* iterMultiple: (var*einapp)*code -> (var*einapp)*code
   * Repeatedly splits every lifted application in newbies2 (and whatever
   * those produce) until nothing more is lifted. In the returned code, the
   * applications produced by splitting an entry are placed before that entry,
   * so each lifted variable's definition precedes the application using it.
   *)
  fun iterMultiple (einapp2, newbies2) = let
        fun itercode ([], rest, code) = (rest, code)
          | itercode (e1::newbies, rest, code) = let
              val (einapp3, code3) = split e1
              val (rest4, code4) = itercode(code3, [], [])
              in
                itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
              end
        val (rest, code) = itercode(newbies2, [], [])
        in
          (einapp2, code @ rest)
        end

  (* iterSplit: var*einapp -> (var*einapp)*code
   * Sweeps the body (handleEin.sweep), splits the result once, and then
   * recursively splits the lifted pieces with iterMultiple.
   *)
  fun iterSplit (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args)) = let
        val b = sweep body
        val einapp2 = assignEinApp(y, params, index, b, args)
        val (einapp3, newbies2) = split einapp2
        in
          iterMultiple(einapp3, newbies2)
        end

  (* gettest: var*einapp -> (var*einapp)*code
   * Entry point: splits einapp with iterSplit. When testing is nonzero it
   * first prints a banner followed by the incoming application.
   *)
  fun gettest einapp = (case testing
         of 0 => iterSplit einapp
          | _ => let
              val star = "\n************* SPLIT********\n"
              val _ = print(String.concat[star, "\n", "start get test", printEINAPP einapp])
              in
                iterSplit einapp
              end
        (* end case *))

  end; (* local *)

end (* Split *)
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: (einapp2,newbies) end (*end case*)) end; (* local *) end (* local *)