Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3092, Tue Mar 17 20:02:38 2015 UTC revision 3195, Wed May 20 21:02:12 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val testlift=1      val testlift=0
42    
43      val cnt = ref 0      val cnt = ref 0
44    
45      fun printEINAPP e=MidToString.printEINAPP e      fun printEINAPP e=MidToString.printEINAPP e
46      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
47      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
48         fun printEINAPP e=MidToString.printEINAPP e
49    
50      fun transitionToString(testreplace,a,b)=(case testreplace      fun transitionToString(testreplace,a,b)=(case testreplace
51          of 0=> 1          of 0=> 1
52          | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)          | 2 => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)
53            |_ =>(print(String.concat["\nReplaced:",P.printbody a]);1)
54          (*end case*))          (*end case*))
55      fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}      fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
56      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
# Line 61  Line 64 
64          (*end case*))          (*end case*))
65      fun  einapptostring (body,a,b)=(case testlift      fun  einapptostring (body,a,b)=(case testlift
66          of 0=>1          of 0=>1
67          | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)          | 2=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)
68            |_ =>(print(String.concat["\nLifted",P.printbody body]);1)
69          (*end case*))          (*end case*))
70    
71    
# Line 168  Line 172 
172      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
173        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
174    
175    
176        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
177          | multiMergePs e=multiPs e
178    
179    
180      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
181              -> ein_exp* *code              -> ein_exp* *code
182      * Transforms position to world space      * Transforms position to world space
# Line 182  Line 191 
191          val originalb=Ein.body e          val originalb=Ein.body e
192          val params=Ein.params e          val params=Ein.params e
193          val index=Ein.index e          val index=Ein.index e
194            val _= ("\n"^P.printbody originalb)
195    
196          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
197          val fid=length(params)          val fid=length(params)
# Line 209  Line 218 
218              code@[einapp]              code@[einapp]
219          end          end
220    
221        val tsplitvar=true
222      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
223          val Pid=0          val Pid=0
224          val tid=1          val tid=1
# Line 225  Line 234 
234          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
235          val tshape=alpha@newdx          val tshape=alpha@newdx
236          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
237          val exp = multiPs(Ps,newsx,t)          val (splitvar,body)=(case originalb
238          val body=(case originalb              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
239              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
240              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])              | _                                  => (case tsplitvar
241              | _                                  => exp                  of false => (false,multiPs(Ps,newsx,t))
242                    | true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
243                    (*end case*))
244              (*end case*))              (*end case*))
245    
246          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
247          in          in
248              (ein0,sizes,dx)              (splitvar,ein0,sizes,dx)
249          end          end
250    
251      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let
252          val originalb=Ein.body e          val originalb=Ein.body e
253          val params=Ein.params e          val params=Ein.params e
254          val index=Ein.index e          val index=Ein.index e
255            val _=("\n"^P.printbody originalb)
256    
257          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
258          val fid=length(params)          val fid=length(params)
# Line 251  Line 263 
263    
264    
265          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
266          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
267          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
268          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
269          val rtn0=(y,einApp0)          val rtn0=(case splitvar
270                of false => [(y,einApp0)]
271                | _      => Split.splitEinApp (y,einApp0)
272                (*end case*))
273    
274          (*lifted probe*)          (*lifted probe*)
275          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
# Line 262  Line 277 
277          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
278          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
279          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
280          val rtn=code@[rtn1,rtn0]          val rtn=code@[rtn1]@rtn0
281          val _= einapptostring (p,rtn1,rtn0)          (*val _= einapptostring (originalb,rtn1,rtn0)*)
282          in          in
283              rtn              rtn
284          end          end
# Line 274  Line 289 
289      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
290      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
291      *)      *)
292      fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let  (*    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let*)
293                fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
294    
295    
296          fun checkConst ([],a) = liftProbe a          fun checkConst ([],a) = liftProbe a
297          | checkConst ((E.C _::_),a) =replaceProbe a          | checkConst ((E.C _::_),a) =(("\n \n constant field"^(printEINAPP e));replaceProbe a)
298          | checkConst ((_ ::es),a)=checkConst(es,a)          | checkConst ((_ ::es),a)=checkConst(es,a)
299          fun rewriteBody b=(case b          fun rewriteBody b=(case b
300              of E.Probe(E.Conv(_,_,_,[]),_)              of E.Probe(E.Conv(_,_,_,[]),_)
301                  => replaceProbe(1,e,b, [])                  => replaceProbe(0,e,b, [])
302              | E.Probe(E.Conv (_,alpha,_,dx),_)              | E.Probe(E.Conv (_,alpha,_,dx),_)
303                  => checkConst(alpha@dx,(0,e,b,[]))                  => checkConst(alpha@dx,(0,e,b,[]))
304              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
305                  => replaceProbe(1,e,p, sx)                  => replaceProbe(0,e,p, sx)
306              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
307                  => checkConst(dx,(0,e,p,sx))                  => checkConst(dx,(0,e,p,sx))
308                (*| E.Sum(sx as [(v,_,_)],p as (E.Probe((E.Conv(_,alpha,_,dx),_))))=>(case
309                        (List.find (fn x => x = v) dx)
310                        of NONE=>  checkConst(alpha@dx,(1,e,p,sx))
311                        (*need to push summation to lifted exp rather than transform exp.*)
312                        | SOME _=> replaceProbe(1,e, p, sx)
313                        (*end case*))*)
314              | E.Sum(sx,E.Probe p)              | E.Sum(sx,E.Probe p)
315                  => replaceProbe(1,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
316              | E.Sum(sx,E.Prod[eps,E.Probe p])              | E.Sum(sx,E.Prod[eps,E.Probe p])
317                  => replaceProbe(1,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
318              | _ => [e]              | _ => [e]
319              (* end case *))              (* end case *))
320          in  
321              rewriteBody (Ein.body ein)          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
322            in  (case var
323            of NONE=> (("\n \n not replacing"^(printEINAPP e));(rewriteBody(Ein.body ein),fieldset))
324            | SOME v=> (("\n replacing"^(P.printerE ein));( [(y,DstIL.VAR v)] , fieldset))
325                (*end case*))
326    
327            (*val code=rewriteBody(Ein.body ein)
328            in (code,fieldset)*)
329          end          end
330    
331    end; (* local *)    end; (* local *)

Legend:
Removed from v.3092  
changed lines
  Added in v.3195

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0