Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2838, Tue Nov 25 03:40:24 2014 UTC revision 2843, Mon Dec 8 01:27:25 2014 UTC
# Line 4  Line 4 
4   * All rights reserved.   * All rights reserved.
5   *)   *)
6    
   
 (*  
 A couple of different approaches.  
 One approach is to find all the Probe(Conv). Gerenerate exp for it  
 Then use Subst function to sub in. That takes care for index matching and  
   
 *)  
   
 (*This approach creates probe expanded terms, and adds params to the end. *)  
   
   
7  structure ProbeEin = struct  structure ProbeEin = struct
8    
9      local      local
# Line 35  Line 24 
24      structure F=Filter      structure F=Filter
25      structure T=TransformEin      structure T=TransformEin
26      structure split=Split      structure split=Split
27        structure cleanI=cleanIndex
28    
29    
30      val testing=0      val testing=1
31    
32    
33      in      in
34    
35    
36    (* This file expands probed fields
37    *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
38    * Param_ids are used to note the placement of the argument in the midIL.var list
39    * Index_ids bind the shape of an Image or differentiation.
40    * Generally, we will refer to the following
41    *dim:dimension of field V
42    * s: support of kernel H
43    * alpha: The alpha in <V_alpha * H^(deltas)>
44    * deltas: The deltas in <V_alpha * H^(deltas)>
45    * Vid:param_id for V
46    * hid:param_id for H
47    * nid: integer position param_id
48    * fid :fractional position param_id
49    *img-imginfo about V
50    *)
51    
52    
53  val cnt = ref 0  val cnt = ref 0
54  fun genName prefix = let  fun genName prefix = let
55  val n = !cnt  val n = !cnt
# Line 50  Line 59 
59  end  end
60    
61    
62    fun iterSx e=F.iterSx e
63    fun transformToIndexSpace e=T.transformToIndexSpace e
64    fun transformToImgSpace  e=T.transformToImgSpace  e
65  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
66  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
67  fun testp n=(case testing  fun testp n=(case testing
68      of 0=> 1      of 0=> 1
69      | _ =>(print(String.concat n);1)      | _ =>(print(String.concat n);1)
70      (*end case*))      (*end case*))
   
 (*transform image-space position x to world space position*)  
   
 fun getTys 1= (DstTy.intTy,[],[])  
  | getTys dim = (DstTy.iVecTy dim,[dim],[dim,dim])  
   
   
 fun WorldToImagespace(dim,v,posx,imgArgDst)=let  
         val translate=DstOp.Translate v  
         val transform=DstOp.Transform v  
         val (_ ,fty,pty)=getTys dim  
         val mty=DstTy.TensorTy  pty  
         val rty=DstTy.TensorTy fty  
   
         val M  = DstV.new (genName "M", mty)   (*transform dim by dim?*)  
         val T  = DstV.new (genName "T", rty)  
         val x  = DstV.new (genName "x", rty)            (*Image-Space position*)  
         val x0  = DstV.new (genName "x0", rty)  
         val (PosToImgSpaceA,PosToImgSpaceB)=(case dim  
             of 1=>(mk.prodScalar,mk.addScalar)  
             | _ => (mk.transformA(dim,dim) ,mk.transformB(dim))  
             (*end case*))  
         val code=[  
             assign(M, transform, [imgArgDst]),  
             assign(T, translate, [imgArgDst]),  
             assignEin(x0, PosToImgSpaceA,[M,posx]) , (*xo=MX*)  
             assignEin(x, PosToImgSpaceB,[x0,T])  (*x=x0+T*)  
         ]  
     in (M,x,code)  
         end  
   
   
 (*Create fractional, and integer position vectors*)  
 fun transformToImgSpace  (dim,v,posx,imgArgDst)=let  
     val (ity,fty,pty)=getTys dim  
     val mty=DstTy.TensorTy  pty  
     val rty=DstTy.TensorTy fty  
   
     val f  = DstV.new ("f", rty)            (*fractional*)  
     val nd = DstV.new ("nd",  rty)           (*real position*)  
     val n  = DstV.new ("n", ity)           (*integer position*)  
     val P  = DstV.new ("P",mty)   (*transform dim by dim?*)  
   
     val (M,x,code1)=WorldToImagespace(dim,v,posx,imgArgDst)  
     val (P,PCode)=(case dim  
         of 1=>(M,[])  
         | _ =>(P,[assignEin(P, mk.transpose(pty), [M])])  
         (*end case*))  
     val code=[  
         assign(nd, DstOp.Floor dim, [x]),   (*nd *)  
         assignEin(f, mk.subTen(fty),[x,nd]),           (*fractional*)  
         assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)  
         ]  
     in ([n,f],P,code1@PCode@code)  
     end  
   
   
71   fun getRHSDst x  = (case DstIL.Var.binding x   fun getRHSDst x  = (case DstIL.Var.binding x
72      of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)      of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
73      | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'      | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 120  Line 75 
75   (* end case *))   (* end case *))
76    
77    
78   (*Get Img, and Kern Args*)  (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
79   fun getArgsDst(hid,hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)      uses the Param_ids for the image, kernel, and position tensor to get the Mid-IL arguments
80      returns the support of ther kernel, and image
81    *)
82     fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
83      of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let      of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let
84      in      in
85          ((Kernel.support h) ,img)          ((Kernel.support h) ,img,ImageInfo.dim img)
86      end      end
87   |  _ => raise Fail "Expected Image and kernel argument"   |  _ => raise Fail "Expected Image and kernel arguments"
88   (*end case*))   (*end case*))
89    
90    
91    (*handleArgs():int*int*int*Mid IL.Var list ->int*Mid.ILVars list* code*int* low-il-var
92  fun handleArgs(V,hid,t,args)=let  * uses the Param_ids for the image, kernel, and tensor and gets the mid-IL vars for each
93    *Transforms the position to index space
94    *P-mid-il var for the (transformation matrix)transpose
95    *)
96    fun handleArgs(Vid,hid,tid,args)=let
97        val imgArg=List.nth(args,Vid)
98      val hArg=List.nth(args,hid)      val hArg=List.nth(args,hid)
99      val imgArg=List.nth(args,V)      val newposArg=List.nth(args,tid)
100      val newposArg=List.nth(args,t)      val (s,img,dim) =getArgsDst(hArg,imgArg,args)
     val (s,img) =getArgsDst(hid,hArg,imgArg,args)  
     val dim=ImageInfo.dim img  
101      val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)      val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
102      in (dim,args@argsT,code, s,P)      in (dim,args@argsT,code, s,P)
103      end      end
104    
105    
106  (*Created new body for probe*)  (*createBody:int*int*int, index_id list, param_id, param_id, param_id, param_id
107  fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let  * expands the body for the probed field
108    *)
109      (*sumIndex creating summaiton Index for body*)  fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
     fun sumIndex 0=[]  
     |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]  
   
110    
111        (*1-d fields*)
112      fun createKRND1 ()=let      fun createKRND1 ()=let
113          val sum=sx          val sum=sx
114          val dels=List.map (fn e=>(E.C 0,e)) deltas          val dels=List.map (fn e=>(E.C 0,e)) deltas
115          val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]          val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
116          val rest= E.Krn(h,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))          val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
117          in          in
118              E.Prod [E.Img(V,shape,pos),rest]              E.Prod [E.Img(Vid,alpha,pos),rest]
119    
120          end          end
   
121      (*createKRN Image field and kernels *)      (*createKRN Image field and kernels *)
122      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
123      | createKRN(dim,imgpos,rest)=let      | createKRN(dim,imgpos,rest)=let
124          val dim'=dim-1          val dim'=dim-1
125          val sum=sx+dim'          val sum=sx+dim'
126          val dels=List.map (fn e=>(E.C dim',e)) deltas          val dels=List.map (fn e=>(E.C dim',e)) deltas
127          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
128          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))          val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
129          in          in
130              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
131          end          end
   
132      val exp=(case dim      val exp=(case dim
133          of 1 => createKRND1()          of 1 => createKRND1()
134          | _=> createKRN(dim, [],[])          | _=> createKRN(dim, [],[])
135          (*end case*))          (*end case*))
136    
137      val esum=sumIndex (dim)      (*sumIndex creating summaiton Index for body*)
138      in E.Sum(esum, exp)      val slb=1-s
139        val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
140    in
141        E.Sum(esum, exp)
142      end      end
143    
144    (*getsumshift:sum_index_id list* index_id list-> int
145  fun ShapeConv([],n)=[]  *get fresh/unused index_id, returns int
146      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)  *)
147      | ShapeConv(E.V v::es, n)=  fun getsumshift(sx,index) =let
148          if(n>v) then [E.V v] @ ShapeConv(es, n)      val nsumshift= (case sx
149          else ShapeConv(es,n)          of []=> length(index)
   
   
 fun mapIndex([],_)=[]  
     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)  
     | mapIndex(E.C c::es,index) = mapIndex(es,index)  
   
   
   
  (* Expand probe in place eplaceProbe(b,params,args, index, sx,args)*)  
  fun replaceProbe(b,params,args,index, sumIndex)=let  
   
     val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b  
     val fid=length(params)  
     val nid=fid+1  
     val n=length(index)  
   
     val nshift=length(dx)  
     val nsumshift =(case sumIndex  
         of []=> n  
150          | _=>let          | _=>let
151              val (E.V v,_,_)=List.hd(List.rev sumIndex)              val (E.V v,_,_)=List.hd(List.rev sx)
152              in v+1              in v+1
153              end              end
154      (* end case *))      (* end case *))
155        val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
     val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sumIndex  
156      val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),"\nThink nshift is ", Int.toString nsumshift]      val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),"\nThink nshift is ", Int.toString nsumshift]
157        in
158            nsumshift
159        end
160    
161      (*Outer Index-id Of Probe*)  (*formBody:ein_exp->ein_exp
162      val VShape=ShapeConv(alpha, n)  *just does a quick rewrite
163      val HShape=ShapeConv(dx, n)  *)
164      val shape=VShape@HShape  fun formBody(E.Sum([],e))=formBody e
165      (* Bindings for Shape*)  | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
166      val shapebind= mapIndex(shape,index)  | formBody(E.Prod [e])=e
167      val Vshapebind= mapIndex(VShape,index)  | formBody e=e
168    
169    
170      val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,args)  (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list :ein_exp* *code
171      val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)  * Transforms position to world space
172    * transforms result back to index_space
173    * rewrites body
174    * replace probe with expanded version
175    *)
176     fun replaceProbe(b,params,args,index, sx)=let
177    
178        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b
179        val fid=length(params)
180        val nid=fid+1
181        val Pid=nid+1
182        val nshift=length(dx)
183        val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
184        val freshIndex=getsumshift(sx,index)
185        val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
186      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
187      val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)      val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
188      val body' =(case nshift      val body' =formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))
         of 0=> body''  
         | _ => E.Sum(sxT, E.Prod(restT@[body'']))  
         (*end case*))  
189      val args'=argsA@[PArg]      val args'=argsA@[PArg]
190      in      in
191          (body',params',args' ,code)          (body',params',args' ,code)
192      end      end
193    
194    
195   (* sx-[] then move out, otherwise keep in *)  (* expandEinOp: code->  code list
196    *Looks to see if the expression has a probe. If so, replaces it.
197    * Note how we keeps eps type expressions so we have less time in mid-to-low-il stage
198    *)
199  fun expandEinOp( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let  fun expandEinOp( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
200        fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",
201        (String.concatWith",\t"(List.map split.printEINAPP code))]
202    
203        fun rewriteBody b=(case b
204      (*b-current body, info-original ein op, data-new assigments*)          of  E.Probe e =>let
205      fun rewriteBody b= let              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])
         in (case b  
             of E.Probe(E.Conv _, E.Tensor _) =>let  
             val (body',params',args',newbies)=replaceProbe(b, params,args,index, [])  
206              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
207              val code=newbies@[einapp]              val code=newbies@[einapp]
208              in              in
209                  (1,code)                  code
210              end              end
211          | E.Sum(sx,E.Probe e) =>let          | E.Sum(sx,E.Probe e) =>let
212              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
# Line 259  Line 214 
214              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
215              val code=newbies@[einapp]              val code=newbies@[einapp]
216              in              in
217                  (1,code)                  code
218              end              end
219          | _=> (0,[e])          | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let
220          (* end case *))              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
221                val  body'=E.Sum(sx,E.Prod[eps,body'])
222                val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
223                val code=newbies@[einapp]
224                in
225                    code
226          end          end
227            | _=> [e]
      val empty =fn key =>NONE  
   
     val (c,code)=rewriteBody body  
     val b=String.concatWith",\t"(List.map split.printEINAPP code)  
     val _ =(case c  
         of 1 =>print(String.concat["\nbody",split.printEINAPP e, "\n=>\n",b ])  
         | _ =>print(String.concat[""])  
228          (*end case*))          (*end case*))
229      in      in
230          code          rewriteBody body
231      end      end
232    
233    
234    
235    end; (* local *)    end; (* local *)
236    
237  end (* local *)  end (* local *)

Legend:
Removed from v.2838  
changed lines
  Added in v.2843

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0