Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2845, Fri Dec 12 06:46:23 2014 UTC revision 3017, Mon Mar 9 15:52:22 2015 UTC
# Line 31  Line 31 
31    
32      in      in
33    
34      val testing=1      val testing=0
35      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
36      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
37      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])
# Line 40  Line 40 
40      fun cleanIndex e =cleanI.cleanIndex e      fun cleanIndex e =cleanI.cleanIndex e
41      fun printEINAPP e=MidToString.printEINAPP e      fun printEINAPP e=MidToString.printEINAPP e
42      fun isZero e=handleE.isZero e      fun isZero e=handleE.isZero e
43        fun sweep e=handleE.sweep e
44      fun itos i =Int.toString i      fun itos i =Int.toString i
45        fun filterSca e=Filter.filterSca e
46      fun err str=raise Fail str      fun err str=raise Fail str
47      val cnt = ref 0      val cnt = ref 0
48      fun genName prefix = let      fun genName prefix = let
# Line 61  Line 63 
63      *creates new param and replacement tensor for the original ein_exp      *creates new param and replacement tensor for the original ein_exp
64      *)      *)
65      fun lift(name,e,params,index,sx,args)=let      fun lift(name,e,params,index,sx,args)=let
66    
67          val (tshape,sizes,body)=cleanIndex(e,index,sx)          val (tshape,sizes,body)=cleanIndex(e,index,sx)
68          val id=length(params)          val id=length(params)
69          val Rparams=params@[E.TEN(1,sizes)]          val Rparams=params@[E.TEN(1,sizes)]
# Line 83  Line 86 
86          | E.Apply _   => 0          | E.Apply _   => 0
87          | E.Lift _    => 0          | E.Lift _    => 0
88          | E.Neg _     => 1          | E.Neg _     => 1
89            | E.Sqrt _    => 1
90            | E.PowInt _    => 1
91            | E.PowReal _    => 1
92          | E.Add _     => 1          | E.Add _     => 1
93          | E.Sub _     => 1          | E.Sub _     => 1
94          | E.Prod _    => 1          | E.Prod _    => 1
# Line 113  Line 119 
119      fun rewriteOps(name,list1,params,index,sx,args)=let      fun rewriteOps(name,list1,params,index,sx,args)=let
120          fun m([],rest,params,args,code)=(rest,params,args,code)          fun m([],rest,params,args,code)=(rest,params,args,code)
121          | m(e1::es,rest,params,args,code)=let          | m(e1::es,rest,params,args,code)=let
122    
123              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)
124              in              in
125                  m(es,rest@[e1'],params',args',code@code')                  m(es,rest@[e1'],params',args',code@code')
# Line 134  Line 141 
141      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
142      *)      *)
143      fun handleNeg(y,e1,params,index,args)=let      fun handleNeg(y,e1,params,index,args)=let
144          val (e1',params',args',code)=  rewriteOp(DstV.name y, e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOp("neg", e1,params,index,[],args)
145          val body =E.Neg e1'          val body =E.Neg e1'
146          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
147          in          in
148              (einapp,code)              (einapp,code)
149          end          end
150    
151        (* handleSqrt:var*ein_exp *params*index*args-> (var*einap)*code
152        * calls rewriteOp() lift  on ein_exp
153        *)
154        fun handleSqrt(y,e1,params,index,args)=let
155            val (e1',params',args',code)=  rewriteOp("sqrt", e1,params,index,[],args)
156            val body =E.Sqrt e1'
157            val einapp= rewriteOrig(y,body,params',index,[],args')
158        in
159            (einapp,code)
160        end
161    
162    
163    (* handlePowInt:var*ein_exp *params*index*args-> (var*einap)*code
164    * calls rewriteOp() lift  on ein_exp
165    *)
166    fun handlePowInt(y,(e1,n1),params,index,args)=let
167    val (e1',params',args',code)=  rewriteOp("powint", e1,params,index,[],args)
168    val body =E.PowInt(e1',n1)
169    val einapp= rewriteOrig(y,body,params',index,[],args')
170    in
171    (einapp,code)
172    end
173    
174    
175        (* handlePowReal:var*ein_exp *params*index*args-> (var*einap)*code
176        * calls rewriteOp() lift  on ein_exp
177        *)
178        fun handlePowReal(y,(e1,n1),params,index,args)=let
179        val (e1',params',args',code)=  rewriteOp("powreal", e1,params,index,[],args)
180        val body =E.PowReal(e1',n1)
181        val einapp= rewriteOrig(y,body,params',index,[],args')
182        in
183        (einapp,code)
184        end
185    
186    
187     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code
188      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
189      *)      *)
190      fun handleSub(y,e1,e2,params,index,args)=let      fun handleSub(y,e1,e2,params,index,args)=let
191          val ([e1',e2'],params',args',code)=  rewriteOps(DstV.name y,[e1,e2],params,index,[],args)          val ([e1',e2'],params',args',code)=  rewriteOps("subt",[e1,e2],params,index,[],args)
192          val body =E.Sub(e1',e2')          val body =E.Sub(e1',e2')
193          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
194          in          in
# Line 156  Line 199 
199      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
200      *)      *)
201      fun handleDiv(y,e1,e2,params,index,args)=let      fun handleDiv(y,e1,e2,params,index,args)=let
202          val (e1',params1',args1',code1')=rewriteOp(DstV.name y,e1,params,index,[],args)          val (e1',params1',args1',code1')=rewriteOp("div-num",e1,params,index,[],args)
203          val (e2',params2',args2',code2')=rewriteOp(DstV.name y,e2,params1',[],[],args1')          val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',index,[],args1')
204            (*val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',[],[],args1')*)
205          val body =E.Div(e1',e2')          val body =E.Div(e1',e2')
206          val einapp= rewriteOrig(y,body,params2',index,[],args2')          val einapp= rewriteOrig(y,body,params2',index,[],args2')
207          in          in
# Line 168  Line 212 
212      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
213      *)      *)
214      fun handleAdd(y,e1,params,index,args)=let      fun handleAdd(y,e1,params,index,args)=let
215          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOps("add",e1,params,index,[],args)
216          val body =E.Add e1'          val body =E.Add e1'
217          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
218          in          in
# Line 179  Line 223 
223       * calls rewriteOps() lift  on ein_exp       * calls rewriteOps() lift  on ein_exp
224       *)       *)
225      fun handleProd(y,e1,params,index,args)=let      fun handleProd(y,e1,params,index,args)=let
226          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOps("prod",e1,params,index,[],args)
227          val body =E.Prod e1'          val body =E.Prod e1'
228          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
229          in          in
# Line 190  Line 234 
234      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
235      *)      *)
236      fun handleSumProd(y,e1,params,index,sx,args)=let      fun handleSumProd(y,e1,params,index,sx,args)=let
237          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,sx,args)          val (e1',params',args',code)=  rewriteOps("sumprod",e1,params,index,sx,args)
238          val body= E.Sum(sx,E.Prod e1')          val body= E.Sum(sx,E.Prod e1')
239          val einapp= rewriteOrig(y,body,params',index,sx,args')          val einapp= rewriteOrig(y,body,params',index,sx,args')
240          in          in
# Line 205  Line 249 
249          val zero=   (setEinZero y,[])          val zero=   (setEinZero y,[])
250          val default=((y,einapp),[])          val default=((y,einapp),[])
251          val sumIndex=ref []          val sumIndex=ref []
252            val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
253            val _=testp["\n\nStarting split",P.printbody body]
254          fun rewrite b=(case b          fun rewrite b=(case b
255              of E.Probe _              => default              of E.Probe (E.Conv _,_)   => default
256                | E.Probe(E.Field _,_)    => raise Fail str
257                | E.Probe _               => raise Fail str
258              | E.Conv _                => zero              | E.Conv _                => zero
259              | E.Field _               => zero              | E.Field _               => zero
260              | E.Apply _               => zero              | E.Apply _               => zero
# Line 216  Line 264 
264              | E.Eps2 _                => default              | E.Eps2 _                => default
265              | E.Tensor _              => default              | E.Tensor _              => default
266              | E.Const _               => default              | E.Const _               => default
267                | E.ConstR _              => default
268              | E.Neg e1                => handleNeg(y,e1,params,index,args)              | E.Neg e1                => handleNeg(y,e1,params,index,args)
269                | E.Sqrt e1               => handleSqrt(y,e1,params,index,args)
270                | E.PowInt e1             => handlePowInt(y,e1,params,index,args)
271                | E.PowReal e1            => handlePowReal(y,e1,params,index,args)
272              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)
273              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)
274              | E.Sum(_,E.Prod[E.Eps2 _, E.Probe _ ])      => default              | E.Sum(sx,E.Tensor(id,[]))=> rewrite (E.Tensor(id,[]))
275              | E.Sum(_,E.Prod[E.Epsilon _, E.Probe _ ])      => default              | E.Sum(sx,E.Const c)      =>rewrite ( E.Const c )
276              | E.Sum(_,E.Probe _)      => default              | E.Sum(sx,E.ConstR r)    => rewrite (E.ConstR r)
             | E.Sum(_,E.Conv _)       => zero  
             | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)  
277              | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))              | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))
278              | E.Sum(sx,E.Add a)       => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))              | E.Sum(sx,E.Add a)       => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))
279              | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))              | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
280              | E.Sum(sx,E.Div(e1,e2))  => rewrite(E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))              | E.Sum(sx,E.Div(e1,e2))  => rewrite (E.Sum(sx,E.Prod[e1,E.Div(E.Const 1,e2)]))
281                | E.Sum(sx,E.Lift e )     => rewrite (E.Lift(E.Sum(sx,e)))
282                | E.Sum(sx,E.PowReal(e,n1)) => rewrite (E.PowReal(E.Sum(sx,e),n1))
283                | E.Sum(sx,E.Sqrt e)        => rewrite (E.Sqrt(E.Sum(sx,e)))
284              | E.Sum(c1, E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))              | E.Sum(c1, E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))
285                | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default
286                | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default
287                | E.Sum(_,E.Probe(E.Conv _,_))    => default
288                | E.Sum(_,E.Conv _)       => zero
289                | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)
290              | E.Sum(sx,_)             => default              | E.Sum(sx,_)             => default
291              | E.Add e1                => handleAdd(y,e1,params,index,args)              | E.Add e1                => handleAdd(y,e1,params,index,args)
292              | E.Prod e1               => handleProd(y,e1,params,index,args)              | E.Prod e1               => handleProd(y,e1,params,index,args)
# Line 243  Line 301 
301          end          end
302          |split(y,app) =((y,app),[])          |split(y,app) =((y,app),[])
303    
304    
305        (*Distribute summation if needed*)
306        fun distributeSummation(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
307            fun rewrite b=(case b
308            of E.Sum(sx,E.Tensor(id,[]))    => E.Tensor(id,[])
309            | E.Sum(sx,E.Const c)           => E.Const c
310            | E.Sum(sx,E.ConstR r)          => E.ConstR r
311            | E.Sum(sx,E.Neg n)             => rewrite(E.Neg(E.Sum(sx,n)))
312            | E.Sum(sx,E.Add a)             => rewrite(E.Add(List.map (fn e=> E.Sum(sx,e)) a))
313            | E.Sum(sx,E.Sub (e1,e2))       => rewrite(E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
314           (* | E.Sum(sx,E.Div(e1,e2))        => rewrite( E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))*)
315            | E.Sum(sx,E.Lift e )           => rewrite (E.Lift(E.Sum(sx,e)))
316            | E.Sum(sx,E.PowReal(e,n1))     => rewrite(E.PowReal(E.Sum(sx,e),n1))
317            | E.Sum(sx,E.Sqrt e)            => rewrite(E.Sqrt(E.Sum(sx,e)))
318            | E.Sum(sx,E.Sum (c2,e))        => rewrite (E.Sum (sx@c2,e))
319            | E.Sum(sx,E.Prod p)            => let
320                    val (c,e)=filterSca(sx,p)
321                    in e end
322            | E.Div(e1,e2)                  => E.Div(rewrite e1, rewrite e2)
323            | E.Sub(e1,e2)                  => E.Sub(rewrite e1, rewrite e2)
324            | E.Add es                      => E.Add(List.map rewrite es)
325            | E.Prod es                     => E.Prod(List.map rewrite es)
326            | E.Neg e                       => E.Neg(rewrite e)
327            | E.Sqrt e                      => E.Sqrt(rewrite e)
328            | _                             => b
329        (*end case*))
330        val body =rewrite body
331        val _ =testp["\nAfter distributeSummation \n",P.printbody body]
332        val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=body})
333        val einapp2= (y,DstIL.EINAPP(ein,args))
334        in
335            split(einapp2)
336        end
337        |distributeSummation(y,app) =((y,app),[])
338    
339    
340    
341      (* iterMultiple:code*code=> (code*code)      (* iterMultiple:code*code=> (code*code)
342       * recursively split ein expression into smaller pieces       * recursively split ein expression into smaller pieces
343      *)      *)
344      fun iterMultiple(einapp2,newbies2)=let      fun iterMultiple(einapp2,newbies2)=let
345          fun itercode([],rest,code)=(rest,code)          fun itercode([],rest,code,_)=(rest,code)
346          | itercode(e1::newbies,rest,code)=let          | itercode(e1::newbies,rest,code,cnt)=let
347              val (einapp3,code3) =split(e1)              val _ =testp["\n\n******* split term **",Int.toString cnt," *****","\n \n",printEINAPP(e1),"\n=>\n"]
348              val (rest4,code4)=itercode(code3,[],[])              val (einapp3,code3) = distributeSummation e1
349              in itercode(newbies,rest@[einapp3],code4@rest4@code)              val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))]
350                val (rest4,code4)=itercode(code3,[],[],cnt+1)
351                in itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)
352              end              end
353          val(rest,code)= itercode(newbies2,[],[])          val(rest,code)= itercode(newbies2,[],[],1)
354          in          in
355              (einapp2,code@rest)              (einapp2,code@rest)
356          end          end
357    
358      fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let      fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
359          (*val (_,_,body')=cleanIndex(body,index,[])          val bodysweep=handleE.sweep body
360          val einapp1= assignEinApp(y,params,index,body',args)          val _=testp["\nPresweep\n",P.printbody body,"\n\n Sweep\n",P.printbody bodysweep,"\n"]
361          *)          val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=bodysweep})
362          val (_,sizes,body')=cleanIndex(body,index,[])          val _=testp["\n\n Clean Summation\n",P.printbody(Ein.body ein),"\n"]
363          val einapp1= assignEinApp(y,params,index,body',args)          val einapp2=(y,DstIL.EINAPP(ein, args))
364          val a=testp["\n rewriten einapp\n \t",printEINAPP einapp1]          val _ =testp["\n\n******* split term **",Int.toString (0)," ***** \n \t==>\n",printEINAPP(einapp2)]
365          val (einapp2,newbies2)=split einapp1          val (einapp3,newbies2)=distributeSummation einapp2
366            val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies2))]
367    
368      in      in
369          iterMultiple(einapp2,newbies2)          iterMultiple(einapp3,newbies2)
370      end      end
371    
372      (* gettest:code*code=> (code*code)      (* gettest:code*code=> (code*code)
# Line 276  Line 375 
375      fun gettest einapp=(case testing      fun gettest einapp=(case testing
376          of 0=>iterSplit(einapp)          of 0=>iterSplit(einapp)
377          | _=>let          | _=>let
378              val star="\n************* SPLIT********\n"              val star="\n************* SPLIT INITIAL********\n"
379              val _ =print(String.concat[star,"\n","start get test",printEINAPP einapp])              val _ =testp[star,"\n","start get test",printEINAPP einapp]
380              val (einapp2,newbies)=iterSplit(einapp)              val (einapp2,newbies)=iterSplit(einapp)
381              val a=printEINAPP einapp2              val _ =testp["\n\n Returning \n\n =>",printEINAPP einapp2,
382              val b=String.concatWith",\n\t"(List.map printEINAPP newbies)                      " newbies\n\t",String.concatWith",\n\t"(List.map printEINAPP newbies), "\n",star]
             val _ =print(String.concat[printEINAPP einapp,"=>",a," newbies\n\t",b, "\n",a,star])  
383              in              in
384                  (einapp2,newbies)                  (einapp2,newbies)
385              end              end

Legend:
Removed from v.2845  
changed lines
  Added in v.3017

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0