Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3602 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 : jhr 3349 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 : cchiw 2606 * All rights reserved.
7 :     *)
8 :    
9 :     structure ProbeEin = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 :     structure DstIL = MidIL
15 :     structure DstOp = MidOps
16 : jhr 3060 structure P = Printer
17 :     structure T = TransformEin
18 :     structure MidToS = MidToString
19 : cchiw 2976 structure DstV = DstIL.Var
20 :     structure DstTy = MidILTypes
21 :    
22 : cchiw 2606 in
23 :    
24 : cchiw 2870 (* This file expands probed fields
25 :     * Take a look at ProbeEin tex file for examples
26 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
27 :     * Param_ids are used to note the placement of the argument in the midIL.var list
28 :     * Index_ids keep track of the shape of an Image or differentiation.
29 :     * Mu bind Index_id
30 :     * Generally, we will refer to the following
31 :     *dim:dimension of field V
32 :     * s: support of kernel H
33 :     * alpha: The alpha in <V_alpha * H^(deltas)>
34 :     * deltas: The deltas in <V_alpha * H^(deltas)>
35 :     * Vid:param_id for V
36 :     * hid:param_id for H
37 :     * nid: integer position param_id
38 :     * fid :fractional position param_id
39 : cchiw 3166 * img-imginfo about V
40 : cchiw 2870 *)
41 : cchiw 3033
42 : cchiw 2923 val testing=0
43 : cchiw 3448 val valnumflag=true
44 :     val tsplitvar=true
45 : cchiw 3444 val fieldliftflag=true
46 : cchiw 3448 val constflag=false
47 :     val detflag =true
48 :    
49 : cchiw 3444
50 :    
51 : cchiw 2845 val cnt = ref 0
52 :     fun transformToIndexSpace e=T.transformToIndexSpace e
53 :     fun transformToImgSpace e=T.transformToImgSpace e
54 : cchiw 3268 fun toStringBind e=(MidToString.toStringBind e)
55 : cchiw 3260 fun mkEin e=Ein.mkEin e
56 : cchiw 3033 fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
57 : cchiw 3448 fun setConst e = E.setConst e
58 :     fun setNeg e = E.setNeg e
59 :     fun setExp e = E.setExp e
60 :     fun setDiv e= E.setDiv e
61 :     fun setSub e= E.setSub e
62 :     fun setProd e= E.setProd e
63 :     fun setAdd e= E.setAdd e
64 :    
65 : cchiw 2845 fun testp n=(case testing
66 :     of 0=> 1
67 : cchiw 3444 | _ =>((String.concat n);1)
68 : cchiw 2845 (*end case*))
69 : cchiw 3260
70 :    
71 : cchiw 2845 fun getRHSDst x = (case DstIL.Var.binding x
72 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
73 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
74 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
75 :     (* end case *))
76 : cchiw 2838
77 : cchiw 2606
78 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
79 :     uses the Param_ids for the image, kernel,
80 :     and position tensor to get the Mid-IL arguments
81 :     returns the support of ther kernel, and image
82 :     *)
83 : jhr 3060 fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
84 :     of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ ))=> let
85 : cchiw 2845 in
86 : jhr 3060 ((Kernel.support h) ,img,ImageInfo.dim img)
87 : cchiw 2845 end
88 : cchiw 3448 | ((k,_),(i,_)) => raise Fail (String.concat["Expected kernel:", (DstOp.toString k ),"Expected Image:", (DstOp.toString i)])
89 : cchiw 2845 (*end case*))
90 : cchiw 2606
91 :    
92 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
93 :     ->int*Mid.ILVars list* code*int* low-il-var
94 :     * uses the Param_ids for the image, kernel, and tensor
95 :     * and gets the mid-IL vars for each.
96 :     *Transforms the position to index space
97 :     *P is the mid-il var for the (transformation matrix)transpose
98 :     *)
99 :     fun handleArgs(Vid,hid,tid,args)=let
100 :     val imgArg=List.nth(args,Vid)
101 :     val hArg=List.nth(args,hid)
102 :     val newposArg=List.nth(args,tid)
103 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
104 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
105 : cchiw 2606 in
106 : cchiw 2845 (dim,args@argsT,code, s,P)
107 : cchiw 2606 end
108 : cchiw 2838
109 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
110 :     * expands the body for the probed field
111 :     *)
112 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
113 :     (*1-d fields*)
114 :     fun createKRND1 ()=let
115 :     val sum=sx
116 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
117 : cchiw 3448 val pos=[setAdd[E.Tensor(fid,[]),E.Value(sum)]]
118 :     val rest= E.Krn(hid,dels,setSub(E.Tensor(nid,[]),E.Value(sum)))
119 : cchiw 2845 in
120 : cchiw 3448 setProd[E.Img(Vid,alpha,pos),rest]
121 : cchiw 2843 end
122 : cchiw 2845 (*createKRN Image field and kernels *)
123 : cchiw 3448 fun createKRN(0,imgpos,rest)=setProd ([E.Img(Vid,alpha,imgpos)] @rest)
124 : cchiw 2845 | createKRN(dim,imgpos,rest)=let
125 :     val dim'=dim-1
126 :     val sum=sx+dim'
127 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
128 : cchiw 3448 val pos=[setAdd[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
129 :     val rest'= E.Krn(hid,dels,setSub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
130 : cchiw 2845 in
131 :     createKRN(dim',pos@imgpos,[rest']@rest)
132 :     end
133 :     val exp=(case dim
134 :     of 1 => createKRND1()
135 :     | _=> createKRN(dim, [],[])
136 :     (*end case*))
137 :     (*sumIndex creating summaiton Index for body*)
138 :     val slb=1-s
139 : cchiw 3444 val _=List.tabulate(dim, (fn dim=> (String.concat[" sx:",Int.toString(sx)," dim:",Int.toString(dim),"esum",Int.toString(sx+dim) ]) ))
140 : cchiw 2845 val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
141 : cchiw 2843 in
142 : cchiw 2845 E.Sum(esum, exp)
143 : cchiw 2606 end
144 :    
145 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
146 :     *get fresh/unused index_id, returns int
147 :     *)
148 : cchiw 3444 fun getsumshift(sx,n) =let
149 : cchiw 2845 val nsumshift= (case sx
150 : cchiw 3444 of []=> n
151 : cchiw 2845 | _=>let
152 :     val (E.V v,_,_)=List.hd(List.rev sx)
153 :     in v+1
154 :     end
155 :     (* end case *))
156 : cchiw 3444
157 : cchiw 2845 val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
158 : cchiw 3444 val _ =(String.concat["\n", "SumIndex:" ,(String.concatWith"," aa),
159 :     "\n\t Index length:",Int.toString n,
160 :     "\n\t Freshindex: ", Int.toString nsumshift])
161 : cchiw 2845 in
162 :     nsumshift
163 :     end
164 : cchiw 2611
165 : cchiw 2845 (*formBody:ein_exp->ein_exp
166 :     *just does a quick rewrite
167 :     *)
168 :     fun formBody(E.Sum([],e))=formBody e
169 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
170 : cchiw 3448 | formBody(E.Opn(E.Prod, [e]))=e
171 : cchiw 2845 | formBody e=e
172 : cchiw 2606
173 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
174 : cchiw 3448 fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
175 : cchiw 3444 (*
176 : cchiw 3448 | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))
177 : cchiw 3444 *)
178 : cchiw 3448 | multiPs([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
179 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
180 : cchiw 3195
181 : cchiw 2976
182 : cchiw 3448 fun multiMergePs([P0,P1],[sx0,sx1],body)=E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])
183 : cchiw 3195 | multiMergePs e=multiPs e
184 :    
185 : cchiw 3602
186 :    
187 :     fun arrangeBodyLift(body, Ps, newsx, probe')=(case body
188 :     of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,probe'))
189 :     | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false, E.Sum(sx, setProd[eps0, multiPs(Ps, newsx,probe')]))
190 :     | E.Probe _ => (true, multiPs(Ps, newsx, probe'))
191 :     | _ => raise Fail "impossible"
192 :     (* end case *))
193 :    
194 :    
195 : cchiw 2845 (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
196 :     -> ein_exp* *code
197 :     * Transforms position to world space
198 :     * transforms result back to index_space
199 :     * rewrites body
200 :     * replace probe with expanded version
201 :     *)
202 : cchiw 3448 (* fun replaceProbe(testN,y,originalb,b,params,args,index, sx)*)
203 :    
204 : cchiw 3048 fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
205 :     =let
206 :     val originalb=Ein.body e
207 :     val params=Ein.params e
208 :     val index=Ein.index e
209 : cchiw 3267 val _ = testp["\n***************** \n Replace ************ \n"]
210 : cchiw 3260 val _= toStringBind (y, DstIL.EINAPP(e,args))
211 : cchiw 3048
212 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
213 : cchiw 2845 val fid=length(params)
214 :     val nid=fid+1
215 :     val Pid=nid+1
216 :     val nshift=length(dx)
217 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
218 : cchiw 3444 val freshIndex=getsumshift(sx,length(index))
219 : cchiw 2845 val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
220 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
221 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
222 : cchiw 3602 (*
223 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
224 : cchiw 3033
225 :     val body'=(case originalb
226 :     of E.Sum(sx, E.Probe _) => E.Sum(sx,body')
227 : cchiw 3448 | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,setProd[eps0,body'])
228 : cchiw 3033 | _ => body'
229 :     (*end case*))
230 : cchiw 3260
231 : cchiw 3602 *)
232 :     val (_,body')= arrangeBodyLift(originalb, Ps, newsx1, body')
233 : cchiw 3033
234 : cchiw 2845 val args'=argsA@[PArg]
235 : cchiw 3033 val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
236 :     in
237 :     code@[einapp]
238 : cchiw 2845 end
239 : cchiw 2976
240 : cchiw 3448
241 : cchiw 3048 fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
242 : cchiw 2976 val Pid=0
243 :     val tid=1
244 : cchiw 3260
245 :     (*Assumes body is already clean*)
246 : cchiw 3048 val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
247 :    
248 :     (*need to rewrite dx*)
249 : cchiw 3260 val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
250 : cchiw 3048 of []=> ([],index,E.Conv(9,alpha,7,newdx))
251 : cchiw 3260 | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
252 : cchiw 3048 (*end case*))
253 :    
254 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
255 : cchiw 3260 fun filterAlpha []=[]
256 :     | filterAlpha(E.C _::es)= filterAlpha es
257 :     | filterAlpha(e1::es)=[e1]@(filterAlpha es)
258 :    
259 :     val tshape=filterAlpha(alpha')@newdx
260 : cchiw 3033 val t=E.Tensor(tid,tshape)
261 : cchiw 3602 (*
262 : cchiw 3195 val (splitvar,body)=(case originalb
263 : cchiw 3448 of E.Sum(sx, E.Probe _) => (true,multiPs(Ps,sx@newsx,t))
264 :     | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false,E.Sum(sx,setProd[eps0,multiPs(Ps,newsx,t)]))
265 : cchiw 3195 | _ => (case tsplitvar
266 : cchiw 3448 of(* true => (true,multiMergePs(Ps,newsx,t)) (*pushes summations in place*)
267 : cchiw 3259 | false*) _ => (true,multiPs(Ps,newsx,t))
268 : cchiw 3195 (*end case*))
269 : cchiw 3448 (*end case*))
270 : cchiw 3602 *)
271 :     val (splitvar,body)= arrangeBodyLift(originalb, Ps, newsx, t)
272 :    
273 : cchiw 3444 val _ =(case splitvar
274 :     of true=> (String.concat["splitvar is true", P.printbody body])
275 :     | _ => (String.concat["splitvar is false",P.printbody body])
276 :     (*end case*))
277 :    
278 :    
279 : cchiw 3048 val ein0=mkEin(params,index,body)
280 : cchiw 2976 in
281 : cchiw 3260 (splitvar,ein0,sizes,dx,alpha')
282 : cchiw 2976 end
283 : cchiw 3048
284 : cchiw 3260 fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
285 : cchiw 3448 val _=testp["\n******* Lift Geneirc Probe ***\n"]
286 : cchiw 3048 val originalb=Ein.body e
287 :     val params=Ein.params e
288 : cchiw 3189 val index=Ein.index e
289 : cchiw 3444 val _ = (toStringBind (y, DstIL.EINAPP(e,args)))
290 : cchiw 2976
291 : cchiw 3048 val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
292 : cchiw 2976 val fid=length(params)
293 :     val nid=fid+1
294 :     val nshift=length(dx)
295 : cchiw 3048 val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
296 : cchiw 3444 val freshIndex=getsumshift(sx,length(index))
297 : cchiw 2976
298 :     (*transform T*P*P..Ps*)
299 : cchiw 3260 val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
300 : cchiw 3448
301 : cchiw 3048 val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
302 :     val einApp0=mkEinApp(ein0,[PArg,FArg])
303 : cchiw 3195 val rtn0=(case splitvar
304 : cchiw 3444 of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
305 :     | _ => let
306 :     val bind3 = (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
307 :     in Split.splitEinApp bind3
308 :     end
309 : cchiw 3195 (*end case*))
310 : cchiw 2976
311 :     (*lifted probe*)
312 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
313 : cchiw 3444 val freshIndex'= length(sizes)
314 :    
315 :     val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
316 : cchiw 3033 val ein1=mkEin(params',sizes,body')
317 : cchiw 2976 val einApp1=mkEinApp(ein1,args')
318 : cchiw 3048 val rtn1=(FArg,einApp1)
319 : cchiw 3195 val rtn=code@[rtn1]@rtn0
320 : cchiw 3260 val _= List.map toStringBind ([rtn1]@rtn0)
321 : cchiw 3444 val _=(String.concat["\n* end Lift Geneirc Probe ******** \n"])
322 : cchiw 2976 in
323 :     rtn
324 :     end
325 :    
326 : cchiw 3444 fun searchFullField (fieldset,code1,body1,dx)=let
327 :     val (lhs,_)=code1
328 :     fun continueReconstruction ()=let
329 :     val _=print"Tash:don't replaced"
330 :     in (case dx
331 :     of []=> (lhs,replaceProbe(1,code1,body1,[]))
332 :     | _ =>(lhs,liftProbe(1,code1,body1,[]))
333 :     (*end case*))
334 :     end
335 :     in (case valnumflag
336 : cchiw 3448 of false => (fieldset,continueReconstruction())
337 :     | true => (case (einSet.rtnVarN(fieldset,code1))
338 : cchiw 3444 of (fieldset,NONE) => (fieldset,continueReconstruction())
339 :     | (fieldset,SOME m) =>(print"TASH:replaced"; (fieldset,(m,[])))
340 :     (*end case*))
341 :     (*end case*))
342 :     end
343 :    
344 : cchiw 3448 fun liftFieldMat(newvx,e)=
345 :     let
346 :     val _=testp[ "\n ***************************** start FieldMat\n"]
347 :     val (y, DstIL.EINAPP(ein,args))=e
348 :     val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
349 :     val index0=Ein.index ein
350 :     val index1 = index0@[3]
351 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
352 :     (* clean to get body indices in order *)
353 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
354 :     val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
355 : cchiw 3444
356 : cchiw 3448 val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
357 :     val ein1 = mkEin(Ein.params ein,index1,body1)
358 :     val code1= (lhs1,mkEinApp(ein1,args))
359 :     val codeAll= (case dx
360 :     of []=> replaceProbe(1,code1,body1,[])
361 :     | _ =>liftProbe(1,code1,body1,[])
362 :     (*end case*))
363 :    
364 :     (*Probe that tensor at a constant position c1*)
365 :     val param0 = [E.TEN(1,index1)]
366 :     val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
367 :     val body0 = E.Tensor(0,[c1]@nx)
368 :     val ein0 = mkEin(param0,index0,body0)
369 :     val einApp0 = mkEinApp(ein0,[lhs1])
370 :     val code0 = (y,einApp0)
371 :     val _= toStringBind code0
372 :     val _=testp["\n end FieldMat *****************************\n "]
373 :     in
374 :     codeAll@[code0]
375 :     end
376 :    
377 :     fun liftFieldVec(newvx,e,fieldset)=
378 :     let
379 :     val _=testp[ "\n ***************************** start FieldVec\n"]
380 :     val (y, DstIL.EINAPP(ein,args))=e
381 :     val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
382 :     val index0=Ein.index ein
383 :     val index1 = index0@[3]
384 :     val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
385 :     (* clean to get body indices in order *)
386 :     val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
387 :    
388 :    
389 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
390 :     val ein1 = mkEin(Ein.params ein,index1,body1)
391 :     val code1= (lhs1,mkEinApp(ein1,args))
392 :     val (fieldset,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)
393 :    
394 :     (*Probe that tensor at a constant position c1*)
395 :     val param0 = [E.TEN(1,index1)]
396 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
397 :     val body0 = E.Tensor(0,[c1]@nx)
398 :     val ein0 = mkEin(param0,index0,body0)
399 :     val einApp0 = mkEinApp(ein0,[lhs0])
400 :     val code0 = (y,einApp0)
401 :    
402 :     val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
403 :     val _ = (toStringBind code0)
404 :     val _ = testp[ "\n end FieldVec *****************************\n "]
405 :     in
406 :     codeAll@[code0]
407 :     end
408 :    
409 :    
410 :    
411 :     fun liftFieldSum e =
412 :     let
413 :     val _=testp[ "\n************************************* Start Lift Field Sum\n"]
414 :     val (y, DstIL.EINAPP(ein,args))=e
415 :     val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[c1,v0],h,dx),pos))=Ein.body ein
416 :     val index0=Ein.index ein
417 :     val index1 = index0@[3]@[3]
418 :     val shiftdx=List.tabulate(length(dx),fn n=>E.V (n+2))
419 :     val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)
420 :    
421 :    
422 :     val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
423 :     val ein1 = mkEin(Ein.params ein,index1,body1)
424 :     val code1= (lhs1,mkEinApp(ein1,args))
425 :     val codeAll= (case dx
426 :     of [] => replaceProbe(1,code1,body1,[])
427 :     | _ =>liftProbe(1,code1,body1,[])
428 :     (*end case*))
429 :    
430 :     (*Probe that tensor at a constant position c1*)
431 :     val param0 = [E.TEN(1,index1)]
432 :     val nx=List.tabulate(length(dx),fn n=>E.V n)
433 :     val body0 = E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
434 :     val ein0 = mkEin(param0,index0,body0)
435 :     val einApp0 = mkEinApp(ein0,[lhs1])
436 :     val code0 = (y,einApp0)
437 :     val _ = toStringBind e
438 :     val _ = toStringBind code0
439 :     val _ = (String.concat ["\norig",P.printbody(Ein.body ein),"\n replace i ",P.printbody body1,"\nfreshtensor",P.printbody body0])
440 :     val _ =((List.map toStringBind (codeAll@[code0])))
441 :     val _ = testp["\n*** end Field Sum*************************************\n"]
442 :     in
443 :     codeAll@[code0]
444 :     end
445 :    
446 :    
447 : cchiw 2845 (* expandEinOp: code-> code list
448 : cchiw 3259 * A this point we only have simple ein ops
449 :     * Looks to see if the expression has a probe. If so, replaces it.
450 : cchiw 2845 * Note how we keeps eps expressions so only generate pieces that are used
451 :     *)
452 : cchiw 3229 fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let
453 : cchiw 3448
454 :     fun checkConst(es,a)=(case constflag
455 :     of true => liftProbe a
456 :     | _ => let
457 :     fun fConst ([],a) =
458 :     (case fieldliftflag
459 :     of true => liftProbe a
460 :     | _ => replaceProbe a
461 :     (*end case*))
462 :     | fConst ((E.C _::_),a) = replaceProbe a
463 :     | fConst ((_ ::es),a)= checkConst(es,a)
464 :     in fConst(es,a) end
465 :     (* end case*))
466 :     fun rewriteBodyB b=(case b
467 :     of (E.Probe(E.Conv(_,_,_,[]),_))
468 : cchiw 3262 => replaceProbe(0,e,b,[])
469 : cchiw 3448 | (E.Probe(E.Conv (_,alpha,_,dx),_))
470 : cchiw 3259 => checkConst(dx,(0,e,b,[])) (*scans dx for contant*)
471 : cchiw 3448 | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
472 : cchiw 3260 => replaceProbe(0,e,p, sx) (*no dx*)
473 : cchiw 3448 | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
474 : cchiw 3260 => checkConst(dx,(0,e,p,sx)) (*scalar field*)
475 : cchiw 3448 | (E.Sum(sx,E.Probe p))
476 : cchiw 3094 => replaceProbe(0,e,E.Probe p, sx)
477 : cchiw 3448 | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
478 : cchiw 3094 => replaceProbe(0,e,E.Probe p,sx)
479 : cchiw 3448 | _ => [e]
480 : cchiw 2845 (* end case *))
481 : cchiw 3448 fun rewriteBody b=(case detflag
482 :     of true => (case b
483 :     of (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
484 :     => liftFieldMat (1,e)
485 :     | (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
486 :     => liftFieldMat (2,e)
487 :     | (E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2]),pos))
488 :     => liftFieldMat (3,e)
489 :     | _ => rewriteBodyB b
490 :     (* end case *))
491 :     | _ => rewriteBodyB b
492 :     (* end case *))
493 : cchiw 3444 val (fieldset,var) = (case valnumflag
494 :     of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
495 :     | _ => (fieldset,NONE)
496 :     (*end case*))
497 :    
498 :     fun matchField b=(case b
499 :     of E.Probe _ => 1
500 :     | E.Sum (_, E.Probe _)=>1
501 : cchiw 3448 | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _]))=>1
502 : cchiw 3444 | _ =>0
503 :     (*end case*))
504 : cchiw 3448 val b=Ein.body ein
505 : cchiw 3444
506 : cchiw 3174 in (case var
507 : cchiw 3448 of NONE=> ((rewriteBody(Ein.body ein),fieldset,matchField(Ein.body ein),0))
508 : cchiw 3444 | SOME v=> (("\n mapp_replacing"^(P.printerE ein)^":");( [(y,DstIL.VAR v)],fieldset, matchField(Ein.body ein),1))
509 : cchiw 3174 (*end case*))
510 : cchiw 2845 end
511 : cchiw 2843
512 : cchiw 2606 end; (* local *)
513 :    
514 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0