Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2870, Wed Feb 25 21:47:43 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3312, Fri Oct 16 20:11:52 2015 UTC
# Line 14  Line 14 
14      structure P=Printer      structure P=Printer
15      structure T=TransformEin      structure T=TransformEin
16      structure MidToS=MidToString      structure MidToS=MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
20      in      in
21    
22      (* This file expands probed fields      (* This file expands probed fields
# Line 35  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val cnt = ref 0      val testlift=0
42        val detflag =true
43    
44    
45        val cnt = ref 0
46      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
47      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
48        fun toStringBind e=(MidToString.toStringBind e)
49        fun mkEin e=Ein.mkEin e
50        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
51    
52      fun testp n=(case testing      fun testp n=(case testing
53          of 0=> 1          of 0=> 1
54          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
55          (*end case*))          (*end case*))
56    
57    
58      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
59          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
60          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 68 
68      returns the support of ther kernel, and image      returns the support of ther kernel, and image
69      *)      *)
70      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
71          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
72              in              in
73                  ((Kernel.support h) ,img,ImageInfo.dim img)                  ((Kernel.support h) ,img,ImageInfo.dim img)
74              end              end
# Line 142  Line 154 
154      | formBody(E.Prod [e])=e      | formBody(E.Prod [e])=e
155      | formBody e=e      | formBody e=e
156    
157        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
158        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
159          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
160    
161    
162        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
163          | multiMergePs e=multiPs e
164    
165    
166      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
167              -> ein_exp* *code              -> ein_exp* *code
168      * Transforms position to world space      * Transforms position to world space
# Line 149  Line 170 
170      * rewrites body      * rewrites body
171      * replace probe with expanded version      * replace probe with expanded version
172      *)      *)
173      fun replaceProbe(b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
174          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
175         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
176            =let
177            val originalb=Ein.body e
178            val params=Ein.params e
179            val index=Ein.index e
180            val _ = testp["\n***************** \n Replace ************ \n"]
181            val _=  toStringBind (y, DstIL.EINAPP(e,args))
182    
183            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
184          val fid=length(params)          val fid=length(params)
185          val nid=fid+1          val nid=fid+1
186          val Pid=nid+1          val Pid=nid+1
# Line 160  Line 190 
190          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
191          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
192          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
193            val body' = multiPs(Ps,newsx1,body')
194    
195          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
196          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
197              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
198              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
199              (*end case*))              (*end case*))
200    
201    
202          val args'=argsA@[PArg]          val args'=argsA@[PArg]
203            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
204          in          in
205              (body',params',args' ,code)              code@[einapp]
206          end          end
207    
208      (* expandEinOp: code->  code list      val tsplitvar=true
209      *Looks to see if the expression has a probe. If so, replaces it.      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
210      * Note how we keeps eps expressions so only generate pieces that are used          val Pid=0
211      *)          val tid=1
     fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let  
         fun rewriteBody b=(case b  
             of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"  
             | E.Probe e =>let  
212    
213                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])          (*Assumes body is already clean*)
214                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
                 val code=newbies@[einapp]  
215    
216                  in          (*need to rewrite dx*)
217                      code          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
218                  end              of []=> ([],index,E.Conv(9,alpha,7,newdx))
219              | E.Sum(sx,E.Probe e)  =>let              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
220                (*end case*))
221    
222                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
223                  val  body'=E.Sum(sx,body')          fun filterAlpha []=[]
224                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))            | filterAlpha(E.C _::es)= filterAlpha es
225                  val code=newbies@[einapp]            | filterAlpha(e1::es)=[e1]@(filterAlpha es)
226    
227            val tshape=filterAlpha(alpha')@newdx
228            val t=E.Tensor(tid,tshape)
229            val (splitvar,body)=(case originalb
230                of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
231                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
232                | _                                  => (case tsplitvar
233                    of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
234                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
235                    (*end case*))
236                (*end case*))
237    
238            val ein0=mkEin(params,index,body)
239                  in                  in
240                      code              (splitvar,ein0,sizes,dx,alpha')
241                  end                  end
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
242    
243                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
244                  val  body'=E.Sum(sx,E.Prod[eps,body'])          val _=testp["\n******* Lift ******** \n"]
245                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
246                  val code=newbies@[einapp]          val params=Ein.params e
247            val index=Ein.index e
248            val _=  toStringBind (y, DstIL.EINAPP(e,args))
249    
250            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
251            val fid=length(params)
252            val nid=fid+1
253            val nshift=length(dx)
254            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
255            val freshIndex=getsumshift(sx,index)
256    
257            (*transform T*P*P..Ps*)
258            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
259            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
260            val einApp0=mkEinApp(ein0,[PArg,FArg])
261            val rtn0=(case splitvar
262                of false => [(y,einApp0)]
263                | _      => Split.splitEinApp (y,einApp0)
264                (*end case*))
265    
266            (*lifted probe*)
267            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
268            val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
269            val ein1=mkEin(params',sizes,body')
270            val einApp1=mkEinApp(ein1,args')
271            val rtn1=(FArg,einApp1)
272            val rtn=code@[rtn1]@rtn0
273            val _= List.map toStringBind ([rtn1]@rtn0)
274    
275                  in                  in
276                      code              rtn
277                  end                  end
278              | _=> [e]  
279        (* expandEinOp: code->  code list
280        * A this point we only have simple ein ops
281        * Looks to see if the expression has a probe. If so, replaces it.
282        * Note how we keeps eps expressions so only generate pieces that are used
283        *)
284       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
285    
286            fun checkConst ([],a) = liftProbe a
287            | checkConst ((E.C _::_),a) = replaceProbe a
288            | checkConst ((_ ::es),a)= checkConst(es,a)
289    
290            fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
291                let
292    
293                    val _= toStringBind e
294                    val index0=Ein.index ein
295                    val index1 = index0@[3]
296                    val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
297                     (* clean to get body indices in order *)
298                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
299                    val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
300    
301                    val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
302                    val ein1 = mkEin(Ein.params ein,index1,body1)
303                    val code1= (lhs1,mkEinApp(ein1,args))
304                    val codeAll= (case dx
305                        of []=> replaceProbe(1,code1,body1,[])
306                        | _ =>liftProbe(1,code1,body1,[])
307                    (*end case*))
308    
309                    (*Probe that tensor at a constant position E.C c1*)
310                    val param0 = [E.TEN(1,index1)]
311                    val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
312                    val body0 =  E.Tensor(0,[E.C c1]@nx)
313                    val ein0 = mkEin(param0,index0,body0)
314                    val einApp0 = mkEinApp(ein0,[lhs1])
315                    val code0 = (y,einApp0)
316                    val _= toStringBind code0
317                in
318                    codeAll@[code0]
319                end
320    
321            fun rewriteBody b=(case (detflag,b)
322                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
323                    => liftFieldMat (1,b)
324                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
325                    => liftFieldMat (2,b)
326                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
327                    => liftFieldMat (3,b)
328                | (_,E.Probe(E.Conv(_,_,_,[]),_))
329                    => replaceProbe(0,e,b,[])
330                | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
331                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
332                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
333                    => replaceProbe(0,e,p, sx)  (*no dx*)
334                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
335                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
336                | (_,E.Sum(sx,E.Probe p))
337                    => replaceProbe(0,e,E.Probe p, sx)
338                | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
339                    => replaceProbe(0,e,E.Probe p,sx)
340                | (_,_) => [e]
341                (* end case *))
342    
343            val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
344    
345            fun matchField b=(case b
346                of E.Probe _ => 1
347                | E.Sum (_, E.Probe _)=>1
348                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
349                | _ =>0
350                (*end case*))
351    
352            in  (case var
353                of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
354                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
355              (* end case *))              (* end case *))
         in  
             rewriteBody body  
356          end          end
357    
358    end; (* local *)    end; (* local *)

Legend:
Removed from v.2870  
changed lines
  Added in v.3312

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0