Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3030, Tue Mar 10 01:24:41 2015 UTC revision 3033, Tue Mar 10 15:17:25 2015 UTC
# Line 27  Line 27 
27      structure P=Printer      structure P=Printer
28      structure cleanP=cleanParams      structure cleanP=cleanParams
29      structure cleanI=cleanIndex      structure cleanI=cleanIndex
30      structure handleE=handleEin  
31    
32      in      in
33    
34      val testing=1      val testing=0
35      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
36      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
37      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])
# Line 39  Line 39 
39      fun cleanParams e =cleanP.cleanParams e      fun cleanParams e =cleanP.cleanParams e
40      fun cleanIndex e =cleanI.cleanIndex e      fun cleanIndex e =cleanI.cleanIndex e
41      fun printEINAPP e=MidToString.printEINAPP e      fun printEINAPP e=MidToString.printEINAPP e
42      fun isZero e=handleE.isZero e  
43      fun sweep e=handleE.sweep e  
44    
45      fun itos i =Int.toString i      fun itos i =Int.toString i
46      fun filterSca e=Filter.filterSca e      fun filterSca e=Filter.filterSca e
47      fun err str=raise Fail str      fun err str=raise Fail str
# Line 121  Line 122 
122          | m(e1::es,rest,params,args,code)=let          | m(e1::es,rest,params,args,code)=let
123    
124              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)
125  val _ =print("rewriteOP:\n"^P.printbody e1^"\n\t=>"^ P.printbody e1')              (*val _ =testp["rewriteOP:\n",P.printbody e1,"\n\t=>", P.printbody e1']*)
126              in              in
127                  m(es,rest@[e1'],params',args',code@code')                  m(es,rest@[e1'],params',args',code@code')
128              end              end
# Line 133  Line 134 
134             When the operation is zero then we return a real.             When the operation is zero then we return a real.
135          -Moved is Zero to before split.          -Moved is Zero to before split.
136      *)      *)
137      fun rewriteOrig(y,body,params,index,sx,args) =(case (isZero body)      fun rewriteOrig(y,body,params,index,sx,args) =cleanParams(y,body,params,index,args)
         of 1=>  setEinZero y  
         | _ => cleanParams(y,body,params,index,args)  
         (*end case*))  
138    
139      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code      (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code
140      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
# Line 273  Line 271 
271              | E.PowReal e1            => handlePowReal(y,e1,params,index,args)              | E.PowReal e1            => handlePowReal(y,e1,params,index,args)
272              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)              | E.Sub (e1,e2)           => handleSub(y,e1,e2,params,index,args)
273              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)              | E.Div (e1,e2)           => handleDiv(y,e1,e2,params,index,args)
             | E.Sum(sx,E.Tensor(id,[]))=> rewrite (E.Tensor(id,[]))  
             | E.Sum(sx,E.Const c)      =>rewrite ( E.Const c )  
             | E.Sum(sx,E.ConstR r)    => rewrite (E.ConstR r)  
             | E.Sum(sx,E.Neg n)       => rewrite (E.Neg(E.Sum(sx,n)))  
             | E.Sum(sx,E.Add a)       => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))  
             | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))  
             | E.Sum(sx,E.Div(e1,e2))  => rewrite (E.Sum(sx,E.Prod[e1,E.Div(E.Const 1,e2)]))  
             | E.Sum(sx,E.Lift e )     => rewrite (E.Lift(E.Sum(sx,e)))  
             | E.Sum(sx,E.PowReal(e,n1)) => rewrite (E.PowReal(E.Sum(sx,e),n1))  
             | E.Sum(sx,E.Sqrt e)        => rewrite (E.Sqrt(E.Sum(sx,e)))  
             | E.Sum(c1,E.Sum (c2,e))    => rewrite (E.Sum (c1@c2,e))  
274              | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default              | E.Sum(_,E.Prod[E.Eps2 _, E.Probe(E.Conv _,_)  ])      => default
275              | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default              | E.Sum(_,E.Prod[E.Epsilon _, E.Probe(E.Conv _,_)  ])      => default
276              | E.Sum(_,E.Probe(E.Conv _,_))    => default              | E.Sum(_,E.Probe(E.Conv _,_))    => default
             | E.Sum(_,E.Conv _)       => zero  
277              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)              | E.Sum(sx,E.Prod e1)     => handleSumProd(y,e1,params,index,sx,args)
278              | E.Sum(sx,_)             => default              | E.Sum(sx,E.Add [a0,a1,a2])             => err(" \n summation not distributed:add3\n found:\n\t"^P.printbody a0^"\nin \n"^str)
279    
280                | E.Sum(sx,E.Add n)             => err(" \n summation not distributed:add\n"^Int.toString(length n)^"found:\n\t"^P.printbody b^"\nin \n"^str)
281                | E.Sum(sx,E.Sub _)             => err(" summation not distributed:sub"^str)
282                | E.Sum(sx,E.Sqrt _)             => err(" summation not distributed:sqrt"^str)
283                | E.Sum(sx,E.Neg _)             => err(" summation not distributed:neg"^str)
284                | E.Sum(sx,_)             => err(" summation not distributed:"^str)
285              | E.Add e1                => handleAdd(y,e1,params,index,args)              | E.Add e1                => handleAdd(y,e1,params,index,args)
286              | E.Prod e1               => handleProd(y,e1,params,index,args)              | E.Prod e1               => handleProd(y,e1,params,index,args)
287              | E.Partial _             => err(" Partial used after normalize")              | E.Partial _             => err(" Partial used after normalize")
# Line 303  Line 295 
295          end          end
296          |split(y,app) =((y,app),[])          |split(y,app) =((y,app),[])
297    
   
     (*Distribute summation if needed*)  
     fun distributeSummation(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let  
         fun rewrite b=(case b  
         of E.Sum(sx,E.Tensor(id,[]))    => E.Tensor(id,[])  
         | E.Sum(sx,E.Const c)           => E.Const c  
         | E.Sum(sx,E.ConstR r)          => E.ConstR r  
         | E.Sum(sx,E.Neg n)             => rewrite(E.Neg(E.Sum(sx,n)))  
         | E.Sum(sx,E.Add a)             => rewrite(E.Add(List.map (fn e=> E.Sum(sx,e)) a))  
         | E.Sum(sx,E.Sub (e1,e2))       => rewrite(E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))  
        (* | E.Sum(sx,E.Div(e1,e2))        => rewrite( E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))*)  
         | E.Sum(sx,E.Lift e )           => rewrite (E.Lift(E.Sum(sx,e)))  
         | E.Sum(sx,E.PowReal(e,n1))     => rewrite(E.PowReal(E.Sum(sx,e),n1))  
         | E.Sum(sx,E.Sqrt e)            => rewrite(E.Sqrt(E.Sum(sx,e)))  
         | E.Sum(sx,E.Sum (c2,e))        => rewrite (E.Sum (sx@c2,e))  
         | E.Sum(sx,E.Prod p)            => let  
                 val (c,e)=filterSca(sx,p)  
                 in e end  
         | E.Div(e1,e2)                  => E.Div(rewrite e1, rewrite e2)  
         | E.Sub(e1,E.Const 0)           => rewrite e1  
         | E.Sub(e1,e2)                  => E.Sub(rewrite e1, rewrite e2)  
         | E.Add es                      => E.Add(List.map rewrite es)  
         | E.Prod es                     => E.Prod(List.map rewrite es)  
         | E.Neg e                       => E.Neg(rewrite e)  
         | E.Sqrt e                      => E.Sqrt(rewrite e)  
         | _                             => b  
     (*end case*))  
     val body =rewrite body  
     val _ =testp["\nAfter distributeSummation \n",P.printbody body]  
     val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=body})  
     val einapp2= (y,DstIL.EINAPP(ein,args))  
     in  
         split(einapp2)  
     end  
     |distributeSummation(y,app) =((y,app),[])  
   
   
   
298      (* iterMultiple:code*code=> (code*code)      (* iterMultiple:code*code=> (code*code)
299       * recursively split ein expression into smaller pieces       * recursively split ein expression into smaller pieces
300      *)      *)
# Line 348  Line 302 
302          fun itercode([],rest,code,_)=(rest,code)          fun itercode([],rest,code,_)=(rest,code)
303          | itercode(e1::newbies,rest,code,cnt)=let          | itercode(e1::newbies,rest,code,cnt)=let
304              val _ =testp["\n\n******* split term **",Int.toString cnt," *****","\n \n",printEINAPP(e1),"\n=>\n"]              val _ =testp["\n\n******* split term **",Int.toString cnt," *****","\n \n",printEINAPP(e1),"\n=>\n"]
305              val (einapp3,code3) = distributeSummation e1               val (einapp3,code3) = split e1
306              val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))]              val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP code3))]
307              val (rest4,code4)=itercode(code3,[],[],cnt+1)              val (rest4,code4)=itercode(code3,[],[],cnt+1)
308              in itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)              in itercode(newbies,rest@[einapp3],code4@rest4@code,cnt+2)
# Line 358  Line 312 
312              (einapp2,code@rest)              (einapp2,code@rest)
313          end          end
314    
     fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let  
         val bodysweep=handleE.sweep body  
         val _=testp["\nPresweep\n",P.printbody body,"\n\n Sweep\n",P.printbody bodysweep,"\n"]  
         val ein=SummationEin.cleanSummation(Ein.EIN{params=params, index=index, body=bodysweep})  
         val _=testp["\n\n Clean Summation\n",P.printbody(Ein.body ein),"\n"]  
         val einapp2=(y,DstIL.EINAPP(ein, args))  
         val _ =testp["\n\n******* split term **",Int.toString (0)," ***** \n \t==>\n",printEINAPP(einapp2)]  
         val (einapp3,newbies2)=distributeSummation einapp2  
         val _ =testp["\n\t===>\n",printEINAPP(einapp3),"\nand\n",(String.concatWith",\n\t"(List.map printEINAPP newbies2))]  
   
     in  
         iterMultiple(einapp3,newbies2)  
     end  
   
     (* gettest:code*code=> (code*code)  
     * print results for splitting einapp  
     *)  
     fun gettest einapp=(case testing  
         of 0=>iterSplit(einapp)  
         | _=>let  
             val star="\n************* SPLIT INITIAL********\n"  
             val _ =testp[star,"\n","start get test",printEINAPP einapp]  
             val (einapp2,newbies)=iterSplit(einapp)  
             val _ =testp["\n\n Returning \n\n =>",printEINAPP einapp2,  
                     " newbies\n\t",String.concatWith",\n\t"(List.map printEINAPP newbies), "\n",star]  
             in  
                 (einapp2,newbies)  
             end  
         (*end case*))  
315    
316    end; (* local *)    end; (* local *)
317    

Legend:
Removed from v.3030  
changed lines
  Added in v.3033

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0