Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/split.sml
 [diderot] / branches / charisee / src / compiler / high-to-mid / split.sml

# Diff of /branches/charisee/src/compiler/high-to-mid/split.sml

revision 2838, Tue Nov 25 03:40:24 2014 UTC revision 2843, Mon Dec 8 01:27:25 2014 UTC
# Line 4  Line 4
5   *)   *)
6
7     (*
8      During the transition from high-IL to mid-IL, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators.  Each EIN expression then only has one operation working on  constants, tensors, deltas, epsilons, images and kernels.
9
10     (1) When the outer EIN operator is $\in {--, +, -, *, /, \sum}$ then for each subexpression analyze to see if they need to be rewritten.
11
12     (1a.) When a subexpression is a field expression $\circledast,\nabla$ then it becomes 0. When it is another operation ${@ --, +, -, *, /, \sum}$ then we lift that subexpression and create a new EIN operator. We replace the subexpression with a tensor expression that represent it's size.
13
14     (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15
16     (1c) Call cleanParams.sml to clean the params in the subexpression.\\
17     (2) All the lifted subexpressions in the original EIN operator are replaced with tensors and non-probed fields with zeros. Call isZero() to determine if the body is zero. If so, needs to return 0. Otherwise clean the EIN operator.
18
19     *)
20
21  structure Split = struct  structure Split = struct
22
# Line 36  Line 47
47
48      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
49      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
50      fun setEinZero(y,params,index,args)=  (y,DstIL.EINAPP(setEin(params,index,E.Const 0),args))      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])
51        fun setEinZero y=  (y,einappzero)
52      fun cleanParams e =cleanP.cleanParams e      fun cleanParams e =cleanP.cleanParams e
53      fun cleanIndex e =cleanI.cleanIndex e      fun cleanIndex e =cleanI.cleanIndex e
54      fun itos i =Int.toString i      fun itos i =Int.toString i
# Line 67  Line 79
79         | printEINAPP(id,_)= String.concat([Var.toString id,"<",(DstTy.toString (Var.ty id)),"> non-einapp\n"])         | printEINAPP(id,_)= String.concat([Var.toString id,"<",(DstTy.toString (Var.ty id)),"> non-einapp\n"])
80
81
(* mkreplacement:params*index*index_list*int list* ein_exp-> ein_exp* params*args*code*
*creates new param and replacement tensor for the original ein_exp
*Then cleans params for suebxpression
*)
fun mkreplacement(params,args,tshape,sizes,body)=let
val id=length(params)
val params'=params@[E.TEN(1,sizes)]
val e'=E.Tensor(id,tshape)
val M  = DstV.new (genName ("TLifted_"^itos id), DstTy.TensorTy sizes)
val args'=args@[M]
val einapp=cleanParams(M,body,params',sizes,args')
in
(e',params',args',[einapp])
end

82      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
83      *lifts expression and returns replacement tensor      *lifts expression and returns replacement tensor
84      * cleans the index and params of subexpression      * cleans the index and params of subexpression
85        *creates new param and replacement tensor for the original ein_exp
86      *)      *)
87      fun lift(e,params,index,sx,args)=let      fun lift(e,params,index,sx,args)=let
88          val (tshape,sizes,body)=cleanIndex(e,index,sx)          val (tshape,sizes,body)=cleanIndex(e,index,sx)
89          val (Re,Rparams,Rargs,code)=mkreplacement(params,args,tshape,sizes,body)          val id=length(params)
90      in          val Rparams=params@[E.TEN(1,sizes)]
91          (Re,Rparams,Rargs,code)          val Re=E.Tensor(id,tshape)
92      end          val M  = DstV.new (genName ("TLifted_"^itos id), DstTy.TensorTy sizes)
93            val Rargs=args@[M]
94            val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
95
(* simplelift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
*lifts expression and returns replacement tensor
* cleans params of subexpression
*)
fun simplelift(e,params,index,args)=(*let
val tshape=List.map (fn x => E.V x) index
val(Re,Rparams,Rargs,code)=mkreplacement(params,args,tshape,index,e)
96          in          in
97          (Re,Rparams,Rargs,code)          (Re,Rparams,Rargs,[einapp])
98          end          end
*)lift(e,params,index,[],args)

99
100
101      (* isOp: ein->int      (* isOp: ein->int
# Line 131  Line 122
122      (*end case*))      (*end case*))
123
124
125      (* simpleOp:ein_exp*params*index*args-> ein_exp*params*args*code      (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code
126       * If e1 an op then call simplelift() to replace it       * If e1 an op then call lift() to replace it
127       * Otherwise rewrite to 0 or it remains the same       * Otherwise rewrite to 0 or it remains the same
128       *)       *)
129      fun simpleOp(e1,params,index,args)=(case (isOp e1)      fun rewriteOp(e1,params,index,sx,args)=(case (isOp e1)
130          of  0   => (E.Const 0,params,args,[])          of  0   => (E.Const 0,params,args,[])
131          | 2     => (e1,params,args,[])          | 2     => (e1,params,args,[])
132          | _     => simplelift(e1,params,index,args)          | _     => lift(e1,params,index,sx,args)
133          (*end*))          (*end*))
134
135
136      (* simpleOps:ein_exp list*params*index*args-> ein_exp list*params*args*code      (* rewriteOps:ein_exp list*params*index*sum_id list*mid-il vars
137       * calls simpleOp on ein_exp list             -> ein_exp list*params*args*code
138         * calls rewriteOp on ein_exp list
139       *)       *)
140      fun simpleOps(list1,params,index,args)=let      fun rewriteOps(list1,params,index,sx,args)=let
141          fun m([],rest,params,args,code)=(rest,params,args,code)          fun m([],rest,params,args,code)=(rest,params,args,code)
142          | m(e1::es,rest,params,args,code)=let          | m(e1::es,rest,params,args,code)=let
143              val (e1',params',args',code')= simpleOp(e1,params,index,args)              val (e1',params',args',code')= rewriteOp(e1,params,index,sx,args)
144              in              in
145                  m(es,rest@[e1'],params',args',code@code')                  m(es,rest@[e1'],params',args',code@code')
146              end              end
# Line 156  Line 148
148              m(list1,[],params,args,[])              m(list1,[],params,args,[])
149          end          end
150
151      (* prodOps:ein_exp list*params*index*sum_id list*args-> ein_exp list*params*args*code      (*isZero: var* ein_exp* params*index list*mid-il vars
152       * calls lift  on ein_exp list             When the operation is zero then we return a real.
153       *)       *)
154      fun prodOps(list1,params,index,sx,args)=let      fun isZero(y,body,params,index,sx,args) =(case (cleanP.isZero body)
155          fun m([],rest,params,args,code)=(rest,params,args,code)          of 1=>  setEinZero y
| m(e1::es,rest,params,args,code)=(case (isOp e1)
of  0   => m(es,rest@[E.Const 0],params,args,code)
| 2     => m(es,rest@[e1],params,args,code)
| 1     => let
val (e1',params',args',code')= lift(e1,params,index,sx,args)
in
m(es,rest@[e1'],params',args',code@code')
end
(*end case*))
in
m(list1,[],params,args,[])
end

fun isZero(y,body,params,index,args) =(case (cleanI.isZero body)
of 1=>  setEinZero(y,params,[],args)
156          | _ =>  cleanParams(y,body,params,index,args)          | _ =>  cleanParams(y,body,params,index,args)
157      (*end case*))      (*end case*))
158
159      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code
160      * calls simpleOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
161      *)      *)
162      fun handleNeg(y,e1,params,index,args)=let      fun handleNeg(y,e1,params,index,args)=let
163          val (e1',params',args',code)=  simpleOp(e1,params,index,args)          val (e1',params',args',code)=  rewriteOp(e1,params,index,[],args)
164          val body =E.Neg e1'          val body =E.Neg e1'
165          val einapp= isZero(y,body,params',index,args')          val einapp= isZero(y,body,params',index,[],args')
166      in      in
167          (einapp,code)          (einapp,code)
168      end      end
169
170     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code
171      * calls simpleOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
172      *)      *)
173      fun handleSub(y,e1,e2,params,index,args)=let      fun handleSub(y,e1,e2,params,index,args)=let
174          val ([e1',e2'],params',args',code)=  simpleOps([e1,e2],params,index,args)          val ([e1',e2'],params',args',code)=  rewriteOps([e1,e2],params,index,[],args)
175          val body =E.Sub(e1',e2')          val body =E.Sub(e1',e2')
176          val einapp= isZero(y,body,params',index,args')          val einapp= isZero(y,body,params',index,[],args')
177      in      in
178          (einapp,code)          (einapp,code)
179      end      end
180
181      (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code      (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code
182      * calls simpleOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
183      *)      *)
184      fun handleDiv(y,e1,e2,params,index,args)=let      fun handleDiv(y,e1,e2,params,index,args)=let
185          val (e1',params1',args1',code1')=simpleOp(e1,params,index,args)          val (e1',params1',args1',code1')=rewriteOp(e1,params,index,[],args)
186          val (e2',params2',args2',code2')=simpleOp(e2,params1',[],args1')          val (e2',params2',args2',code2')=rewriteOp(e2,params1',[],[],args1')
187          val body =E.Div(e1',e2')          val body =E.Div(e1',e2')
188          val einapp= isZero(y,body,params2',index,args2')          val einapp= isZero(y,body,params2',index,[],args2')
189      in      in
190              (einapp,code1'@code2')              (einapp,code1'@code2')
191      end      end
192
193
195      * calls simpleOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
196      *)      *)
198          val (e1',params',args',code)=  simpleOps(e1,params,index,args)          val (e1',params',args',code)=  rewriteOps(e1,params,index,[],args)
200          val einapp= isZero(y,body,params',index,args')          val einapp= isZero(y,body,params',index,[],args')
201      in      in
202          (einapp,code)          (einapp,code)
203      end      end
204
205      (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code      (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code
206       * calls prodOps() lift  on ein_exp       * calls rewriteOps() lift  on ein_exp
207       *)       *)
208      fun handleProd(y,e1,params,index,args)=let      fun handleProd(y,e1,params,index,args)=let
209          val (e1',params',args',code)=  prodOps(e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOps(e1,params,index,[],args)
210          val body =E.Prod e1'          val body =E.Prod e1'
211          val einapp= isZero(y,body,params',index,args')          val einapp= isZero(y,body,params',index,[],args')
212      in      in
213          (einapp,code)          (einapp,code)
214      end      end
215
216     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code
217      * calls prodOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
218      *)      *)
219      fun handleSumProd(y,e1,params,index,sx,args)=let      fun handleSumProd(y,e1,params,index,sx,args)=let
220          val _ =List.map (fn (_,_,ub)=> Int.toString ub) sx          val (e1',params',args',code)=  rewriteOps(e1,params,index,sx,args)
val (e1',params',args',code)=  prodOps(e1,params,index,sx,args)
221          val body= E.Sum(sx,E.Prod e1')          val body= E.Sum(sx,E.Prod e1')
222          val einapp= isZero(y,body,params',index,args')          val einapp= isZero(y,body,params',index,sx,args')
223      in      in
224          (einapp,code)          (einapp,code)
225      end      end
226
227      (* split:var*ein_app-> (var*einap)*code      (* split:var*ein_app-> (var*einap)*code
228      * split ein expression into smaller pieces      * split ein expression into smaller pieces
229          note we leave summation around probe exp
230      *)      *)
231      fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let      fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
232          val zero=   (setEinZero(y,params,[],args),[])          val zero=   (setEinZero y,[])
233          val default=((y,einapp),[])          val default=((y,einapp),[])
234          val sumIndex=ref []          val sumIndex=ref []
235          fun rewrite b=(case b          fun rewrite b=(case b
# Line 264  Line 240
240              | E.Lift e                => zero              | E.Lift e                => zero
241              | E.Delta _               => default              | E.Delta _               => default
242              | E.Epsilon _             => default              | E.Epsilon _             => default
243                | E.Eps2 _                => default
244              | E.Tensor _              => default              | E.Tensor _              => default
245              | E.Const _               => default              | E.Const _               => default
246              | E.Neg e1                => handleNeg(y,e1,params,index,args)              | E.Neg e1                => handleNeg(y,e1,params,index,args)
247              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)
248              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)
249                | E.Sum(_,E.Prod[E.Eps2 _, E.Probe _ ])      => default
250                | E.Sum(_,E.Prod[E.Epsilon _, E.Probe _ ])      => default
251              | E.Sum(_,E.Probe _)      => default              | E.Sum(_,E.Probe _)      => default
252              | E.Sum(_,E.Conv _)       => zero              | E.Sum(_,E.Conv _)       => zero
253              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)
# Line 286  Line 265
265              | E.Img _                 => err("Probe used before expand")              | E.Img _                 => err("Probe used before expand")
266              (*end case *))              (*end case *))
267
268
269          val (einapp2,newbies) =rewrite body          val (einapp2,newbies) =rewrite body
270          in          in
271              (einapp2,newbies)              (einapp2,newbies)
# Line 309  Line 289
289          end          end
290
291
292      fun iterSplit(y,einapp)=let      fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
293          val (einapp2,newbies2)=split(y,einapp)          val (_,_,body')=cleanIndex(body,index,[])
294            val einapp1= assignEinApp(y,params,index,body',args)
295            val (einapp2,newbies2)=split einapp1
296      in      in
297          iterMultiple(einapp2,newbies2)          iterMultiple(einapp2,newbies2)
298      end      end

Legend:
 Removed from v.2838 changed lines Added in v.2843