Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2606, Wed Apr 30 16:05:25 2014 UTC revision 2608, Fri May 2 18:04:54 2014 UTC
# Line 1  Line 1 
1  (* examples.sml  (* Currently under construction
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4   * All rights reserved.   * All rights reserved.
# Line 36  Line 36 
36      structure split=SplitEin      structure split=SplitEin
37      structure F=Filter      structure F=Filter
38    
39  val testing=0  val testing=1
40    
41  datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
42  datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int  datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
# Line 57  Line 57 
57      "but found ", SrcIL.vbToString vb])      "but found ", SrcIL.vbToString vb])
58      (* end case *))      (* end case *))
59    
60    (*Transform differentiation index to world-space*)
61    (*Returns new deltas, summations, and List of Tensor Products*)
62    fun ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let
63    
64        val dim'=dim-1
65        fun setMatrix(imgix,wrdix) = E.Tensor(1,[ E.V imgix, E.V wrdix])
66    
67        fun changeDel([],n,newdels,sx,rest)=(newdels,sx,rest)
68         | changeDel((E.C _)::es,_,_,_,_)= raise Fail "unsure what to do with constant differentiation"
69         | changeDel((E.V v)::es,n,newdels,sx,rest)= let
70            val P=setMatrix(v, n)
71            val n'=E.V n
72            in
73                changeDel(es,n+1,newdels@[n'],sx@[(n',0,dim')],rest@[P])
74            end
75    
76        val n=length(outerShape)
77        val (newdels,sx,rest)=changeDel(dels,n,[],[],[])
78        in
79            (newdels,sx,rest)
80        end
81    
82    (*Transformation is Lifted*)
83    fun ImgtoWorldSpaceLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let
84        val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)
85        val tshape=List.tabulate((length(alpha)),fn v=> E.V v)
86    
87        val newbie'=Ein.EIN{
88                params=[E.TEN(1,outerShape), E.TEN(1,[dim,dim])],
89                index=outerShape,
90                body=E.Sum(sx,E.Prod([E.Tensor(0,tshape@newdels)]@rest))
91            }
92    
93        val _ = print(String.concat["\n Transform \n ",(split.printA (newArg,newbie',[oldArg,PArg])) ,"\n"])
94        val data=assignEin (newArg, newbie', [oldArg,PArg])
95        val ix=List.tabulate((length(dels)),fn _=> dim)
96        in
97            (newdels,ix,[data],[],[],[])
98        end
99    
100    
101    fun ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,id)= let
102        val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)
103        val params=[E.TEN(id,[dim,dim])]
104        in
105            (newdels,[],[],params,sx,rest)
106        end
107    
108    
109    (*E.Sum(sx,rest) *)
110    
111    
112    
113    (*returns final argumentVar, new dels, and assignments*)
114    fun decideIfTransform(dels,outerShape,alpha,dim,PArg,Pid,lift)=let
115        val oldArg = DstV.new ("ProbeResult", DstTy.tensorTy outerShape)
116        in (case (dels,lift)
117            of ([],_) => (oldArg,oldArg,dels,[],[],[],[],[])
118            | (_,0) => let  (*need to transform to world-space*)
119                val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)
120                val (dels',ix',assigments,_,_,_)=ImgtoWorldSpaceLift(dels,outerShape,alpha,dim,oldArg,newArg,PArg)
121                in
122                    (oldArg,newArg,dels', ix',assigments,[],[],[])
123                end
124            | ( _,_) =>  let
125                    val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)
126                   val (dels',_,_,params,sx,rest)=ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,Pid)
127                    in (oldArg, oldArg, dels',[],[],params,sx,rest)
128                    end
129            (*end case*))
130        end
131    
132    
133  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
134  fun transformToImgSpace  (dim,v,posx)=let  fun transformToImgSpace  (dim,v,posx)=let
135    
136      val translate=DstOp.Translate v      val translate=DstOp.Translate v
137      val transform=DstOp.Transform v      val transform=DstOp.Transform v
138      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)      val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)
139    
140      val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)      val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (*translate*)
141      val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)      val x  = DstV.new ("x", DstTy.vecTy dim)            (*Image-Space position*)
142      val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)      val f  = DstV.new ("f", DstTy.vecTy dim)            (*fractional*)
143      val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)      val nd = DstV.new ("nd", DstTy.vecTy dim)           (*real position*)
144      val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)      val n  = DstV.new ("n", DstTy.iVecTy dim)           (*integer position*)
145      val PosToImgSpace=mk.transform(dim,dim)      val PosToImgSpace=mk.transform(dim,dim)
146        val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (*transform dim by dim?*)
147    
148      val code=[      val code=[
149          assign(M, transform, []),          assign(M, transform, []),
# Line 76  Line 151 
151          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)          assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)
152          assign(nd, DstOp.Floor dim, [x]),   (*nd *)          assign(nd, DstOp.Floor dim, [x]),   (*nd *)
153          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)          assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)
154          assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)          assign(n, DstOp.RealToInt dim, [nd]), (*real to Int*)
155            assignEin(P, mk.transpose([dim,dim]), [M])
156      ]      ]
157      in ([n,f],code)      in ([n,f],P,code)
158      end      end
159    
160    
# Line 108  Line 184 
184      |  _ => raise Fail "Not a kernel argument"      |  _ => raise Fail "Not a kernel argument"
185    
186    
187  fun handleArgs(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),origargs,lift)=let  fun handleArgs(V,h,t,(params,args),origargs,lift)=let
188      val E.IMG(dim)=List.nth(params,V)      val E.IMG(dim)=List.nth(params,V)
189      val kArg=List.nth(origargs,h)      val kArg=List.nth(origargs,h)
190      val imgArg=List.nth(origargs,V)      val imgArg=List.nth(origargs,V)
191      val newposArg=List.nth(args, t)      val newposArg=List.nth(args, t)
192      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)      val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)
193      val (argsT,code')=transformToImgSpace(dim,img,newposArg)      val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)
194      in (dim,argsVH@argsT,argcode@code', s)      in (dim,argsVH@argsT,argcode@code', s,P)
195      end      end
196      | handleArgs _ =raise Fail"Expression is wrong for handleArgs"  
197    
198  (*createDels=> creates the kronecker deltas for each Kernel*)  (*createDels=> creates the kronecker deltas for each Kernel*)
199  fun createDels([],_)= []  fun createDels([],_)= []
# Line 161  Line 237 
237    
238    
239    
240  (*Lift probe*)  (*Currently have three different Functions. One that Lifts Field and Transormation. One that just one lift, and the other does it all in place*)
241    (*Lift probe and does transformation Lifted*)
242    fun liftProbeTester(b,(params,args),index, sumIndex,origargs)=let
243    
244        val E.Probe(E.Conv(V,alpha,H,dels),E.Tensor(t,_))=b
245        val newId=length(params)
246        val n=length(index)
247    
248        (*Create new tensor replacement*)
249        val shape=ShapeConv(alpha@dels, n)
250        val newB=E.Tensor(newId,shape)
251    
252        (* Create new Param*)
253        (*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)
254        val shape'= mapIndex(shape,index)
255        val newP= E.TEN(1,shape')
256    
257    
258        val A1=ShapeConv(alpha, n)
259        val alpha1= mapIndex(A1,index)
260    
261        (*Expand Probe*)
262        val ns=length sumIndex
263    
264        val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
265        val (oldArg,newArg,dx, ix,assigments,paramsA,sxA,restA) = decideIfTransform(dels,shape',alpha1,dim,PArg,n+1,ns)
266    
267        val body' =(case ns
268            of 0=>    createBody(dim, s,n+length(dx),alpha,dx,0, 1, 3, 2)
269            |_=>let
270                val (E.V v,_,_)=List.nth(sumIndex, ns-1)
271                val v'=v+length(dx) (*Shifted because of the swapped indices *)
272                val body'=createBody(dim, s,v'+2,alpha,dx,0, 1, 3, 2)
273                in  E.Sum(sumIndex ,body')
274                end
275        (* end case *))
276    
277        val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
278        val (p',i',b',a')=shift.clean(params'@paramsA, index@ix,body', args')
279        val newbie'=Ein.EIN{params=p', index=i', body=b'}
280    
281    
282        val data=assignEin (oldArg, newbie', a')
283    
284        val _ = (case testing
285            of 0 => 1
286            | _ => (print(String.concat["\n Lift Probe\n", split.printA(oldArg, newbie', a'),"\n"]);1)
287                (*end case *))
288        in (newB, (params@[newP],args@[newArg]) ,code@[data]@assigments)
289        end
290    
291    
292    (*Lift probe and Multiply by P*)
293  fun liftProbe(b,(params,args),index, sumIndex,origargs)=let  fun liftProbe(b,(params,args),index, sumIndex,origargs)=let
294    
295      val E.Probe(E.Conv(_,alpha,_,dx),pos)=b  val E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_))=b
296      val newId=length(params)      val newId=length(params)
297      val n=length(index)      val n=length(index)
298    
# Line 177  Line 305 
305      val shape'= mapIndex(shape,index)      val shape'= mapIndex(shape,index)
306      val newP= E.TEN(1,shape')      val newP= E.TEN(1,shape')
307    
     (*Create new Arg*)  
308      val newArg = DstV.new ("PC", DstTy.tensorTy shape')      val newArg = DstV.new ("PC", DstTy.tensorTy shape')
309    
310      (*Expand Probe*)      (*Expand Probe*)
311      val ns=length sumIndex      val ns=length sumIndex
312    
313      val (dim,args',code,s) = handleArgs(b,(params,args), origargs,1)  val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
314    
315    
316      val body' =(case ns      val body' =(case ns
317          of 0=>    createBody(dim, s,n,alpha,dx,0, 1, 3, 2)          of 0=>    createBody(dim, s,n,alpha,dx,0, 1, 3, 2)
318          |_=>let          |_=>let
319              val (E.V v,_,_)=List.nth(sumIndex, ns-1)              val (E.V v,_,_)=List.nth(sumIndex, ns-1)
320    
321              val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)              val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)
322              in  E.Sum(sumIndex ,body')              in  E.Sum(sumIndex ,body')
323              end              end
# Line 197  Line 326 
326      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
327      val (p',i',b',a')=shift.clean(params', index,body', args')      val (p',i',b',a')=shift.clean(params', index,body', args')
328      val newbie'=Ein.EIN{params=p', index=i', body=b'}      val newbie'=Ein.EIN{params=p', index=i', body=b'}
329    
330    
331      val data=assignEin (newArg, newbie', a')      val data=assignEin (newArg, newbie', a')
332    
333      val _ = (case testing      val _ = (case testing
# Line 207  Line 338 
338      end      end
339    
340    
341    
342    (*Does not yet do transformation*)
343   (* Expand probe in place *)   (* Expand probe in place *)
344   fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let   fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let
345    
346      val E.Probe(E.Conv(V,alpha,h,dx),pos)=b      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
347      val fid=length(params)      val fid=length(params)
348      val n=length(index)      val n=length(index)
349    
350      (*Expand Probe*)      (*Expand Probe*)
351      val ns=length sumIndex      val ns=length sumIndex
352      val (dim,args',code,s) = handleArgs(b,(params,args), origargs,0)      val (dim,args',code,s,P) = handleArgs(V,h,t,(params,args), origargs,0)
353      val nid=fid+1      val nid=fid+1
354      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
355      val body' =(case ns      val body' =(case ns
# Line 331  Line 464 
464    
465      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
466      val e'=Ein.EIN{params=params', index=index, body=body'}      val e'=Ein.EIN{params=params', index=index, body=body'}
467      val _ =(case testing      (*val _ =(case testing
468          of 0 => 1          of 0 => 1
469          | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)          | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)
470          (*end case*))          (*end case*))*)
471      in      in
472          ((e',args'),newbies)          ((e',args'),newbies)
473      end      end

Legend:
Removed from v.2606  
changed lines
  Added in v.2608

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0