Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2870, Wed Feb 25 21:47:43 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3318, Sat Oct 17 04:50:28 2015 UTC
# Line 14  Line 14 
14      structure P=Printer      structure P=Printer
15      structure T=TransformEin      structure T=TransformEin
16      structure MidToS=MidToString      structure MidToS=MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
20      in      in
21    
22      (* This file expands probed fields      (* This file expands probed fields
# Line 35  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val cnt = ref 0      val testlift=0
42        val detflag =true
43        val fieldliftflag=true
44        val valnumflag=true
45    
46    
47        val cnt = ref 0
48      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
49      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
50        fun toStringBind e=(MidToString.toStringBind e)
51        fun mkEin e=Ein.mkEin e
52        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
53    
54      fun testp n=(case testing      fun testp n=(case testing
55          of 0=> 1          of 0=> 1
56          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
57          (*end case*))          (*end case*))
58    
59    
60      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
61          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
62          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 70 
70      returns the support of ther kernel, and image      returns the support of ther kernel, and image
71      *)      *)
72      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
73          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
74              in              in
75                  ((Kernel.support h) ,img,ImageInfo.dim img)                  ((Kernel.support h) ,img,ImageInfo.dim img)
76              end              end
# Line 142  Line 156 
156      | formBody(E.Prod [e])=e      | formBody(E.Prod [e])=e
157      | formBody e=e      | formBody e=e
158    
159        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
160        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
161          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
162    
163    
164        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
165          | multiMergePs e=multiPs e
166    
167    
168      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
169              -> ein_exp* *code              -> ein_exp* *code
170      * Transforms position to world space      * Transforms position to world space
# Line 149  Line 172 
172      * rewrites body      * rewrites body
173      * replace probe with expanded version      * replace probe with expanded version
174      *)      *)
175      fun replaceProbe(b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
176          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
177         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
178            =let
179            val originalb=Ein.body e
180            val params=Ein.params e
181            val index=Ein.index e
182            val _ = testp["\n***************** \n Replace ************ \n"]
183            val _=  toStringBind (y, DstIL.EINAPP(e,args))
184    
185            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
186          val fid=length(params)          val fid=length(params)
187          val nid=fid+1          val nid=fid+1
188          val Pid=nid+1          val Pid=nid+1
# Line 160  Line 192 
192          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
193          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
194          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
195            val body' = multiPs(Ps,newsx1,body')
196    
197          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
198          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
199              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
200              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
201              (*end case*))              (*end case*))
202    
203    
204          val args'=argsA@[PArg]          val args'=argsA@[PArg]
205            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
206          in          in
207              (body',params',args' ,code)              code@[einapp]
208          end          end
209    
210      (* expandEinOp: code->  code list      val tsplitvar=true
211      *Looks to see if the expression has a probe. If so, replaces it.      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
212      * Note how we keeps eps expressions so only generate pieces that are used          val Pid=0
213      *)          val tid=1
     fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let  
         fun rewriteBody b=(case b  
             of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"  
             | E.Probe e =>let  
214    
215                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])          (*Assumes body is already clean*)
216                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
                 val code=newbies@[einapp]  
217    
218                  in          (*need to rewrite dx*)
219                      code          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
220                  end              of []=> ([],index,E.Conv(9,alpha,7,newdx))
221              | E.Sum(sx,E.Probe e)  =>let              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
222                (*end case*))
223    
224                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
225                  val  body'=E.Sum(sx,body')          fun filterAlpha []=[]
226                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))            | filterAlpha(E.C _::es)= filterAlpha es
227                  val code=newbies@[einapp]            | filterAlpha(e1::es)=[e1]@(filterAlpha es)
228    
229            val tshape=filterAlpha(alpha')@newdx
230            val t=E.Tensor(tid,tshape)
231            val (splitvar,body)=(case originalb
232                of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
233                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
234                | _                                  => (case tsplitvar
235                    of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
236                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
237                    (*end case*))
238                (*end case*))
239    
240            val ein0=mkEin(params,index,body)
241                  in                  in
242                      code              (splitvar,ein0,sizes,dx,alpha')
243                  end                  end
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
244    
245                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
246                  val  body'=E.Sum(sx,E.Prod[eps,body'])          val _=testp["\n******* Lift ******** \n"]
247                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
248                  val code=newbies@[einapp]          val params=Ein.params e
249            val index=Ein.index e
250            val _=  toStringBind (y, DstIL.EINAPP(e,args))
251    
252            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
253            val fid=length(params)
254            val nid=fid+1
255            val nshift=length(dx)
256            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
257            val freshIndex=getsumshift(sx,index)
258    
259            (*transform T*P*P..Ps*)
260            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
261            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
262            val einApp0=mkEinApp(ein0,[PArg,FArg])
263            val rtn0=(case splitvar
264                of false => [(y,einApp0)]
265                | _      => Split.splitEinApp (y,einApp0)
266                (*end case*))
267    
268            (*lifted probe*)
269            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
270            val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
271            val ein1=mkEin(params',sizes,body')
272            val einApp1=mkEinApp(ein1,args')
273            val rtn1=(FArg,einApp1)
274            val rtn=code@[rtn1]@rtn0
275            val _= List.map toStringBind ([rtn1]@rtn0)
276    
277                  in                  in
278                      code              rtn
279                  end                  end
280              | _=> [e]  
281        (* expandEinOp: code->  code list
282        * A this point we only have simple ein ops
283        * Looks to see if the expression has a probe. If so, replaces it.
284        * Note how we keeps eps expressions so only generate pieces that are used
285        *)
286       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
287    
288            fun checkConst ([],a) =
289                (case fieldliftflag
290                    of true => liftProbe a
291                    | _ => replaceProbe a
292                (*end case*))
293            | checkConst ((E.C _::_),a) = replaceProbe a
294            | checkConst ((_ ::es),a)= checkConst(es,a)
295    
296            fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
297                let
298    
299                    val _= toStringBind e
300                    val index0=Ein.index ein
301                    val index1 = index0@[3]
302                    val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
303                     (* clean to get body indices in order *)
304                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
305                    val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
306    
307                    val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
308                    val ein1 = mkEin(Ein.params ein,index1,body1)
309                    val code1= (lhs1,mkEinApp(ein1,args))
310                    val codeAll= (case dx
311                        of []=> replaceProbe(1,code1,body1,[])
312                        | _ =>liftProbe(1,code1,body1,[])
313                    (*end case*))
314    
315                    (*Probe that tensor at a constant position E.C c1*)
316                    val param0 = [E.TEN(1,index1)]
317                    val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
318                    val body0 =  E.Tensor(0,[E.C c1]@nx)
319                    val ein0 = mkEin(param0,index0,body0)
320                    val einApp0 = mkEinApp(ein0,[lhs1])
321                    val code0 = (y,einApp0)
322                    val _= toStringBind code0
323                in
324                    codeAll@[code0]
325                end
326    
327            fun rewriteBody b=(case (detflag,b)
328                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
329                    => liftFieldMat (1,b)
330                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
331                    => liftFieldMat (2,b)
332                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
333                    => liftFieldMat (3,b)
334                | (_,E.Probe(E.Conv(_,_,_,[]),_))
335                    => replaceProbe(0,e,b,[])
336                | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
337                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
338                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
339                    => replaceProbe(0,e,p, sx)  (*no dx*)
340                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
341                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
342                | (_,E.Sum(sx,E.Probe p))
343                    => replaceProbe(0,e,E.Probe p, sx)
344                | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
345                    => replaceProbe(0,e,E.Probe p,sx)
346                | (_,_) => [e]
347                (* end case *))
348    
349            val (fieldset,var) = (case valnumflag
350                of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
351                | _     => (fieldset,NONE)
352            (*end case*))
353    
354            fun matchField b=(case b
355                of E.Probe _ => 1
356                | E.Sum (_, E.Probe _)=>1
357                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
358                | _ =>0
359                (*end case*))
360    
361            in  (case var
362                of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
363                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
364              (* end case *))              (* end case *))
         in  
             rewriteBody body  
365          end          end
366    
367    end; (* local *)    end; (* local *)

Legend:
Removed from v.2870  
changed lines
  Added in v.3318

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0