Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2922, Tue Mar 3 03:55:09 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3362, Sun Nov 1 18:26:02 2015 UTC
# Line 1  Line 1 
1  (* Expands probe ein  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
# Line 14  Line 16 
16      structure P=Printer      structure P=Printer
17      structure T=TransformEin      structure T=TransformEin
18      structure MidToS=MidToString      structure MidToS=MidToString
19        structure DstV = DstIL.Var
20        structure DstTy = MidILTypes
21    
22      in      in
23    
24      (* This file expands probed fields      (* This file expands probed fields
# Line 34  Line 39 
39      *img-imginfo about V      *img-imginfo about V
40      *)      *)
41    
42      val testing=1      val testing=0
43      val cnt = ref 0      val testlift=1
44        val detflag =true
45        val fieldliftflag=true
46        val valnumflag=true
47    
48    
49        val cnt = ref 0
50      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
51      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
52        fun toStringBind e=(MidToString.toStringBind e)
53        fun mkEin e=Ein.mkEin e
54        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
55    
56      fun testp n=(case testing      fun testp n=(case testing
57          of 0=> 1          of 0=> 1
58          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
59          (*end case*))          (*end case*))
60    
61    
62      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
63          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)          of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
64          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'          | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 56  Line 72 
72      returns the support of ther kernel, and image      returns the support of ther kernel, and image
73      *)      *)
74      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
75          of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
76              in              in
77                  ((Kernel.support h) ,img,ImageInfo.dim img)                  ((Kernel.support h) ,img,ImageInfo.dim img)
78              end              end
# Line 142  Line 158 
158      | formBody(E.Prod [e])=e      | formBody(E.Prod [e])=e
159      | formBody e=e      | formBody e=e
160    
161        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
162        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
163        (*
164          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))
165          *)
166          | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
167    
168    
169        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
170          | multiMergePs e=multiPs e
171    
172    
173      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
174              -> ein_exp* *code              -> ein_exp* *code
175      * Transforms position to world space      * Transforms position to world space
# Line 149  Line 177 
177      * rewrites body      * rewrites body
178      * replace probe with expanded version      * replace probe with expanded version
179      *)      *)
180      fun replaceProbe(b,params,args,index, sx)=let  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
181          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b  
182         fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
183            =let
184            val originalb=Ein.body e
185            val params=Ein.params e
186            val index=Ein.index e
187            val _ = testp["\n***************** \n Replace ************ \n"]
188            val _=  toStringBind (y, DstIL.EINAPP(e,args))
189    
190            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
191          val fid=length(params)          val fid=length(params)
192          val nid=fid+1          val nid=fid+1
193          val Pid=nid+1          val Pid=nid+1
# Line 160  Line 197 
197          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
198          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
199          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
200            val body' = multiPs(Ps,newsx1,body')
201    
202          (*silly change in order of product to match vis branch WorldtoSpace functions*)          val body'=(case originalb
203          val body' =(case Ps              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
204              of [_,_,_]=>        formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
205              | _ =>  formBody(E.Sum(newsx1, E.Prod([body']@Ps)))              | _                                  => body'
206              (*end case*))              (*end case*))
207    
208    
209          val args'=argsA@[PArg]          val args'=argsA@[PArg]
210            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
211          in          in
212              (body',params',args' ,code)              code@[einapp]
213          end          end
214    
215      (* expandEinOp: code->  code list      val tsplitvar=true
216      *Looks to see if the expression has a probe. If so, replaces it.      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
217      * Note how we keeps eps expressions so only generate pieces that are used          val Pid=0
218      *)          val tid=1
219      fun expandEinOp( e as (y, DstIL.EINAPP(einorig, args))) = let  
220          val ein=SummationEin.cleanSummation(einorig)          (*Assumes body is already clean*)
221          val Ein.EIN{params, index, body}=ein          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
         fun rewriteBody b=(case b  
             of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"  
             | E.Probe e =>let  
222    
223                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])          (*need to rewrite dx*)
224                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
225                  val code=newbies@[einapp]              of []=> ([],index,E.Conv(9,alpha,7,newdx))
226                | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
227                (*end case*))
228    
229            val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
230            fun filterAlpha []=[]
231              | filterAlpha(E.C _::es)= filterAlpha es
232              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
233    
234            val tshape=filterAlpha(alpha')@newdx
235            val t=E.Tensor(tid,tshape)
236            val (splitvar,body)=(case originalb
237                of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
238                | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
239                | _                                  => (case tsplitvar
240                    of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
241                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
242                    (*end case*))
243                (*end case*))
244    
245            val _ =(case splitvar
246            of true=> (String.concat["splitvar is true", P.printbody body])
247            | _ => (String.concat["splitvar is false",P.printbody body])
248            (*end case*))
249    
250    
251            val ein0=mkEin(params,index,body)
252                  in                  in
253                      code              (splitvar,ein0,sizes,dx,alpha')
254                  end                  end
             | E.Sum(sx,E.Probe e)  =>let  
255    
256                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
257                  val  body'=E.Sum(sx,body')          val _=testp["\n******* Lift ******** \n"]
258                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
259                  val code=newbies@[einapp]          val params=Ein.params e
260            val index=Ein.index e
261            val _=  toStringBind (y, DstIL.EINAPP(e,args))
262    
263                  in          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
264                      code          val fid=length(params)
265            val nid=fid+1
266            val nshift=length(dx)
267            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
268            val freshIndex=getsumshift(sx,index)
269    
270            (*transform T*P*P..Ps*)
271            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
272            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
273            val einApp0=mkEinApp(ein0,[PArg,FArg])
274            val rtn0=(case splitvar
275                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
276                | _      => let
277                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
278                     in Split.splitEinApp(bind3,0)
279                  end                  end
280              | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let              (*end case*))
281    
282                  val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)          (*lifted probe*)
283                  val  body'=E.Sum(sx,E.Prod[eps,body'])          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
284                  val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
285                  val code=newbies@[einapp]          val ein1=mkEin(params',sizes,body')
286                  in          val einApp1=mkEinApp(ein1,args')
287                      code          val rtn1=(FArg,einApp1)
288            val rtn=code@[rtn1]@rtn0
289            val _= List.map toStringBind ([rtn1]@rtn0)
290    
291            in
292                rtn
293                  end                  end
294              | _=> [(y, DstIL.EINAPP(ein,args))]  
295    
296        fun liftFieldMat(newvx,e)=
297            let
298                val (y, DstIL.EINAPP(ein,args))=e
299                val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
300                val index0=Ein.index ein
301                val index1 = index0@[3]
302                val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
303                (* clean to get body indices in order *)
304                val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
305                val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
306    
307                val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
308                val ein1 = mkEin(Ein.params ein,index1,body1)
309                val code1= (lhs1,mkEinApp(ein1,args))
310                val codeAll= (case dx
311                of []=> replaceProbe(1,code1,body1,[])
312                | _ =>liftProbe(1,code1,body1,[])
313                (*end case*))
314    
315                (*Probe that tensor at a constant position  c1*)
316                val param0 = [E.TEN(1,index1)]
317                val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
318                val body0 =  E.Tensor(0,[c1]@nx)
319                val ein0 = mkEin(param0,index0,body0)
320                val einApp0 = mkEinApp(ein0,[lhs1])
321                val code0 = (y,einApp0)
322                val _= toStringBind code0
323            in
324                codeAll@[code0]
325        end
326    
327        fun liftFieldSum e =
328        let
329            val _=print"\n*************************************\n"
330            val (y, DstIL.EINAPP(ein,args))=e
331            val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
332            val index0=Ein.index ein
333            val index1 = index0@[3]@[3]
334            val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
335            val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
336    
337    
338            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
339            val ein1 = mkEin(Ein.params ein,index1,body1)
340            val code1= (lhs1,mkEinApp(ein1,args))
341            val codeAll= (case dx
342            of []=> replaceProbe(1,code1,body1,[])
343            | _ =>liftProbe(1,code1,body1,[])
344              (* end case *))              (* end case *))
345    
346            (*Probe that tensor at a constant position  c1*)
347            val param0 = [E.TEN(1,index1)]
348            val nx=List.tabulate(length(dx),fn n=>E.V n)
349            val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
350            val ein0 = mkEin(param0,index0,body0)
351            val einApp0 = mkEinApp(ein0,[lhs1])
352            val code0 = (y,einApp0)
353            val _= toStringBind  e
354            val _ =toStringBind code0
355           val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
356           val _ =(String.concat(List.map toStringBind (codeAll@[code0])))
357                   val _=print"\n*************************************\n"
358          in          in
359              rewriteBody body          codeAll@[code0]
360        end
361    
362    
363        (* expandEinOp: code->  code list
364        * A this point we only have simple ein ops
365        * Looks to see if the expression has a probe. If so, replaces it.
366        * Note how we keeps eps expressions so only generate pieces that are used
367        *)
368       fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
369    
370            fun checkConst ([],a) =
371                (case fieldliftflag
372                    of true => liftProbe a
373                    | _ => replaceProbe a
374                (*end case*))
375            | checkConst ((E.C _::_),a) = replaceProbe a
376            | checkConst ((_ ::es),a)= checkConst(es,a)
377    
378            fun rewriteBody b=(case (detflag,b)
379                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
380                    => liftFieldMat (1,e)
381                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
382                    => liftFieldMat (2,e)
383                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
384                    => liftFieldMat (3,e)
385                | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
386                    => liftFieldSum e
387                | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
388                    => liftFieldSum e
389                | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
390                    => liftFieldSum e
391    
392    
393                | (_,E.Probe(E.Conv(_,_,_,[]),_))
394                    => replaceProbe(0,e,b,[])
395                | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
396                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
397                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
398                    => replaceProbe(0,e,p, sx)  (*no dx*)
399                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
400                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
401                | (_,E.Sum(sx,E.Probe p))
402                    => replaceProbe(0,e,E.Probe p, sx)
403                | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
404                    => replaceProbe(0,e,E.Probe p,sx)
405                | (_,_) => [e]
406                (* end case *))
407    
408            val (fieldset,var) = (case valnumflag
409                of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
410                | _     => (fieldset,NONE)
411            (*end case*))
412    
413            fun matchField b=(case b
414                of E.Probe _ => 1
415                | E.Sum (_, E.Probe _)=>1
416                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
417                | _ =>0
418                (*end case*))
419            fun toStrField b=(case b
420                of E.Probe _ => print (P.printbody b)
421                | E.Sum (_, E.Probe _)=>print (P.printbody b)
422                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>print (P.printbody b)
423                | _ =>print ""
424                (*end case*))
425                val b=Ein.body ein
426    (*
427            val _=  toStrField b
428      *)
429            in  (case var
430                of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
431                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
432                (*end case*))
433          end          end
434    
435    end; (* local *)    end; (* local *)

Legend:
Removed from v.2922  
changed lines
  Added in v.3362

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0