Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 3092, Tue Mar 17 20:02:38 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3311, Fri Oct 16 20:09:14 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val testlift=1      val testlift=0
42      val cnt = ref 0      val detflag =false
43    
44    
45      fun printEINAPP e=MidToString.printEINAPP e      val cnt = ref 0
46      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
47      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
48        fun toStringBind e=(MidToString.toStringBind e)
49      fun transitionToString(testreplace,a,b)=(case testreplace      fun mkEin e=Ein.mkEin e
         of 0=> 1  
         | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)  
         (*end case*))  
     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}  
50      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
     fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body  
     fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=  
             (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))  
51    
52      fun testp n=(case testing      fun testp n=(case testing
53          of 0=> 1          of 0=> 1
54          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
55          (*end case*))          (*end case*))
     fun  einapptostring (body,a,b)=(case testlift  
         of 0=>1  
         | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)  
         (*end case*))  
56    
57    
58      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
# Line 168  Line 158 
158      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
159        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
160    
161    
162        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
163          | multiMergePs e=multiPs e
164    
165    
166      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
167              -> ein_exp* *code              -> ein_exp* *code
168      * Transforms position to world space      * Transforms position to world space
# Line 182  Line 177 
177          val originalb=Ein.body e          val originalb=Ein.body e
178          val params=Ein.params e          val params=Ein.params e
179          val index=Ein.index e          val index=Ein.index e
180            val _ = testp["\n***************** \n Replace ************ \n"]
181            val _=  toStringBind (y, DstIL.EINAPP(e,args))
182    
183          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
184          val fid=length(params)          val fid=length(params)
# Line 201  Line 197 
197              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
198              | _                                  => body'              | _                                  => body'
199              (*end case*))              (*end case*))
200          val _=transitionToString(testN,originalb,body')  
201    
202          val args'=argsA@[PArg]          val args'=argsA@[PArg]
203          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
# Line 209  Line 205 
205              code@[einapp]              code@[einapp]
206          end          end
207    
208        val tsplitvar=true
209      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
210          val Pid=0          val Pid=0
211          val tid=1          val tid=1
212    
213            (*Assumes body is already clean*)
214          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
215    
216          (*need to rewrite dx*)          (*need to rewrite dx*)
217          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
218              of []=> ([],index,E.Conv(9,alpha,7,newdx))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
219              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
220              (*end case*))              (*end case*))
221    
222          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
223          val tshape=alpha@newdx          fun filterAlpha []=[]
224              | filterAlpha(E.C _::es)= filterAlpha es
225              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
226    
227            val tshape=filterAlpha(alpha')@newdx
228          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
229          val exp = multiPs(Ps,newsx,t)          val (splitvar,body)=(case originalb
230          val body=(case originalb              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
231              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
232              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])              | _                                  => (case tsplitvar
233              | _                                  => exp                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
234                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
235                    (*end case*))
236              (*end case*))              (*end case*))
237    
238          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
239          in          in
240              (ein0,sizes,dx)              (splitvar,ein0,sizes,dx,alpha')
241          end          end
242    
243      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
244            val _=testp["\n******* Lift ******** \n"]
245          val originalb=Ein.body e          val originalb=Ein.body e
246          val params=Ein.params e          val params=Ein.params e
247          val index=Ein.index e          val index=Ein.index e
248            val _=  toStringBind (y, DstIL.EINAPP(e,args))
249    
250          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
251          val fid=length(params)          val fid=length(params)
# Line 249  Line 254 
254          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
255          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,index)
256    
   
257          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
258          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
259          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
260          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
261          val rtn0=(y,einApp0)          val rtn0=(case splitvar
262                of false => [(y,einApp0)]
263                | _      => Split.splitEinApp (y,einApp0)
264                (*end case*))
265    
266          (*lifted probe*)          (*lifted probe*)
267          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
268          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
269          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
270          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
271          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
272          val rtn=code@[rtn1,rtn0]          val rtn=code@[rtn1]@rtn0
273          val _= einapptostring (p,rtn1,rtn0)          val _= List.map toStringBind ([rtn1]@rtn0)
274    
275          in          in
276              rtn              rtn
277          end          end
278    
   
279      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
280      *A this point we only have simple ein ops      *A this point we only have simple ein ops
281      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
282      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
283      *)      *)
284      fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
285    
286          fun checkConst ([],a) = liftProbe a          fun checkConst ([],a) = liftProbe a
287          | checkConst ((E.C _::_),a) =replaceProbe a          | checkConst ((E.C _::_),a) =replaceProbe a
288          | checkConst ((_ ::es),a)=checkConst(es,a)          | checkConst ((_ ::es),a)=checkConst(es,a)
289          fun rewriteBody b=(case b  
290              of E.Probe(E.Conv(_,_,_,[]),_)          fun liftFieldMat(newvx,E.Probe(E.Conv(V,[E.C c1,E.V 0],h,dx),pos))=
291                  => replaceProbe(1,e,b, [])              let
292              | E.Probe(E.Conv (_,alpha,_,dx),_)  
293                  => checkConst(alpha@dx,(0,e,b,[]))                  val _= toStringBind e
294              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))                  val index0=Ein.index ein
295                  => replaceProbe(1,e,p, sx)                  val index1 = index0@[3]
296              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))                  val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, E.V 0],h,dx),pos)
297                  => checkConst(dx,(0,e,p,sx))                   (* clean to get body indices in order *)
298              | E.Sum(sx,E.Probe p)                  val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
299                  => replaceProbe(1,e,E.Probe p, sx)                  val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
300              | E.Sum(sx,E.Prod[eps,E.Probe p])  
301                  => replaceProbe(1,e,E.Probe p,sx)                  val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
302              | _ => [e]                  val ein1 = mkEin(Ein.params ein,index1,body1)
303              (* end case *))                  val code1= (lhs1,mkEinApp(ein1,args))
304                    val codeAll= (case dx
305                        of []=> replaceProbe(1,code1,body1,[])
306                        | _ =>liftProbe(1,code1,body1,[])
307                    (*end case*))
308    
309                    (*Probe that tensor at a constant position E.C c1*)
310                    val param0 = [E.TEN(1,index1)]
311                    val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
312                    val body0 =  E.Tensor(0,[E.C c1]@nx)
313                    val ein0 = mkEin(param0,index0,body0)
314                    val einApp0 = mkEinApp(ein0,[lhs1])
315                    val code0 = (y,einApp0)
316                    val _= toStringBind code0
317          in          in
318              rewriteBody (Ein.body ein)                  codeAll@[code0]
319                end
320    
321            fun rewriteBody b=(case (detflag,b)
322                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
323                    => liftFieldMat (1,b)
324                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
325                    => liftFieldMat (2,b)
326                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
327                    => liftFieldMat (3,b)
328                | (_,E.Probe(E.Conv(_,_,_,[]),_))
329                    => replaceProbe(0,e,b,[])
330                | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
331                    => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
332                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
333                    => replaceProbe(0,e,p, sx)  (*no dx*)
334                | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
335                    => checkConst(dx,(0,e,p,sx)) (*scalar field*)
336                | (_,E.Sum(sx,E.Probe p))
337                    => replaceProbe(0,e,E.Probe p, sx)
338                | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
339                    => replaceProbe(0,e,E.Probe p,sx)
340                | (_,_) => [e]
341                (* end case *))
342    
343            val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
344    
345            fun matchField b=(case b
346                of E.Probe _ => 1
347                | E.Sum (_, E.Probe _)=>1
348                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
349                | _ =>0
350                (*end case*))
351    
352            in  (case var
353                of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
354                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
355                (*end case*))
356          end          end
357    
358    end; (* local *)    end; (* local *)

Legend:
Removed from v.3092  
changed lines
  Added in v.3311

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0