Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# View of /branches/charisee/src/compiler/high-to-mid/split.sml

Mon Dec 8 01:27:25 2014 UTC (4 years, 9 months ago) by cchiw
File size: 12592 byte(s)
added 2-d cross product, new rep. of 2-d curl
(* Currently under construction
*
* COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
*)

(*
During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators.  Each EIN expression then only has one operation working on  constants, tensors, deltas, epsilons, images and kernels.

(1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$, analyze each subexpression to see whether it needs to be rewritten.

(1a) When a subexpression is a field expression ($\circledast$, $\nabla$) it becomes 0. When it is another operation $\{--, +, -, *, /, \sum\}$ we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represents its shape.

(1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.

(1c) Call cleanParams.sml to clean the params in the subexpression.

(2) All the lifted subexpressions in the original EIN operator are replaced with tensors, and non-probed fields with zeros. Call isZero() to determine whether the body is zero. If so, return 0. Otherwise clean the EIN operator.

*)

structure Split = struct

local

structure E = Ein
structure mk= mkOperators
structure SrcIL = HighIL
structure SrcTy = HighILTypes
structure SrcOp = HighOps
structure SrcSV = SrcIL.StateVar
structure VTbl = SrcIL.Var.Tbl
structure DstIL = MidIL
structure DstTy = MidILTypes
structure DstOp = MidOps
structure DstV = DstIL.Var
structure SrcV = SrcIL.Var
structure P=Printer
structure F=Filter
structure T=TransformEin
structure Var = MidIL.Var
structure cleanP=cleanParams
structure cleanI=cleanIndex

val testing=1
in

(* setEin: params * index * body -> EIN operator
 * Packs the three components into an EIN record.
 *)
fun setEin (params, index, body) = E.EIN{body = body, index = index, params = params}

(* assignEinApp: pairs the variable y with the application of a freshly
 * built EIN operator to args.
 *)
fun assignEinApp (y, params, index, body, args) = let
      val rator = setEin (params, index, body)
      in
        (y, DstIL.EINAPP(rator, args))
      end

(* An EIN application with no params, no index, no args and body Const 0 *)
val einappzero = DstIL.EINAPP(E.EIN{params = [], index = [], body = E.Const 0}, [])

(* setEinZero: binds y to the zero application above *)
fun setEinZero y = (y, einappzero)
(* Short local names for the entry points of the two cleaning passes *)
val cleanParams = cleanP.cleanParams
val cleanIndex = cleanI.cleanIndex
(* itos: int -> string *)
val itos = Int.toString

(* err: string -> 'a; aborts by raising Fail with the given message *)
fun err msg = raise (Fail msg)
(* Counter backing genName; it is bumped on every call. *)
val cnt = ref 0

(* genName: string -> string
 * Returns prefix ^ "_" ^ k, where k is the current counter value, and
 * increments the counter so that the next call gets a different suffix.
 *)
fun genName prefix = let
      val suffix = Int.toString (!cnt)
      in
        cnt := !cnt + 1;
        prefix ^ "_" ^ suffix
      end
(* testp: string list -> int
 * Debug print: when the testing flag is non-zero the concatenation of the
 * strings is printed. The result is always 1.
 *)
fun testp n =
      if testing = 0
        then 1
        else (print (String.concat n); 1)

(* printEINAPP: var * rhs -> string
 * Renders one assignment for debug output. EINAPP and OP right-hand sides
 * are shown as type<var> followed by the operator and its arguments; any
 * other right-hand side is shown as var<type> with a "non-einapp" marker.
 *)
fun printEINAPP (id, rhs) = let
      val tyStr = DstTy.toString (Var.ty id)
      fun argList args = String.concatWith " , " (List.map Var.toString args)
      in
        case rhs
         of DstIL.EINAPP(rator, args) =>
              String.concat[tyStr, "<", Var.toString id, "> ==", P.printerE rator, argList args, "\n"]
          | DstIL.OP(rator, args) =>
              String.concat[tyStr, "<", Var.toString id, "> =", DstOp.toString rator, argList args, "\n"]
          | _ =>
              String.concat[Var.toString id, "<", tyStr, "> non-einapp\n"]
        (* end case *)
      end

(* lift: ein_exp * params * index * sum_id * args -> ein_exp * params * args * code
 * Pulls the sub-expression e out into its own EIN application.
 *  - cleanIndex gives the replacement tensor's shape, its sizes and the
 *    re-indexed body
 *  - a fresh mid-IL variable of tensor type is created for the lifted value
 *  - a new tensor parameter and that variable are appended to params and args;
 *    the replacement expression refers to the new parameter by its position
 *  - cleanParams builds the assignment for the lifted variable
 * Returns the replacement tensor expression, the extended params and args,
 * and a one-element code list holding the new assignment.
 *)
fun lift(e,params,index,sx,args)=let
      val (tshape,sizes,body) = cleanIndex(e,index,sx)
      (* position the appended parameter will occupy *)
      val slot = List.length params
      val lifted = DstV.new (genName ("TLifted_" ^ Int.toString slot), DstTy.TensorTy sizes)
      val params' = params @ [E.TEN(1,sizes)]
      val args' = args @ [lifted]
      in
        (E.Tensor(slot,tshape), params', args', [cleanParams(lifted,body,params',sizes,args')])
      end

(* isOp: ein_exp -> int
 * Classifies a sub-expression so the caller knows how to rewrite it:
 *   0 - field-level term (Field, Conv, Apply, Lift): replaced by the constant 0
 *   1 - operator or probe (Neg, Sub, Prod, Div, Sum, Probe): lifted into its
 *       own EIN operator
 *   2 - anything else: left in place
 * Raises Fail for constructors that must not appear at this stage.
 * NOTE(review): the file header lists + among the operators to be lifted, but
 * no addition constructor is matched here, so such a node is classified as 2.
 * Confirm whether a case was lost.
 *)
fun isOp e =(case e
    of E.Field _  => 0
    | E.Conv _    => 0
    | E.Apply _   => 0
    | E.Lift _    => 0
    | E.Neg _     => 1
    | E.Sub _     => 1
    | E.Prod _    => 1
    | E.Div _     => 1
    | E.Sum _     => 1
    | E.Probe _   => 1
    | E.Partial _ => err(" Partial used after normalize")
    | E.Krn _     => err("Krn used before expand")
    | E.Value _   => err("Value used before expand")
    (* the message used to name Probe although the offending node is Img *)
    | E.Img _     => err("Img used before expand")
    | _           => 2
    (*end case*))

(* rewriteOp: ein_exp * params * index * sum_id * args -> ein_exp * params * args * code
 * Rewrites one sub-expression according to its isOp class:
 *   0     -> the constant 0; params and args are unchanged, no code
 *   2     -> the expression itself; params and args are unchanged, no code
 *   other -> the result of lift()
 *)
fun rewriteOp(e1,params,index,sx,args)=let
      val kind = isOp e1
      in
        if kind = 0 then (E.Const 0,params,args,[])
        else if kind = 2 then (e1,params,args,[])
        else lift(e1,params,index,sx,args)
      end

(* rewriteOps: ein_exp list * params * index * sum_id list * mid-il vars
 *             -> ein_exp list * params * args * code
 * Applies rewriteOp to every expression from left to right. The params and
 * args produced by one step are the input of the next; the rewritten
 * expressions keep their order and the generated code is concatenated in
 * order.
 *)
fun rewriteOps(list1,params,index,sx,args)=let
      fun step (e, (revExps, ps, vs, code)) = let
            val (e', ps', vs', code') = rewriteOp(e, ps, index, sx, vs)
            in
              (e' :: revExps, ps', vs', code @ code')
            end
      val (revExps, params', args', code) = List.foldl step ([], params, args, []) list1
      in
        (List.rev revExps, params', args', code)
      end

(* isZero: var * ein_exp * params * index list * sum_id list * mid-il vars -> var * einapp
 * When cleanP.isZero reports 1 for the body, y is bound to the zero EIN
 * application; otherwise the assignment is built by cleanParams.
 * The sx argument is accepted but not used.
 *)
fun isZero(y,body,params,index,sx,args) =
      if cleanP.isZero body = 1
        then setEinZero y
        else cleanParams(y,body,params,index,args)

(* handleNeg: var * ein_exp * params * index * args -> (var * einapp) * code
 * Rewrites the operand with rewriteOp (no summation indices), wraps the
 * result in Neg and lets isZero build the assignment for y.
 *)
fun handleNeg(y,e1,params,index,args)=let
      val (operand,ps,vs,lifted) = rewriteOp(e1,params,index,[],args)
      in
        (isZero(y,E.Neg operand,ps,index,[],vs), lifted)
      end

(* handleSub: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
 * Rewrites the left and then the right operand with rewriteOp (no summation
 * indices), feeding the params and args from the first rewrite into the
 * second, rebuilds the subtraction and lets isZero build the assignment
 * for y. The code of both rewrites is returned in operand order.
 * Two explicit rewriteOp calls are used instead of binding the result of
 * rewriteOps to a two-element list pattern, which was a non-exhaustive
 * binding.
 *)
fun handleSub(y,e1,e2,params,index,args)=let
      val (e1',params1,args1,code1) = rewriteOp(e1,params,index,[],args)
      val (e2',params2,args2,code2) = rewriteOp(e2,params1,index,[],args1)
      in
        (isZero(y,E.Sub(e1',e2'),params2,index,[],args2), code1@code2)
      end

(* handleDiv: var * ein_exp * ein_exp * params * index * args -> (var * einapp) * code
 * Rewrites the numerator and then the denominator with rewriteOp, threading
 * params and args from the first rewrite into the second, rebuilds the
 * division and lets isZero build the assignment for y.
 *)
fun handleDiv(y,e1,e2,params,index,args)=let
      val (num,psN,vsN,codeN) = rewriteOp(e1,params,index,[],args)
      (* the denominator is rewritten with an empty index list;
       * presumably it is expected to be a scalar -- TODO confirm *)
      val (den,psD,vsD,codeD) = rewriteOp(e2,psN,[],[],vsN)
      in
        (isZero(y,E.Div(num,den),psD,index,[],vsD), codeN@codeD)
      end

(* handleAdd: var * ein_exp list * params * index * args -> (var * einapp) * code
 * calls rewriteOps() lift on ein_exp
 * NOTE(review): the comment opener, the function header and the binding of
 * body were missing from this definition, leaving an unterminated fragment.
 * They are reconstructed here by analogy with handleProd; the name handleAdd
 * and the constructor E.Add are assumptions -- confirm against the upstream
 * file.
 *)
fun handleAdd(y,e1,params,index,args)=let
val (e1',params',args',code)=  rewriteOps(e1,params,index,[],args)
val body =E.Add e1'
val einapp= isZero(y,body,params',index,[],args')
in
(einapp,code)
end

(* handleProd: var * ein_exp list * params * index * args -> (var * einapp) * code
 * Rewrites every factor with rewriteOps (no summation indices), rebuilds the
 * product and lets isZero build the assignment for y.
 *)
fun handleProd(y,e1,params,index,args)=let
      val (factors,ps,vs,lifted) = rewriteOps(e1,params,index,[],args)
      in
        (isZero(y,E.Prod factors,ps,index,[],vs), lifted)
      end

(* handleSumProd: var * ein_exp list * params * index * sum_id list * args
 *                -> (var * einapp) * code
 * Same as the plain product case, except that the summation indices sx are
 * passed to rewriteOps and isZero, and the rebuilt product is wrapped in a
 * summation over sx.
 *)
fun handleSumProd(y,e1,params,index,sx,args)=let
      val (factors,ps,vs,lifted) = rewriteOps(e1,params,index,sx,args)
      val summed = E.Sum(sx,E.Prod factors)
      in
        (isZero(y,summed,ps,index,sx,vs), lifted)
      end

(* split: var * rhs -> (var * rhs) * code
 * Splits one EIN application into a simpler application for y plus the code
 * for any sub-expressions that were lifted out of it.
 *  - probes, deltas, epsilons, tensors and constants are returned unchanged
 *  - field-level bodies (Conv, Field, Apply, Lift) become the zero application
 *  - Neg, Sub, Div, Prod and Sum-of-Prod go to the matching handle* function
 *  - a summation over Neg, Sub or Div is pushed inside the operator and
 *    nested summations are merged before rewriting again
 *  - note we leave the summation around a probe expression
 * A right-hand side that is not an EIN application is returned unchanged
 * with no extra code.
 * Raises Fail for constructors that must not appear at this stage.
 * NOTE(review): the file header lists + among the operators that are split,
 * but there is no case for an addition node here. Confirm whether one was
 * lost.
 *)
fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
    val zero=   (setEinZero y,[])
    val default=((y,einapp),[])
    fun rewrite b=(case b
        of E.Probe _              => default
        | E.Conv _                => zero
        | E.Field _               => zero
        | E.Apply _               => zero
        | E.Lift _                => zero
        | E.Delta _               => default
        | E.Epsilon _             => default
        | E.Eps2 _                => default
        | E.Tensor _              => default
        | E.Const _               => default
        | E.Neg e1                => handleNeg(y,e1,params,index,args)
        | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)
        | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)
        | E.Sum(_,E.Prod[E.Eps2 _, E.Probe _ ])      => default
        | E.Sum(_,E.Prod[E.Epsilon _, E.Probe _ ])      => default
        | E.Sum(_,E.Probe _)      => default
        | E.Sum(_,E.Conv _)       => zero
        | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)
        | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))
        | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
        (* NOTE(review): summing numerator and denominator separately only
         * equals the original sum when the denominator does not depend on
         * the summed indices -- confirm that this is guaranteed upstream *)
        | E.Sum(sx,E.Div(e1,e2))  => rewrite(E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))
        | E.Sum(c1, E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))
        | E.Sum _                 => default
        | E.Prod e1               => handleProd(y,e1,params,index,args)
        | E.Partial _             => err(" Partial used after normalize")
        | E.Krn _                 => err("Krn used before expand")
        | E.Value _               => err("Value used before expand")
        (* the message used to name Probe although the offending node is Img *)
        | E.Img _                 => err("Img used before expand")
        (*end case *))
    in
      rewrite body
    end
  | split(y,app) =((y,app),[])

(* iterMultiple: (var*rhs) * code -> (var*rhs) * code
* recursively split ein expression into smaller pieces
* einapp2 is returned untouched; every assignment in newbies2 is passed to
* split, and whatever code that produces is processed the same way.
*)
fun iterMultiple(einapp2,newbies2)=let
(* itercode(todo, rest, code):
*   rest - the split results of the assignments at this level, in input order
*   code - everything produced by splitting the code those results generated
*)
fun itercode([],rest,code)=(rest,code)
| itercode(e1::newbies,rest,code)=let
val (einapp3,code3) =split(e1)
(* the code generated for e1 is itself split, starting from empty accumulators *)
val (rest4,code4)=itercode(code3,[],[])
(* the pieces obtained for e1 are placed in front of the previously collected code *)
in itercode(newbies,rest@[einapp3],code4@rest4@code)
end
val(rest,code)= itercode(newbies2,[],[])
in
(* the nested code is listed before the top-level split results *)
(einapp2,code@rest)
end

(* iterSplit: var * rhs -> (var * rhs) * code
 * Entry point for splitting one assignment. For an EIN application the body
 * is first passed through cleanIndex (with no summation indices), the
 * application is rebuilt with the cleaned body and split, and the code that
 * produces is split further by iterMultiple.
 * Any other right-hand side is returned unchanged with no extra code, which
 * matches what split does for such input; previously this raised Match.
 *)
fun iterSplit(y,DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
      val (_,_,body') = cleanIndex(body,index,[])
      val einapp1 = assignEinApp(y,params,index,body',args)
      in
        iterMultiple(split einapp1)
      end
  | iterSplit(y,rhs) = ((y,rhs),[])

(* gettest: var * rhs -> (var * rhs) * code
 * Runs iterSplit on the assignment. When the testing flag is non-zero it
 * also prints a banner before the split, and afterwards the original
 * assignment, the resulting assignment and the list of generated
 * assignments. The value returned is the result of iterSplit in both cases.
 *)
fun gettest(einapp)=
      if testing = 0
        then iterSplit(einapp)
        else let
          val star = "\n*************\n"
          val () = print star
          val result as (einapp2,newbies) = iterSplit(einapp)
          val top = printEINAPP einapp2
          val rest = String.concatWith ",\n\t" (List.map printEINAPP newbies)
          in
            print (String.concat[printEINAPP einapp,"=>",top," newbies\n\t",rest,"\n",top,star]);
            result
          end

end; (* local *)

end (* structure Split *)