Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3259, Mon Sep 21 15:14:52 2015 UTC revision 3260, Wed Sep 23 16:09:21 2015 UTC
# Line 41  Line 41 
41      val testlift=0      val testlift=0
42    
43      val cnt = ref 0      val cnt = ref 0
   
     fun printEINAPP e=MidToString.printEINAPP e  
44      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
45      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
46      fun printEINAPP e=MidToString.printEINAPP e      fun toStringBind e=print(MidToString.toStringBind e)
47        fun mkEin e=Ein.mkEin e
     fun transitionToString(testreplace,a,b)=(case testreplace  
         of 0=> 1  
         | 2 => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)  
         |_ =>(print(String.concat["\nReplaced:",P.printbody a]);1)  
         (*end case*))  
     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}  
48      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
     fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body  
     fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=  
             (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))  
     fun setBodye(body',E.EIN{params,index,body})=  
            E.EIN{params=params,index=index,body=body'}  
49    
50      fun testp n=(case testing      fun testp n=(case testing
51          of 0=> 1          of 0=> 1
52          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
53          (*end case*))          (*end case*))
     fun  einapptostring (body,a,b)=(case testlift  
         of 0=>1  
         | 2=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)  
         |_ =>(print(String.concat["\nLifted",P.printbody body]);1)  
         (*end case*))  
54    
55    
56      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
# Line 156  Line 138 
138                  end                  end
139              (* end case *))              (* end case *))
140          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx          val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
141          val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),          val _ =print(String.concat["\n", "SumIndex" ,(String.concatWith"," aa),
142              "\nThink nshift is ", Int.toString nsumshift]              "\nThink nshift is ", Int.toString nsumshift])
143          in          in
144              nsumshift              nsumshift
145          end          end
# Line 193  Line 175 
175          val originalb=Ein.body e          val originalb=Ein.body e
176          val params=Ein.params e          val params=Ein.params e
177          val index=Ein.index e          val index=Ein.index e
178          val _= ("\n"^P.printbody originalb)          val _ = print("\n***************** \n Replace ************ \n")
179            val _=  toStringBind (y, DstIL.EINAPP(e,args))
180    
181          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
182          val fid=length(params)          val fid=length(params)
# Line 212  Line 195 
195              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
196              | _                                  => body'              | _                                  => body'
197              (*end case*))              (*end case*))
198          val _=transitionToString(testN,originalb,body')  
199    
200          val args'=argsA@[PArg]          val args'=argsA@[PArg]
201          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
# Line 225  Line 208 
208          val Pid=0          val Pid=0
209          val tid=1          val tid=1
210    
211            (*Assumes body is already clean*)
212          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
213    
214          (*need to rewrite dx*)          (*need to rewrite dx*)
215          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
216              of []=> ([],index,E.Conv(9,alpha,7,newdx))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
217              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
218              (*end case*))              (*end case*))
219    
220          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
221          val tshape=alpha@newdx          fun filterAlpha []=[]
222              | filterAlpha(E.C _::es)= filterAlpha es
223              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
224    
225            val tshape=filterAlpha(alpha')@newdx
226          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
227          val (splitvar,body)=(case originalb          val (splitvar,body)=(case originalb
228              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
# Line 247  Line 235 
235    
236          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
237          in          in
238              (splitvar,ein0,sizes,dx)              (splitvar,ein0,sizes,dx,alpha')
239          end          end
240    
241      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
242            val _=print("\n******* Lift ******** \n")
243          val originalb=Ein.body e          val originalb=Ein.body e
244          val params=Ein.params e          val params=Ein.params e
245          val index=Ein.index e          val index=Ein.index e
246          val _=print("\n Inside Lift: "^P.printbody originalb)          val _=  toStringBind (y, DstIL.EINAPP(e,args))
247    
248          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
249          val fid=length(params)          val fid=length(params)
# Line 263  Line 252 
252          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
253          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,index)
254    
   
255          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
256          val (splitvar,ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
257          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
258          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
259          val rtn0=(case splitvar          val rtn0=(case splitvar
# Line 275  Line 263 
263    
264          (*lifted probe*)          (*lifted probe*)
265          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
266          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
267          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
268          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
269          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
270          val rtn=code@[rtn1]@rtn0          val rtn=code@[rtn1]@rtn0
271            val _= List.map toStringBind ([rtn1]@rtn0)
272    
273          in          in
274              rtn              rtn
# Line 292  Line 281 
281      *)      *)
282     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
283    
284          fun checkConst ([],a) = liftProbe a          fun checkConst ([],a) = (*liftProbe a*) replaceProbe a
285          | checkConst ((E.C _::_),a) =(print("\n \n constant field"^(printEINAPP e));replaceProbe a)          | checkConst ((E.C _::_),a) = replaceProbe a
286          | checkConst ((_ ::es),a)= checkConst(es,a)          | checkConst ((_ ::es),a)= checkConst(es,a)
287    
288            fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
289                let
290    
291                    val _= toStringBind e
292                    val index0=Ein.index ein
293                    val index1 = index0@[3]
294                    val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
295                     (* clean to get body indices in order *)
296                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
297                    val _=print(String.concat["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1])
298    
299                    val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
300                    val ein1 = mkEin(Ein.params ein,index1,body1)
301                    val code1= (lhs1,mkEinApp(ein1,args))
302                    val codeAll= liftProbe(1,code1,body1,[])
303    
304    
305                    (*Probe that tensor at a constant position E.C c1*)
306                    val param0 = [E.TEN(1,index1)]
307                    val body0 =  E.Tensor(0,[E.C c1]@dx)
308                    val ein0 = mkEin(param0,index0,body0)
309                    val einApp0 = mkEinApp(ein0,[lhs1])
310                    val code0 = (y,einApp0)
311                    val _= toStringBind code0
312                in
313                    codeAll@[code0]
314                end
315    
316          fun rewriteBody b=(case b          fun rewriteBody b=(case b
317              of E.Probe(E.Conv(_,_,_,[]),_)              of E.Probe(E.Conv(_,_,_,[]),_)
318                  => replaceProbe(0,e,b, [])                  => replaceProbe(0,e,b, [])
319    
320                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos)
321                    => liftFieldMat (2,b)
322                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos)
323                    => liftFieldMat (3,b)
324    
325              | E.Probe(E.Conv (_,alpha,_,dx),_)              | E.Probe(E.Conv (_,alpha,_,dx),_)
326                  => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)                  => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
327              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
328                  => replaceProbe(0,e,p, sx)                  => replaceProbe(0,e,p, sx)  (*no dx*)
329              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
330                  => checkConst(dx,(0,e,p,sx))                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)
             | E.Sum(sx as [(v,_,_)],p as (E.Probe((E.Conv(_,alpha,_,dx),_))))=>  
                 (print(P.printbody(b));case (List.find (fn x => x = v) dx)  
                     of NONE=>  checkConst(dx,(1,e,p,sx))  
                         (*need to push summation to lifted exp rather than transform exp.*)  
                     | SOME _=> replaceProbe(1,e, p, sx)  
                 (*end case*))  
             | E.Sum(sx as [(v1,_,_),(v2,_,_)],p as (E.Probe((E.Conv(_,alpha,_,dx),_))))=>  
                 (print(P.printbody(b));case ((List.find (fn x => x = v1) dx),(List.find (fn x => x = v2) dx))  
                     of (NONE,NONE)=>  checkConst(dx,(1,e,p,sx))  
                     |  ( _ , NONE)=>  checkConst(dx,(1,e,p,sx))  
                     |  _=> (*replaceProbe(1,e, p, sx)*) checkConst(dx,(1,e,p,sx))  
                 (*end case*))  
331              | E.Sum(sx,E.Probe p)              | E.Sum(sx,E.Probe p)
332                  => replaceProbe(0,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
333              | E.Sum(sx,E.Prod[eps,E.Probe p])              | E.Sum(sx,E.Prod[eps,E.Probe p])
# Line 326  Line 337 
337    
338          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
339          in  (case var          in  (case var
340              of NONE=> (("\n \n not replacing"^(printEINAPP e));(rewriteBody(Ein.body ein),fieldset))          of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset))
341              | SOME v=> (("\n replacing"^(P.printerE ein));( [(y,DstIL.VAR v)] , fieldset))              | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)] , fieldset))
342              (*end case*))              (*end case*))
343          end          end
344    

Legend:
Removed from v.3259  
changed lines
  Added in v.3260

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0