Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2870, Wed Feb 25 21:47:43 2015 UTC revision 3060, Fri Mar 13 22:12:44 2015 UTC
# Line 14  Line 14 
14      structure P=Printer      structure P=Printer
15      structure T=TransformEin      structure T=TransformEin
16      structure MidToS=MidToString      structure MidToS=MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
20      in      in
21    
22      (* This file expands probed fields      (* This file expands probed fields
# Line 35  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41        val testlift=1
42      val cnt = ref 0      val cnt = ref 0
43    
44        fun printEINAPP e=MidToString.printEINAPP e
45      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
46      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
47    
48        fun transitionToString(testreplace,a,b)=(case testreplace
49            of 0=> 1
50            | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)
51            (*end case*))
52        fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
53        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
54        fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body
55        fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=
56                (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))
57    
58      fun testp n=(case testing      fun testp n=(case testing
59          of 0=> 1          of 0=> 1
60          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
61          (*end case*))          (*end case*))
62        fun  einapptostring (body,a,b)=(case testlift
63            of 0=>1
64            | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)
65            (*end case*))
66    
67    
68      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
69          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
70          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 78 
78      returns the support of ther kernel, and image      returns the support of ther kernel, and image
79      *)      *)
80      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
81          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
82              in              in
83                  ((Kernel.support h) ,img,ImageInfo.dim img)                  ((Kernel.support h) ,img,ImageInfo.dim img)
84              end              end
# Line 142  Line 164 
164      | formBody(E.Prod [e])=e      | formBody(E.Prod [e])=e
165      | formBody e=e      | formBody e=e
166    
167        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
168        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
169          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
170    
171      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
172              -> ein_exp* *code              -> ein_exp* *code
173      * Transforms position to world space      * Transforms position to world space
# Line 149  Line 175 
175      * rewrites body      * rewrites body
176      * replace probe with expanded version      * replace probe with expanded version
177      *)      *)
178      fun replaceProbe(b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
179          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
180         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
181            =let
182            val originalb=Ein.body e
183            val params=Ein.params e
184            val index=Ein.index e
185    
186    
187            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
188          val fid=length(params)          val fid=length(params)
189          val nid=fid+1          val nid=fid+1
190          val Pid=nid+1          val Pid=nid+1
# Line 160  Line 194 
194          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
195          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
196          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
197            val body' = multiPs(Ps,newsx1,body')
198    
199          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
200          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
201              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
202              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
203              (*end case*))              (*end case*))
204            val _=transitionToString(testN,originalb,body')
205    
206          val args'=argsA@[PArg]          val args'=argsA@[PArg]
207            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
208          in          in
209              (body',params',args' ,code)              code@[einapp]
210          end          end
211    
     (* expandEinOp: code->  code list  
     *Looks to see if the expression has a probe. If so, replaces it.  
     * Note how we keeps eps expressions so only generate pieces that are used  
     *)  
     fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let  
         fun rewriteBody b=(case b  
             of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"  
             | E.Probe e =>let  
212    
213                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
214                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val Pid=0
215                  val code=newbies@[einapp]          val tid=1
216    
217                  in          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
                     code  
                 end  
             | E.Sum(sx,E.Probe e)  =>let  
218    
219                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)          (*need to rewrite dx*)
220                  val  body'=E.Sum(sx,body')          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx
221                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
222                  val code=newbies@[einapp]              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
223                (*end case*))
224    
225            val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
226            val tshape=alpha@newdx
227            val t=E.Tensor(tid,tshape)
228            val exp = multiPs(Ps,newsx,t)
229            val body=(case originalb
230                of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)
231                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])
232                | _                                  => exp
233                (*end case*))
234    
235            val ein0=mkEin(params,index,body)
236                  in                  in
237                      code              (ein0,sizes,dx)
238                  end                  end
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
239    
240                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let
241                  val  body'=E.Sum(sx,E.Prod[eps,body'])          val originalb=Ein.body e
242                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val params=Ein.params e
243                  val code=newbies@[einapp]          val index=Ein.index e
244    
245            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
246            val fid=length(params)
247            val nid=fid+1
248            val nshift=length(dx)
249            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
250            val freshIndex=getsumshift(sx,index)
251    
252    
253            (*transform T*P*P..Ps*)
254            val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
255            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
256            val einApp0=mkEinApp(ein0,[PArg,FArg])
257            val rtn0=(y,einApp0)
258    
259            (*lifted probe*)
260            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
261            val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
262            val ein1=mkEin(params',sizes,body')
263            val einApp1=mkEinApp(ein1,args')
264            val rtn1=(FArg,einApp1)
265            val rtn=code@[rtn1,rtn0]
266            val _= einapptostring (p,rtn1,rtn0)
267                  in                  in
268                      code              rtn
269                  end                  end
270    
271    
272        (* expandEinOp: code->  code list
273        *A this point we only have simple ein ops
274        *Looks to see if the expression has a probe. If so, replaces it.
275        * Note how we keeps eps expressions so only generate pieces that are used
276        *)
277        fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let
278            fun rewriteBody b=(case b
279                of E.Probe(E.Conv(_,_,_,[]),_)
280                    => replaceProbe(0,e,b, [])
281                | E.Probe(E.Conv _,_)
282                    => liftProbe(0,e,b,[])
283                | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
284                    => replaceProbe(0,e,p, sx)
285                | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,_),_))
286                    => liftProbe(0,e,p,sx)
287                | E.Sum(sx,E.Probe p)
288                    => replaceProbe(1,e,E.Probe p, sx)
289                | E.Sum(sx,E.Prod[eps,E.Probe p])
290                    => replaceProbe(1,e,E.Probe p,sx)
291              | _=> [e]              | _=> [e]
292              (* end case *))              (* end case *))
293          in          in
294              rewriteBody body              rewriteBody (Ein.body ein)
295          end          end
296    
297    end; (* local *)    end; (* local *)

Legend:
Removed from v.2870  
changed lines
  Added in v.3060

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0