Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3033, Tue Mar 10 15:17:25 2015 UTC revision 3048, Wed Mar 11 20:00:27 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41        val testlift=0
42      val cnt = ref 0      val cnt = ref 0
43    
44      fun printEINAPP e=MidToString.printEINAPP e      fun printEINAPP e=MidToString.printEINAPP e
45      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
46      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
47    
48      fun transitionToString(a,b)=      fun transitionToString(testreplace,a,b)=(case testreplace
49          print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b])          of 0=> 1
50            | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)
51            (*end case*))
52      fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}      fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
53      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
54      fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body      fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body
# Line 57  Line 59 
59          of 0=> 1          of 0=> 1
60          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
61          (*end case*))          (*end case*))
62        fun  einapptostring (body,a,b)=(case testlift
63            of 0=>1
64            | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)
65            (*end case*))
66    
67    
68      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
69          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
70          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 167  Line 175 
175      * rewrites body      * rewrites body
176      * replace probe with expanded version      * replace probe with expanded version
177      *)      *)
178      fun replaceProbe(y,originalb,b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
179          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
180         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
181            =let
182            val originalb=Ein.body e
183            val params=Ein.params e
184            val index=Ein.index e
185    
186    
187            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
188          val fid=length(params)          val fid=length(params)
189          val nid=fid+1          val nid=fid+1
190          val Pid=nid+1          val Pid=nid+1
# Line 185  Line 201 
201              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
202              | _                                  => body'              | _                                  => body'
203              (*end case*))              (*end case*))
204          val _=transitionToString(originalb,body')          val _=transitionToString(testN,originalb,body')
205    
206          val args'=argsA@[PArg]          val args'=argsA@[PArg]
207          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
# Line 194  Line 210 
210          end          end
211    
212    
213      fun shapeoflifted (tshape,index,sx)=let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
         val _ =List.map (fn (E.V e)=>print(Int.toString(e)^"-")) tshape  
         val sizeMapp= cleanIndex.mkSizeMapp(index,sx)  
         val sizes= List.map(fn E.V e1=> cleanIndex.lkupIX(e1,sizeMapp,"Could not find Size of")) tshape  
         in sizes end  
   
     fun createEinApp(alpha,index,freshIndex,dim,dx,sx)= let  
214          val Pid=0          val Pid=0
215          val tid=1          val tid=1
216    
217          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
218          val params=[E.TEN(1,[dim,dim]),E.TEN(1,index)]  
219            (*need to rewrite dx*)
220            val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx
221                of []=> ([],index,E.Conv(9,alpha,7,newdx))
222                | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
223                (*end case*))
224    
225            val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
226          val tshape=alpha@newdx          val tshape=alpha@newdx
227          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
228          val body = multiPs(Ps,newsx,t)          val exp = multiPs(Ps,newsx,t)
229          val rator=mkEin(params,index,body)          val body=(case originalb
230          val sizes= (case sx              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)
231              of []=> index              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])
232              | _=>(*shapeoflifted (tshape,index,sx@newsx)*) index              | _                                  => exp
233              )              (*end case*))
234    
235            val ein0=mkEin(params,index,body)
236          in          in
237              (rator,sizes)              (ein0,sizes,dx)
238          end          end
239    
240      fun  einapptostring (body,a,b)=      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let
241          print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b])          val originalb=Ein.body e
242            val params=Ein.params e
243            val index=Ein.index e
244    
245      fun liftProbe(y,b,params,args,index, sx)=let          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
         val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
246          val fid=length(params)          val fid=length(params)
247          val nid=fid+1          val nid=fid+1
248          val nshift=length(dx)          val nshift=length(dx)
249          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
250          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,index)
251    
252    
253          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
254          val FArg  = DstV.new ("F", DstTy.TensorTy(index))          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
255          val (rator,sizes)=createEinApp(alpha,index,freshIndex,dim,dx,sx)          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
256          val einApp0=mkEinApp(rator,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
257            val rtn0=(y,einApp0)
258    
259          (*lifted probe*)          (*lifted probe*)
260          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
261          val dxinner=(*List.tabulate(length(dx), (fn e=>E.V e))*) dx          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
         val body' = createBody(dim, s,freshIndex+nshift,alpha,dxinner,Vid, hid, nid, fid)  
         val args'=argsA  
   
262          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
263          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
264            val rtn1=(FArg,einApp1)
265          val rtn=(code,(FArg,einApp1),(y,einApp0))          val rtn=code@[rtn1,rtn0]
266            val _= einapptostring (p,rtn1,rtn0)
267          in          in
268              rtn              rtn
269          end          end
270    
271    
   
272      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
273        *A this point we only have simple ein ops
274      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
275      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
276      *)      *)
277      fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let      fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let
278          fun rewriteBody b=(case b          fun rewriteBody b=(case b
279              of E.Probe(E.Conv(_,_,_,[]),_)              of E.Probe(E.Conv(_,_,_,[]),_)
280                  => replaceProbe(y,b,b,params,args, index, [])                  => replaceProbe(0,e,b, [])
281              | E.Probe(E.Conv _,_) =>let              | E.Probe(E.Conv _,_)
282                  val (a0,a1,a2)=liftProbe(y,b,params,args, index, [])                  => liftProbe(0,e,b,[])
                 val _= einapptostring (b,a1,a2)  
                 in  
                     a0@[a1,a2]  
                 end  
283             | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))             | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
284                  => replaceProbe(y,b,p,params,args, index, sx)                  => replaceProbe(0,e,p, sx)
285                | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,_),_))
286          (*                  => liftProbe(0,e,p,sx)
287              | E.Sum(sx,b as E.Probe(E.Conv(_,[],_,_),_))  =>let              | E.Sum(sx,E.Probe p)
288                  val _=print(String.concat["\n\n\n lift probe:\n",P.printerE ein,"\n=>"])                  => replaceProbe(1,e,E.Probe p, sx)
289                | E.Sum(sx,E.Prod[eps,E.Probe p])
290                  val (a0,a1,a2)=liftProbe(y,b,params,args, index,sx)                  => replaceProbe(1,e,E.Probe p,sx)
291                  val (y,DstIL.EINAPP(ein,args))=a2              | _ => [e]
                 val  body'=E.Sum(sx,Ein.body ein)  
                 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=Ein.params ein, index=Ein.index ein, body=body'},args))  
                 val _=einapptostring (b,a1,einapp)  
                 in a0@[a1,einapp]  
                 end  
 *)  
             | E.Sum(sx,E.Probe e)  
                 => replaceProbe(y,b,E.Probe e,params,args, index, sx)  
             | E.Sum(sx,E.Prod[eps,E.Probe e])  
                 => replaceProbe(y,b,E.Probe e,params,args, index, sx)  
             | _=> [(y, DstIL.EINAPP(ein,args))]  
292              (* end case *))              (* end case *))
293          in          in
294              rewriteBody body              rewriteBody (Ein.body ein)
295          end          end
296    
297    end; (* local *)    end; (* local *)

Legend:
Removed from v.3033  
changed lines
  Added in v.3048

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0