Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2843, Mon Dec 8 01:27:25 2014 UTC revision 2845, Fri Dec 12 06:46:23 2014 UTC
# Line 14  Line 14 
14   (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.   (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
15    
16   (1c) Call cleanParams.sml to clean the params in the subexpression.\\   (1c) Call cleanParams.sml to clean the params in the subexpression.\\
  (2) All the lifted subexpressions in the original EIN operator are replaced with tensors and non-probed fields with zeros. Call isZero() to determine if the body is zero. If so, needs to return 0. Otherwise clean the EIN operator.  
   
17   *)   *)
18    
19  structure Split = struct  structure Split = struct
# Line 23  Line 21 
21      local      local
22    
23      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
24      structure DstIL = MidIL      structure DstIL = MidIL
25      structure DstTy = MidILTypes      structure DstTy = MidILTypes
     structure DstOp = MidOps  
26      structure DstV = DstIL.Var      structure DstV = DstIL.Var
     structure SrcV = SrcIL.Var  
27      structure P=Printer      structure P=Printer
     structure F=Filter  
     structure T=TransformEin  
     structure Var = MidIL.Var  
28      structure cleanP=cleanParams      structure cleanP=cleanParams
29      structure cleanI=cleanIndex      structure cleanI=cleanIndex
30        structure handleE=handleEin
31    
     val testing=1  
32      in      in
33    
34        val testing=1
35      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}      fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
36      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))      fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
37      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])      val einappzero=DstIL.EINAPP(setEin([],[],E.Const 0),[])
38      fun setEinZero y=  (y,einappzero)      fun setEinZero y=  (y,einappzero)
39      fun cleanParams e =cleanP.cleanParams e      fun cleanParams e =cleanP.cleanParams e
40      fun cleanIndex e =cleanI.cleanIndex e      fun cleanIndex e =cleanI.cleanIndex e
41        fun printEINAPP e=MidToString.printEINAPP e
42        fun isZero e=handleE.isZero e
43      fun itos i =Int.toString i      fun itos i =Int.toString i
44      fun err str=raise Fail str      fun err str=raise Fail str
45      val cnt = ref 0      val cnt = ref 0
# Line 65  Line 54 
54          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
55          (*end case*))          (*end case*))
56    
     fun printEINAPP(id, DstIL.EINAPP(rator, args))=let  
         val a=String.concatWith " , " (List.map Var.toString args)  
         in  
             String.concat([(DstTy.toString (Var.ty id)),"<",Var.toString id,"> ==",P.printerE rator, a,"\n"])  
         end  
       | printEINAPP(id, DstIL.OP(rator, args))=let  
           val a=String.concatWith " , " (List.map Var.toString args)  
          in  
             String.concat([(DstTy.toString (Var.ty id)),"<",Var.toString id,"> =",DstOp.toString rator,a,"\n"])  
           end  
   
        | printEINAPP(id,_)= String.concat([Var.toString id,"<",(DstTy.toString (Var.ty id)),"> non-einapp\n"])  
   
57    
58      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)      (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
59      *lifts expression and returns replacement tensor      *lifts expression and returns replacement tensor
60      * cleans the index and params of subexpression      * cleans the index and params of subexpression
61      *creates new param and replacement tensor for the original ein_exp      *creates new param and replacement tensor for the original ein_exp
62      *)      *)
63      fun lift(e,params,index,sx,args)=let      fun lift(name,e,params,index,sx,args)=let
64          val (tshape,sizes,body)=cleanIndex(e,index,sx)          val (tshape,sizes,body)=cleanIndex(e,index,sx)
65          val id=length(params)          val id=length(params)
66          val Rparams=params@[E.TEN(1,sizes)]          val Rparams=params@[E.TEN(1,sizes)]
67          val Re=E.Tensor(id,tshape)          val Re=E.Tensor(id,tshape)
68          val M  = DstV.new (genName ("TLifted_"^itos id), DstTy.TensorTy sizes)          val M  = DstV.new (genName (name^"_l_"^itos id), DstTy.TensorTy sizes)
69          val Rargs=args@[M]          val Rargs=args@[M]
70          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)          val einapp=cleanParams(M,body,Rparams,sizes,Rargs)
71    
# Line 97  Line 73 
73          (Re,Rparams,Rargs,[einapp])          (Re,Rparams,Rargs,[einapp])
74      end      end
75    
   
76      (* isOp: ein->int      (* isOp: ein->int
77       * checks to see if this sub-expression is pulled out or split form original       * checks to see if this sub-expression is pulled out or split form original
78       * 0-becomes zero,1-remains the same, 2-operator       * 0-becomes zero,1-remains the same, 2-operator
# Line 121  Line 96 
96          | _           => 2          | _           => 2
97      (*end case*))      (*end case*))
98    
   
99      (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code      (* rewriteOp:ein_exp*params*index*args-> ein_exp*params*args*code
100       * If e1 an op then call lift() to replace it       * If e1 an op then call lift() to replace it
101       * Otherwise rewrite to 0 or it remains the same       * Otherwise rewrite to 0 or it remains the same
102       *)       *)
103      fun rewriteOp(e1,params,index,sx,args)=(case (isOp e1)      fun rewriteOp(name,e1,params,index,sx,args)=(case (isOp e1)
104          of  0   => (E.Const 0,params,args,[])          of  0   => (E.Const 0,params,args,[])
105          | 2     => (e1,params,args,[])          | 2     => (e1,params,args,[])
106          | _     => lift(e1,params,index,sx,args)          | _     => lift(name,e1,params,index,sx,args)
107          (*end*))          (*end*))
108    
   
109      (* rewriteOps:ein_exp list*params*index*sum_id list*mid-il vars      (* rewriteOps:ein_exp list*params*index*sum_id list*mid-il vars
110             -> ein_exp list*params*args*code             -> ein_exp list*params*args*code
111       * calls rewriteOp on ein_exp list       * calls rewriteOp on ein_exp list
112       *)       *)
113      fun rewriteOps(list1,params,index,sx,args)=let      fun rewriteOps(name,list1,params,index,sx,args)=let
114          fun m([],rest,params,args,code)=(rest,params,args,code)          fun m([],rest,params,args,code)=(rest,params,args,code)
115          | m(e1::es,rest,params,args,code)=let          | m(e1::es,rest,params,args,code)=let
116              val (e1',params',args',code')= rewriteOp(e1,params,index,sx,args)              val (e1',params',args',code')= rewriteOp(name,e1,params,index,sx,args)
117              in              in
118                  m(es,rest@[e1'],params',args',code@code')                  m(es,rest@[e1'],params',args',code@code')
119              end              end
# Line 148  Line 121 
121              m(list1,[],params,args,[])              m(list1,[],params,args,[])
122          end          end
123    
124      (*isZero: var* ein_exp* params*index list*mid-il vars      (*rewriteOrig: var* ein_exp* params*index list*mid-il vars
125             When the operation is zero then we return a real.             When the operation is zero then we return a real.
126            -Moved is Zero to before split.
127      *)      *)
128      fun isZero(y,body,params,index,sx,args) =(case (cleanP.isZero body)      fun rewriteOrig(y,body,params,index,sx,args) =(case (isZero body)
129          of 1=>  setEinZero y          of 1=>  setEinZero y
130          | _ => cleanParams(y,body,params,index,args)          | _ => cleanParams(y,body,params,index,args)
131      (*end case*))      (*end case*))
# Line 160  Line 134 
134      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
135      *)      *)
136      fun handleNeg(y,e1,params,index,args)=let      fun handleNeg(y,e1,params,index,args)=let
137          val (e1',params',args',code)=  rewriteOp(e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOp(DstV.name y, e1,params,index,[],args)
138          val body =E.Neg e1'          val body =E.Neg e1'
139          val einapp= isZero(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
140      in      in
141          (einapp,code)          (einapp,code)
142      end      end
# Line 171  Line 145 
145      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
146      *)      *)
147      fun handleSub(y,e1,e2,params,index,args)=let      fun handleSub(y,e1,e2,params,index,args)=let
148          val ([e1',e2'],params',args',code)=  rewriteOps([e1,e2],params,index,[],args)          val ([e1',e2'],params',args',code)=  rewriteOps(DstV.name y,[e1,e2],params,index,[],args)
149          val body =E.Sub(e1',e2')          val body =E.Sub(e1',e2')
150          val einapp= isZero(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
151      in      in
152          (einapp,code)          (einapp,code)
153      end      end
# Line 182  Line 156 
156      * calls rewriteOp() lift  on ein_exp      * calls rewriteOp() lift  on ein_exp
157      *)      *)
158      fun handleDiv(y,e1,e2,params,index,args)=let      fun handleDiv(y,e1,e2,params,index,args)=let
159          val (e1',params1',args1',code1')=rewriteOp(e1,params,index,[],args)          val (e1',params1',args1',code1')=rewriteOp(DstV.name y,e1,params,index,[],args)
160          val (e2',params2',args2',code2')=rewriteOp(e2,params1',[],[],args1')          val (e2',params2',args2',code2')=rewriteOp(DstV.name y,e2,params1',[],[],args1')
161          val body =E.Div(e1',e2')          val body =E.Div(e1',e2')
162          val einapp= isZero(y,body,params2',index,[],args2')          val einapp= rewriteOrig(y,body,params2',index,[],args2')
163      in      in
164              (einapp,code1'@code2')              (einapp,code1'@code2')
165      end      end
166    
   
167      (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code      (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code
168      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
169      *)      *)
170      fun handleAdd(y,e1,params,index,args)=let      fun handleAdd(y,e1,params,index,args)=let
171          val (e1',params',args',code)=  rewriteOps(e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,[],args)
172          val body =E.Add e1'          val body =E.Add e1'
173          val einapp= isZero(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
174      in      in
175          (einapp,code)          (einapp,code)
176      end      end
# Line 206  Line 179 
179       * calls rewriteOps() lift  on ein_exp       * calls rewriteOps() lift  on ein_exp
180       *)       *)
181      fun handleProd(y,e1,params,index,args)=let      fun handleProd(y,e1,params,index,args)=let
182          val (e1',params',args',code)=  rewriteOps(e1,params,index,[],args)          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,[],args)
183          val body =E.Prod e1'          val body =E.Prod e1'
184          val einapp= isZero(y,body,params',index,[],args')          val einapp= rewriteOrig(y,body,params',index,[],args')
185      in      in
186          (einapp,code)          (einapp,code)
187      end      end
# Line 217  Line 190 
190      * calls rewriteOps() lift  on ein_exp      * calls rewriteOps() lift  on ein_exp
191      *)      *)
192      fun handleSumProd(y,e1,params,index,sx,args)=let      fun handleSumProd(y,e1,params,index,sx,args)=let
193          val (e1',params',args',code)=  rewriteOps(e1,params,index,sx,args)          val (e1',params',args',code)=  rewriteOps(DstV.name y,e1,params,index,sx,args)
194          val body= E.Sum(sx,E.Prod e1')          val body= E.Sum(sx,E.Prod e1')
195          val einapp= isZero(y,body,params',index,sx,args')          val einapp= rewriteOrig(y,body,params',index,sx,args')
196      in      in
197          (einapp,code)          (einapp,code)
198      end      end
# Line 264  Line 237 
237              | E.Value _               => err("Value used before expand")              | E.Value _               => err("Value used before expand")
238              | E.Img _                 => err("Probe used before expand")              | E.Img _                 => err("Probe used before expand")
239              (*end case *))              (*end case *))
   
   
240          val (einapp2,newbies) =rewrite body          val (einapp2,newbies) =rewrite body
241          in          in
242              (einapp2,newbies)              (einapp2,newbies)
243          end          end
244          |split(y,app) =((y,app),[])          |split(y,app) =((y,app),[])
245    
   
246      (* iterMultiple:code*code=> (code*code)      (* iterMultiple:code*code=> (code*code)
247       * recursively split ein expression into smaller pieces       * recursively split ein expression into smaller pieces
248      *)      *)
# Line 288  Line 258 
258              (einapp2,code@rest)              (einapp2,code@rest)
259          end          end
260    
   
261      fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let      fun iterSplit(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
262          val (_,_,body')=cleanIndex(body,index,[])          (*val (_,_,body')=cleanIndex(body,index,[])
263          val einapp1= assignEinApp(y,params,index,body',args)          val einapp1= assignEinApp(y,params,index,body',args)
264            *)
265            val (_,sizes,body')=cleanIndex(body,index,[])
266            val einapp1= assignEinApp(y,params,index,body',args)
267            val a=testp["\n rewriten einapp\n \t",printEINAPP einapp1]
268          val (einapp2,newbies2)=split einapp1          val (einapp2,newbies2)=split einapp1
269      in      in
270          iterMultiple(einapp2,newbies2)          iterMultiple(einapp2,newbies2)
271      end      end
272    
   
   
273      (* gettest:code*code=> (code*code)      (* gettest:code*code=> (code*code)
274      * print results for splitting einapp      * print results for splitting einapp
275      *)      *)
276      fun gettest(einapp)=(case testing      fun gettest einapp=(case testing
277          of 0=>iterSplit(einapp)          of 0=>iterSplit(einapp)
278          | _=>let          | _=>let
279              val star="\n*************\n"              val star="\n************* SPLIT********\n"
280              val _ =print(String.concat[star])              val _ =print(String.concat[star,"\n","start get test",printEINAPP einapp])
281              val (einapp2,newbies)=iterSplit(einapp)              val (einapp2,newbies)=iterSplit(einapp)
282              val a=printEINAPP einapp2              val a=printEINAPP einapp2
283              val b=String.concatWith",\n\t"(List.map printEINAPP newbies)              val b=String.concatWith",\n\t"(List.map printEINAPP newbies)
# Line 316  Line 287 
287              end              end
288          (*end case*))          (*end case*))
289    
   
290    end; (* local *)    end; (* local *)
291    
292  end (* local *)  end (* local *)

Legend:
Removed from v.2843  
changed lines
  Added in v.2845

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0