Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3443 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : cchiw 2606 * All rights reserved.
7 :     *)
8 :    
9 :     structure ProbeEin = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 :     structure DstIL = MidIL
15 :     structure DstOp = MidOps
16 : jhr 3060 structure P = Printer
17 :     structure T = TransformEin
18 :     structure MidToS = MidToString
19 : cchiw 2976 structure DstV = DstIL.Var
20 :     structure DstTy = MidILTypes
21 :    
22 : cchiw 2606 in
23 :    
24 : cchiw 2870 (* This file expands probed fields
25 :     * Take a look at ProbeEin tex file for examples
26 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27 :     * Param_ids are used to note the placement of the argument in the midIL.var list
28 :     * Index_ids keep track of the shape of an Image or differentiation.
29 :     * Mu bind Index_id
30 :     * Generally, we will refer to the following
31 :     *dim:dimension of field V
32 :     * s: support of kernel H
33 :     * alpha: The alpha in <V_alpha * H^(deltas)>
34 :     * deltas: The deltas in <V_alpha * H^(deltas)>
35 :     * Vid:param_id for V
36 :     * hid:param_id for H
37 :     * nid: integer position param_id
38 :     * fid :fractional position param_id
39 : cchiw 3166 * img-imginfo about V
40 : cchiw 2870 *)
41 : cchiw 3033
42 : cchiw 2923 val testing=0
43 : cchiw 3443
44 :    
45 :    
46 : cchiw 3325 val valnumflag=true
47 : cchiw 3443 val tsplitvar=false
48 :     val fieldliftflag=false
49 :     val constflag=false
50 :     val detflag =false
51 :     val detsumflag=false
52 : cchiw 3307
53 :    
54 : cchiw 2845 val cnt = ref 0
55 :     fun transformToIndexSpace e=T.transformToIndexSpace e
56 :     fun transformToImgSpace e=T.transformToImgSpace e
57 : cchiw 3268 fun toStringBind e=(MidToString.toStringBind e)
58 : cchiw 3260 fun mkEin e=Ein.mkEin e
59 : cchiw 3033 fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
60 : cchiw 3441 fun setConst e = E.setConst e
61 :     fun setNeg e = E.setNeg e
62 :     fun setExp e = E.setExp e
63 :     fun setDiv e= E.setDiv e
64 :     fun setSub e= E.setSub e
65 :     fun setProd e= E.setProd e
66 :     fun setAdd e= E.setAdd e
67 :    
68 : cchiw 2845 fun testp n=(case testing
69 :     of 0=> 1
70 : cchiw 3383 | _ =>((String.concat n);1)
71 : cchiw 2845 (*end case*))
72 : cchiw 3260
73 :    
74 : cchiw 2845 fun getRHSDst x = (case DstIL.Var.binding x
75 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
76 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
77 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
78 :     (* end case *))
79 : cchiw 2838
80 : cchiw 2606
81 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
82 :     uses the Param_ids for the image, kernel,
83 :     and position tensor to get the Mid-IL arguments
84 :     returns the support of ther kernel, and image
85 :     *)
86 : jhr 3060 fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
87 :     of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
88 : cchiw 2845 in
89 : jhr 3060 ((Kernel.support h) ,img,ImageInfo.dim img)
90 : cchiw 2845 end
91 : cchiw 3441 | ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
92 : cchiw 2845 (*end case*))
93 : cchiw 2606
94 :    
95 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
96 :     ->int*Mid.ILVars list* code*int* low-il-var
97 :     * uses the Param_ids for the image, kernel, and tensor
98 :     * and gets the mid-IL vars for each.
99 :     *Transforms the position to index space
100 :     *P is the mid-il var for the (transformation matrix)transpose
101 :     *)
102 :     fun handleArgs(Vid,hid,tid,args)=let
103 :     val imgArg=List.nth(args,Vid)
104 :     val hArg=List.nth(args,hid)
105 :     val newposArg=List.nth(args,tid)
106 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
107 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
108 : cchiw 2606 in
109 : cchiw 2845 (dim,args@argsT,code, s,P)
110 : cchiw 2606 end
111 : cchiw 2838
112 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
113 :     * expands the body for the probed field
114 :     *)
115 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
116 :     (*1-d fields*)
117 :     fun createKRND1 ()=let
118 :     val sum=sx
119 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
120 : cchiw 3441 val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
121 :     val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
122 : cchiw 2845 in
123 : cchiw 3441 setProd[E.Img(Vid,alpha,pos),rest]
124 : cchiw 2843 end
125 : cchiw 2845 (*createKRN Image field and kernels *)
126 : cchiw 3441 fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
127 : cchiw 2845 | createKRN(dim,imgpos,rest)=let
128 :     val dim'=dim-1
129 :     val sum=sx+dim'
130 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
131 : cchiw 3441 val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
132 :     val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
133 : cchiw 2845 in
134 :     createKRN(dim',pos@imgpos,[rest']@rest)
135 :     end
136 :     val exp=(case dim
137 :     of 1 => createKRND1()
138 :     | _=> createKRN(dim, [],[])
139 :     (*end case*))
140 :     (*sumIndex creating summaiton Index for body*)
141 :     val slb=1-s
142 : cchiw 3383 val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
143 : cchiw 2845 val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
144 : cchiw 2843 in
145 : cchiw 2845 E.Sum(esum, exp)
146 : cchiw 2606 end
147 :    
148 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
149 :     *get fresh/unused index_id, returns int
150 :     *)
151 : cchiw 3383 fun getsumshift(sx,n) =let
152 : cchiw 2845 val nsumshift= (case sx
153 : cchiw 3383 of []=> n
154 : cchiw 2845 | _=>let
155 :     val (E.V v,_,_)=List.hd(List.rev sx)
156 :     in v+1
157 :     end
158 :     (* end case *))
159 : cchiw 3383
160 : cchiw 2845 val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
161 : cchiw 3383 val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
162 :     "\n\t Index length:",Int.toString n,
163 :     "\n\t Freshindex: ", Int.toString nsumshift])
164 : cchiw 2845 in
165 :     nsumshift
166 :     end
167 : cchiw 2611
168 : cchiw 2845 (*formBody:ein_exp->ein_exp
169 :     *just does a quick rewrite
170 :     *)
171 :     fun formBody(E.Sum([],e))=formBody e
172 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
173 : cchiw 3441 | formBody(E.Opn(E.Prod, [e]))=e
174 : cchiw 2845 | formBody e=e
175 : cchiw 2606
176 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
177 : cchiw 3441 fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
178 : cchiw 3353 (*
179 : cchiw 3441 | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
180 : cchiw 3353 *)
181 : cchiw 3441 | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
182 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
183 : cchiw 3195
184 : cchiw 2976
185 : cchiw 3441 fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
186 : cchiw 3195 | multiMergePs e=multiPs e
187 :    
188 :    
189 : cchiw 2845 (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
190 :     -> ein_exp* *code
191 :     * Transforms position to world space
192 :     * transforms result back to index_space
193 :     * rewrites body
194 :     * replace probe with expanded version
195 :     *)
196 : cchiw 3048 (* fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
197 :    
198 :     fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
199 :     =let
200 :     val originalb=Ein.body e
201 :     val params=Ein.params e
202 :     val index=Ein.index e
203 : cchiw 3267 val _ = testp["\n***************** \n Replace ************ \n"]
204 : cchiw 3260 val _= toStringBind (y, DstIL.EINAPP(e,args))
205 : cchiw 3048
206 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
207 : cchiw 2845 val fid=length(params)
208 :     val nid=fid+1
209 :     val Pid=nid+1
210 :     val nshift=length(dx)
211 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
212 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
213 : cchiw 2845 val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
214 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
215 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
216 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
217 : cchiw 3033
218 :     val body'=(case originalb
219 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
220 : cchiw 3441 | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,setProd[eps0,body'])
221 : cchiw 3033 | _ => body'
222 :     (*end case*))
223 : cchiw 3260
224 : cchiw 3033
225 : cchiw 2845 val args'=argsA@[PArg]
226 : cchiw 3033 val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
227 :     in
228 :     code@[einapp]
229 : cchiw 2845 end
230 : cchiw 2976
231 : cchiw 3443
232 : cchiw 3048 fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
233 : cchiw 2976 val Pid=0
234 :     val tid=1
235 : cchiw 3260
236 :     (*Assumes body is already clean*)
237 : cchiw 3048 val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
238 :    
239 :     (*need to rewrite dx*)
240 : cchiw 3260 val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
241 : cchiw 3048 of []=> ([],index,E.Conv(9,alpha,7,newdx))
242 : cchiw 3260 | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
243 : cchiw 3048 (*end case*))
244 :    
245 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
246 : cchiw 3260 fun filterAlpha []=[]
247 :     | filterAlpha(E.C _::es)= filterAlpha es
248 :     | filterAlpha(e1::es)=[e1]@(filterAlpha es)
249 :    
250 :     val tshape=filterAlpha(alpha')@newdx
251 : cchiw 3033 val t=E.Tensor(tid,tshape)
252 : cchiw 3441
253 : cchiw 3195 val (splitvar,body)=(case originalb
254 : cchiw 3441 of E.Sum(sx, E.Probe _) => (true,multiPs(Ps,sx@newsx,t))
255 :     | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
256 : cchiw 3195 | _ => (case tsplitvar
257 : cchiw 3441 of(* true => (true,multiMergePs(Ps,newsx,t)) (*pushes summations in place*)
258 : cchiw 3259 | false*) _ => (true,multiPs(Ps,newsx,t))
259 : cchiw 3195 (*end case*))
260 : cchiw 3441 (*end case*))
261 : cchiw 3048
262 : cchiw 3324 val _ =(case splitvar
263 : cchiw 3325 of true=> (String.concat["splitvar is true", P.printbody body])
264 :     | _ => (String.concat["splitvar is false",P.printbody body])
265 : cchiw 3324 (*end case*))
266 :    
267 :    
268 : cchiw 3048 val ein0=mkEin(params,index,body)
269 : cchiw 2976 in
270 : cchiw 3260 (splitvar,ein0,sizes,dx,alpha')
271 : cchiw 2976 end
272 : cchiw 3048
273 : cchiw 3260 fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
274 : cchiw 3441 val _=testp["\n******* Lift Geneirc Probe ***\n"]
275 : cchiw 3048 val originalb=Ein.body e
276 :     val params=Ein.params e
277 : cchiw 3189 val index=Ein.index e
278 : cchiw 3383 val _ = (toStringBind (y, DstIL.EINAPP(e,args)))
279 : cchiw 2976
280 : cchiw 3048 val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
281 : cchiw 2976 val fid=length(params)
282 :     val nid=fid+1
283 :     val nshift=length(dx)
284 : cchiw 3048 val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
285 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
286 : cchiw 2976
287 :     (*transform T*P*P..Ps*)
288 : cchiw 3260 val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
289 : cchiw 3441
290 : cchiw 3048 val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
291 :     val einApp0=mkEinApp(ein0,[PArg,FArg])
292 : cchiw 3195 val rtn0=(case splitvar
293 : cchiw 3324 of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
294 :     | _ => let
295 :     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
296 : cchiw 3440 in Split.splitEinApp bind3
297 : cchiw 3324 end
298 : cchiw 3195 (*end case*))
299 : cchiw 2976
300 :     (*lifted probe*)
301 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
302 : cchiw 3383 val freshIndex'= length(sizes)
303 :    
304 :     val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
305 : cchiw 3033 val ein1=mkEin(params',sizes,body')
306 : cchiw 2976 val einApp1=mkEinApp(ein1,args')
307 : cchiw 3048 val rtn1=(FArg,einApp1)
308 : cchiw 3195 val rtn=code@[rtn1]@rtn0
309 : cchiw 3260 val _= List.map toStringBind ([rtn1]@rtn0)
310 : cchiw 3383 val _=(String.concat["\n* end Lift Geneirc Probe ******** \n"])
311 : cchiw 2976 in
312 :     rtn
313 :     end
314 :    
315 : cchiw 3383 fun searchFullField (fieldset,code1,body1,dx)=let
316 :     val (lhs,_)=code1
317 :     fun continueReconstruction ()=let
318 :     val _=print"Tash:don't replaced"
319 :     in (case dx
320 :     of []=> (lhs,replaceProbe(1,code1,body1,[]))
321 :     | _ =>(lhs,liftProbe(1,code1,body1,[]))
322 :     (*end case*))
323 :     end
324 :     in (case valnumflag
325 : cchiw 3441 of false => (fieldset,continueReconstruction())
326 :     | true => (case (einSet.rtnVarN(fieldset,code1))
327 : cchiw 3383 of (fieldset,NONE) => (fieldset,continueReconstruction())
328 :     | (fieldset,SOME m) =>(print"TASH:replaced"; (fieldset,(m,[])))
329 :     (*end case*))
330 :     (*end case*))
331 :     end
332 : cchiw 3324
333 :     fun liftFieldMat(newvx,e)=
334 :     let
335 : cchiw 3441 val _=testp[ "\n ***************************** start FieldMat\n"]
336 : cchiw 3324 val (y, DstIL.EINAPP(ein,args))=e
337 :     val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
338 :     val index0=Ein.index ein
339 :     val index1 = index0@[3]
340 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
341 :     (* clean to get body indices in order *)
342 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
343 :     val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
344 :    
345 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
346 :     val ein1 = mkEin(Ein.params ein,index1,body1)
347 :     val code1= (lhs1,mkEinApp(ein1,args))
348 :     val codeAll= (case dx
349 :     of []=> replaceProbe(1,code1,body1,[])
350 :     | _ =>liftProbe(1,code1,body1,[])
351 :     (*end case*))
352 :    
353 :     (*Probe that tensor at a constant position c1*)
354 :     val param0 = [E.TEN(1,index1)]
355 :     val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
356 :     val body0 = E.Tensor(0,[c1]@nx)
357 :     val ein0 = mkEin(param0,index0,body0)
358 :     val einApp0 = mkEinApp(ein0,[lhs1])
359 :     val code0 = (y,einApp0)
360 :     val _= toStringBind code0
361 : cchiw 3441 val _=testp["\n end FieldMat *****************************\n "]
362 : cchiw 3324 in
363 :     codeAll@[code0]
364 :     end
365 :    
366 : cchiw 3383 fun liftFieldVec(newvx,e,fieldset)=
367 :     let
368 : cchiw 3441 val _=testp[ "\n ***************************** start FieldVec\n"]
369 : cchiw 3383 val (y, DstIL.EINAPP(ein,args))=e
370 :     val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
371 :     val index0=Ein.index ein
372 :     val index1 = index0@[3]
373 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
374 :     (* clean to get body indices in order *)
375 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
376 :    
377 :    
378 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
379 :     val ein1 = mkEin(Ein.params ein,index1,body1)
380 :     val code1= (lhs1,mkEinApp(ein1,args))
381 :     val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)
382 :    
383 :     (*Probe that tensor at a constant position c1*)
384 :     val param0 = [E.TEN(1,index1)]
385 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
386 :     val body0 = E.Tensor(0,[c1]@nx)
387 :     val ein0 = mkEin(param0,index0,body0)
388 :     val einApp0 = mkEinApp(ein0,[lhs0])
389 :     val code0 = (y,einApp0)
390 :    
391 : cchiw 3441 val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
392 :     val _ = (toStringBind code0)
393 :     val _ = testp[ "\n end FieldVec *****************************\n "]
394 : cchiw 3383 in
395 :     codeAll@[code0]
396 :     end
397 :    
398 :    
399 :    
400 : cchiw 3324 fun liftFieldSum e =
401 :     let
402 : cchiw 3441 val _=testp[ "\n************************************* Start Lift Field Sum\n"]
403 : cchiw 3324 val (y, DstIL.EINAPP(ein,args))=e
404 :     val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
405 :     val index0=Ein.index ein
406 :     val index1 = index0@[3]@[3]
407 :     val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
408 :     val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
409 :    
410 :    
411 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
412 :     val ein1 = mkEin(Ein.params ein,index1,body1)
413 :     val code1= (lhs1,mkEinApp(ein1,args))
414 :     val codeAll= (case dx
415 : cchiw 3441 of [] => replaceProbe(1,code1,body1,[])
416 :     | _ =>liftProbe(1,code1,body1,[])
417 : cchiw 3383 (*end case*))
418 : cchiw 3324
419 :     (*Probe that tensor at a constant position c1*)
420 :     val param0 = [E.TEN(1,index1)]
421 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
422 :     val body0 = E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
423 :     val ein0 = mkEin(param0,index0,body0)
424 :     val einApp0 = mkEinApp(ein0,[lhs1])
425 :     val code0 = (y,einApp0)
426 : cchiw 3374 val _ = toStringBind e
427 :     val _ = toStringBind code0
428 : cchiw 3383 val _ = (String.concat ["\norig",P.printbody(Ein.body ein),"\n replace i ",P.printbody body1,"\nfreshtensor",P.printbody body0])
429 :     val _ =((List.map toStringBind (codeAll@[code0])))
430 : cchiw 3441 val _ = testp["\n*** end Field Sum*************************************\n"]
431 : cchiw 3324 in
432 :     codeAll@[code0]
433 :     end
434 :    
435 :    
436 : cchiw 2845 (* expandEinOp: code-> code list
437 : cchiw 3259 * A this point we only have simple ein ops
438 :     * Looks to see if the expression has a probe. If so, replaces it.
439 : cchiw 2845 * Note how we keeps eps expressions so only generate pieces that are used
440 :     *)
441 : cchiw 3229 fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
442 : cchiw 3443
443 :     fun checkConst(es,a)=(case constflag
444 :     of true => liftProbe a
445 :     | _ => let
446 :     fun fConst ([],a) =
447 :     (case fieldliftflag
448 :     of true => liftProbe a
449 :     | _ => replaceProbe a
450 :     (*end case*))
451 :     | fConst ((E.C _::_),a) = replaceProbe a
452 :     | fConst ((_ ::es),a)= checkConst(es,a)
453 :     in fConst(es,a) end
454 :     (* end case*))
455 :    
456 : cchiw 3261
457 : cchiw 3443
458 :     fun rewriteBodyB b=(case b
459 :     of (E.Probe(E.Conv(_,_,_,[]),_))
460 :     => replaceProbe(0,e,b,[])
461 :     | (E.Probe(E.Conv (_,alpha,_,dx),_))
462 :     => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
463 :     | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
464 :     => replaceProbe(0,e,p, sx) (*no dx*)
465 :     | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
466 :     => checkConst(dx,(0,e,p,sx)) (*scalar field*)
467 :     | (E.Sum(sx,E.Probe p))
468 :     => replaceProbe(0,e,E.Probe p, sx)
469 :     | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
470 :     => replaceProbe(0,e,E.Probe p,sx)
471 :     | _ => [e]
472 :     (* end case *))
473 :    
474 :    
475 :     fun rewriteBody b=(case detflag
476 :     of true => (case (detsumflag,b)
477 :     of (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
478 :     => liftFieldMat (1,e)
479 :     | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
480 :     => liftFieldMat (2,e)
481 :     | (_,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
482 :     => liftFieldMat (3,e)
483 :     | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
484 :     => liftFieldSum e
485 :     | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
486 :     => liftFieldSum e
487 :     | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
488 :     => liftFieldSum e
489 :     | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))
490 :     => liftFieldVec (0,e,fieldset)
491 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
492 :     => liftFieldVec (1,e,fieldset)
493 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
494 :     => liftFieldVec (2,e,fieldset)
495 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
496 :     => liftFieldVec (3,e,fieldset)
497 :     | _ => rewriteBodyB b
498 :     (* end case *))
499 :     | _ => rewriteBodyB b
500 : cchiw 2845 (* end case *))
501 : cchiw 3174
502 : cchiw 3314 val (fieldset,var) = (case valnumflag
503 :     of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
504 :     | _ => (fieldset,NONE)
505 :     (*end case*))
506 : cchiw 3271
507 :     fun matchField b=(case b
508 :     of E.Probe _ => 1
509 :     | E.Sum (_, E.Probe _)=>1
510 : cchiw 3441 | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
511 : cchiw 3271 | _ =>0
512 :     (*end case*))
513 : cchiw 3327 fun toStrField b=(case b
514 : cchiw 3383 of E.Probe _ => print("\n"^(P.printbody b))
515 : cchiw 3369 | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))
516 : cchiw 3441 | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))
517 : cchiw 3383 | _ => print""
518 : cchiw 3327 (*end case*))
519 : cchiw 3353 val b=Ein.body ein
520 : cchiw 3369
521 : cchiw 3174 in (case var
522 : cchiw 3415 of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
523 : cchiw 3344 | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
524 : cchiw 3174 (*end case*))
525 : cchiw 2845 end
526 : cchiw 2843
527 : cchiw 2606 end; (* local *)
528 :    
529 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0