Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 3092, Tue Mar 17 20:02:38 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3278, Tue Oct 13 03:59:10 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val testlift=1      val testlift=0
     val cnt = ref 0  
42    
43      fun printEINAPP e=MidToString.printEINAPP e      val cnt = ref 0
44      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
45      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
46        fun toStringBind e=(MidToString.toStringBind e)
47      fun transitionToString(testreplace,a,b)=(case testreplace      fun mkEin e=Ein.mkEin e
         of 0=> 1  
         | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)  
         (*end case*))  
     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}  
48      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
     fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body  
     fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=  
             (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))  
49    
50      fun testp n=(case testing      fun testp n=(case testing
51          of 0=> 1          of 0=> 1
52          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
53          (*end case*))          (*end case*))
     fun  einapptostring (body,a,b)=(case testlift  
         of 0=>1  
         | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)  
         (*end case*))  
54    
55    
56      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
# Line 168  Line 156 
156      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
157        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
158    
159    
160        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
161          | multiMergePs e=multiPs e
162    
163    
164      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
165              -> ein_exp* *code              -> ein_exp* *code
166      * Transforms position to world space      * Transforms position to world space
# Line 182  Line 175 
175          val originalb=Ein.body e          val originalb=Ein.body e
176          val params=Ein.params e          val params=Ein.params e
177          val index=Ein.index e          val index=Ein.index e
178            val _ = testp["\n***************** \n Replace ************ \n"]
179            val _=  toStringBind (y, DstIL.EINAPP(e,args))
180    
181          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
182          val fid=length(params)          val fid=length(params)
# Line 201  Line 195 
195              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
196              | _                                  => body'              | _                                  => body'
197              (*end case*))              (*end case*))
198          val _=transitionToString(testN,originalb,body')  
199    
200          val args'=argsA@[PArg]          val args'=argsA@[PArg]
201          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
# Line 209  Line 203 
203              code@[einapp]              code@[einapp]
204          end          end
205    
206        val tsplitvar=true
207      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
208          val Pid=0          val Pid=0
209          val tid=1          val tid=1
210    
211            (*Assumes body is already clean*)
212          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
213    
214          (*need to rewrite dx*)          (*need to rewrite dx*)
215          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
216              of []=> ([],index,E.Conv(9,alpha,7,newdx))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
217              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
218              (*end case*))              (*end case*))
219    
220          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
221          val tshape=alpha@newdx          fun filterAlpha []=[]
222              | filterAlpha(E.C _::es)= filterAlpha es
223              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
224    
225            val tshape=filterAlpha(alpha')@newdx
226          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
227          val exp = multiPs(Ps,newsx,t)          val (splitvar,body)=(case originalb
228          val body=(case originalb              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
229              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
230              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])              | _                                  => (case tsplitvar
231              | _                                  => exp                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
232                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
233                    (*end case*))
234              (*end case*))              (*end case*))
235    
236          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
237          in          in
238              (ein0,sizes,dx)              (splitvar,ein0,sizes,dx,alpha')
239          end          end
240    
241      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
242            val _=testp["\n******* Lift ******** \n"]
243          val originalb=Ein.body e          val originalb=Ein.body e
244          val params=Ein.params e          val params=Ein.params e
245          val index=Ein.index e          val index=Ein.index e
246            val _=  toStringBind (y, DstIL.EINAPP(e,args))
247    
248          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
249          val fid=length(params)          val fid=length(params)
# Line 249  Line 252 
252          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
253          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,index)
254    
   
255          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
256          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
257          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
258          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
259          val rtn0=(y,einApp0)          val rtn0=(case splitvar
260                of false => [(y,einApp0)]
261                | _      => Split.splitEinApp (y,einApp0)
262                (*end case*))
263    
264          (*lifted probe*)          (*lifted probe*)
265          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
266          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
267          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
268          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
269          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
270          val rtn=code@[rtn1,rtn0]          val rtn=code@[rtn1]@rtn0
271          val _= einapptostring (p,rtn1,rtn0)          val _= List.map toStringBind ([rtn1]@rtn0)
272    
273          in          in
274              rtn              rtn
275          end          end
276    
   
277      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
278      *A this point we only have simple ein ops      *A this point we only have simple ein ops
279      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
280      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
281      *)      *)
282      fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
283    
284          fun checkConst ([],a) = liftProbe a          fun checkConst ([],a) = liftProbe a
285          | checkConst ((E.C _::_),a) =replaceProbe a          | checkConst ((E.C _::_),a) =replaceProbe a
286          | checkConst ((_ ::es),a)=checkConst(es,a)          | checkConst ((_ ::es),a)=checkConst(es,a)
287    
288            fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
289                let
290    
291                    val _= toStringBind e
292                    val index0=Ein.index ein
293                    val index1 = index0@[3]
294                    val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
295                     (* clean to get body indices in order *)
296                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
297                    val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
298    
299                    val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
300                    val ein1 = mkEin(Ein.params ein,index1,body1)
301                    val code1= (lhs1,mkEinApp(ein1,args))
302                    val codeAll= (case dx
303                        of []=> replaceProbe(1,code1,body1,[])
304                        | _ =>liftProbe(1,code1,body1,[])
305                    (*end case*))
306    
307                    (*Probe that tensor at a constant position E.C c1*)
308                    val param0 = [E.TEN(1,index1)]
309                    val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
310                    val body0 =  E.Tensor(0,[E.C c1]@nx)
311                    val ein0 = mkEin(param0,index0,body0)
312                    val einApp0 = mkEinApp(ein0,[lhs1])
313                    val code0 = (y,einApp0)
314                    val _= toStringBind code0
315                in
316                    codeAll@[code0]
317                end
318    
319          fun rewriteBody b=(case b          fun rewriteBody b=(case b
320              of E.Probe(E.Conv(_,_,_,[]),_)              of E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos)
321                  => replaceProbe(1,e,b, [])                  => liftFieldMat (1,b)
322                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos)
323                    => liftFieldMat (2,b)
324                | E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos)
325                    => liftFieldMat (3,b)
326                | E.Probe(E.Conv(_,_,_,[]),_)
327                    => replaceProbe(0,e,b,[])
328              | E.Probe(E.Conv (_,alpha,_,dx),_)              | E.Probe(E.Conv (_,alpha,_,dx),_)
329                  => checkConst(alpha@dx,(0,e,b,[]))                  => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
330              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
331                  => replaceProbe(1,e,p, sx)                  => replaceProbe(0,e,p, sx)  (*no dx*)
332              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
333                  => checkConst(dx,(0,e,p,sx))                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)
334              | E.Sum(sx,E.Probe p)              | E.Sum(sx,E.Probe p)
335                  => replaceProbe(1,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
336              | E.Sum(sx,E.Prod[eps,E.Probe p])              | E.Sum(sx,E.Prod[eps,E.Probe p])
337                  => replaceProbe(1,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
338              | _ => [e]              | _ => [e]
339              (* end case *))              (* end case *))
340          in  
341              rewriteBody (Ein.body ein)          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
342    
343            fun matchField b=(case b
344                of E.Probe _ => 1
345                | E.Sum (_, E.Probe _)=>1
346                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
347                | _ =>0
348                (*end case*))
349    
350            in  (case var
351                of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
352                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
353                (*end case*))
354          end          end
355    
356    end; (* local *)    end; (* local *)

Legend:
Removed from v.3092  
changed lines
  Added in v.3278

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0