Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3030, Tue Mar 10 01:24:41 2015 UTC revision 3033, Tue Mar 10 15:17:25 2015 UTC
# Line 43  Line 43 
43      fun printEINAPP e=MidToString.printEINAPP e      fun printEINAPP e=MidToString.printEINAPP e
44      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
45      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
46    
47        fun transitionToString(a,b)=
48            print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b])
49    
50        fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
51        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
52        fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body
53        fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=
54                (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))
55    
56      fun testp n=(case testing      fun testp n=(case testing
57          of 0=> 1          of 0=> 1
58          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
# Line 157  Line 167 
167      * rewrites body      * rewrites body
168      * replace probe with expanded version      * replace probe with expanded version
169      *)      *)
170      fun replaceProbe(b,params,args,index, sx)=let      fun replaceProbe(y,originalb,b,params,args,index, sx)=let
171          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b
172          val fid=length(params)          val fid=length(params)
173          val nid=fid+1          val nid=fid+1
# Line 169  Line 179 
179          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
180          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
181          val body' = multiPs(Ps,newsx1,body')          val body' = multiPs(Ps,newsx1,body')
182    
183            val body'=(case originalb
184                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
185                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
186                | _                                  => body'
187                (*end case*))
188            val _=transitionToString(originalb,body')
189    
190          val args'=argsA@[PArg]          val args'=argsA@[PArg]
191            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
192          in          in
193              (body',params',args' ,code)              code@[einapp]
194          end          end
195    
     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}  
     fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)  
196    
197      fun createEinApp(alpha,index,freshIndex,dim,dx)= let      fun shapeoflifted (tshape,index,sx)=let
198            val _ =List.map (fn (E.V e)=>print(Int.toString(e)^"-")) tshape
199            val sizeMapp= cleanIndex.mkSizeMapp(index,sx)
200            val sizes= List.map(fn E.V e1=> cleanIndex.lkupIX(e1,sizeMapp,"Could not find Size of")) tshape
201            in sizes end
202    
203        fun createEinApp(alpha,index,freshIndex,dim,dx,sx)= let
204          val Pid=0          val Pid=0
205          val tid=1          val tid=1
206    
207          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
208          val params=[E.TEN(1,[dim,dim]),E.TEN(1,index)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,index)]
209          val t=E.Tensor(tid,alpha@newdx)          val tshape=alpha@newdx
210            val t=E.Tensor(tid,tshape)
211          val body = multiPs(Ps,newsx,t)          val body = multiPs(Ps,newsx,t)
212          val rator=mkEin(params,index,body)          val rator=mkEin(params,index,body)
213            val sizes= (case sx
214                of []=> index
215                | _=>(*shapeoflifted (tshape,index,sx@newsx)*) index
216                )
217          in          in
218              rator              (rator,sizes)
219          end          end
220    
221        fun  einapptostring (body,a,b)=
222            print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b])
223    
224    
225      fun liftProbe(y,b,params,args,index, sx)=let      fun liftProbe(y,b,params,args,index, sx)=let
# Line 203  Line 233 
233    
234          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
235          val FArg  = DstV.new ("F", DstTy.TensorTy(index))          val FArg  = DstV.new ("F", DstTy.TensorTy(index))
236          val rator=createEinApp(alpha,index,freshIndex,dim,dx)          val (rator,sizes)=createEinApp(alpha,index,freshIndex,dim,dx,sx)
237          val einApp0=mkEinApp(rator,[PArg,FArg])          val einApp0=mkEinApp(rator,[PArg,FArg])
238    
239    
240          (*lifted probe*)          (*lifted probe*)
241          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
242          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val dxinner=(*List.tabulate(length(dx), (fn e=>E.V e))*) dx
243            val body' = createBody(dim, s,freshIndex+nshift,alpha,dxinner,Vid, hid, nid, fid)
244          val args'=argsA          val args'=argsA
245          val ein1=mkEin(params', index,body')  
246            val ein1=mkEin(params',sizes,body')
247          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
248    
249          val rtn=code@[(FArg,einApp1),(y,einApp0)]          val rtn=(code,(FArg,einApp1),(y,einApp0))
250    
          val _=print(String.concat["\n lift probe:\n",P.printbody b,"\n\t=>\n\t", printEINAPP (FArg,einApp1),  "&\n\t", printEINAPP (y,einApp0)])  
251          in          in
252              rtn              rtn
253          end          end
254    
255    
256    
257      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
258      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
259      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
260      *)      *)
261      fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let      fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let
262          fun rewriteBody b=(case b          fun rewriteBody b=(case b
263              of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator."              of E.Probe(E.Conv(_,_,_,[]),_)
264              | E.Probe(E.Conv(_,_,_,[]),_) =>let                  => replaceProbe(y,b,b,params,args, index, [])
   
                 val (body',params',args',newbies)=replaceProbe(b,params,args, index, [])  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))  
                 val code=newbies@[einapp]  
                val _=print(String.concat["\n Replace probe:\n",P.printbody b,"\n=>",P.printbody body'])  
                 in  
                     code  
                 end  
265              | E.Probe(E.Conv _,_) =>let              | E.Probe(E.Conv _,_) =>let
266                    val _=print(String.concat["\n Lift  probe:\n",P.printerE ein])                  val (a0,a1,a2)=liftProbe(y,b,params,args, index, [])
267                  in  liftProbe(y,b,params,args, index, [])                  val _= einapptostring (b,a1,a2)
                 end  
             | E.Sum(sx,E.Probe e)  =>let  
                 val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)  
                 val  body'=E.Sum(sx,body')  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))  
                 val code=newbies@[einapp]  
                 val _=print(String.concat["\n Replace probe:\n",P.printerE ein,"\n=>",P.printbody body'])  
   
                 in  
                     code  
                 end  
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
                 val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)  
                 val  body'=E.Sum(sx,E.Prod[eps,body'])  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))  
                 val code=newbies@[einapp]  
                 val _=print(String.concat["\n Replace probe:\n",P.printerE ein,"\n=>",P.printbody body'])  
268                  in                  in
269                      code                      a0@[a1,a2]
270                  end                  end
271               | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
272                    => replaceProbe(y,b,p,params,args, index, sx)
273    
274            (*
275                | E.Sum(sx,b as E.Probe(E.Conv(_,[],_,_),_))  =>let
276                    val _=print(String.concat["\n\n\n lift probe:\n",P.printerE ein,"\n=>"])
277    
278                    val (a0,a1,a2)=liftProbe(y,b,params,args, index,sx)
279                    val (y,DstIL.EINAPP(ein,args))=a2
280                    val  body'=E.Sum(sx,Ein.body ein)
281                    val einapp=(y,DstIL.EINAPP(Ein.EIN{params=Ein.params ein, index=Ein.index ein, body=body'},args))
282                    val _=einapptostring (b,a1,einapp)
283                    in a0@[a1,einapp]
284                    end
285    *)
286                | E.Sum(sx,E.Probe e)
287                    => replaceProbe(y,b,E.Probe e,params,args, index, sx)
288                | E.Sum(sx,E.Prod[eps,E.Probe e])
289                    => replaceProbe(y,b,E.Probe e,params,args, index, sx)
290              | _=> [(y, DstIL.EINAPP(ein,args))]              | _=> [(y, DstIL.EINAPP(ein,args))]
291              (* end case *))              (* end case *))
292          in          in

Legend:
Removed from v.3030  
changed lines
  Added in v.3033

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0