Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

branches/charisee/src/compiler/high-to-mid/ProbeEin.sml revision 2843, Mon Dec 8 01:27:25 2014 UTC branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml revision 3503, Thu Dec 17 23:13:57 2015 UTC
# Line 1  Line 1 
1  (* Currently under construction  (* Expands probe ein
2   *   *
3   * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)   * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4     *
5     * COPYRIGHT (c) 2015 The University of Chicago
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
# Line 9  Line 11 
11      local      local
12    
13      structure E = Ein      structure E = Ein
     structure mk= mkOperators  
     structure SrcIL = HighIL  
     structure SrcTy = HighILTypes  
     structure SrcOp = HighOps  
     structure SrcSV = SrcIL.StateVar  
     structure VTbl = SrcIL.Var.Tbl  
14      structure DstIL = MidIL      structure DstIL = MidIL
     structure DstTy = MidILTypes  
15      structure DstOp = MidOps      structure DstOp = MidOps
     structure DstV = DstIL.Var  
     structure SrcV = SrcIL.Var  
16      structure P=Printer      structure P=Printer
     structure F=Filter  
17      structure T=TransformEin      structure T=TransformEin
18      structure split=Split      structure MidToS = MidToString
19      structure cleanI=cleanIndex      structure DstV = DstIL.Var
20        structure DstTy = MidILTypes
   
     val testing=1  
   
21    
22      in      in
23    
   
24  (* This file expands probed fields  (* This file expands probed fields
25        * Take a look at ProbeEin tex file for examples
26  *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )  *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27  * Param_ids are used to note the placement of the argument in the midIL.var list  * Param_ids are used to note the placement of the argument in the midIL.var list
28  * Index_ids bind the shape of an Image or differentiation.      * Index_ids  keep track of the shape of an Image or differentiation.
29        * Mu  bind Index_id
30  * Generally, we will refer to the following  * Generally, we will refer to the following
31  *dim:dimension of field V  *dim:dimension of field V
32  * s: support of kernel H  * s: support of kernel H
# Line 49  Line 39 
39  *img-imginfo about V  *img-imginfo about V
40  *)  *)
41    
42        val testing=0
43        val valnumflag=true
44        val tsplitvar=true
45        val fieldliftflag=true
46        val constflag=true
47        val detflag =true
48        val detsumflag=true
49        fun incUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
50        fun decUse (DstIL.V{useCnt, ...}) = (useCnt := !useCnt - 1)
51    
52  val cnt = ref 0  val cnt = ref 0
 fun genName prefix = let  
 val n = !cnt  
 in  
 cnt := n+1;  
 String.concat[prefix, "_", Int.toString n]  
 end  
   
   
 fun iterSx e=F.iterSx e  
53  fun transformToIndexSpace e=T.transformToIndexSpace e  fun transformToIndexSpace e=T.transformToIndexSpace e
54  fun transformToImgSpace  e=T.transformToImgSpace  e  fun transformToImgSpace  e=T.transformToImgSpace  e
55  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))      fun toStringBind e=(MidToString.toStringBind e)
56  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))      fun mkEin e=Ein.mkEin e
57        fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
58        fun setConst e = E.setConst e
59        fun setNeg e  =  E.setNeg e
60        fun setExp e  =  E.setExp e
61        fun setDiv e= E.setDiv e
62        fun setSub e= E.setSub e
63        fun setProd e= E.setProd e
64        fun setAdd e= E.setAdd e
65        fun mkCx es =List.map (fn c => E.C (c,true)) es
66        fun mkCxSingle c = E.C (c,true)
67    
68  fun testp n=(case testing  fun testp n=(case testing
69      of 0=> 1      of 0=> 1
70      | _ =>(print(String.concat n);1)      | _ =>(print(String.concat n);1)
71      (*end case*))      (*end case*))
72    
73    
74  fun getRHSDst x  = (case DstIL.Var.binding x  fun getRHSDst x  = (case DstIL.Var.binding x
75      of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)      of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
76      | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'      | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
# Line 76  Line 79 
79    
80    
81  (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int  (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
82      uses the Param_ids for the image, kernel, and position tensor to get the Mid-IL arguments          uses the Param_ids for the image, kernel,
83            and position tensor to get the Mid-IL arguments
84    returns the support of ther kernel, and image    returns the support of ther kernel, and image
85  *)  *)
86   fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)   fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
87      of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let          of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
88      in      in
89          ((Kernel.support h) ,img,ImageInfo.dim img)          ((Kernel.support h) ,img,ImageInfo.dim img)
90      end      end
91   |  _ => raise Fail "Expected Image and kernel arguments"              |  ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
92   (*end case*))   (*end case*))
93    
94    
95  (*handleArgs():int*int*int*Mid IL.Var list ->int*Mid.ILVars list* code*int* low-il-var      (*handleArgs():int*int*int*Mid IL.Var list
96  * uses the Param_ids for the image, kernel, and tensor and gets the mid-IL vars for each          ->int*Mid.ILVars list* code*int* low-il-var
97            * uses the Param_ids for the image, kernel, and tensor
98            * and gets the mid-IL vars for each.
99  *Transforms the position to index space  *Transforms the position to index space
100  *P-mid-il var for the (transformation matrix)transpose          *P is the mid-il var for the (transformation matrix)transpose
101  *)  *)
102  fun handleArgs(Vid,hid,tid,args)=let  fun handleArgs(Vid,hid,tid,args)=let
103      val imgArg=List.nth(args,Vid)      val imgArg=List.nth(args,Vid)
# Line 99  Line 105 
105      val newposArg=List.nth(args,tid)      val newposArg=List.nth(args,tid)
106      val (s,img,dim) =getArgsDst(hArg,imgArg,args)      val (s,img,dim) =getArgsDst(hArg,imgArg,args)
107      val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)      val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
108      in (dim,args@argsT,code, s,P)          in
109                (dim,args@argsT,code, s,P)
110      end      end
111    
112        (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
 (*createBody:int*int*int, index_id list, param_id, param_id, param_id, param_id  
113  * expands the body for the probed field  * expands the body for the probed field
114  *)  *)
115  fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let  fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
   
116      (*1-d fields*)      (*1-d fields*)
117      fun createKRND1 ()=let      fun createKRND1 ()=let
118          val sum=sx          val sum=sx
119          val dels=List.map (fn e=>(E.C 0,e)) deltas              val dels=List.map (fn e=>(mkCxSingle 0,e)) deltas
120          val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
121          val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))              val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
122          in          in
123              E.Prod [E.Img(Vid,alpha,pos),rest]                 setProd[E.Img(Vid,alpha,pos),rest]
   
124          end          end
125      (*createKRN Image field and kernels *)      (*createKRN Image field and kernels *)
126      fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)          fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
127      | createKRN(dim,imgpos,rest)=let      | createKRN(dim,imgpos,rest)=let
128          val dim'=dim-1          val dim'=dim-1
129          val sum=sx+dim'          val sum=sx+dim'
130          val dels=List.map (fn e=>(E.C dim',e)) deltas              val dels=List.map (fn e=>(mkCxSingle  dim',e)) deltas
131          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]              val pos=[setAdd[E.Tensor(fid,[mkCxSingle dim']),E.Value(sum)]]
132          val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))              val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[mkCxSingle dim']),E.Value(sum)))
133          in          in
134              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
135          end          end
# Line 133  Line 137 
137          of 1 => createKRND1()          of 1 => createKRND1()
138          | _=> createKRN(dim, [],[])          | _=> createKRN(dim, [],[])
139          (*end case*))          (*end case*))
   
140      (*sumIndex creating summaiton Index for body*)      (*sumIndex creating summaiton Index for body*)
141      val slb=1-s      val slb=1-s
142            val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
143      val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))      val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
144  in  in
145      E.Sum(esum, exp)      E.Sum(esum, exp)
146  end  end
147    
148  (*getsumshift:sum_index_id list* index_id list-> int      (*getsumshift:sum_indexid list* int list-> int
149  *get fresh/unused index_id, returns int  *get fresh/unused index_id, returns int
150  *)  *)
151  fun getsumshift(sx,index) =let      fun getsumshift(sx,n) =let
152      val nsumshift= (case sx      val nsumshift= (case sx
153          of []=> length(index)              of []=> n
154          | _=>let          | _=>let
155              val (E.V v,_,_)=List.hd(List.rev sx)              val (E.V v,_,_)=List.hd(List.rev sx)
156              in v+1              in v+1
157              end              end
158          (* end case *))          (* end case *))
159    
160      val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx      val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
161      val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),"\nThink nshift is ", Int.toString nsumshift]          val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
162            "\n\t Index length:",Int.toString n,
163            "\n\t Freshindex: ", Int.toString nsumshift])
164      in      in
165          nsumshift          nsumshift
166      end      end
# Line 163  Line 170 
170  *)  *)
171  fun formBody(E.Sum([],e))=formBody e  fun formBody(E.Sum([],e))=formBody e
172  | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)  | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
173  | formBody(E.Prod [e])=e      | formBody(E.Opn(E.Prod, [e]))=e
174  | formBody e=e  | formBody e=e
175    
176        (* silly change in order of the product to match vis branch WorldtoSpace functions*)
177        fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
178        (*
179          | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
180          *)
181          | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
182          | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
183    
184  (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list :ein_exp* *code  
185        fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
186          | multiMergePs e=multiPs e
187    
188    
189        (* *******************************************  Replace probe *******************************************  *)
190        (* replaceProbe
191  * Transforms position to world space  * Transforms position to world space
192  * transforms result back to index_space  * transforms result back to index_space
193  * rewrites body  * rewrites body
194  * replace probe with expanded version  * replace probe with expanded version
195  *)  *)
196   fun replaceProbe(b,params,args,index, sx)=let       fun replaceProbe((y, DstIL.EINAPP(e,args)),p ,sx)
197            =let
198            val originalb=Ein.body e
199            val params=Ein.params e
200            val index=Ein.index e
201            val _ = testp["\n***************** \n Replace ************ \n"]
202            val _=  toStringBind (y, DstIL.EINAPP(e,args))
203    
204      val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
205      val fid=length(params)      val fid=length(params)
206      val nid=fid+1      val nid=fid+1
207      val Pid=nid+1      val Pid=nid+1
208      val nshift=length(dx)      val nshift=length(dx)
209      val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)      val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
210      val freshIndex=getsumshift(sx,index)          val freshIndex=getsumshift(sx,length(index))
211      val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)      val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
212      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]      val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
213      val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)      val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
214      val body' =formBody(E.Sum(newsx1, E.Prod(Ps@[body'])))          val body' = multiPs(Ps,newsx1,body')
215    
216            val body'=(case originalb
217                of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
218                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])
219                | _                                  => body'
220                (*end case*))
221    
222      val args'=argsA@[PArg]      val args'=argsA@[PArg]
223            val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
224      in      in
225          (body',params',args' ,code)              code@[einapp]
226      end      end
227    
228        (* ******************************************* Lift probe *******************************************  *)
229        fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
230            val Pid=0
231            val tid=1
232    
233  (* expandEinOp: code->  code list          (*Assumes body is already clean*)
234  *Looks to see if the expression has a probe. If so, replaces it.          val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
 * Note how we keeps eps type expressions so we have less time in mid-to-low-il stage  
 *)  
 fun expandEinOp( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let  
     fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",  
     (String.concatWith",\t"(List.map split.printEINAPP code))]  
235    
236      fun rewriteBody b=(case b          (*need to rewrite dx*)
237          of  E.Probe e =>let          val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
238              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, [])              of []=> ([],index,E.Conv(9,alpha,7,newdx))
239              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
240              val code=newbies@[einapp]              (*end case*))
241              in  
242                  code          val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
243              end          fun filterAlpha []=[]
244          | E.Sum(sx,E.Probe e)  =>let            | filterAlpha(E.C _::es)= filterAlpha es
245              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)            | filterAlpha(e1::es)=[e1]@(filterAlpha es)
246              val  body'=E.Sum(sx,body')  
247              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val tshape=filterAlpha(alpha')@newdx
248              val code=newbies@[einapp]          val t=E.Tensor(tid,tshape)
249    
250            val (splitvar,body)=(case originalb
251                of E.Sum(sx, E.Probe _)              => (true,multiPs(Ps,sx@newsx,t))
252                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
253                | _                                  => (case tsplitvar
254                  of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
255                    | false*) _ =>   (true,multiPs(Ps,newsx,t))
256                    (*end case*))
257            (*end case*))
258    
259            val _ =(case splitvar
260            of true=> (String.concat["splitvar is true", P.printbody body])
261            | _ => (String.concat["splitvar is false",P.printbody body])
262            (*end case*))
263    
264    
265            val ein0=mkEin(params,index,body)
266              in              in
267                  code              (splitvar,ein0,sizes,dx,alpha')
268              end              end
269          | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let  
270              val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)      fun liftProbe((y, DstIL.EINAPP(e,args)),p ,sx)=let
271              val  body'=E.Sum(sx,E.Prod[eps,body'])          val _=testp["\n******* Lift Geneirc Probe ***\n"]
272              val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))          val originalb=Ein.body e
273              val code=newbies@[einapp]          val params=Ein.params e
274              in          val index=Ein.index e
275                  code          val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))
276    
277            val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
278            val fid=length(params)
279            val nid=fid+1
280            val nshift=length(dx)
281            val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
282            val freshIndex=getsumshift(sx,length(index))
283    
284            (*transform T*P*P..Ps*)
285            val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
286    
287            val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
288            val einApp0=mkEinApp(ein0,[PArg,FArg])
289            val rtn0=(case splitvar
290                of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
291                | _      => let
292                     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
293                     in Split.splitEinApp bind3
294              end              end
         | _=> [e]  
295          (* end case *))          (* end case *))
296    
297            (*lifted probe*)
298            val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
299            val freshIndex'= length(sizes)
300    
301            val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
302            val ein1=mkEin(params',sizes,body')
303            val einApp1=mkEinApp(ein1,args')
304            val rtn1=(FArg,einApp1)
305            val rtn=code@[rtn1]@rtn0
306            val _= List.map toStringBind ([rtn1]@rtn0)
307             val _=(String.concat["\n* end  Lift Geneirc Probe  ******** \n"])
308            in
309                rtn
310            end
311    
312        (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
313        (* scans dx for contant
314         * arg:(1,code1, body1,[])
315         *)
316        fun reconstruction([],arg)= replaceProbe arg
317         | reconstruction(dx,arg)=(case (constflag,fieldliftflag)
318            of (true,true) => liftProbe arg
319            | (_,false)    => replaceProbe arg
320            | _ => let
321                fun fConst [] = liftProbe arg
322                | fConst (E.C _::_) = replaceProbe arg
323                | fConst (_ ::es)= fConst es
324                in fConst dx end
325            (* end case*))
326    
327        (* **************************************************** Index Tensor **************************************************** *)
328        (*Push constant indices to tensor replacement*)
329        fun getF (e,fieldset,dim,newvx)= let
330            val (y, DstIL.EINAPP(ein,args))=e
331            val index0=Ein.index ein
332            val index1 = index0@dim
333            val b=Ein.body ein
334    
335            val (c1,dx,body1)=(case b
336                of  E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=>let
337                    val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
338                    val b=E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
339                    in (c1,dx,b) end
340                | E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=> let
341                    val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
342                    (* clean to get body indices in order *)
343                    val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
344                    in (c1,dx,body1) end
345                |  E.Probe(E.Conv(V,[c1],h,dx),pos)=> let
346                   val body1_unshifted= E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
347                   val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
348                   in (c1,dx,body1) end
349                (*end case*))
350    
351            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
352            val ein1 = mkEin(Ein.params ein,index1,body1)
353            val code1= (lhs1,mkEinApp(ein1,args))
354    
355            val (_,(lhs0,codeAll))= (case valnumflag
356                of false    => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
357                | true      => (case  (einVarSet.rtnVarN(fieldset,code1))
358                    of (fieldset,NONE)     => (fieldset,(lhs1, reconstruction(dx,(code1,body1,[]))))
359                    | (fieldset,SOME m)   =>  (fieldset,(m,[]))
360                    (*end case*))
361                (*end case*))
362    
363            (*Probe that tensor at a constant position  c1*)
364            val param0 = [E.TEN(1,index1)]
365            val nx=List.tabulate(newvx,fn n=>E.V n)
366            val body0 =  (case b
367                of E.Sum([(vsum,0,n)],_)=>  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
368                | _ => E.Tensor(0,[c1]@nx)
369                (*end case*))
370            val ein0 = mkEin(param0,index0,body0)
371            val einApp0 = mkEinApp(ein0,[lhs0])
372            val code0 = (y,einApp0)
373            val _= toStringBind code0
374      in      in
375          rewriteBody body              codeAll@[code0]
376      end      end
377        (* **************************************************** General Fn **************************************************** *)
378        (* expandEinOp: code->  code list
379        * A this point we only have simple ein ops
380        * Looks to see if the expression has a probe. If so, replaces it.
381        * Note how we keeps eps expressions so only generate pieces that are used
382        *)
383        fun expandEinOp( e0 as (y, DstIL.EINAPP(ein,args)),fieldset,varset)=let
384            fun rewriteBody(e,p as E.Probe(E.Conv(_,alpha,_,dx),_))= (case (detflag,alpha,dx)
385                of (true,[E.C(_,true), E.V 0],[])            => getF(e,fieldset,[3],1)
386                | (true,[E.C(_,true), E.V 0],[E.V 1])        => getF(e,fieldset,[3],2)
387                | (true,[E.C(_,true), E.V 0],[E.V 1,E.V 2])  => getF(e,fieldset,[3],3)
388                | (true,[E.C(_,true)],[])                    => getF(e,fieldset,[3],0)
389                | (true,[E.C(_,true)],[E.V 0])               => getF(e,fieldset,[3],1)
390                | (true,[E.C(_,true)],[E.V 0,E.V 1])         => getF(e,fieldset,[3],2)
391                | (true,[E.C(_,true)],[E.V 0,E.V 1,E.V 2])   => getF(e,fieldset,[3],3)
392                | _                                          => reconstruction(dx,(e,p,[]))
393                (*end case*))
394            | rewriteBody(e,E.Sum(sx,p as E.Probe(E.Conv(_,alpha,_,dx),_)))= (case (detsumflag,sx,alpha,dx)
395                of (true,[(E.V 0,0,_)],[E.V 0 ,E.V 0],[])              => getF(e,fieldset,[3,3],0)
396                | (true,[(E.V 1,0,_)],[E.V 1 ,E.V 1],[E.V 0])          => getF(e,fieldset,[3,3],1)
397                | (true,[(E.V 2,0,_)],[E.V 2 ,E.V 2],[E.V 0,E.V 1])    => getF(e,fieldset,[3,3],2)
398                | (_,_,_,[])                                => replaceProbe(e,p, sx)  (*no dx*)
399                | (_,_,[],_)                                => reconstruction(dx,(e,p,sx))
400                | _                                         => replaceProbe(e,p, sx)
401                (* end case *))
402            | rewriteBody(e,E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))     = replaceProbe(e,E.Probe p,sx)
403            | rewriteBody (e,_)  = [e]
404    
405            val b=Ein.body ein
406            fun matchField()=(case b
407                of E.Probe _ => 1
408                | E.Sum (_, E.Probe _)=>1
409                | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
410                | _ =>0
411                (*end case*))
412            val (fieldset,varset,code,flag) = (case valnumflag
413                of true => (case (einVarSet.rtnVarN(fieldset,e0))
414                   of  (fldset,NONE)     => (fldset,varset,rewriteBody(e0,b),0)
415                  | (fldset,SOME v)    => (fldset,varset,[(y,DstIL.VAR v)],1)
416                     (*of (fldset, NONE)      => (fldset,varset,rewriteBody((y,DstIL.EINAPP(ein,List.map (fn a=>einVarSet.replaceArg(varset,a)) args)), b),0)
417                     | (fldset,SOME v)    => (fldset,einVarSet.VarSet.add(varset,einVarSet.VAR(v,y)),[],1)*)
418                    (*end case*))
419                | _     => (fieldset,varset,rewriteBody(e0, b),0)
420                (*end case*))
421            val m=matchField()
422            in  (code,fieldset,varset,m,flag) end
423    
424    end; (* local *)    end; (* local *)
425    

Legend:
Removed from v.2843  
changed lines
  Added in v.3503

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0