Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2612, Wed May 7 02:58:55 2014 UTC revision 2829, Wed Nov 12 23:24:38 2014 UTC
# Line 37  Line 37 
37      structure F=Filter      structure F=Filter
38      structure T=TransformEin      structure T=TransformEin
39    
40  val testing=1      val testing=0
41    
42    
 datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  
 datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  
43      in      in
44    
45    
46  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
47  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
48    fun testp n=(case testing
49  fun getRHS x  = (case SrcIL.Var.binding x      of 0=> 1
50      of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)      | _ =>(print(String.concat n);1)
     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'  
     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)  
     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)  
     | SrcIL.VB_NONE=>(S2 2,[])  
     | vb => raise Fail(concat[  
     "expected rhs operator for ", SrcIL.Var.toString x,  
     "but found ", SrcIL.vbToString vb])  
51      (* end case *))      (* end case *))
52    
   
   
53  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
54  fun transformToImgSpace  (dim,v,posx)=let  fun transformToImgSpace  (dim,v,posx,imgArgDst)=let
   
55      val translate=DstOp.Translate v      val translate=DstOp.Translate v
56      val transform=DstOp.Transform v      val transform=DstOp.Transform v
57      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)
58        val T  = DstV.new ("T", DstTy.tensorTy [dim])   (*translate*)
     val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)  
59      val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)      val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)
60        val x0  = DstV.new ("x0", DstTy.vecTy dim)
61        val x1  = DstV.new ("x1", DstTy.vecTy dim)
62      val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)      val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)
63      val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)      val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)
64      val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)      val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)
65      val PosToImgSpace=mk.transform(dim,dim)      val PosToImgSpace=mk.transform(dim,dim)
66        val PosToImgSpaceA=mk.transformA(dim,dim)
67        val PosToImgSpaceB=mk.transformB dim
68      val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)      val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)
   
69      val code=[      val code=[
70          assign(M, transform, []),          assign(M, transform, [imgArgDst]),
71          assign(T, translate, []),          assign(T, translate, [imgArgDst]),
72          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)          (*assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)*)
73            assignEin(x0, PosToImgSpaceA,[M,posx]) ,
74            assignEin(x, PosToImgSpaceB,[x0,T]) ,
75    
76          assign(nd, DstOp.Floor dim, [x]),   (*nd *)          assign(nd, DstOp.Floor dim, [x]),   (*nd *)
77          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)
78          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)
# Line 87  Line 81 
81      in ([n,f],P,code)      in ([n,f],P,code)
82      end      end
83    
84    fun getRHS x  = (case SrcIL.Var.binding x
85  fun replaceH(kvar, place,args)=let      of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (rator, args)
86      val l1=List.take(args, place)      | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
87      val l2=List.drop(args,place+1)      | vb => raise Fail(concat[ "expected rhs operator for ", SrcIL.Var.toString x, "but found ", SrcIL.vbToString vb])
88      in l1@[kvar]@l2 end      (* end case *))
   
89    
90  (*Get Img, and Kern Args*)  (*Get Img, and Kern Args*)
91  fun getArgs(hid,hArg,V,imgArg,args,lift)=case (getRHS hArg,getRHS imgArg)  fun getArgs(hid,hArg,V,imgArg,args,lift,varI)=(case (getRHS hArg,getRHS imgArg)
92      of ((O2(SrcOp.Kernel(h, i)),argK),(O2(SrcOp.LoadImage img),_))=> let      of ((SrcOp.Kernel(h, i), _ ),(SrcOp.LoadImage img, _ ))=> let
93          val hvar=DstV.new ("KNL", DstTy.KernelTy)          val hvar=DstV.new ("KNL", DstTy.KernelTy)
94          val imgvar=DstV.new ("IMG", DstTy.ImageTy img)          val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
95          val argsVK= (case lift          val argsVK= (case lift
96              of 0=> let              of 0=> let
97                  val argsN=replaceH(hvar, hid,args)                  val _=print "non lift"
98                  in replaceH(imgvar, V,argsN) end                  val l1=List.take(args, hid)
99              | _ => [imgvar, hvar]                  val l2=List.drop(args,hid+1)
100                    in
101                        l1@[hvar]@l2
102                    end
103                | _ => [varI, hvar]
104          (* end case *))          (* end case *))
         val assigments=[assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])]  
105    
106            val assigments=[assign (hvar, DstOp.Kernel(h, i), [])]
107          in          in
108              (Kernel.support h ,img, assigments,argsVK)              ((Kernel.support h) ,img, assigments,argsVK)
109          end          end
110      | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"      |  _ => raise Fail "Expected Image and kernel argument"
111      |  _ => raise Fail "Not a kernel argument"      (*end case*))
112    
113    
114  fun handleArgs(V,h,t,(params,args),origargs,lift)=let  fun handleArgs(V,h,t,(params,args),origargs,lift,dstargs)=let
115      val E.IMG(dim)=List.nth(params,V)      val E.IMG(dim)=List.nth(params,V)
116      val kArg=List.nth(origargs,h)      val kArg=List.nth(origargs,h)
117      val imgArg=List.nth(origargs,V)      val imgArg=List.nth(origargs,V)
118      val newposArg=List.nth(args, t)      val newposArg=List.nth(args, t)
119      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)      val imgArgDst=List.nth(dstargs,V)
120      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift,imgArgDst)
121        val (argsT,P,code')=transformToImgSpace(dim,img,newposArg,imgArgDst)
122      in (dim,argsVH@argsT,argcode@code', s,P)      in (dim,argsVH@argsT,argcode@code', s,P)
123      end      end
124    
# Line 166  Line 164 
164      | mapIndex(E.C c::es,index) = mapIndex(es,index)      | mapIndex(E.C c::es,index) = mapIndex(es,index)
165    
166    
167    (*
168  (*Lift probe and Multiply by P*)  (*Lift probe and Multiply by P*)
169  fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let  fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let
170      val _ =print "Lift Probe"      val _ =print "Lift Probe"
# Line 194  Line 193 
193    
194      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)
195      val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
196        val _ =print("\nSupport is "^Int.toString support)
197    
198      (*New transformations:params, sx, rest, will be empty if no transformation is made*)      (*New transformations:params, sx, rest, will be empty if no transformation is made*)
199      val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)      val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)
# Line 223  Line 223 
223      in      in
224          (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)          (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)
225      end      end
226     |liftProbe _ =raise Fail"Incorrect body for Probe"
227    
228    *)
229    
230    
 (*Does not yet do transformation*)  
231   (* Expand probe in place *)   (* Expand probe in place *)
232   fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let   fun replaceProbe(b,(params,args),index, sumIndex,origargs,dstargs)=let
233    
234      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
235      val fid=length(params)      val fid=length(params)
# Line 253  Line 254 
254      val Vshapebind= mapIndex(VShape,index)      val Vshapebind= mapIndex(VShape,index)
255    
256    
257      val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)      val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0,dstargs)
258        val _ =testp["\nSupport is ",Int.toString s]
259      val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)      val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
260    
261      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
# Line 263  Line 265 
265          | _ => E.Sum(sxT, E.Prod(restT@[body'']))          | _ => E.Sum(sxT, E.Prod(restT@[body'']))
266          (*end case*))          (*end case*))
267      val args'=argsA@[PArg]      val args'=argsA@[PArg]
     val _ =(case testing  
         of 0=> 1  
         | _ =>  let  
268              val subexp=Ein.EIN{params=params', index=index, body=body'}              val subexp=Ein.EIN{params=params', index=index, body=body'}
269              val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])      val _ = testp["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"]
             in 1 end  
         (* end case *))  
   
270    
271      in (body',(params',args') ,code)      in (body',(params',args') ,code)
272      end      end
273    
274    (*
275  (*Checks if (1) Summation variable occurs just once (2) it matches n.  (*Checks if (1) Summation variable occurs just once (2) it matches n.
276  Then we lift otherwise expand in place *)  Then we lift otherwise expand in place *)
277  fun checkSum(sx,b,info,index,origargs)=(case sx  fun checkSum(sx,b,info,index,origargs)=(case sx
# Line 293  Line 290 
290          end          end
291      | _ =>replaceProbe(b, info,index, sx,origargs)      | _ =>replaceProbe(b, info,index, sx,origargs)
292      (*end case*))      (*end case*))
293    *)
294    
295  fun flatten []=[]  fun flatten []=[]
296      | flatten(e1::es)=e1@(flatten es)      | flatten(e1::es)=e1@(flatten es)
# Line 322  Line 319 
319                  (sumIndex:=tl(s);(e',infoK,dataK))                  (sumIndex:=tl(s);(e',infoK,dataK))
320              end              end
321    
322    
323            (*Nothing liftProbe and checkSum are commented out.
324                Some mistake underestimating size of dimension*)
325    
326          fun filter es=let          fun filter es=let
327              fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)              fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
328              | filterApply(B::es, doneA, infoA,dataA)= let              | filterApply(B::es, doneA, infoA,dataA)= let
# Line 335  Line 336 
336              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let
337                  val ref sx=sumIndex                  val ref sx=sumIndex
338                  in (case sx                  in (case sx
339                      of   [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)                      of  (* [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
340                        | [i]=> checkSum(i,b, info,index,origargs)                        | [i]=> checkSum(i,b, info,index,origargs)
341                        | _ => let                        |*) _ => let
342                          val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)                          val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs,args)
343                          in (E.Sum(c,b),m,code)                          in (E.Sum(c,b),m,code)
344                          end                          end
345                  (* end case*))                  (* end case*))
# Line 346  Line 347 
347          | E.Probe(E.Conv _, E.Tensor _) =>let          | E.Probe(E.Conv _, E.Tensor _) =>let
348              val ref sx=sumIndex              val ref sx=sumIndex
349              in (case sx              in (case sx
350                  of []=> liftProbe(b, info,index, [],origargs)                  of (* []=> liftProbe(b, info,index, [],origargs)
351                  | [i]=> checkSum(i,b, info,index,origargs)                  | [i]=> checkSum(i,b, info,index,origargs)
352                  | _ => replaceProbe(b, info,index, flatten sx,origargs)                  |*) _ => replaceProbe(b, info,index, flatten sx,origargs,args)
353               (* end case*))               (* end case*))
354              end              end
355          | E.Probe _=> (dummy,info,[])          | E.Probe _=> (dummy,info,[])

Legend:
Removed from v.2612  
changed lines
  Added in v.2829

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0