Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 3190, Sat Apr 4 00:18:24 2015 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3325, Tue Oct 20 17:04:54 2015 UTC
# Line 38  Line 38 
38      *)      *)
39    
40      val testing=0      val testing=0
41      val testlift=0      val testlift=1
42        val detflag =true
43        val fieldliftflag=true
44        val valnumflag=true
45    
     val cnt = ref 0  
46    
47      fun printEINAPP e=MidToString.printEINAPP e      val cnt = ref 0
48      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
49      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
50       fun printEINAPP e=MidToString.printEINAPP e      fun toStringBind e=(MidToString.toStringBind e)
51        fun mkEin e=Ein.mkEin e
     fun transitionToString(testreplace,a,b)=(case testreplace  
         of 0=> 1  
         | 2 => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)  
         |_ =>(print(String.concat["\nReplaced:",P.printbody a]);1)  
         (*end case*))  
     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}  
52      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
     fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body  
     fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=  
             (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))  
53    
54      fun testp n=(case testing      fun testp n=(case testing
55          of 0=> 1          of 0=> 1
56          | _ =>(print(String.concat n);1)          | _ =>(print(String.concat n);1)
57          (*end case*))          (*end case*))
     fun  einapptostring (body,a,b)=(case testlift  
         of 0=>1  
         | 2=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)  
         |_ =>(print(String.concat["\nLifted",P.printbody body]);1)  
         (*end case*))  
58    
59    
60      fun getRHSDst x  = (case DstIL.Var.binding x      fun getRHSDst x  = (case DstIL.Var.binding x
# Line 170  Line 158 
158    
159      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
160      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))      fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
161          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, E.Prod([P0,body,P1])))
162        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))        | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
163    
164    
165        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],E.Prod[P0,E.Sum([sx1],E.Prod[P1,body])])
166          | multiMergePs e=multiPs e
167    
168    
169      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
170              -> ein_exp* *code              -> ein_exp* *code
171      * Transforms position to world space      * Transforms position to world space
# Line 186  Line 180 
180          val originalb=Ein.body e          val originalb=Ein.body e
181          val params=Ein.params e          val params=Ein.params e
182          val index=Ein.index e          val index=Ein.index e
183          val _= ("\n"^P.printbody originalb)          val _ = testp["\n***************** \n Replace ************ \n"]
184            val _=  toStringBind (y, DstIL.EINAPP(e,args))
185    
186          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
187          val fid=length(params)          val fid=length(params)
# Line 205  Line 200 
200              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
201              | _                                  => body'              | _                                  => body'
202              (*end case*))              (*end case*))
203          val _=transitionToString(testN,originalb,body')  
204    
205          val args'=argsA@[PArg]          val args'=argsA@[PArg]
206          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
# Line 213  Line 208 
208              code@[einapp]              code@[einapp]
209          end          end
210    
211        val tsplitvar=true
212      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
213          val Pid=0          val Pid=0
214          val tid=1          val tid=1
215    
216            (*Assumes body is already clean*)
217          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
218    
219          (*need to rewrite dx*)          (*need to rewrite dx*)
220          val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
221              of []=> ([],index,E.Conv(9,alpha,7,newdx))              of []=> ([],index,E.Conv(9,alpha,7,newdx))
222              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)              | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
223              (*end case*))              (*end case*))
224    
225          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
226          val tshape=alpha@newdx          fun filterAlpha []=[]
227              | filterAlpha(E.C _::es)= filterAlpha es
228              | filterAlpha(e1::es)=[e1]@(filterAlpha es)
229    
230            val tshape=filterAlpha(alpha')@newdx
231          val t=E.Tensor(tid,tshape)          val t=E.Tensor(tid,tshape)
232          val exp = multiPs(Ps,newsx,t)          val (splitvar,body)=(case originalb
233          val body=(case originalb              of E.Sum(sx, E.Probe _)              => (false,E.Sum(sx,multiPs(Ps,newsx,t)))
234              of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
235              | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])              | _                                  => (case tsplitvar
236              | _                                  => exp                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
237                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
238                    (*end case*))
239                (*end case*))
240    
241            val _ =(case splitvar
242            of true=> (String.concat["splitvar is true", P.printbody body])
243            | _ => (String.concat["splitvar is false",P.printbody body])
244              (*end case*))              (*end case*))
245    
246    
247          val ein0=mkEin(params,index,body)          val ein0=mkEin(params,index,body)
248          in          in
249              (ein0,sizes,dx)              (splitvar,ein0,sizes,dx,alpha')
250          end          end
251    
252      fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
253            val _=testp["\n******* Lift ******** \n"]
254          val originalb=Ein.body e          val originalb=Ein.body e
255          val params=Ein.params e          val params=Ein.params e
256          val index=Ein.index e          val index=Ein.index e
257          val _=("\n"^P.printbody originalb)          val _=  toStringBind (y, DstIL.EINAPP(e,args))
258    
259          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
260          val fid=length(params)          val fid=length(params)
# Line 254  Line 263 
263          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
264          val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,index)
265    
   
266          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
267          val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)          val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
268          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
269          val einApp0=mkEinApp(ein0,[PArg,FArg])          val einApp0=mkEinApp(ein0,[PArg,FArg])
270          val rtn0=(y,einApp0)          val rtn0=(case splitvar
271                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
272                | _      => let
273                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
274                     in Split.splitEinApp bind3
275                     end
276                (*end case*))
277    
278          (*lifted probe*)          (*lifted probe*)
279          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
280          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val body' = createBody(dim, s,freshIndex+nshift,alpha',dx,Vid, hid, nid, fid)
281          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
282          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
283          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
284          val rtn=code@[rtn1,rtn0]          val rtn=code@[rtn1]@rtn0
285          val _= einapptostring (originalb,rtn1,rtn0)          val _= List.map toStringBind ([rtn1]@rtn0)
286    
287          in          in
288              rtn              rtn
289          end          end
290    
291    
292        fun liftFieldMat(newvx,e)=
293            let
294                val (y, DstIL.EINAPP(ein,args))=e
295                val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
296                val index0=Ein.index ein
297                val index1 = index0@[3]
298                val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
299                (* clean to get body indices in order *)
300                val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
301                val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
302    
303                val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
304                val ein1 = mkEin(Ein.params ein,index1,body1)
305                val code1= (lhs1,mkEinApp(ein1,args))
306                val codeAll= (case dx
307                of []=> replaceProbe(1,code1,body1,[])
308                | _ =>liftProbe(1,code1,body1,[])
309                (*end case*))
310    
311                (*Probe that tensor at a constant position  c1*)
312                val param0 = [E.TEN(1,index1)]
313                val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
314                val body0 =  E.Tensor(0,[c1]@nx)
315                val ein0 = mkEin(param0,index0,body0)
316                val einApp0 = mkEinApp(ein0,[lhs1])
317                val code0 = (y,einApp0)
318                val _= toStringBind code0
319            in
320                codeAll@[code0]
321        end
322    
323        fun liftFieldSum e =
324        let
325            val _=print"\n*************************************\n"
326            val (y, DstIL.EINAPP(ein,args))=e
327            val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
328            val index0=Ein.index ein
329            val index1 = index0@[3]@[3]
330            val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
331            val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
332    
333    
334            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
335            val ein1 = mkEin(Ein.params ein,index1,body1)
336            val code1= (lhs1,mkEinApp(ein1,args))
337            val codeAll= (case dx
338            of []=> replaceProbe(1,code1,body1,[])
339            | _ =>liftProbe(1,code1,body1,[])
340            (*end case*))
341    
342            (*Probe that tensor at a constant position  c1*)
343            val param0 = [E.TEN(1,index1)]
344            val nx=List.tabulate(length(dx),fn n=>E.V n)
345            val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
346            val ein0 = mkEin(param0,index0,body0)
347            val einApp0 = mkEinApp(ein0,[lhs1])
348            val code0 = (y,einApp0)
349            val _= toStringBind  e
350            val _ =toStringBind code0
351           val _ = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])
352           val _ =(String.concat(List.map toStringBind (codeAll@[code0])))
353                   val _=print"\n*************************************\n"
354            in
355            codeAll@[code0]
356        end
357    
358    
359      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
360      *A this point we only have simple ein ops      *A this point we only have simple ein ops
361      *Looks to see if the expression has a probe. If so, replaces it.      *Looks to see if the expression has a probe. If so, replaces it.
362      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
363      *)      *)
 (*    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let*)  
364              fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let              fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
365    
366            fun checkConst ([],a) =
367          fun checkConst ([],a) = liftProbe a              (case fieldliftflag
368          | checkConst ((E.C _::_),a) =(("\n \n constant field"^(printEINAPP e));replaceProbe a)                  of true => liftProbe a
369                    | _ => replaceProbe a
370                (*end case*))
371            | checkConst ((E.C _::_),a) = replaceProbe a
372          | checkConst ((_ ::es),a)= checkConst(es,a)          | checkConst ((_ ::es),a)= checkConst(es,a)
373          fun rewriteBody b=(case b  
374              of E.Probe(E.Conv(_,_,_,[]),_)          fun rewriteBody b=(case (detflag,b)
375                of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
376                    => liftFieldMat (1,e)
377                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
378                    => liftFieldMat (2,e)
379                | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
380                    => liftFieldMat (3,e)
381                | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
382                    => liftFieldSum e
383                | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
384                    => liftFieldSum e
385                | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
386                    => liftFieldSum e
387    
388    
389                | (_,E.Probe(E.Conv(_,_,_,[]),_))
390                  => replaceProbe(0,e,b, [])                  => replaceProbe(0,e,b, [])
391              | E.Probe(E.Conv (_,alpha,_,dx),_)              | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
392                  => checkConst(alpha@dx,(0,e,b,[]))                  => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
393              | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))              | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
394                  => replaceProbe(0,e,p, sx)                  => replaceProbe(0,e,p, sx)  (*no dx*)
395              | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))              | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
396                  => checkConst(dx,(0,e,p,sx))                  => checkConst(dx,(0,e,p,sx)) (*scalar field*)
397              (*| E.Sum(sx as [(v,_,_)],p as (E.Probe((E.Conv(_,alpha,_,dx),_))))=>(case              | (_,E.Sum(sx,E.Probe p))
                     (List.find (fn x => x = v) dx)  
                     of NONE=>  checkConst(alpha@dx,(1,e,p,sx))  
                     (*need to push summation to lifted exp rather than transform exp.*)  
                     | SOME _=> replaceProbe(1,e, p, sx)  
                     (*end case*))*)  
             | E.Sum(sx,E.Probe p)  
398                  => replaceProbe(0,e,E.Probe p, sx)                  => replaceProbe(0,e,E.Probe p, sx)
399              | E.Sum(sx,E.Prod[eps,E.Probe p])              | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
400                  => replaceProbe(0,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
401              | _ => [e]              | (_,_) => [e]
402              (* end case *))              (* end case *))
403    
404          val (fieldset,var) = einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))          val (fieldset,var) = (case valnumflag
405          in  (case var              of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
406          of NONE=> (("\n \n not replacing"^(printEINAPP e));(rewriteBody(Ein.body ein),fieldset))              | _     => (fieldset,NONE)
407          | SOME v=> (("\n replacing"^(P.printerE ein));( [(y,DstIL.VAR v)] , fieldset))          (*end case*))
408    
409            fun matchField b=(case b
410                of E.Probe _ => 1
411                | E.Sum (_, E.Probe _)=>1
412                | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
413                | _ =>0
414              (*end case*))              (*end case*))
415    
416          (*val code=rewriteBody(Ein.body ein)          in  (case var
417          in (code,fieldset)*)              of NONE=> (("\n \n mapp_not_replacing:"^(P.printerE ein)^":");(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
418                | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
419                (*end case*))
420          end          end
421    
422    end; (* local *)    end; (* local *)

Legend:
Removed from v.3190  
changed lines
  Added in v.3325

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0