Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3441 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : cchiw 2606 * All rights reserved.
7 :     *)
8 :    
9 :     structure ProbeEin = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 :     structure DstIL = MidIL
15 :     structure DstOp = MidOps
16 : jhr 3060 structure P = Printer
17 :     structure T = TransformEin
18 :     structure MidToS = MidToString
19 : cchiw 2976 structure DstV = DstIL.Var
20 :     structure DstTy = MidILTypes
21 :    
22 : cchiw 2606 in
23 :    
24 : cchiw 2870 (* This file expands probed fields
25 :     * Take a look at ProbeEin tex file for examples
26 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27 :     * Param_ids are used to note the placement of the argument in the midIL.var list
28 :     * Index_ids keep track of the shape of an Image or differentiation.
29 :     * Mu bind Index_id
30 :     * Generally, we will refer to the following
31 :     *dim:dimension of field V
32 :     * s: support of kernel H
33 :     * alpha: The alpha in <V_alpha * H^(deltas)>
34 :     * deltas: The deltas in <V_alpha * H^(deltas)>
35 :     * Vid:param_id for V
36 :     * hid:param_id for H
37 :     * nid: integer position param_id
38 :     * fid :fractional position param_id
39 : cchiw 3166 * img-imginfo about V
40 : cchiw 2870 *)
41 : cchiw 3033
42 : cchiw 2923 val testing=0
43 : cchiw 3325 val testlift=1
44 : cchiw 3441 val detflag =true
45 : cchiw 3362 val fieldliftflag=true
46 : cchiw 3325 val valnumflag=true
47 : cchiw 3307
48 :    
49 : cchiw 2845 val cnt = ref 0
50 :     fun transformToIndexSpace e=T.transformToIndexSpace e
51 :     fun transformToImgSpace e=T.transformToImgSpace e
52 : cchiw 3268 fun toStringBind e=(MidToString.toStringBind e)
53 : cchiw 3260 fun mkEin e=Ein.mkEin e
54 : cchiw 3033 fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
55 : cchiw 3441 fun setConst e = E.setConst e
56 :     fun setNeg e = E.setNeg e
57 :     fun setExp e = E.setExp e
58 :     fun setDiv e= E.setDiv e
59 :     fun setSub e= E.setSub e
60 :     fun setProd e= E.setProd e
61 :     fun setAdd e= E.setAdd e
62 :    
63 : cchiw 2845 fun testp n=(case testing
64 :     of 0=> 1
65 : cchiw 3383 | _ =>((String.concat n);1)
66 : cchiw 2845 (*end case*))
67 : cchiw 3260
68 :    
69 : cchiw 2845 fun getRHSDst x = (case DstIL.Var.binding x
70 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
71 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
72 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
73 :     (* end case *))
74 : cchiw 2838
75 : cchiw 2606
76 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
77 :     uses the Param_ids for the image, kernel,
78 :     and position tensor to get the Mid-IL arguments
79 :     returns the support of ther kernel, and image
80 :     *)
81 : jhr 3060 fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
82 :     of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
83 : cchiw 2845 in
84 : jhr 3060 ((Kernel.support h) ,img,ImageInfo.dim img)
85 : cchiw 2845 end
86 : cchiw 3441 | ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
87 : cchiw 2845 (*end case*))
88 : cchiw 2606
89 :    
90 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
91 :     ->int*Mid.ILVars list* code*int* low-il-var
92 :     * uses the Param_ids for the image, kernel, and tensor
93 :     * and gets the mid-IL vars for each.
94 :     *Transforms the position to index space
95 :     *P is the mid-il var for the (transformation matrix)transpose
96 :     *)
97 :     fun handleArgs(Vid,hid,tid,args)=let
98 :     val imgArg=List.nth(args,Vid)
99 :     val hArg=List.nth(args,hid)
100 :     val newposArg=List.nth(args,tid)
101 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
102 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
103 : cchiw 2606 in
104 : cchiw 2845 (dim,args@argsT,code, s,P)
105 : cchiw 2606 end
106 : cchiw 2838
107 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
108 :     * expands the body for the probed field
109 :     *)
110 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
111 :     (*1-d fields*)
112 :     fun createKRND1 ()=let
113 :     val sum=sx
114 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
115 : cchiw 3441 val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
116 :     val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
117 : cchiw 2845 in
118 : cchiw 3441 setProd[E.Img(Vid,alpha,pos),rest]
119 : cchiw 2843 end
120 : cchiw 2845 (*createKRN Image field and kernels *)
121 : cchiw 3441 fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
122 : cchiw 2845 | createKRN(dim,imgpos,rest)=let
123 :     val dim'=dim-1
124 :     val sum=sx+dim'
125 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
126 : cchiw 3441 val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
127 :     val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
128 : cchiw 2845 in
129 :     createKRN(dim',pos@imgpos,[rest']@rest)
130 :     end
131 :     val exp=(case dim
132 :     of 1 => createKRND1()
133 :     | _=> createKRN(dim, [],[])
134 :     (*end case*))
135 :     (*sumIndex creating summaiton Index for body*)
136 :     val slb=1-s
137 : cchiw 3383 val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
138 : cchiw 2845 val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
139 : cchiw 2843 in
140 : cchiw 2845 E.Sum(esum, exp)
141 : cchiw 2606 end
142 :    
143 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
144 :     *get fresh/unused index_id, returns int
145 :     *)
146 : cchiw 3383 fun getsumshift(sx,n) =let
147 : cchiw 2845 val nsumshift= (case sx
148 : cchiw 3383 of []=> n
149 : cchiw 2845 | _=>let
150 :     val (E.V v,_,_)=List.hd(List.rev sx)
151 :     in v+1
152 :     end
153 :     (* end case *))
154 : cchiw 3383
155 : cchiw 2845 val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
156 : cchiw 3383 val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
157 :     "\n\t Index length:",Int.toString n,
158 :     "\n\t Freshindex: ", Int.toString nsumshift])
159 : cchiw 2845 in
160 :     nsumshift
161 :     end
162 : cchiw 2611
163 : cchiw 2845 (*formBody:ein_exp->ein_exp
164 :     *just does a quick rewrite
165 :     *)
166 :     fun formBody(E.Sum([],e))=formBody e
167 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
168 : cchiw 3441 | formBody(E.Opn(E.Prod, [e]))=e
169 : cchiw 2845 | formBody e=e
170 : cchiw 2606
171 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
172 : cchiw 3441 fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
173 : cchiw 3353 (*
174 : cchiw 3441 | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
175 : cchiw 3353 *)
176 : cchiw 3441 | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
177 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
178 : cchiw 3195
179 : cchiw 2976
180 : cchiw 3441 fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
181 : cchiw 3195 | multiMergePs e=multiPs e
182 :    
183 :    
184 : cchiw 2845 (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
185 :     -> ein_exp* *code
186 :     * Transforms position to world space
187 :     * transforms result back to index_space
188 :     * rewrites body
189 :     * replace probe with expanded version
190 :     *)
191 : cchiw 3048 (* fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
192 :    
193 :     fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
194 :     =let
195 :     val originalb=Ein.body e
196 :     val params=Ein.params e
197 :     val index=Ein.index e
198 : cchiw 3267 val _ = testp["\n***************** \n Replace ************ \n"]
199 : cchiw 3260 val _= toStringBind (y, DstIL.EINAPP(e,args))
200 : cchiw 3048
201 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
202 : cchiw 2845 val fid=length(params)
203 :     val nid=fid+1
204 :     val Pid=nid+1
205 :     val nshift=length(dx)
206 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
207 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
208 : cchiw 2845 val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
209 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
210 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
211 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
212 : cchiw 3033
213 :     val body'=(case originalb
214 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
215 : cchiw 3441 | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,setProd[eps0,body'])
216 : cchiw 3033 | _ => body'
217 :     (*end case*))
218 : cchiw 3260
219 : cchiw 3033
220 : cchiw 2845 val args'=argsA@[PArg]
221 : cchiw 3033 val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
222 :     in
223 :     code@[einapp]
224 : cchiw 2845 end
225 : cchiw 2976
226 : cchiw 3195 val tsplitvar=true
227 : cchiw 3048 fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
228 : cchiw 2976 val Pid=0
229 :     val tid=1
230 : cchiw 3260
231 :     (*Assumes body is already clean*)
232 : cchiw 3048 val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
233 :    
234 :     (*need to rewrite dx*)
235 : cchiw 3260 val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
236 : cchiw 3048 of []=> ([],index,E.Conv(9,alpha,7,newdx))
237 : cchiw 3260 | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
238 : cchiw 3048 (*end case*))
239 :    
240 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
241 : cchiw 3260 fun filterAlpha []=[]
242 :     | filterAlpha(E.C _::es)= filterAlpha es
243 :     | filterAlpha(e1::es)=[e1]@(filterAlpha es)
244 :    
245 :     val tshape=filterAlpha(alpha')@newdx
246 : cchiw 3033 val t=E.Tensor(tid,tshape)
247 : cchiw 3441
248 : cchiw 3195 val (splitvar,body)=(case originalb
249 : cchiw 3441 of E.Sum(sx, E.Probe _) => (true,multiPs(Ps,sx@newsx,t))
250 :     | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
251 : cchiw 3195 | _ => (case tsplitvar
252 : cchiw 3441 of(* true => (true,multiMergePs(Ps,newsx,t)) (*pushes summations in place*)
253 : cchiw 3259 | false*) _ => (true,multiPs(Ps,newsx,t))
254 : cchiw 3195 (*end case*))
255 : cchiw 3441 (*end case*))
256 : cchiw 3048
257 : cchiw 3324 val _ =(case splitvar
258 : cchiw 3325 of true=> (String.concat["splitvar is true", P.printbody body])
259 :     | _ => (String.concat["splitvar is false",P.printbody body])
260 : cchiw 3324 (*end case*))
261 :    
262 :    
263 : cchiw 3048 val ein0=mkEin(params,index,body)
264 : cchiw 2976 in
265 : cchiw 3260 (splitvar,ein0,sizes,dx,alpha')
266 : cchiw 2976 end
267 : cchiw 3048
268 : cchiw 3260 fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
269 : cchiw 3441 val _=testp["\n******* Lift Geneirc Probe ***\n"]
270 : cchiw 3048 val originalb=Ein.body e
271 :     val params=Ein.params e
272 : cchiw 3189 val index=Ein.index e
273 : cchiw 3383 val _ = (toStringBind (y, DstIL.EINAPP(e,args)))
274 : cchiw 2976
275 : cchiw 3048 val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
276 : cchiw 2976 val fid=length(params)
277 :     val nid=fid+1
278 :     val nshift=length(dx)
279 : cchiw 3048 val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
280 : cchiw 3383 val freshIndex=getsumshift(sx,length(index))
281 : cchiw 2976
282 :     (*transform T*P*P..Ps*)
283 : cchiw 3260 val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
284 : cchiw 3441
285 : cchiw 3048 val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
286 :     val einApp0=mkEinApp(ein0,[PArg,FArg])
287 : cchiw 3195 val rtn0=(case splitvar
288 : cchiw 3324 of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
289 :     | _ => let
290 :     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
291 : cchiw 3440 in Split.splitEinApp bind3
292 : cchiw 3324 end
293 : cchiw 3195 (*end case*))
294 : cchiw 2976
295 :     (*lifted probe*)
296 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
297 : cchiw 3383 val freshIndex'= length(sizes)
298 :    
299 :     val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
300 : cchiw 3033 val ein1=mkEin(params',sizes,body')
301 : cchiw 2976 val einApp1=mkEinApp(ein1,args')
302 : cchiw 3048 val rtn1=(FArg,einApp1)
303 : cchiw 3195 val rtn=code@[rtn1]@rtn0
304 : cchiw 3260 val _= List.map toStringBind ([rtn1]@rtn0)
305 : cchiw 3383 val _=(String.concat["\n* end Lift Geneirc Probe ******** \n"])
306 : cchiw 2976 in
307 :     rtn
308 :     end
309 :    
310 : cchiw 3383 fun searchFullField (fieldset,code1,body1,dx)=let
311 :     val (lhs,_)=code1
312 :     fun continueReconstruction ()=let
313 :     val _=print"Tash:don't replaced"
314 :     in (case dx
315 :     of []=> (lhs,replaceProbe(1,code1,body1,[]))
316 :     | _ =>(lhs,liftProbe(1,code1,body1,[]))
317 :     (*end case*))
318 :     end
319 :     in (case valnumflag
320 : cchiw 3441 of false => (fieldset,continueReconstruction())
321 :     | true => (case (einSet.rtnVarN(fieldset,code1))
322 : cchiw 3383 of (fieldset,NONE) => (fieldset,continueReconstruction())
323 :     | (fieldset,SOME m) =>(print"TASH:replaced"; (fieldset,(m,[])))
324 :     (*end case*))
325 :     (*end case*))
326 :     end
327 : cchiw 3324
328 :     fun liftFieldMat(newvx,e)=
329 :     let
330 : cchiw 3441 val _=testp[ "\n ***************************** start FieldMat\n"]
331 : cchiw 3324 val (y, DstIL.EINAPP(ein,args))=e
332 :     val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
333 :     val index0=Ein.index ein
334 :     val index1 = index0@[3]
335 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
336 :     (* clean to get body indices in order *)
337 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
338 :     val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
339 :    
340 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
341 :     val ein1 = mkEin(Ein.params ein,index1,body1)
342 :     val code1= (lhs1,mkEinApp(ein1,args))
343 :     val codeAll= (case dx
344 :     of []=> replaceProbe(1,code1,body1,[])
345 :     | _ =>liftProbe(1,code1,body1,[])
346 :     (*end case*))
347 :    
348 :     (*Probe that tensor at a constant position c1*)
349 :     val param0 = [E.TEN(1,index1)]
350 :     val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
351 :     val body0 = E.Tensor(0,[c1]@nx)
352 :     val ein0 = mkEin(param0,index0,body0)
353 :     val einApp0 = mkEinApp(ein0,[lhs1])
354 :     val code0 = (y,einApp0)
355 :     val _= toStringBind code0
356 : cchiw 3441 val _=testp["\n end FieldMat *****************************\n "]
357 : cchiw 3324 in
358 :     codeAll@[code0]
359 :     end
360 :    
361 : cchiw 3383 fun liftFieldVec(newvx,e,fieldset)=
362 :     let
363 : cchiw 3441 val _=testp[ "\n ***************************** start FieldVec\n"]
364 : cchiw 3383 val (y, DstIL.EINAPP(ein,args))=e
365 :     val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
366 :     val index0=Ein.index ein
367 :     val index1 = index0@[3]
368 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
369 :     (* clean to get body indices in order *)
370 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
371 :    
372 :    
373 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
374 :     val ein1 = mkEin(Ein.params ein,index1,body1)
375 :     val code1= (lhs1,mkEinApp(ein1,args))
376 :     val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)
377 :    
378 :     (*Probe that tensor at a constant position c1*)
379 :     val param0 = [E.TEN(1,index1)]
380 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
381 :     val body0 = E.Tensor(0,[c1]@nx)
382 :     val ein0 = mkEin(param0,index0,body0)
383 :     val einApp0 = mkEinApp(ein0,[lhs0])
384 :     val code0 = (y,einApp0)
385 :    
386 : cchiw 3441 val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
387 :     val _ = (toStringBind code0)
388 :     val _ = testp[ "\n end FieldVec *****************************\n "]
389 : cchiw 3383 in
390 :     codeAll@[code0]
391 :     end
392 :    
393 :    
394 :    
395 : cchiw 3324 fun liftFieldSum e =
396 :     let
397 : cchiw 3441 val _=testp[ "\n************************************* Start Lift Field Sum\n"]
398 : cchiw 3324 val (y, DstIL.EINAPP(ein,args))=e
399 :     val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
400 :     val index0=Ein.index ein
401 :     val index1 = index0@[3]@[3]
402 :     val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
403 :     val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
404 :    
405 :    
406 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
407 :     val ein1 = mkEin(Ein.params ein,index1,body1)
408 :     val code1= (lhs1,mkEinApp(ein1,args))
409 :     val codeAll= (case dx
410 : cchiw 3441 of [] => replaceProbe(1,code1,body1,[])
411 :     | _ =>liftProbe(1,code1,body1,[])
412 : cchiw 3383 (*end case*))
413 : cchiw 3324
414 :     (*Probe that tensor at a constant position c1*)
415 :     val param0 = [E.TEN(1,index1)]
416 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
417 :     val body0 = E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
418 :     val ein0 = mkEin(param0,index0,body0)
419 :     val einApp0 = mkEinApp(ein0,[lhs1])
420 :     val code0 = (y,einApp0)
421 : cchiw 3374 val _ = toStringBind e
422 :     val _ = toStringBind code0
423 : cchiw 3383 val _ = (String.concat ["\norig",P.printbody(Ein.body ein),"\n replace i ",P.printbody body1,"\nfreshtensor",P.printbody body0])
424 :     val _ =((List.map toStringBind (codeAll@[code0])))
425 : cchiw 3441 val _ = testp["\n*** end Field Sum*************************************\n"]
426 : cchiw 3324 in
427 :     codeAll@[code0]
428 :     end
429 :    
430 :    
431 : cchiw 2845 (* expandEinOp: code-> code list
432 : cchiw 3259 * A this point we only have simple ein ops
433 :     * Looks to see if the expression has a probe. If so, replaces it.
434 : cchiw 2845 * Note how we keeps eps expressions so only generate pieces that are used
435 :     *)
436 : cchiw 3229 fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
437 : cchiw 3261
438 : cchiw 3441 fun checkConst ([],a) =
439 : cchiw 3314 (case fieldliftflag
440 :     of true => liftProbe a
441 :     | _ => replaceProbe a
442 :     (*end case*))
443 : cchiw 3260 | checkConst ((E.C _::_),a) = replaceProbe a
444 : cchiw 3166 | checkConst ((_ ::es),a)= checkConst(es,a)
445 : cchiw 3307 fun rewriteBody b=(case (detflag,b)
446 :     of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
447 : cchiw 3324 => liftFieldMat (1,e)
448 : cchiw 3307 | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
449 : cchiw 3324 => liftFieldMat (2,e)
450 : cchiw 3307 | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
451 : cchiw 3324 => liftFieldMat (3,e)
452 :     | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
453 :     => liftFieldSum e
454 :     | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
455 :     => liftFieldSum e
456 :     | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
457 :     => liftFieldSum e
458 : cchiw 3383 | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))
459 :     => liftFieldVec (0,e,fieldset)
460 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
461 :     => liftFieldVec (1,e,fieldset)
462 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
463 :     => liftFieldVec (2,e,fieldset)
464 :     | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
465 :     => liftFieldVec (3,e,fieldset)
466 : cchiw 3307 | (_,E.Probe(E.Conv(_,_,_,[]),_))
467 : cchiw 3262 => replaceProbe(0,e,b,[])
468 : cchiw 3307 | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
469 : cchiw 3259 => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
470 : cchiw 3307 | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
471 : cchiw 3260 => replaceProbe(0,e,p, sx) (*no dx*)
472 : cchiw 3307 | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
473 : cchiw 3260 => checkConst(dx,(0,e,p,sx)) (*scalar field*)
474 : cchiw 3307 | (_,E.Sum(sx,E.Probe p))
475 : cchiw 3094 => replaceProbe(0,e,E.Probe p, sx)
476 : cchiw 3441 | (_, E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
477 : cchiw 3094 => replaceProbe(0,e,E.Probe p,sx)
478 : cchiw 3307 | (_,_) => [e]
479 : cchiw 2845 (* end case *))
480 : cchiw 3174
481 : cchiw 3314 val (fieldset,var) = (case valnumflag
482 :     of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
483 :     | _ => (fieldset,NONE)
484 :     (*end case*))
485 : cchiw 3271
486 :     fun matchField b=(case b
487 :     of E.Probe _ => 1
488 :     | E.Sum (_, E.Probe _)=>1
489 : cchiw 3441 | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
490 : cchiw 3271 | _ =>0
491 :     (*end case*))
492 : cchiw 3327 fun toStrField b=(case b
493 : cchiw 3383 of E.Probe _ => print("\n"^(P.printbody b))
494 : cchiw 3369 | E.Sum (_, E.Probe _)=>print("\n"^ (P.printbody b))
495 : cchiw 3441 | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>print("\n"^ (P.printbody b))
496 : cchiw 3383 | _ => print""
497 : cchiw 3327 (*end case*))
498 : cchiw 3353 val b=Ein.body ein
499 : cchiw 3369
500 : cchiw 3174 in (case var
501 : cchiw 3415 of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
502 : cchiw 3344 | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
503 : cchiw 3174 (*end case*))
504 : cchiw 2845 end
505 : cchiw 2843
506 : cchiw 2606 end; (* local *)
507 :    
508 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0