Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3459, Mon Nov 23 19:29:51 2015 UTC revision 3655, Thu Feb 4 04:12:40 2016 UTC
# Line 40  Line 40 
40      *)      *)
41    
42      val testing=0      val testing=0
   
   
   
43      val valnumflag=true      val valnumflag=true
44      val tsplitvar=true      val tsplitvar=true
45      val fieldliftflag=true      val fieldliftflag=true
46      val constflag=false      val constflag =  true
47      val detflag =true      val detflag =true
48      val detsumflag=true      val detsumflag=true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52        val liftimgflag =false
53        val pullKrn= false
54    
55      val cnt = ref 0      val cnt = ref 0
56      fun transformToIndexSpace e=T.transformToIndexSpace e      fun transformToIndexSpace e=T.transformToIndexSpace e
57      fun transformToImgSpace  e=T.transformToImgSpace  e      fun transformToImgSpace  e=T.transformToImgSpace  e
58        fun transformToImgSpaceF  e=T.transformToImgSpaceF  e
59      fun toStringBind e=(MidToString.toStringBind e)      fun toStringBind e=(MidToString.toStringBind e)
60        fun toStringBindp e=(MidToString.toStringBind e)
61      fun mkEin e=Ein.mkEin e      fun mkEin e=Ein.mkEin e
62      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)      fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
63      fun setConst e = E.setConst e      fun setConst e = E.setConst e
# Line 64  Line 67 
67      fun setSub e= E.setSub e      fun setSub e= E.setSub e
68      fun setProd e= E.setProd e      fun setProd e= E.setProd e
69      fun setAdd e= E.setAdd e      fun setAdd e= E.setAdd e
70        fun mkCx es =List.map (fn c => E.C (c,true)) es
71        fun mkCxSingle c = E.C (c,true)
72    
73      fun testp n=(case testing      fun testp n=(case testing
74          of 0=> 1          of 0=> 1
75          | _ =>((String.concat n);1)          | _ =>(print(String.concat n);1)
76          (*end case*))          (*end case*))
77    
78    
# Line 84  Line 89 
89      returns the support of ther kernel, and image      returns the support of ther kernel, and image
90      *)      *)
91      fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)      fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
92          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> ((Kernel.support h) ,img,ImageInfo.dim img)
             in  
               ((Kernel.support h) ,img,ImageInfo.dim img)  
             end  
93              |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])              |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
94          (*end case*))          (*end case*))
95    
# Line 109  Line 111 
111              (dim,args@argsT,code, s,P)              (dim,args@argsT,code, s,P)
112          end          end
113    
114        fun handleArgsF(fieldset,Vid,hid,tid,args)=let
115            val imgArg=List.nth(args,Vid)
116            val hArg=List.nth(args,hid)
117            val newposArg=List.nth(args,tid)
118            val (s,img,dim) =getArgsDst(hArg,imgArg,args)
119            val (fieldset,argsT,P,code)=transformToImgSpaceF(fieldset,dim,img,newposArg,imgArg)
120            in
121                (fieldset,dim,args@argsT,code, s,P)
122            end
123    
124    
125    
126    
127      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
128      * expands the body for the probed field      * expands the body for the probed field
129      *)      *)
# Line 116  Line 131 
131          (*1-d fields*)          (*1-d fields*)
132          fun createKRND1 ()=let          fun createKRND1 ()=let
133              val sum=sx              val sum=sx
134              val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
135              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
136              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
137              in              in
138                 setProd[E.Img(Vid,alpha,pos),rest]                 setProd[E.Img(Vid,alpha,pos),rest]
139              end              end
140    
141            fun mkImg(imgpos)=E.Img(Vid,alpha,imgpos)
142    
143          (*createKRN Image field and kernels *)          (*createKRN Image field and kernels *)
144          fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=setProd ([mkImg(imgpos)] @rest)
145          | createKRN(dim,imgpos,rest)=let          | createKRN(dim,imgpos,rest)=let
146              val dim'=dim-1              val dim'=dim-1
147              val sum=sx+dim'              val sum=sx+dim'
148              val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
149              val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
150              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
151              in              in
152                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
153              end              end
# Line 145  Line 163 
163          E.Sum(esum, exp)          E.Sum(esum, exp)
164      end      end
165    
166    
167    
168        (* build position *)
169        fun buildPos (dir,dim,argsA,hid,nid,s) =let
170            val vA  = DstV.new ("kernel_pos", DstTy.TensorTy([]))
171            val p=[E.KRN,E.TEN(1,[dim])]
172            val pos=setSub(E.Tensor(1,[mkCxSingle dir]),E.Value(0))
173            val exp= E.BuildPos(s,pos)
174            (*val exp = E.Sum([(E.V 0, slb, s)],E.Krn(0,[],pos))*)
175            val a=[List.nth(argsA,hid),List.nth(argsA,nid)]
176            val A=(vA,mkEinApp(mkEin(p,[],exp),a))
177            in (vA,A) end
178    
179        (* apply differentiation *)
180        fun getKrn1Del(dx,dim,args,slb,s)= let
181            val n=Int.toString(dx)
182            val vA  = DstV.new ("kernel_del"^n, DstTy.TensorTy([]))
183            val p=[E.KRN,E.TEN(1,[dim])]
184            val exp = E.EvalKrn dx
185            val A = (vA,mkEinApp(mkEin(p,[],exp),args))
186            in (vA,A) end
187    
188        (*create holder expression*)
189        fun mkHolder(dim,args) =let
190            val n=List.length(args)
191            val vA  = DstV.new ("kernel_cons", DstTy.TensorTy([n]))
192            val p=[E.KRN,E.TEN(1,[dim])]
193            val A= (vA,mkEinApp(mkEin(p,[],E.Holder n),args))
194            in (vA,A) end
195    
196        (*lifted Kernel expressions*)
197        fun liftKrn(dx,dir,dim,argsA,hid,nid,slb,s)=let
198            val  (vA,A)=buildPos(dir,dim,argsA,hid,nid,s)
199            val args=[List.nth(argsA,hid),vA]
200            fun iter(0,vBs,Bs)=let
201                val (vA,A)=getKrn1Del(0,dim,args,slb,s)
202                in (vA::vBs,A::Bs) end
203              | iter (n,vBs,Bs)= let
204                val (vA,A)=getKrn1Del(n,dim,args,slb,s)
205                in iter(n-1,vA::vBs,A::Bs) end
206            val (vBs,Bs)=iter(length(dx),[],[])
207            val (vC,C)  =mkHolder(dim,vBs)
208            in (vC,(A::Bs)@[C]) end
209    
210    
211    
212        fun createBody2(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=let
213            (*1-d fields*)
214            val slb=1-s
215    
216    
217            (*making image*)
218            val tid=(case liftimgflag
219                of true => length(params)-1
220                | _     => length(params)-1
221            (*end case*))
222            fun mkImg imgpos =(case liftimgflag
223                of true=>(E.Tensor(Vid,alpha),SOME(E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim),slb,s))),E.Img(Vid,alpha,imgpos))))
224                |  _ =>let
225                    val imgpos= List.tabulate(dim,fn e=> setAdd[E.Tensor(fid,[mkCxSingle e]),E.Value(e+sx)])
226                    in (E.Img(Vid,alpha,imgpos),NONE) end
227            (*end case*))
228    
229            fun createKRND1 ()=let
230                val sum=sx
231                val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
232                val imgpos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
233                val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
234                val (talpha,iexp)= mkImg imgpos
235                in (setProd[talpha,rest],iexp,NONE,NONE)end
236    
237            (*createKRN Image field and kernels *)
238            fun createKRN(0,orig,imgpos,vAs,krnpos)= let
239                val (talpha,iexp)= mkImg imgpos
240                in  (setProd ([talpha]@orig),iexp,SOME vAs,SOME krnpos) end
241            | createKRN(d,orig,imgpos,vAs,krnpos)=let
242                val dim'=d-1
243                val sum=sx+dim'
244                val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
245                val ipos=setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(dim')]
246                val opos= E.Krn(hid,dels,E.Tensor(tid+d,[]))
247                val (vA,A)= liftKrn(dels,dim',dim,argsA,hid,nid,slb,s)
248                in
249                    createKRN(dim',[opos]@orig,[ipos]@imgpos,[vA]@vAs,A@krnpos)
250                end
251    
252            val (oexp,iexp,vAs,keinapp)=(case dim
253                of 1 => createKRND1()
254                | _=> createKRN(dim, [],[],[],[])
255            (*end case*))
256    
257            val oexp=E.Sum(List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s))), oexp)
258            in (oexp,iexp,vAs,keinapp) end
259    
260        fun createBody3(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)=
261                createBody2(dim, s,sx,[],deltas,Vid, hid, nid, fid,params,argsA)
262        |   createBody3(dim, s,sx,alpha,deltas,Vid, hid, nid, fid,params,argsA)=
263                (createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid),NONE,NONE,NONE)
264    
265      (*getsumshift:sum_indexid list* int list-> int      (*getsumshift:sum_indexid list* int list-> int
266      *get fresh/unused index_id, returns int      *get fresh/unused index_id, returns int
267      *)      *)
# Line 185  Line 302 
302      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])      fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
303        | multiMergePs e=multiPs e        | multiMergePs e=multiPs e
304    
305        (* *******************************************  setImage *******************************************  *)
306        fun replaceImgA(es,vid,newbie)=List.take(es,vid)@[newbie]@List.drop(es,vid+1)
307        fun setImage(params',argsA,code,vexp2,index,alpha,paraminstant,Vid,s)=
308            (case vexp2
309                of NONE =>(params',argsA,code)
310                | SOME vexp  =>    let
311                    val iArg  = DstV.new ("Img", DstTy.TensorTy([]))
312                    val alphax=List.map (fn (E.V i)=>List.nth(index,i)) alpha
313                    val ieinapp=(iArg,mkEinApp(mkEin(paraminstant,alphax,vexp),argsA))
314                    (*
315                     val _ =print(String.concat["\n****\n Image (",Int.toString(length(argsA)),")"])
316                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
317                    val _ =print(String.concat["\n replace at ",Int.toString Vid ," with " , DstIL.Var.toString iArg ,"\n"])*)
318                    val argsA=replaceImgA(argsA,Vid,iArg)
319                    val params'=replaceImgA(params',Vid,E.TEN(2,[(s-(1-s)+1)*(s-(1-s)+1),(s-(1-s)+1)]))
320                    val code=code@[ieinapp]
321                    (*
322                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") argsA))
323                    val _ =print(String.concat["\n****\n Image(",Int.toString(length(argsA)),")"])*)
324                    in (params',argsA,code) end
325            (*end case*))
326    
327          (*kernels*)
328          fun setKernel(params',args',code,vAs2,keinapp2,dim)=
329            (case (vAs2,keinapp2)
330                of (NONE,NONE)=> (params',args',code)
331                | (SOME vAs,SOME keinapp) => let
332                (*
333                    val _ =print"\n****\n Kernels\n"
334                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])
335                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))*)
336                    val args'=   args'@vAs
337                    val params'= params'@(List.tabulate(dim,fn _=> E.TEN(2,[])))
338                    val code=code@keinapp
339                    (*
340                    val _ =print"\n"
341                    val _ =print(String.concat(List.map (fn e=> DstIL.Var.toString(e)^",") args'))
342                    val _ =print(String.concat["\n****\n Kernel(",Int.toString(length(args')),")"])*)
343                    in (params',args',code) end
344            (*end case*))
345    
346      (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list        fun setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)=let
347              -> ein_exp* *code          val (params',args',code)=setImage(params',args',code,vexp2,index,alpha,paraminstant,Vid,s)
348            in setKernel(params',args',code,vAs2,keinapp2,dim) end
349    
350    
351        (* *******************************************  Replace probe *******************************************  *)
352        (* replaceProbe
353      * Transforms position to world space      * Transforms position to world space
354      * transforms result back to index_space      * transforms result back to index_space
355      * rewrites body      * rewrites body
356      * replace probe with expanded version      * replace probe with expanded version
357      *)      *)
358  (*    fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)       fun replaceProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)
   
      fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)  
359          =let          =let
360          val originalb=Ein.body e          val originalb=Ein.body e
361          val params=Ein.params e          val params=Ein.params e
362          val index=Ein.index e          val index=Ein.index e
363          val _ = testp["\n***************** \n Replace ************ \n"]          val _ = (String.concat["\n***************** \n Replace ************ \n"])
364          val _=  toStringBind (y, DstIL.EINAPP(e,args))          val _=  toStringBindp (y, DstIL.EINAPP(e,args))
365    
366          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
367          val fid=length(params)          val fid=length(params)
# Line 221  Line 381 
381              | _                                  => body'              | _                                  => body'
382              (*end case*))              (*end case*))
383    
384            val args'=argsA@[PArg]
385            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
386            val _= List.map toStringBindp(code@[einapp])
387            in
388                (fieldset,code@[einapp])
389            end
390    
391    
392    
393        fun replaceProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx) = let
394            val originalb=Ein.body e
395            val params=Ein.params e
396            val index=Ein.index e
397            val _ = testp["\n***************** \n Replace ************ \n"]
398            val _=  toStringBind (y, DstIL.EINAPP(e,args))
399    
400            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
401            val fid=length(params)
402            val nid=fid+1
403            val Pid=nid+1
404            val nshift=length(dx)
405            val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
406            val freshIndex=getsumshift(sx,length(index))
407            val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
408    
409            val paraminstant=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
410            val params'=paraminstant@[E.TEN(1,[dim,dim])]
411    
412    
413    
414            val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid,paraminstant,argsA)
415            val body' = multiPs(Ps,newsx1,body')
416    
417            val body'=(case originalb
418                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
419                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
420                | _                                  => body'
421                (*end case*))
422    
423            (*images and kernels*)
424            val (params',argsA,code)=setImageKernel(params',argsA,code,vexp2,vAs2,keinapp2,dim,index,alpha,paraminstant,Vid,s)
425    
426    
427            (*replace term*)
428          val args'=argsA@[PArg]          val args'=argsA@[PArg]
429          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))          val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
430            val _= List.map toStringBindp(code@[einapp])
431          in          in
432              code@[einapp]              (fieldset,code@[einapp])
433          end          end
434    
435    
436        (* ******************************************* Lift probe *******************************************  *)
437      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let      fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
438          val Pid=0          val Pid=0
439          val tid=1          val tid=1
# Line 270  Line 475 
475              (splitvar,ein0,sizes,dx,alpha')              (splitvar,ein0,sizes,dx,alpha')
476          end          end
477    
478      fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let      fun liftProbe0(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
479            val _=(String.concat["\n******* Lift Geneirc Probe ***\n"])
480            val originalb=Ein.body e
481            val params=Ein.params e
482            val index=Ein.index e
483            val _ =  (toStringBindp (y, DstIL.EINAPP(e,args)))
484    
485            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
486            val fid=length(params)
487            val nid=fid+1
488            val nshift=length(dx)
489            val (fieldset,dim,args',code,s,PArg) = handleArgsF(fieldset,Vid,hid,tid,args)
490            val freshIndex=getsumshift(sx,length(index))
491    
492            (*transform T*P*P..Ps*)
493            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
494            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
495    
496    
497    
498            (*addedhere*)
499            val ein9=mkEin(params,sizes,E.Conv(Vid,alpha',hid,dx))
500            val einApp9=mkEinApp(ein9,args)
501            val rtn9=(FArg,einApp9)
502            val (fieldset,FArg,rtn1)= (case (einVarSet.rtnVarN(fieldset,rtn9))
503                of  (fieldset,SOME v)  => let
504                        val _ = (" \n did find"^toStringBind(rtn9))
505                        in (fieldset, v,[]) end
506                | (fieldset,NONE)  => let
507                        (*lifted probe*)
508                        val _ =(" \n did not find"^toStringBind(rtn9))
509                        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
510                        val freshIndex'= length(sizes)
511                        val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
512                        val ein1=mkEin(params',sizes,body')
513                        val einApp1=mkEinApp(ein1,args')
514                        val rtn1=(FArg,einApp1)
515                        in (fieldset,FArg ,[rtn1]) end
516                (*end case*))
517    
518            val einApp0=mkEinApp(ein0,[PArg,FArg])
519            val rtn0=(case splitvar
520                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
521                | _      => let
522                val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
523                    in Split.splitEinApp bind3
524            end
525            (*end case*))
526    
527            val rtn=code@rtn1@rtn0
528            val _= List.map toStringBindp (code@rtn1)
529            val _ ="\n**** split code **\n"
530            val _= List.map toStringBindp rtn0
531            in
532                (fieldset,rtn)
533            end
534    
535        fun liftProbe3(fieldset,(y, DstIL.EINAPP(e,args)),p ,sx)=let
536          val _=testp["\n******* Lift Geneirc Probe ***\n"]          val _=testp["\n******* Lift Geneirc Probe ***\n"]
537          val originalb=Ein.body e          val originalb=Ein.body e
538          val params=Ein.params e          val params=Ein.params e
# Line 301  Line 563 
563          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
564          val freshIndex'= length(sizes)          val freshIndex'= length(sizes)
565    
566          val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)          (*val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)*)
567    
568            val (body',vexp2,vAs2,keinapp2) = createBody3(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid,params',args')
569    
570            (*set image and kernel*)
571            val (params',args',code)=setImageKernel(params',args',code,vexp2,vAs2,keinapp2,dim,index,alpha,params',Vid,s)
572    
573          val ein1=mkEin(params',sizes,body')          val ein1=mkEin(params',sizes,body')
574          val einApp1=mkEinApp(ein1,args')          val einApp1=mkEinApp(ein1,args')
575          val rtn1=(FArg,einApp1)          val rtn1=(FArg,einApp1)
576          val rtn=code@[rtn1]@rtn0          val rtn=code@[rtn1]@rtn0
577          val _= List.map toStringBind ([rtn1]@rtn0)          val _= List.map toStringBind ([rtn1]@rtn0)
578           val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])           val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
579              val _= List.map toStringBindp(rtn)
580          in          in
581              rtn              (fieldset,rtn)
582          end          end
583    
584      fun searchFullField (fieldset,code1,body1,dx)=let      fun replaceProbe e= (case pullKrn
585          val (lhs,_)=code1          of true=>replaceProbe3 e
586          fun continueReconstruction ()=let          | false => replaceProbe0 e
             val _=print"Tash:don't replaced"  
             in (case dx  
                 of []=> (lhs,replaceProbe(1,code1,body1,[]))  
                 | _ =>(lhs,liftProbe(1,code1,body1,[]))  
587                  (*end case*))                  (*end case*))
588               end      fun liftProbe e=(case pullKrn
589          in  (case valnumflag          of true=>liftProbe3 e
590              of false    => (fieldset,continueReconstruction())          | false => liftProbe0 e
             | true      => (case  (einSet.rtnVarN(fieldset,code1))  
                 of (fieldset,NONE)     => (fieldset,continueReconstruction())  
                  | (fieldset,SOME m)   =>(print"TASH:replaced"; (fieldset,(m,[])))  
591                  (*end case*))                  (*end case*))
592    
593    
594        (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
595        (* scans dx for contant
596         * arg:(1,code1, body1,[])
597         *)
598        fun reconstruction([],arg)= replaceProbe arg
599         | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
600            of (true,true) => liftProbe arg
601            | (_,false)    => replaceProbe arg
602            | _ => let
603                fun fConst [] = liftProbe arg
604                | fConst (E.C _::_) = replaceProbe arg
605                | fConst (_ ::es)= fConst es
606                in fConst dx end
607              (*end case*))              (*end case*))
         end  
608    
609      fun liftFieldMat(newvx,e)=      (* **************************************************** Index Tensor **************************************************** *)
610          let      (*Push constant indices to tensor replacement*)
611              val _=testp[ "\n ***************************** start FieldMat\n"]      fun getF (e,fieldset,dim,newvx)= let
612              val (y, DstIL.EINAPP(ein,args))=e              val (y, DstIL.EINAPP(ein,args))=e
             val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein  
613              val index0=Ein.index ein              val index0=Ein.index ein
614              val index1 = index0@[3]          val index1 = index0@dim
615            val b=Ein.body ein
616    
617            val (c1,dx,body1)=(case b
618                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
619                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
620                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
621                    in (c1,dx,b) end
622                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
623              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)              val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
624              (* clean to get body indices in order *)              (* clean to get body indices in order *)
625              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])              val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
626              val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]                  in (c1,dx,body1) end
627                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
             val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
             val ein1 = mkEin(Ein.params ein,index1,body1)  
             val code1= (lhs1,mkEinApp(ein1,args))  
             val codeAll= (case dx  
             of []=> replaceProbe(1,code1,body1,[])  
             | _ =>liftProbe(1,code1,body1,[])  
             (*end case*))  
   
             (*Probe that tensor at a constant position  c1*)  
             val param0 = [E.TEN(1,index1)]  
             val nx=List.tabulate(length(dx)+1,fn n=>E.V n)  
             val body0 =  E.Tensor(0,[c1]@nx)  
             val ein0 = mkEin(param0,index0,body0)  
             val einApp0 = mkEinApp(ein0,[lhs1])  
             val code0 = (y,einApp0)  
             val _= toStringBind code0  
             val _=testp["\n end FieldMat *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
   
     fun liftFieldVec(newvx,e,fieldset)=  
     let  
         val _=testp[ "\n ***************************** start FieldVec\n"]  
         val (y, DstIL.EINAPP(ein,args))=e  
         val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]  
628          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)          val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
         (* clean to get body indices in order *)  
629          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])          val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
630                   in (c1,dx,body1) end
631                (*end case*))
632    
633          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))          val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
634          val ein1 = mkEin(Ein.params ein,index1,body1)          val ein1 = mkEin(Ein.params ein,index1,body1)
635          val code1= (lhs1,mkEinApp(ein1,args))          val code1= (lhs1,mkEinApp(ein1,args))
         val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)  
   
         (*Probe that tensor at a constant position  c1*)  
         val param0 = [E.TEN(1,index1)]  
         val nx=List.tabulate(length(dx),fn n=>E.V n)  
         val body0 =  E.Tensor(0,[c1]@nx)  
         val ein0 = mkEin(param0,index0,body0)  
         val einApp0 = mkEinApp(ein0,[lhs0])  
         val code0 = (y,einApp0)  
   
         val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]  
         val _ = (toStringBind code0)  
         val _ = testp[ "\n end FieldVec *****************************\n "]  
         in  
             codeAll@[code0]  
     end  
636    
637            val (lhs0,(fieldset,codeAll))= (case valnumflag
638                of false    => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
639      fun liftFieldSum e =              | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
640      let                  of (fieldset,NONE)     => (lhs1, reconstruction(dx,(fieldset,code1,body1,[])))
641          val _=testp[ "\n************************************* Start Lift Field Sum\n"]                  | (fieldset,SOME m)   =>  (m,(fieldset,[]))
642          val (y, DstIL.EINAPP(ein,args))=e                  (*end case*))
         val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein  
         val index0=Ein.index ein  
         val index1 = index0@[3]@[3]  
         val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))  
         val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)  
   
   
         val lhs1=DstV.new ("L", DstTy.TensorTy(index1))  
         val ein1 = mkEin(Ein.params ein,index1,body1)  
         val code1= (lhs1,mkEinApp(ein1,args))  
         val codeAll= (case dx  
             of []   => replaceProbe(1,code1,body1,[])  
             | _     =>liftProbe(1,code1,body1,[])  
643              (*end case*))              (*end case*))
644    
645          (*Probe that tensor at a constant position  c1*)          (*Probe that tensor at a constant position  c1*)
646          val param0 = [E.TEN(1,index1)]          val param0 = [E.TEN(1,index1)]
647          val nx=List.tabulate(length(dx),fn n=>E.V n)          val nx=List.tabulate(newvx,fn n=>E.V n)
648          val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))          val body0 =  (case b
649                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
650                | _ => E.Tensor(0,[c1]@nx)
651                (*end case*))
652          val ein0 = mkEin(param0,index0,body0)          val ein0 = mkEin(param0,index0,body0)
653          val einApp0 = mkEinApp(ein0,[lhs1])          val einApp0 = mkEinApp(ein0,[lhs0])
654          val code0 = (y,einApp0)          val code0 = (y,einApp0)
         val _ = toStringBind  e  
655          val _ = toStringBind code0          val _ = toStringBind code0
         val _  = (String.concat  ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0])  
         val _  =((List.map toStringBind (codeAll@[code0])))  
         val _ = testp["\n*** end Field Sum*************************************\n"]  
656          in          in
657          codeAll@[code0]              (fieldset,codeAll@[code0])
658      end      end
659        (* **************************************************** General Fn **************************************************** *)
   
660      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
661      * A this point we only have simple ein ops      * A this point we only have simple ein ops
662      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
663      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
664      *)      *)
665     fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let      fun expandEinOp(e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
666            fun rewriteBody(fieldset,e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
667                of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
668                | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
669                | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
670                | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
671                | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
672                | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
673                | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
674                | _                                          => reconstruction(dx,(fieldset,e,p,[]))
675                (*end case*))
676            | rewriteBody(fieldset,e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
677                of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
678                | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
679                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
680                | (_,_,_,[])                                => replaceProbe(fieldset,e,p, sx)  (*no dx*)
681                | (_,_,[],_)                                => reconstruction(dx,(fieldset,e,p,sx))
682                | _                                         => replaceProbe(fieldset,e,p, sx)
683                (* end case *))
684            | rewriteBody(fieldset,e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(fieldset,e,E.Probe p,sx)
685            | rewriteBody (fieldset,e,_)  = (fieldset,[e])
686    
687      fun checkConst(es,a)=(case constflag          val b=Ein.body ein
688          of true => liftProbe a          fun pf()=("\n **************************** starting  **************************** \n"^(P.printerE(ein)))
689          | _ => let          fun matchField()=(case b
690              fun fConst ([],a) =              of E.Probe _ =>  (pf();1)
691                  (case fieldliftflag              | E.Sum (_, E.Probe _)=> (pf();1)
692                      of true => liftProbe a              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=> (pf();1)
                     | _ => replaceProbe a  
                 (*end case*))  
                 | fConst ((E.C _::_),a) = replaceProbe a  
                     (*raise Fail (String.concat["\nFound it:",P.printerE(ein)])*)  
             | fConst ((_ ::es),a)= checkConst(es,a)  
             in fConst(es,a) end  
       (* end case*))  
   
   
   
     fun rewriteBodyB b=(case b  
         of  (E.Probe(E.Conv(_,_,_,[]),_))  
             => replaceProbe(0,e,b,[])  
         | (E.Probe(E.Conv (_,alpha,_,dx),_))  
             => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))  
             => replaceProbe(0,e,p, sx)  (*no dx*)  
         | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))  
             => checkConst(dx,(0,e,p,sx)) (*scalar field*)  
         | (E.Sum(sx,E.Probe p))  
             => replaceProbe(0,e,E.Probe p, sx)  
         | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))  
             => replaceProbe(0,e,E.Probe p,sx)  
         | _ => [e]  
         (* end case *))  
   
   
         fun rewriteBody b=(case detflag  
             of true => (case (detsumflag,b)  
                 of (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))  
                     => liftFieldMat (1,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))  
                     => liftFieldMat (2,e)  
                 | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))  
                     => liftFieldMat (3,e)  
                 | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))  
                     => liftFieldSum e  
                 | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))  
                     => liftFieldSum e  
                 | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))  
                     => liftFieldVec (0,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))  
                     => liftFieldVec (1,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))  
                     => liftFieldVec (2,e,fieldset)  
                 | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))  
                         => liftFieldVec (3,e,fieldset)  
                 | _   => rewriteBodyB b  
                 (* end case *))  
             | _   => rewriteBodyB b  
             (* end case *))  
   
         val (fieldset,var) = (case valnumflag  
             of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))  
             | _     => (fieldset,NONE)  
         (*end case*))  
   
         fun matchField b=(case b  
             of E.Probe _ => 1  
             | E.Sum (_, E.Probe _)=>1  
             | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1  
693              | _ =>0              | _ =>0
694              (*end case*))              (*end case*))
695          fun toStrField b=(case b          val m=matchField()
696              of E.Probe _ => print("\n"^(P.printbody b))          val (fieldset,varset,code,flag) = (case valnumflag
697              | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))              of true => (case (einVarSet.rtnVarN(fieldset,e0))
698              | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))                 of  (fieldset,NONE)     => let
699              | _ => print""                      val(fieldset,code)=rewriteBody(fieldset,e0,b)
700                        in (fieldset,varset,code,0) end
701                  | (fieldset,SOME v)    => (fieldset,varset,[(y,DstIL.VAR v)],1)
702              (*end case*))              (*end case*))
703          val b=Ein.body ein                  | _     => let
704                    val(fieldset,code)=rewriteBody(fieldset,e0, b)
705          in  (case var                  in (fieldset,varset,code,0) end
         of NONE=> (toStrField(b);(rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))  
             | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))  
706              (*end case*))              (*end case*))
707          end  
708            in  (code,fieldset,varset,m,flag) end
709    
710    end; (* local *)    end; (* local *)
711    

Legend:
Removed from v.3459  
changed lines
  Added in v.3655

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0