Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2608, Fri May 2 18:04:54 2014 UTC revision 2611, Mon May 5 21:21:12 2014 UTC
# Line 35  Line 35 
35      structure shift=ShiftEin      structure shift=ShiftEin
36      structure split=SplitEin      structure split=SplitEin
37      structure F=Filter      structure F=Filter
38        structure T=TransformEin
39    
40  val testing=1  val testing=1
41    
# Line 57  Line 58 
58      "but found ", SrcIL.vbToString vb])      "but found ", SrcIL.vbToString vb])
59      (* end case *))      (* end case *))
60    
 (*Transform differentiation index to world-space*)  
 (*Returns new deltas, summations, and List of Tensor Products*)  
 fun ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let  
   
     val dim'=dim-1  
     fun setMatrix(imgix,wrdix) = E.Tensor(1,[ E.V imgix, E.V wrdix])  
   
     fun changeDel([],n,newdels,sx,rest)=(newdels,sx,rest)  
      | changeDel((E.C _)::es,_,_,_,_)= raise Fail "unsure what to do with constant differentiation"  
      | changeDel((E.V v)::es,n,newdels,sx,rest)= let  
         val P=setMatrix(v, n)  
         val n'=E.V n  
         in  
             changeDel(es,n+1,newdels@[n'],sx@[(n',0,dim')],rest@[P])  
         end  
   
     val n=length(outerShape)  
     val (newdels,sx,rest)=changeDel(dels,n,[],[],[])  
     in  
         (newdels,sx,rest)  
     end  
   
 (*Transformation is Lifted*)  
 fun ImgtoWorldSpaceLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg)= let  
     val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)  
     val tshape=List.tabulate((length(alpha)),fn v=> E.V v)  
   
     val newbie'=Ein.EIN{  
             params=[E.TEN(1,outerShape), E.TEN(1,[dim,dim])],  
             index=outerShape,  
             body=E.Sum(sx,E.Prod([E.Tensor(0,tshape@newdels)]@rest))  
         }  
   
     val _ = print(String.concat["\n Transform \n ",(split.printA (newArg,newbie',[oldArg,PArg])) ,"\n"])  
     val data=assignEin (newArg, newbie', [oldArg,PArg])  
     val ix=List.tabulate((length(dels)),fn _=> dim)  
     in  
         (newdels,ix,[data],[],[],[])  
     end  
   
   
 fun ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,id)= let  
     val (newdels,sx,rest)=ImgtoWorldSpace(dels, outerShape,alpha,dim,oldArg,newArg,PArg)  
     val params=[E.TEN(id,[dim,dim])]  
     in  
         (newdels,[],[],params,sx,rest)  
     end  
   
   
 (*E.Sum(sx,rest) *)  
   
   
   
 (*returns final argumentVar, new dels, and assignments*)  
 fun decideIfTransform(dels,outerShape,alpha,dim,PArg,Pid,lift)=let  
     val oldArg = DstV.new ("ProbeResult", DstTy.tensorTy outerShape)  
     in (case (dels,lift)  
         of ([],_) => (oldArg,oldArg,dels,[],[],[],[],[])  
         | (_,0) => let  (*need to transform to world-space*)  
             val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)  
             val (dels',ix',assigments,_,_,_)=ImgtoWorldSpaceLift(dels,outerShape,alpha,dim,oldArg,newArg,PArg)  
             in  
                 (oldArg,newArg,dels', ix',assigments,[],[],[])  
             end  
         | ( _,_) =>  let  
                 val newArg = DstV.new ("IMG-Space", DstTy.tensorTy outerShape)  
                val (dels',_,_,params,sx,rest)=ImgtoWorldSpaceNoLift(dels, outerShape,alpha,dim,oldArg,newArg,PArg,Pid)  
                 in (oldArg, oldArg, dels',[],[],params,sx,rest)  
                 end  
         (*end case*))  
     end  
61    
62    
63  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
# Line 236  Line 166 
166      | mapIndex(E.C c::es,index) = mapIndex(es,index)      | mapIndex(E.C c::es,index) = mapIndex(es,index)
167    
168    
169    (*Lift probe and Multiply by P*)
170    fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let
171        val _ =print "Lift Probe"
172    
 (*Currently have three different Functions. One that Lifts Field and Transormation. One that just one lift, and the other does it all in place*)  
 (*Lift probe and does transformation Lifted*)  
 fun liftProbeTester(b,(params,args),index, sumIndex,origargs)=let  
   
     val E.Probe(E.Conv(V,alpha,H,dels),E.Tensor(t,_))=b  
     val newId=length(params)  
173      val n=length(index)      val n=length(index)
   
     (*Create new tensor replacement*)  
     val shape=ShapeConv(alpha@dels, n)  
     val newB=E.Tensor(newId,shape)  
   
     (* Create new Param*)  
     (*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)  
     val shape'= mapIndex(shape,index)  
     val newP= E.TEN(1,shape')  
   
   
     val A1=ShapeConv(alpha, n)  
     val alpha1= mapIndex(A1,index)  
   
     (*Expand Probe*)  
174      val ns=length sumIndex      val ns=length sumIndex
175        val nshift=length(dx)
176      val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      val np=length(params)
177      val (oldArg,newArg,dx, ix,assigments,paramsA,sxA,restA) = decideIfTransform(dels,shape',alpha1,dim,PArg,n+1,ns)      val nsumshift =(case ns
178            of 0=>   n
179      val body' =(case ns          |_=>let  val (E.V v,_,_)=List.nth(sumIndex, ns-1)
180          of 0=>    createBody(dim, s,n+length(dx),alpha,dx,0, 1, 3, 2)              in
181          |_=>let                  v+1
             val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
             val v'=v+length(dx) (*Shifted because of the swapped indices *)  
             val body'=createBody(dim, s,v'+2,alpha,dx,0, 1, 3, 2)  
             in  E.Sum(sumIndex ,body')  
182              end              end
183      (* end case *))      (* end case *))
184    
     val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]  
     val (p',i',b',a')=shift.clean(params'@paramsA, index@ix,body', args')  
     val newbie'=Ein.EIN{params=p', index=i', body=b'}  
   
   
     val data=assignEin (oldArg, newbie', a')  
   
     val _ = (case testing  
         of 0 => 1  
         | _ => (print(String.concat["\n Lift Probe\n", split.printA(oldArg, newbie', a'),"\n"]);1)  
             (*end case *))  
     in (newB, (params@[newP],args@[newArg]) ,code@[data]@assigments)  
     end  
   
   
 (*Lift probe and Multiply by P*)  
 fun liftProbe(b,(params,args),index, sumIndex,origargs)=let  
   
 val E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_))=b  
 val newId=length(params)  
 val n=length(index)  
   
 (*Create new tensor replacement*)  
 val shape=ShapeConv(alpha@dx, n)  
 val newB=E.Tensor(newId,shape)  
   
 (* Create new Param*)  
 (*  val shape'=List.map (fn E.V v=>(List.nth(index, v))) shape*)  
 val shape'= mapIndex(shape,index)  
 val newP= E.TEN(1,shape')  
185    
186  val newArg = DstV.new ("PC", DstTy.tensorTy shape')      (*Outer Index-id Of Probe*)
187        val VShape=ShapeConv(alpha, n)
188        val HShape=ShapeConv(dx, n)
189        val shape=VShape@HShape
190    
191  (*Expand Probe*)      (* Bindings for Shape*)
192  val ns=length sumIndex      val shapebind= mapIndex(shape,index)
193        val Vshapebind= mapIndex(VShape,index)
194    
195  val (dim,args',code,s,PArg) = handleArgs(V,H,t,(params,args), origargs,1)      (*Look at Args and get dim, mid-il ops, support, and Arg for transformation matrix P*)
196        val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)
197    
198        (*New transformations:params, sx, rest, will be empty if no transformation is made*)
199        val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)
200    
201  val body' =(case ns      (*rewriteBody*)
202  of 0=>    createBody(dim, s,n,alpha,dx,0, 1, 3, 2)      val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)
 |_=>let  
 val (E.V v,_,_)=List.nth(sumIndex, ns-1)  
203    
204  val body'=createBody(dim, s,v+1,alpha,dx,0, 1, 3, 2)      val sx=sumIndex@sxT
205  in  E.Sum(sumIndex ,body')      val body'=(case sx
206  end          of [] =>E.Prod(restT@[bodyExpanded])
207            | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))
208  (* end case *))  (* end case *))
209    
210  val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]      (*create new EIN OPerator*)
211  val (p',i',b',a')=shift.clean(params', index,body', args')      val _ =print("Found this many args ")
212  val newbie'=Ein.EIN{params=p', index=i', body=b'}      val _ =print(Int.toString(length(args')))
   
213    
214  val data=assignEin (newArg, newbie', a')      val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT
215        val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])
216        val newbie'=Ein.EIN{params=p', index=i', body=b'}
217        val data=assignEin (oldArg, newbie', a')
218    
219  val _ = (case testing  val _ = (case testing
220  of 0 => 1  of 0 => 1
221  | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)  | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)
222  (*end case *))  (*end case *))
223  in (newB, (params@[newP],args@[newArg]) ,code@[data])      in
224            (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)
225  end  end
226    
227    
# Line 345  Line 232 
232    
233      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b      val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
234      val fid=length(params)      val fid=length(params)
235        val nid=fid+1
236      val n=length(index)      val n=length(index)
   
     (*Expand Probe*)  
237      val ns=length sumIndex      val ns=length sumIndex
238      val (dim,args',code,s,P) = handleArgs(V,h,t,(params,args), origargs,0)      val nshift=length(dx)
239      val nid=fid+1      val nsumshift =(case ns
240      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          of 0=> n
     val body' =(case ns  
         of 0=> createBody(dim, s,n,alpha,dx,V, h, nid, fid)  
241          |_=>let          |_=>let
242              val (E.V v,_,_)=List.nth(sumIndex, ns-1)              val (E.V v,_,_)=List.nth(sumIndex, ns-1)
243              in createBody(dim, s,v+1,alpha,dx,V, h, nid, fid)              in v+1
244              end              end
245          (* end case *))          (* end case *))
246    
247        (*Outer Index-id Of Probe*)
248        val VShape=ShapeConv(alpha, n)
249        val HShape=ShapeConv(dx, n)
250        val shape=VShape@HShape
251        (* Bindings for Shape*)
252        val shapebind= mapIndex(shape,index)
253        val Vshapebind= mapIndex(VShape,index)
254    
255    
256        val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)
257        val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
258    
259        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
260        val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
261        val body' =(case nshift
262            of 0=> body''
263            | _ => E.Sum(sxT, E.Prod(restT@[body'']))
264            (*end case*))
265        val args'=argsA@[PArg]
266      val _ =(case testing      val _ =(case testing
267          of 0=> 1          of 0=> 1
268          | _ =>  let          | _ =>  let
# Line 367  Line 271 
271              in 1 end              in 1 end
272          (* end case *))          (* end case *))
273    
274    
275      in (body',(params',args') ,code)      in (body',(params',args') ,code)
276      end      end
277    
278    (*Checks if vairable occurs just once. If it does then we can lift *)
279    fun checkSum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(id,beta)),info,index,origargs)=let
280        val n=length(index)
281        val _=print "in check Sum\n "
282        val _=print(P.printbody(E.Sum([(E.V i,lb,ub)],E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)))))
283        in
284         if (i=n) then (case F.findOcc(i,alpha@dx)
285                of  1 => liftProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)),info,index@[ub], [],origargs)
286                | _ =>  replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)
287                (*end case*))
288        else replaceProbe(E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(id,beta)), info,index, [(E.V i,lb,ub)],origargs)
289        end
290    
291  fun flatten []=[]  fun flatten []=[]
292      | flatten(e1::es)=e1@(flatten es)      | flatten(e1::es)=e1@(flatten es)
# Line 398  Line 315 
315                  (sumIndex:=tl(s);(e',infoK,dataK))                  (sumIndex:=tl(s);(e',infoK,dataK))
316              end              end
317    
318            fun filter es=let
319                fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
320                | filterApply(B::es, doneA, infoA,dataA)= let
321                    val (bodyB, infoB,dataB)= rewriteBody(B,infoA)
322                    in
323                        filterApply(es, doneA@[bodyB], infoB,dataA@dataB)
324                    end
325                in filterApply(es, [],info,[])
326                end
327          in (case b          in (case b
328              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let
329                  val ref sx=sumIndex                  val ref sx=sumIndex
330                  in (case sx                  in (case sx
331                      of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)                      of [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
332                      | _ =>  replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)                        | [i]=> checkSum(i,b, info,index,origargs)
333                         | _ => let
334                            val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
335                            in (E.Sum(c,b),m,code)
336                            end
337                  (* end case*))                  (* end case*))
338              end              end
339          | E.Probe(E.Conv _, E.Tensor _) =>let          | E.Probe(E.Conv _, E.Tensor _) =>let
340              val ref sx=sumIndex              val ref sx=sumIndex
341              in (case sx              in (case sx
342                  of []=> liftProbe(b, info,index, [],origargs)                  of []=> liftProbe(b, info,index, [],origargs)
343                    | [i]=> checkSum(i,b, info,index,origargs)
344                  | _=> replaceProbe(b, info,index, flatten sx,origargs)                  | _=> replaceProbe(b, info,index, flatten sx,origargs)
345               (* end case*))               (* end case*))
346              end              end
# Line 434  Line 365 
365              val (bodyB, infoB,dataB)= rewriteBody(b,infoA)              val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
366              in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end              in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
367          | E.Add es=> let          | E.Add es=> let
368              fun filter([], done, info', data)= let              val (done, info',data')= filter es
369                      val (_, e)=F.mkAdd done                      val (_, e)=F.mkAdd done
370                      in (e, info',data)              in (e, info',data')
371                      end                      end
                 | filter(e::es, done, info',data)= let  
                     val (body', info'',data')= rewriteBody(e,info')  
                     in filter(es, done@[body'], info'',data@data') end  
             in filter(es, [],info,[]) end  
   
372          | E.Prod es=> let          | E.Prod es=> let
373              fun filter([], done, info',data)= let              val (done, info',data')= filter es
374                      val (_, e)=F.mkProd done                      val (_, e)=F.mkProd done
375                      in  (e,info', data)              in (e, info',data')
376                      end                      end
                 | filter(e::es, done, info',data)= let  
                     val (body', info'',data')= rewriteBody(e, info')  
                     in filter(es, done@[body'], info'',data@data') end  
                 in filter(es, [],info,[]) end  
377          | _=>  (b,info,[])          | _=>  (b,info,[])
378          (* end case *))          (* end case *))
379          end          end

Legend:
Removed from v.2608  
changed lines
  Added in v.2611

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0