Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2870, Wed Feb 25 21:47:43 2015 UTC revision 3261, Wed Sep 23 20:13:23 2015 UTC
# Line 14  Line 14 
14      structure P=Printer      structure P=Printer
15      structure T=TransformEin      structure T=TransformEin
16      structure MidToS=MidToString      structure MidToS=MidToString
17        structure DstV = DstIL.Var
18        structure DstTy = MidILTypes
19    
20      in      in
21    
22      (* This file expands probed fields      (* This file expands probed fields
# Line 35  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val cnt = ref 0      val testlift=0
42    
43        val cnt = ref 0
44      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
45      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
46        fun toStringBind e=print(MidToString.toStringBind e)
47        fun mkEin e=Ein.mkEin e
48        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
49    
50      fun testp n=(case testing      fun testp n=(case testing
51          of 0=> 1          of 0=> 1
52          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
53          (*end case*))          (*end case*))
54    
55    
56      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
57          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
58          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 66 
66      returns the support of ther kernel, and image      returns the support of ther kernel, and image
67      *)      *)
68      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
69          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
70              in              in
71                  ((Kernel.support h) ,img,ImageInfo.dim img)                  ((Kernel.support h) ,img,ImageInfo.dim img)
72              end              end
# Line 128  Line 138 
138                  end                  end
139              (* end case *))              (* end case *))
140          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
141          val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),          val _ =print(String.concat["\n", "SumIndex" ,(String.concatWith"," aa),
142              "\nThink nshift is ", Int.toString nsumshift]              "\nThink nshift is ", Int.toString nsumshift])
143          in          in
144              nsumshift              nsumshift
145          end          end
# Line 142  Line 152 
152      | formBody(E.Prod [e])=e      | formBody(E.Prod [e])=e
153      | formBody e=e      | formBody e=e
154    
155        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
156        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
157          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
158    
159    
160        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
161          | multiMergePs e=multiPs e
162    
163    
164      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
165              -> ein_exp* *code              -> ein_exp* *code
166      * Transforms position to world space      * Transforms position to world space
# Line 149  Line 168 
168      * rewrites body      * rewrites body
169      * replace probe with expanded version      * replace probe with expanded version
170      *)      *)
171      fun replaceProbe(b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
172          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
173         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
174            =let
175            val originalb=Ein.body e
176            val params=Ein.params e
177            val index=Ein.index e
178            val _ = print("\n***************** \n Replace ************ \n")
179            val _=  toStringBind (y, DstIL.EINAPP(e,args))
180    
181            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
182          val fid=length(params)          val fid=length(params)
183          val nid=fid+1          val nid=fid+1
184          val Pid=nid+1          val Pid=nid+1
# Line 160  Line 188 
188          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
189          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
190          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
191            val body' = multiPs(Ps,newsx1,body')
192    
193          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
194          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
195              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
196              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
197              (*end case*))              (*end case*))
198    
199    
200          val args'=argsA@[PArg]          val args'=argsA@[PArg]
201            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
202          in          in
203              (body',params',args' ,code)              code@[einapp]
204          end          end
205    
206      (* expandEinOp: code->  code list      val tsplitvar=true
207      *Looks to see if the expression has a probe. If so, replaces it.      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
208      * Note how we keeps eps expressions so only generate pieces that are used          val Pid=0
209      *)          val tid=1
210      fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let  
211          fun rewriteBody b=(case b          (*Assumes body is already clean*)
212              of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
213              | E.Probe e =>let  
214            (*need to rewrite dx*)
215            val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
216                of []=> ([],index,E.Conv(9,alpha,7,newdx))
217                | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
218                (*end case*))
219    
220                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
221                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          fun filterAlpha []=[]
222                  val code=newbies@[einapp]            | filterAlpha(E.C _::es)= filterAlpha es
223              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
224    
225            val tshape=filterAlpha(alpha')@newdx
226            val t=E.Tensor(tid,tshape)
227            val (splitvar,body)=(case originalb
228                of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
229                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
230                | _                                  => (case tsplitvar
231                    of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
232                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
233                    (*end case*))
234                (*end case*))
235    
236            val ein0=mkEin(params,index,body)
237                  in                  in
238                      code              (splitvar,ein0,sizes,dx,alpha')
239                  end                  end
             | E.Sum(sx,E.Probe e)  =>let  
240    
241                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
242                  val  body'=E.Sum(sx,body')          val _=print("\n******* Lift ******** \n")
243                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
244                  val code=newbies@[einapp]          val params=Ein.params e
245            val index=Ein.index e
246            val _=  toStringBind (y, DstIL.EINAPP(e,args))
247    
248            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
249            val fid=length(params)
250            val nid=fid+1
251            val nshift=length(dx)
252            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
253            val freshIndex=getsumshift(sx,index)
254    
255            (*transform T*P*P..Ps*)
256            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
257            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
258            val einApp0=mkEinApp(ein0,[PArg,FArg])
259            val rtn0=(case splitvar
260                of false => [(y,einApp0)]
261                | _      => Split.splitEinApp (y,einApp0)
262                (*end case*))
263    
264            (*lifted probe*)
265            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
266            val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
267            val ein1=mkEin(params',sizes,body')
268            val einApp1=mkEinApp(ein1,args')
269            val rtn1=(FArg,einApp1)
270            val rtn=code@[rtn1]@rtn0
271            val _= List.map toStringBind ([rtn1]@rtn0)
272    
273                  in                  in
274                      code              rtn
275                  end                  end
             | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
276    
277                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      (* expandEinOp: code->  code list
278                  val  body'=E.Sum(sx,E.Prod[eps,body'])      * A this point we only have simple ein ops
279                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))      * Looks to see if the expression has a probe. If so, replaces it.
280                  val code=newbies@[einapp]      * Note how we keeps eps expressions so only generate pieces that are used
281        *)
282       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
283    
284            fun checkConst ([],a) = liftProbe a
285            | checkConst ((E.C _::_),a) = replaceProbe a
286            | checkConst ((_ ::es),a)= checkConst(es,a)
287    
288            fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
289                let
290    
291                    val _= toStringBind e
292                    val index0=Ein.index ein
293                    val index1 = index0@[3]
294                    val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
295                     (* clean to get body indices in order *)
296                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
297                    val _=print(String.concat["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1])
298    
299                    val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
300                    val ein1 = mkEin(Ein.params ein,index1,body1)
301                    val code1= (lhs1,mkEinApp(ein1,args))
302                    val codeAll= liftProbe(1,code1,body1,[])
303    
304    
305                    (*Probe that tensor at a constant position E.C c1*)
306                    val param0 = [E.TEN(1,index1)]
307                    val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
308                    val body0 =  E.Tensor(0,[E.C c1]@nx)
309                    val ein0 = mkEin(param0,index0,body0)
310                    val einApp0 = mkEinApp(ein0,[lhs1])
311                    val code0 = (y,einApp0)
312                    val _= toStringBind code0
313                  in                  in
314                      code                  codeAll@[code0]
315                  end                  end
316    
317            fun rewriteBody b=(case b
318                of E.Probe(E.Conv(_,_,_,[]),_)
319                    => replaceProbe(0,e,b,[])
320                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos)
321                    => liftFieldMat (2,b)
322                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos)
323                    => liftFieldMat (3,b)
324                | E.Probe(E.Conv (_,alpha,_,dx),_)
325                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
326                | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
327                    => replaceProbe(0,e,p, sx)  (*no dx*)
328                | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
329                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
330                | E.Sum(sx,E.Probe p)
331                    => replaceProbe(0,e,E.Probe p, sx)
332                | E.Sum(sx,E.Prod[eps,E.Probe p])
333                    => replaceProbe(0,e,E.Probe p,sx)
334              | _=> [e]              | _=> [e]
335              (* end case *))              (* end case *))
336          in  
337              rewriteBody body          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
338            in  (case var
339            of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset))
340                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)] , fieldset))
341                (*end case*))
342          end          end
343    
344    end; (* local *)    end; (* local *)

Legend:
Removed from v.2870  
changed lines
  Added in v.3261

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0