Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/split.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/split.sml revision 3030, Tue Mar 10 01:24:41 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/split.sml revision 3316, Sat Oct 17 01:40:08 2015 UTC
# Line 24  Line 24 
24      structure DstIL = MidIL      structure DstIL = MidIL
25      structure DstTy = MidILTypes      structure DstTy = MidILTypes
26      structure DstV = DstIL.Var      structure DstV = DstIL.Var
27    
28      structure P=Printer      structure P=Printer
29      structure cleanP=cleanParams      structure cleanP=cleanParams
30      structure cleanI=cleanIndex      structure cleanI=cleanIndex
31      structure handleE=handleEin  
32    
33      in      in
34    
35      val testing=1      val numFlag=1   (*remove common subexpression*)
36      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}      val testing=0
37      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))      fun mkEin e = E.mkEin e
38      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])      val einappzero= DstIL.EINAPP(mkEin([],[],E.Const 0),[])
39      fun setEinZero y=  (y,einappzero)      fun setEinZero y=  (y,einappzero)
40      fun cleanParams e =cleanP.cleanParams e      fun cleanParams e =cleanP.cleanParams e
41      fun cleanIndex e =cleanI.cleanIndex e      fun cleanIndex e =cleanI.cleanIndex e
42      fun printEINAPP e=MidToString.printEINAPP e      fun toStringBind e= MidToString.toStringBind e
     fun isZero e=handleE.isZero e  
     fun sweep e=handleE.sweep e  
43      fun itos i =Int.toString i      fun itos i =Int.toString i
     fun filterSca e=Filter.filterSca e  
44      fun err str=raise Fail str      fun err str=raise Fail str
45      val cnt = ref 0      val cnt = ref 0
46        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
47      fun genName prefix = let      fun genName prefix = let
48          val n = !cnt          val n = !cnt
49      in      in
# Line 56  Line 55 
55          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
56          (*end case*))          (*end case*))
57    
   
58      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
59      *lifts expression and returns replacement tensor      *lifts expression and returns replacement tensor
60      * cleans the index and params of subexpression      * cleans the index and params of subexpression
61      *creates new param and replacement tensor for the original ein_exp      *creates new param and replacement tensor for the original ein_exp
62      *)      *)
63      fun lift(name,e,params,index,sx,args)=let      fun lift(name,e,params,index,sx,args,fieldset,flag)=let
   
64          val (tshape,sizes,body)=cleanIndex(e,index,sx)          val (tshape,sizes,body)=cleanIndex(e,index,sx)
65          val id=length(params)          val id=length(params)
66          val Rparams=params@[E.TEN(1,sizes)]          val Rparams=params@[E.TEN(1,sizes)]
# Line 71  Line 68 
68          val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)          val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
69          val Rargs=args@[M]          val Rargs=args@[M]
70          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
71            val (_,einapp0)=einapp
72            val (Rargs,newbies,fieldset) =(case flag
73                of 1=> let
74                    val (fieldset,var) = einSet.rtnVar(fieldset,M,einapp0)
75                    in (case var
76                        of NONE=> (args@[M],[einapp],fieldset)
77                        | SOME v=> (incUse v ;(args@[v],[],fieldset))
78                        (*end case*))
79                    end
80                | _=>(args@[M],[einapp],fieldset)
81                  (*end case*))
82      in      in
83          (Re,Rparams,Rargs,[einapp])              (Re,Rparams,Rargs,newbies,fieldset)
84      end      end
85    
86    
87      (* isOp: ein->int      (* isOp: ein->int
88       * checks to see if this sub-expression is pulled out or split form original       * checks to see if this sub-expression is pulled out or split form original
89       * 0-becomes zero,1-remains the same, 2-operator       * 0-becomes zero,1-remains the same, 2-operator
# Line 87  Line 95 
95          | E.Lift _    => 0          | E.Lift _    => 0
96          | E.Neg _     => 1          | E.Neg _     => 1
97          | E.Sqrt _    => 1          | E.Sqrt _    => 1
98            | E.Cosine _      => 1
99            | E.ArcCosine _   => 1
100            | E.Sine _        => 1
101            | E.ArcSine _     => 1
102          | E.PowInt _    => 1          | E.PowInt _    => 1
103          | E.PowReal _    => 1          | E.PowReal _    => 1
104          | E.Add _     => 1          | E.Add _     => 1
# Line 102  Line 114 
114          | _           => 2          | _           => 2
115      (*end case*))      (*end case*))
116    
117        fun rewriteOp3(name,sx,e1,x)=let
118            val ((y, DstIL.EINAPP(ein,args)),fieldset,flag)=x
119            val params=Ein.params ein
120            val index=Ein.index ein
121            in (case (isOp e1)
122                of  0   => (E.Const 0,params,args,[],fieldset)
123                | 1     => lift(name,e1,params,index,sx,args,fieldset,flag)
124                | 2     => (e1,params,args,[],fieldset)
125                (*end*))
126            end
127    
128      (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code      (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code
129       * If e1 an op then call lift() to replace it       * If e1 an op then call lift() to replace it
      * Otherwise rewrite to 0 or it remains the same  
130       *)       *)
131      fun rewriteOp(name,e1,params,index,sx,args)=(case (isOp e1)      fun rewriteOp(name,e1,params,index,sx,args,fieldset,flag)=(case (isOp e1)
132          of  0   => (E.Const 0,params,args,[])          of  0   => (E.Const 0,params,args,[],fieldset)
133          | 2     => (e1,params,args,[])          | 1     => lift(name,e1,params,index,sx,args,fieldset,flag)
134          | _     =>   lift(name,e1,params,index,sx,args)          | 2     => (e1,params,args,[],fieldset)             (*not lifted*)
135          (*end*))          (*end*))
136    
137      (* rewriteOps:ein_exp list*params*index*sum_id list*mid-il vars      fun rewriteOps(name,list1,params,index,sx,args,fieldset0,flag)=let
138             -> ein_exp list*params*args*code          fun m([],rest,params,args,code,fieldset)=(rest,params,args,code,fieldset)
139       * calls rewriteOp on ein_exp list          | m(e1::es,rest,params,args,code,fieldset)=let
      *)  
     fun rewriteOps(name,list1,params,index,sx,args)=let  
         fun m([],rest,params,args,code)=(rest,params,args,code)  
         | m(e1::es,rest,params,args,code)=let  
140    
141              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)              val (e1',params',args',code',fieldset)= rewriteOp(name,e1,params,index,sx,args,fieldset,flag)
 val _ =print("rewriteOP:\n"^P.printbody e1^"\n\t=>"^ P.printbody e1')  
142              in              in
143                  m(es,rest@[e1'],params',args',code@code')                  m(es,rest@[e1'],params',args',code@code',fieldset)
144              end              end
145          in          in
146              m(list1,[],params,args,[])                  m(list1,[],params,args,[],fieldset0)
147          end          end
148    
149    
150      (*rewriteOrig: var* ein_exp* params*index list*mid-il vars      (*rewriteOrig: var* ein_exp* params*index list*mid-il vars
151             When the operation is zero then we return a real.             When the operation is zero then we return a real.
152          -Moved is Zero to before split.          -Moved is Zero to before split.
153      *)      *)
154      fun rewriteOrig(y,body,params,index,sx,args) =(case (isZero body)      fun rewriteOrig(y,body,params,index,sx,args) =cleanParams(y,body,params,index,args)
155          of 1=>  setEinZero y  
156          | _ => cleanParams(y,body,params,index,args)      fun rewriteOrig3(sx,body,params,args,x) =let
157          (*end case*))          val ((y,DstIL.EINAPP(ein,_)),_,_)=x
158            val index=Ein.index ein
159            in  cleanParams(y,body,params,index,args)
160            end
161    
162      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code
163      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
164      *)      *)
165      fun handleNeg(y,e1,params,index,args)=let      fun handleNeg(e1,x)=let
166          val (e1',params',args',code)=  rewriteOp("neg", e1,params,index,[],args)          val (e1',params',args',code,fieldset)=  rewriteOp3("neg",[],e1,x)
167          val body =E.Neg e1'          val body' =E.Neg e1'
168          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig3([],body',params',args',x)
169          in          in
170              (einapp,code)              (einapp,code,fieldset)
171          end          end
172    
173      (* handleSqrt:var*ein_exp *params*index*args-> (var*einap)*code      (* handleSqrt:var*ein_exp *params*index*args-> (var*einap)*code
174      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
175      *)      *)
176      fun handleSqrt(y,e1,params,index,args)=let      fun handleSqrt(y,e1,params,index,args,fieldset,flag)=let
177          val (e1',params',args',code)=  rewriteOp("sqrt", e1,params,index,[],args)          val (e1',params',args',code,fieldset)=  rewriteOp("sqrt", e1,params,index,[],args,fieldset,flag)
178          val body =E.Sqrt e1'          val body =E.Sqrt e1'
179          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
180      in      in
181          (einapp,code)          (einapp,code,fieldset)
182      end      end
183    
184    
185  (* handlePowInt:var*ein_exp *params*index*args-> (var*einap)*code      (* handleCosine:var*ein_exp *params*index*args-> (var*einap)*code
186        * calls rewriteOp() lift  on ein_exp
187        *)
188        fun handleCosine(y,e1,params,index,args,fieldset,flag)=let
189            val (e1',params',args',code,fieldset)=  rewriteOp("cosine", e1,params,index,[],args,fieldset,flag)
190            val body =E.Cosine e1'
191            val einapp= rewriteOrig(y,body,params',index,[],args')
192            in
193                (einapp,code,fieldset)
194        end
195    
196        (* handleArcCosine:var*ein_exp *params*index*args-> (var*einap)*code
197  * calls rewriteOp() lift  on ein_exp  * calls rewriteOp() lift  on ein_exp
198  *)  *)
199  fun handlePowInt(y,(e1,n1),params,index,args)=let      fun handleArcCosine(y,e1,params,index,args,fieldset,flag)=let
200  val (e1',params',args',code)=  rewriteOp("powint", e1,params,index,[],args)          val (e1',params',args',code,fieldset)=  rewriteOp("ArcCosine", e1,params,index,[],args,fieldset,flag)
201  val body =E.PowInt(e1',n1)          val body =E.ArcCosine e1'
202  val einapp= rewriteOrig(y,body,params',index,[],args')  val einapp= rewriteOrig(y,body,params',index,[],args')
203  in  in
204  (einapp,code)              (einapp,code,fieldset)
205  end  end
206    
207        (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code
208        * calls rewriteOp() lift  on ein_exp
209        *)
210        fun handleSine(y,e1,params,index,args,fieldset,flag)=let
211            val (e1',params',args',code,fieldset)=  rewriteOp("sine", e1,params,index,[],args,fieldset,flag)
212            val body =E.Sine e1'
213            val einapp= rewriteOrig(y,body,params',index,[],args')
214            in
215                (einapp,code,fieldset)
216        end
217    
218      (* handlePowReal:var*ein_exp *params*index*args-> (var*einap)*code      (* handleSine:var*ein_exp *params*index*args-> (var*einap)*code
219      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
220      *)      *)
221      fun handlePowReal(y,(e1,n1),params,index,args)=let      fun handleArcSine(y,e1,params,index,args,fieldset,flag)=let
222      val (e1',params',args',code)=  rewriteOp("powreal", e1,params,index,[],args)          val (e1',params',args',code,fieldset)=  rewriteOp("ArcSine", e1,params,index,[],args,fieldset,flag)
223      val body =E.PowReal(e1',n1)          val body =E.ArcSine e1'
224      val einapp= rewriteOrig(y,body,params',index,[],args')      val einapp= rewriteOrig(y,body,params',index,[],args')
225      in      in
226      (einapp,code)              (einapp,code,fieldset)
227      end      end
228    
229    
230     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code
231      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
232      *)      *)
233      fun handleSub(y,e1,e2,params,index,args)=let      fun handleSub(y,e1,e2,params,index,args,fieldset,flag)=let
234          val ([e1',e2'],params',args',code)=  rewriteOps("subt",[e1,e2],params,index,[],args)          val ([e1',e2'],params',args',code,fieldset)=  rewriteOps("subt",[e1,e2],params,index,[],args,fieldset,flag)
235          val body =E.Sub(e1',e2')          val body =E.Sub(e1',e2')
236          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
237          in          in
238              (einapp,code)              (einapp,code,fieldset)
239          end          end
240    
241      (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code      (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code
242      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
243      *)      *)
244      fun handleDiv(y,e1,e2,params,index,args)=let      fun handleDiv(y,e1,e2,params,index,args,fieldset,flag)=let
245          val (e1',params1',args1',code1')=rewriteOp("div-num",e1,params,index,[],args)          val (e1',params1',args1',code1',fieldset)=rewriteOp("div-num",e1,params,index,[],args,fieldset,flag)
246          val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',index,[],args1')          val (e2',params2',args2',code2',fieldset)=rewriteOp("div-denom",e2,params1',index,[],args1',fieldset,flag)
         (*val (e2',params2',args2',code2')=rewriteOp("div-denom",e2,params1',[],[],args1')*)  
247          val body =E.Div(e1',e2')          val body =E.Div(e1',e2')
248          val einapp= rewriteOrig(y,body,params2',index,[],args2')          val einapp= rewriteOrig(y,body,params2',index,[],args2')
249          in          in
250                  (einapp,code1'@code2')                  (einapp,code1'@code2',fieldset)
251          end          end
252    
253      (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code      (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code
254      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
255      *)      *)
256      fun handleAdd(y,e1,params,index,args)=let  fun   handleAdd(y,e1 as [_,_,_,_],params,index,args,fieldset,flag)=let
257    
258    val (e1',params',args',code,fieldset)=  rewriteOps("add",e1,params,index,[],args,fieldset,flag)
259    fun pb es=String.concatWith "\n\n\t-*-" (List.map P.printbody es)
260    (*)val _ =print("\n****Inside Add:"^Int.toString(length index)^"\n -"^ pb e1 ^"----- newbies\n-"^ pb e1')*)
261    
262    val body =E.Add e1'
263    val einapp= rewriteOrig(y,body,params',index,[],args')
264    in
265    (einapp,code,fieldset)
266    end
267        | handleAdd(y,e1,params,index,args,fieldset,flag)=let
268    
269            val (e1',params',args',code,fieldset)=  rewriteOps("add",e1,params,index,[],args,fieldset,flag)
270            fun pb es=String.concatWith "\n-" (List.map P.printbody es)
271    
272    
         val (e1',params',args',code)=  rewriteOps("add",e1,params,index,[],args)  
273          val body =E.Add e1'          val body =E.Add e1'
274          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
275          in          in
276              (einapp,code)              (einapp,code,fieldset)
277          end          end
278    
279      (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code      (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code
280       * calls rewriteOps() lift  on ein_exp       * calls rewriteOps() lift  on ein_exp
281       *)       *)
282      fun handleProd(y,e1,params,index,args)=let      fun handleProd(y,e1,params,index,args,fieldset,flag)=let
283          val (e1',params',args',code)=  rewriteOps("prod",e1,params,index,[],args)          val (e1',params',args',code,fieldset)=  rewriteOps("prod",e1,params,index,[],args,fieldset,flag)
284          val body =E.Prod e1'          val body =E.Prod e1'
285          val einapp= rewriteOrig(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
286          in          in
287              (einapp,code)              (einapp,code,fieldset)
288          end          end
289    
290     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code
291      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
292      *)      *)
293      fun handleSumProd(y,e1,params,index,sx,args)=let      fun handleSumProd(y,e1,params,index,sx,args,fieldset,flag)=let
294          val (e1',params',args',code)=  rewriteOps("sumprod",e1,params,index,sx,args)          val (e1',params',args',code,fieldset)=  rewriteOps("sumprod",e1,params,index,sx,args,fieldset,flag)
295          val body= E.Sum(sx,E.Prod e1')          val body= E.Sum(sx,E.Prod e1')
296          val einapp= rewriteOrig(y,body,params',index,sx,args')          val einapp= rewriteOrig(y,body,params',index,sx,args')
297          in          in
298              (einapp,code)              (einapp,code,fieldset)
299          end          end
300    
301      (* split:var*ein_app-> (var*einap)*code      (* split:var*ein_app-> (var*einap)*code
302      * split ein expression into smaller pieces      * split ein expression into smaller pieces
303        note we leave summation around probe exp        note we leave summation around probe exp
304      *)      *)
305      fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let      fun split((y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args)),fieldset,flag) =let
306          val zero=   (setEinZero y,[])          val x= ((y,einapp),fieldset,flag)
307          val default=((y,einapp),[])          val zero=   (setEinZero y,[],fieldset)
308            val default=((y,einapp),[],fieldset)
309          val sumIndex=ref []          val sumIndex=ref []
310          val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)          val str="Poorly formed EIN operator. Argument needs to be applied in High-IL"^(P.printbody body)
311          val _=testp["\n\nStarting split",P.printbody body]          val _=testp["\n\nStarting split",P.printbody body]
# Line 257  Line 313 
313              of E.Probe (E.Conv _,_)   => default              of E.Probe (E.Conv _,_)   => default
314              | E.Probe(E.Field _,_)    => raise Fail str              | E.Probe(E.Field _,_)    => raise Fail str
315              | E.Probe _               => raise Fail str              | E.Probe _               => raise Fail str
316              | E.Conv _                => zero              | E.Conv _                => raise Fail "should have been swept"
317              | E.Field _               => zero              | E.Field _               => raise Fail "should have been swept"
318              | E.Apply _               => zero              | E.Apply _               => raise Fail "should have been swept"
319              | E.Lift e                => zero              | E.Lift e                => raise Fail "should have been swept"
320              | E.Delta _               => default              | E.Delta _               => default
321              | E.Epsilon _             => default              | E.Epsilon _             => default
322              | E.Eps2 _                => default              | E.Eps2 _                => default
323              | E.Tensor _              => default              | E.Tensor _              => default
324              | E.Const _               => default              | E.Const _               => default
325              | E.ConstR _              => default              | E.ConstR _              => default
326              | E.Neg e1                => handleNeg(y,e1,params,index,args)              | E.Neg e1                => handleNeg(e1,x)
327              | E.Sqrt e1               => handleSqrt(y,e1,params,index,args)              | E.Sqrt e1               => handleSqrt(y,e1,params,index,args,fieldset,flag)
328              | E.PowInt e1             => handlePowInt(y,e1,params,index,args)              | E.Cosine e1             => handleCosine(y,e1,params,index,args,fieldset,flag)
329              | E.PowReal e1            => handlePowReal(y,e1,params,index,args)              | E.ArcCosine e1          => handleArcCosine(y,e1,params,index,args,fieldset,flag)
330              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)              | E.Sine e1               => handleSine(y,e1,params,index,args,fieldset,flag)
331              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)              | E.ArcSine e1            => handleArcSine(y,e1,params,index,args,fieldset,flag)
332              | E.Sum(sx,E.Tensor(id,[]))=> rewrite (E.Tensor(id,[]))              | E.PowInt e1             => err(" PowInt unsupported")
333              | E.Sum(sx,E.Const c)      =>rewrite ( E.Const c )              | E.PowReal e1            => err(" PowReal unsupported")
334              | E.Sum(sx,E.ConstR r)    => rewrite (E.ConstR r)              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args,fieldset,flag)
335              | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args,fieldset,flag)
336              | E.Sum(sx,E.Add a)       => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))  (*
             | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))  
             | E.Sum(sx,E.Div(e1,e2))  => rewrite (E.Sum(sx,E.Prod[e1,E.Div(E.Const 1,e2)]))  
             | E.Sum(sx,E.Lift e )     => rewrite (E.Lift(E.Sum(sx,e)))  
             | E.Sum(sx,E.PowReal(e,n1)) => rewrite (E.PowReal(E.Sum(sx,e),n1))  
             | E.Sum(sx,E.Sqrt e)        => rewrite (E.Sqrt(E.Sum(sx,e)))  
             | E.Sum(c1,E.Sum (c2,e))    => rewrite (E.Sum (c1@c2,e))  
337              | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default              | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default
338              | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default              | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default
339    *)
340              | E.Sum(_,E.Probe(E.Conv _,_))    => default              | E.Sum(_,E.Probe(E.Conv _,_))    => default
341              | E.Sum(_,E.Conv _)       => zero              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args,fieldset,flag)
342              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)              | E.Sum(sx,E.Delta d)     => handleSumProd(y,[E.Delta d],params,index,sx,args,fieldset,flag)
343              | E.Sum(sx,_)             => default              | E.Sum(sx,E.Tensor _)    => default
344              | E.Add e1                => handleAdd(y,e1,params,index,args)              | E.Sum(sx,_)             => err(" summation not distributed:"^str)
345              | E.Prod e1               => handleProd(y,e1,params,index,args)              | E.Add e1                => handleAdd(y,e1,params,index,args,fieldset,flag)
346                | E.Prod[E.Tensor(id0,[]),E.Tensor(id1,[i]),E.Tensor(id2,[])]=>
347                        rewrite (E.Prod[E.Prod[E.Tensor(id0,[]),E.Tensor(id2,[])],E.Tensor(id1,[i])])
348                | E.Prod e1               => handleProd(y,e1,params,index,args,fieldset,flag)
349              | E.Partial _             => err(" Partial used after normalize")              | E.Partial _             => err(" Partial used after normalize")
350              | E.Krn _                 => err("Krn used before expand")              | E.Krn _                 => err("Krn used before expand")
351              | E.Value _               => err("Value used before expand")              | E.Value _               => err("Value used before expand")
352              | E.Img _                 => err("Probe used before expand")              | E.Img _                 => err("Probe used before expand")
353              (*end case *))              (*end case *))
354          val (einapp2,newbies) =rewrite body          val (einapp2,newbies,fieldset) =rewrite body
         in  
             (einapp2,newbies)  
         end  
         |split(y,app) =((y,app),[])  
   
   
     (*Distribute summation if needed*)  
     fun distributeSummation(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let  
         fun rewrite b=(case b  
         of E.Sum(sx,E.Tensor(id,[]))    => E.Tensor(id,[])  
         | E.Sum(sx,E.Const c)           => E.Const c  
         | E.Sum(sx,E.ConstR r)          => E.ConstR r  
         | E.Sum(sx,E.Neg n)             => rewrite(E.Neg(E.Sum(sx,n)))  
         | E.Sum(sx,E.Add a)             => rewrite(E.Add(List.map (fn e=> E.Sum(sx,e)) a))  
         | E.Sum(sx,E.Sub (e1,e2))       => rewrite(E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))  
        (* | E.Sum(sx,E.Div(e1,e2))        => rewrite( E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))*)  
         | E.Sum(sx,E.Lift e )           => rewrite (E.Lift(E.Sum(sx,e)))  
         | E.Sum(sx,E.PowReal(e,n1))     => rewrite(E.PowReal(E.Sum(sx,e),n1))  
         | E.Sum(sx,E.Sqrt e)            => rewrite(E.Sqrt(E.Sum(sx,e)))  
         | E.Sum(sx,E.Sum (c2,e))        => rewrite (E.Sum (sx@c2,e))  
         | E.Sum(sx,E.Prod p)            => let  
                 val (c,e)=filterSca(sx,p)  
                 in e end  
         | E.Div(e1,e2)                  => E.Div(rewrite e1, rewrite e2)  
         | E.Sub(e1,E.Const 0)           => rewrite e1  
         | E.Sub(e1,e2)                  => E.Sub(rewrite e1, rewrite e2)  
         | E.Add es                      => E.Add(List.map rewrite es)  
         | E.Prod es                     => E.Prod(List.map rewrite es)  
         | E.Neg e                       => E.Neg(rewrite e)  
         | E.Sqrt e                      => E.Sqrt(rewrite e)  
         | _                             => b  
     (*end case*))  
     val body =rewrite body  
     val _ =testp["\nAfter distributeSummation \n",P.printbody body]  
     val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=body})  
     val einapp2= (y,DstIL.EINAPP(ein,args))  
355      in      in
356          split(einapp2)              ((einapp2,newbies),fieldset)
357      end      end
358      |distributeSummation(y,app) =((y,app),[])          |split((y,app),fieldset,_) =(((y,app),[]),fieldset)
   
359    
360    
361      (* iterMultiple:code*code=> (code*code)      fun iterMultiple(einapp2,newbies2,fieldset)=let
      * recursively split ein expression into smaller pieces  
     *)  
     fun iterMultiple(einapp2,newbies2)=let  
362          fun itercode([],rest,code,_)=(rest,code)          fun itercode([],rest,code,_)=(rest,code)
363          | itercode(e1::newbies,rest,code,cnt)=let          | itercode(e1::newbies,rest,code,cnt)=let
364              val _ =testp["\n\n******* split term **",Int.toString cnt," *****","\n \n",printEINAPP(e1),"\n=>\n"]                  val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
             val (einapp3,code3) = distributeSummation e1  
             val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))]  
365              val (rest4,code4)=itercode(code3,[],[],cnt+1)              val (rest4,code4)=itercode(code3,[],[],cnt+1)
366              in itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)          in
367                itercode(newbies,rest@[einapp3],code4@( rest4)@code,cnt+2)
368              end              end
369          val(rest,code)= itercode(newbies2,[],[],1)          val(rest,code)= itercode(newbies2,[],[],1)
370          in          in
371              (einapp2,code@rest)              ((code)@rest@[einapp2])
372          end          end
373    
     fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let  
         val bodysweep=handleE.sweep body  
         val _=testp["\nPresweep\n",P.printbody body,"\n\n Sweep\n",P.printbody bodysweep,"\n"]  
         val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=bodysweep})  
         val _=testp["\n\n Clean Summation\n",P.printbody(Ein.body ein),"\n"]  
         val einapp2=(y,DstIL.EINAPP(ein, args))  
         val _ =testp["\n\n******* split term **",Int.toString (0)," ***** \n \t==>\n",printEINAPP(einapp2)]  
         val (einapp3,newbies2)=distributeSummation einapp2  
         val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies2))]  
374    
375        fun iterAll(einapp2,fieldset)=let
376            fun itercode([],rest,code,_)=(rest,code)
377            | itercode(e1::newbies,rest,code,cnt)=let
378                val ((einapp3,code3),_) = split(e1,fieldset,numFlag)
379                val (rest4,code4)=itercode(code3,[],[],cnt+1)
380                    val _ =testp [toStringBind(e1),"\n\t===>\n",toStringBind(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map toStringBind (code4@rest4)))]
381      in      in
382          iterMultiple(einapp3,newbies2)                  itercode(newbies,rest@[einapp3],code4@( rest4)@code,cnt+2)
383                end
384            val(rest,code)= itercode(einapp2,[],[],0)
385            in
386                (code@rest)
387      end      end
388    
389      (* gettest:code*code=> (code*code)      fun splitEinApp einapp3= let
390      * print results for splitting einapp          val fieldset= einSet.EinSet.empty
391    
392            (* **** split in parts **** *)
393            (*
394            val ((einapp4,newbies4),fieldset)=split(einapp3,fieldset,0)
395            val _ =testp["\n\t===>\n",toStringBind(einapp4),"\nand\n",(String.concatWith",\n\t"(List.map toStringBind newbies4))]
396            val (newbies5)= iterMultiple(einapp4,newbies4,fieldset)
397      *)      *)
398      fun gettest einapp=(case testing  
399          of 0=>iterSplit(einapp)          (* **** split all at once **** *)
400          | _=>let          val (newbies5)= iterAll([einapp3],fieldset)
401              val star="\n************* SPLIT INITIAL********\n"  
             val _ =testp[star,"\n","start get test",printEINAPP einapp]  
             val (einapp2,newbies)=iterSplit(einapp)  
             val _ =testp["\n\n Returning \n\n =>",printEINAPP einapp2,  
                     " newbies\n\t",String.concatWith",\n\t"(List.map printEINAPP newbies), "\n",star]  
402              in              in
403                  (einapp2,newbies)              newbies5
404              end              end
405          (*end case*))  
406    
407    end; (* local *)    end; (* local *)
408    

Legend:
Removed from v.3030  
changed lines
  Added in v.3316

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0