Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2426 - (view) (download)
Original Path: branches/charisee/src/compiler/high-il/normalize-ein.sml

1 : cchiw 2397
2 :     structure NormalizeEin = struct
3 :    
4 :     local
5 :     structure G = GenericEin
6 :     structure E = Ein
7 :     structure S = Specialize
8 :     structure R = Rewrite
9 :    
10 :    
11 :    
12 :     in
13 :    
14 :    
15 :    
16 :     (*
17 :     If changed is true then I know the expression will run through the funciton again.
18 :     However, if not, then I want to make sure that every expression in the Product is examined, and not just individually but as a group.
19 :     Prod[t1,t2,(t3+t4)] indivually=> same
20 :     Prod[t1] @ Prod[t2,(t3+t4)]=> Notice rule here
21 :     Prod[t1] @ Add(Prod (t2, t3), Prod (t2, t4))
22 :     => Add( Prod[t1, Prod(t2,t3)]..)
23 :     => Add (Prod[t1,t2,t3]) Flattened
24 :    
25 :     *)
26 :    
27 :    
28 :    
29 :    
30 :    
31 :     (*Flattens Add constructor: change, expression *)
32 :     fun mkAdd [e]=(1,e)
33 :     | mkAdd(e)=let
34 :     fun flatten((i, (E.Add l)::l'))= flatten(1,l@l')
35 :     |flatten(i,((E.Const c):: l'))=
36 :     if (c>0.0 orelse c<0.0) then let
37 :     val(b,a)=flatten(i,l') in (b,[E.Const c]@a) end
38 :     else flatten(1,l')
39 :     | flatten(i,[])=(i,[])
40 :     | flatten (i,e::l') = let
41 :     val(b,a)=flatten(i,l') in (b,[e]@a) end
42 :    
43 :     val (b,a)=flatten(0,e)
44 :     in case a
45 :     of [] => (1,E.Const(1.0))
46 :     | [e] => (1,e)
47 :     | es => (b,E.Add es)
48 :     (* end case *)
49 :     end
50 :    
51 :     fun mkProd [e]=(1,e)
52 :     | mkProd(e)=let
53 :     fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')
54 :     |flatten(i,((E.Const c):: l'))=
55 :     if(c>0.0 orelse c<0.0) then
56 :     if (c>1.0 orelse c<1.0) then let
57 :     val(b,a)=flatten(i,l') in (b,[E.Const c]@a) end
58 :     else flatten(1,l')
59 :     else (3, [E.Const(0.0)])
60 :     | flatten(i,[])=(i,[])
61 :     | flatten (i,e::l') = let
62 :     val(b,a)=flatten(i,l') in (b,[e]@a) end
63 :     val ( b,a)=flatten(0,e)
64 :     in if(b=3) then (1,E.Const(0.0))
65 :     else case a
66 :     of [] => (1,E.Const(0.0))
67 :     | [e] => (1,e)
68 :     | es => (b, E.Prod es)
69 :     (* end case *)
70 :     end
71 :    
72 :    
73 :     fun mkEps(e)= (case e
74 :     of E.Apply(E.Partial [a], E.Prod( e2::m ))=> (0,e)
75 :     | E.Apply(E.Partial [a,b], E.Prod( (E.Epsilon(i,j,k))::m ))=>
76 :     (if(a=i andalso b=j) then (1,E.Const(0.0))
77 :     else if(a=i andalso b=k) then (1,E.Const(0.0))
78 :     else if(a=j andalso b=i) then (1,E.Const(0.0))
79 :     else if(a=j andalso b=k) then (1,E.Const(0.0))
80 :     else if(a=k andalso b=j) then (1,E.Const(0.0))
81 :     else if(a=k andalso b=i) then (1,E.Const(0.0))
82 :     else (0,e))
83 :     |_=> (0,e)
84 :     (*end case*))
85 :    
86 :     fun mkApply(E.Apply(d, e)) = (case e
87 :     of E.Tensor(a,[])=> (0,E.Const(0.0))
88 :     | E.Tensor _=> (0,E.Apply(d,e))
89 :     | E.Const _=> (1,E.Const(0.0))
90 :     | E.Add l => (1,E.Add(List.map (fn e => E.Apply(d, e)) l))
91 :     | E.Sub(e2, e3) =>(1, E.Sub(E.Apply(d, e2), E.Apply(d, e3)))
92 :     | E.Prod((E.Epsilon c)::e2)=> mkEps(E.Apply(d,e))
93 :     | E.Prod[E.Tensor(a,[]), e2]=> (0, E.Prod[ E.Tensor(a,[]), E.Apply(d, e2)] )
94 :     | E.Prod((E.Tensor(a,[]))::e2)=> (0, E.Prod[E.Tensor(a,[]), E.Apply(d, E.Prod e2)] )
95 :     | E.Prod es => (let
96 :     fun prod [e] = (E.Apply(d, e))
97 :     | prod(e1::e2)=(let val l= prod(e2) val m= E.Prod[e1,l]
98 :     val lr=e2 @[E.Apply(d,e1)] val(b,a) =mkProd lr
99 :     in ( E.Add[ a, m] )
100 :     end)
101 :     | prod _= (E.Const(1.0))
102 :     in (1,prod es)
103 :     end)
104 :     | _=> (0,E.Apply(d,e))
105 :     (*end case*))
106 :    
107 :     fun mkSumApply(E.Sum(c,E.Apply(d, e))) = (case e
108 :     of E.Tensor(a,[])=> (0,E.Const(0.0))
109 :     | E.Tensor _=> (0,E.Sum(c,E.Apply(d,e)))
110 :     | E.Field _ =>(0, E.Sum(c, E.Apply(d,e)))
111 :     | E.Const _=> (1,E.Const(0.0))
112 :     | E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(d, e))) l))
113 :     | E.Sub(e2, e3) =>(1, E.Sub(E.Sum(c,E.Apply(d, e2)), E.Sum(c,E.Apply(d, e3))))
114 :     | E.Prod((E.Epsilon c)::e2)=> mkEps(E.Apply(d,e))
115 :     | E.Prod[E.Tensor(a,[]), e2]=> (0, E.Prod[ E.Tensor(a,[]), E.Sum(c,E.Apply(d, e2))] )
116 :     | E.Prod((E.Tensor(a,[]))::e2)=> (0, E.Prod[E.Tensor(a,[]), E.Sum(c,E.Apply(d, E.Prod e2))] )
117 :     | E.Prod es => (let
118 :     fun prod [e] = (E.Apply(d, e))
119 :     | prod(e1::e2)=(let val l= prod(e2) val m= E.Prod[e1,l]
120 :     val lr=e2 @[E.Apply(d,e1)] val(b,a) =mkProd lr
121 :     in ( E.Add[ a, m] ) end)
122 :     | prod _= (E.Const(1.0))
123 :     in (1, E.Sum(c,prod es)) end)
124 :     | _=> (0,E.Sum(c,E.Apply(d,e)))
125 :     (*end case*))
126 :    
127 :    
128 :    
129 :     (* Identity: (Epsilon ijk Epsilon ilm) e => (Delta jl Delta km - Delta jm Delta kl) e
130 :     The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.
131 :     The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.
132 :     Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )
133 :     This is useful since we can normalize the second list without having to normalize the epsilons again.*)
134 :    
135 :     fun epsToDels(E.Sum(count,E.Prod e))= let
136 :     fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,e3)=
137 :     let
138 :     fun createDeltas(s,t,u,v, e3)=
139 :     (1, E.Sub(E.Sum(2,E.Prod([E.Delta(s,u), E.Delta(t,v)] @e3)),
140 :     E.Sum(2,E.Prod([E.Delta(s,v), E.Delta(t,u)]@e3))))
141 :     in if(a=d) then createDeltas(b,c,e,f, e3)
142 :     else if(a=e) then createDeltas(b,c,f,d, e3)
143 :     else if(a=f) then createDeltas(b,c,d,e, e3)
144 :     else if(b=d) then createDeltas(c,a,e,f, e3)
145 :     else if(b=e) then createDeltas(c,a,f,d,e3)
146 :     else if(b=f) then createDeltas(c,a,d,e,e3)
147 :     else if(c=d) then createDeltas(a,b,e,f,e3)
148 :     else if(c=e) then createDeltas(a,b,f,d,e3)
149 :     else if(c=f) then createDeltas(a,b,d,e,e3)
150 :     else (0,(E.Prod((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::e3)))
151 :     end
152 :     fun findeps(e,[])= (e,[])
153 :     | findeps(e,(E.Epsilon eps)::es)= findeps(e@[E.Epsilon eps],es)
154 :     | findeps(e,es)= (e, es)
155 :     fun distribute([], s)=(0, [],s)
156 :     | distribute([e1], s)=(0, [e1], s)
157 :     | distribute(e1::es, s)= let val(i, exp)=doubleEps(e1::es, s)
158 :     in if(i=1) then (1, tl(es), [exp])
159 :     else let val(a,b,c)= distribute(es, s)
160 :     in (a, [e1]@b, c) end
161 :     end
162 :     val (change, eps,rest)= distribute(findeps([], e))
163 :     in (change, eps,rest) end
164 :    
165 :    
166 :    
167 :    
168 :    
169 :    
170 :    
171 :     (*The Deltas then need to be distributed over to the tensors in the expression e.
172 :     Ex.:Delta ij ,Tensor_j, e=> Tensor_i,e. The mkDelts function compares every Delta in the expression to the tensors in the expressions while keeping the results in the correct order.
173 :     This also returns a list of deltas and a list of the remaining expression.
174 :     *)
175 :    
176 :     fun mkDel(e) = let
177 :     fun Del(i, [],x)= (i,[],x)
178 :     | Del(i, d,[])=(i, d,[])
179 :     | Del(i, (E.Delta(d1,d2))::d, (E.Tensor(id,[x]))::xs)=
180 :     if(x=d2) then (let
181 :     val(i',s,t)= Del(i+1,d, xs)
182 :     in Del(i',s, [E.Tensor(id, [d1])] @t) end)
183 :     else (let val (i',s,t)= Del(i,[E.Delta(d1,d2)],xs)
184 :     val(i2,s2,t2)= Del(i',d,[E.Tensor(id,[x])]@t)
185 :     in (i2,s@s2, t2) end )
186 :     | Del(i, (E.Delta(d1,d2))::d, (E.Field(id,[x]))::xs)=
187 :     if(x=d2) then (let
188 :     val(i',s,t)= Del(i+1,d, xs)
189 :     in Del(i',s, [E.Field(id, [d1])] @t) end)
190 :     else (let val (i',s,t)= Del(i,[E.Delta(d1,d2)],xs)
191 :     val(i2,s2,t2)= Del(i',d,[E.Field(id,[x])]@t)
192 :     in (i2,s@s2, t2) end )
193 :    
194 :     | Del(i, d, t)= (i,d,t)
195 :     fun findels(e,[])= (e,[])
196 :     | findels(e,es)= let val del1= hd(es)
197 :     in (case del1
198 :     of E.Delta _=> findels(e@[del1],tl(es))
199 :     |_=> (e, es))
200 :     end
201 :     val(a,b)= findels([], e)
202 :     in
203 :     Del(0, a, b)
204 :     end
205 :    
206 :    
207 :     (*The Deltas are distributed over to the tensors in the expression e.
208 :     This function checks for instances of the dotProduct.
209 :     Sum_2 (Delta_ij (A_i B_j D_k))=>Sum_1(A_i B_i) D_k
210 :     *)
211 :     fun checkDot(E.Sum(s,E.Prod e))= let
212 :     fun dot(i,d,r, (E.Tensor(ida,[a]))::(E.Tensor(idb,[b]))::ts)=
213 :     if (a=b) then
214 :     dot(i-1,d@[E.Sum(1,E.Prod[(E.Tensor(ida,[a])), (E.Tensor(idb,[b]))])], [],r@ts)
215 :     else dot(i,d, r@[E.Tensor(idb,[b])],(E.Tensor(ida,[a]))::ts)
216 :     |dot(i, d,r, [t])=dot(i,d@[t], [], r)
217 :     |dot(i,d, [],[])= (i,d, [],[])
218 :     |dot(i,d, r, [])= dot(i,d, [], r)
219 :     |dot(i, d, r, (E.Prod p)::t)= dot (i, d, r, p@t)
220 :     |dot(i,d, r, e)= (i,d@r@e, [], [])
221 :    
222 :     val(i,d,r,c)= dot(s,[],[], e)
223 :     val soln= (case d of [d1]=>d1
224 :     |_=> E.Prod d)
225 :     in E.Sum(i,soln) end
226 :     |checkDot(e)= (e)
227 :    
228 :    
229 :    
230 :    
231 :    
232 :    
233 :    
234 :     (*Apply normalize to each term in product list
235 :     or Apply normalize to tail of each list*)
236 :     fun normalize (Ein.EIN{params, index, body}) = let
237 :     val changed = ref false
238 :     fun rewriteBody body = (case body
239 :     of E.Const _=> body
240 :     | E.Tensor _ =>body
241 :     | E.Field _=> body
242 :     | E.Delta _ => body
243 :     | E.Epsilon _=>body
244 :     | E.Conv _=> body
245 :     | E.Partial _=>body
246 :     | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)
247 :     in if (b=1) then ( changed:=true;a) else a end
248 : cchiw 2426 | E.Pair es=> E.Pair(List.map rewriteBody es)
249 :     | E.Value _ => body
250 : cchiw 2397 | E.Sub (a,b)=> E.Sub(rewriteBody a, rewriteBody b)
251 :     | E.Div (a, b) => E.Div(rewriteBody a, rewriteBody b)
252 :     | E.Probe(u,v)=> ( E.Probe(rewriteBody u, v))
253 :     | E.Sum(0, e)=>e
254 :     | E.Sum(_, (E.Const c))=> E.Const c
255 :     | E.Sum(c,(E.Add l))=> E.Add(List.map (fn e => E.Sum(c,e)) l)
256 :    
257 :     | E.Sum(c,E.Prod((E.Delta d)::es))=>(
258 :     let val (i,dels, e)= mkDel((E.Delta d)::es)
259 :     val rest=(case e of [e1]=> rewriteBody e1
260 :     |_=> rewriteBody(E.Prod(e)))
261 :     val soln= (case rest of E.Prod r=> E.Sum(c-i, E.Prod(dels@r))
262 :     |_=>E.Sum(c-i, E.Prod(dels@[rest])))
263 :     val q= checkDot(soln)
264 :     in if (i=0) then q
265 :     else (changed :=true;q)
266 :     end )
267 :    
268 :     | E.Sum(c,E.Prod((E.Epsilon e1 )::(E.Epsilon e2)::xs))=>
269 :     let val (i,eps, e)= epsToDels(body)
270 :     in
271 :     if (i=0) then let val e'=rewriteBody(E.Prod(e)) in (case e'
272 :     of E.Prod m=> let val (i2, p)= mkProd(eps @ m)
273 :     in E.Sum(c, p) end
274 :     |_=>E.Sum(c, E.Prod(eps@ [e']))) end
275 :     else(let val [list]=e
276 :     val ans=rewriteBody(list)
277 :     val soln=(case ans
278 :     of E.Sub (E.Sum(c1,(E.Prod s1)),E.Sum(c2,(E.Prod s2))) =>
279 :     E.Sum(c-3+c1, E.Sub(E.Prod(eps@s1),E.Prod(eps@s2)))
280 :     | E.Sub (E.Sum(c1,s1),E.Sum(c2,s2)) =>
281 :     E.Sum(c-3+c1, E.Prod(eps@ [E.Sub(s1,s2)]))
282 :     |_=> E.Prod(eps@ [ans]))
283 :     in (changed :=true;soln) end
284 :     ) end
285 :    
286 :    
287 :     | E.Sum(c, E.Apply(E.Partial p, E.Prod((E.Delta(i,j))::e3 )))=>
288 :    
289 :     let fun part([], e2, counter)=([], e2, counter)
290 :     | part(p1::ps, [E.Delta(i,j)],counter)=if (p1=j) then ([i]@ps,[],counter-1)
291 :     else (let val (a,b,counter)=part(ps, [E.Delta(i,j)],counter)
292 :     in ([p1]@a, b,counter ) end)
293 :     val (e1,e2,counter)= part(p, [E.Delta(i,j)],c)
294 :    
295 :     in (E.Sum(counter, E.Apply(E.Partial e1, E.Prod(e2@e3)))) end
296 :    
297 :     | E.Sum(c, E.Apply(p, e))=>let
298 :     val e'= rewriteBody(E.Sum(c, e))
299 :     val p'= rewriteBody p
300 :     val (i, e2)= (case e'
301 :     of E.Sum(c',exp)=> mkSumApply(E.Sum(c', E.Apply(p', exp)))
302 :     |_=>mkApply( E.Apply(p', e')))
303 :     in if(i=1) then (changed :=true;e2) else e2 end
304 :    
305 :    
306 :     | E.Sum(c, e)=> E.Sum(c, rewriteBody e)
307 :    
308 :     | E.Prod([e1])=>(rewriteBody e1 )
309 :     | E.Prod(e1::(E.Add(e2))::e3)=>
310 :     (changed := true;
311 :     E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))
312 :     | E.Prod(e1::(E.Sub(e2,e3))::e4)=>
313 :     ( changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))
314 : cchiw 2414 | E.Prod[E.Partial r1,E. Conv(E.Field(id,[i]), deltas)]=>
315 :     (changed:=true; (
316 :     let
317 :     val j1= List.map (fn(x)=> (i,x)) r1
318 :     in E.Conv(E.Field(id,[i]), j1@deltas) end ))
319 : cchiw 2397 | E.Prod((E.Partial r1)::(E.Partial r2)::e) =>
320 :     (changed := true; E.Prod([E.Partial (r1@r2)] @ e) )
321 :     | E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[i1,i2])]=>
322 :     if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))
323 :     else body
324 :     | E.Prod((E.Epsilon eps1)::es)=> (let
325 :     val rest=(case es of [e1] => rewriteBody e1
326 :     |_=> rewriteBody(E.Prod(es)))
327 :     val (i, solution)=(case rest
328 :     of E.Prod m=> mkProd ([E.Epsilon eps1] @m )
329 :     |_=> mkProd([E.Epsilon eps1]@ [rest]))
330 :     in if (i=1) then (changed:=true;solution)
331 :     else solution end)
332 :    
333 :     | E.Prod (e::es) => (let val r=rewriteBody(E.Prod es)
334 :     val (i,solution)= (case r of E.Prod m => mkProd([e]@m )
335 :     |_=> mkProd([e]@ [r]))
336 :     in if (i=1) then (changed:=true;solution)
337 :     else solution end)
338 :     | E.Apply(E.Const _,_) => (E.Const(0.0))
339 :     | E.Apply(E.Partial p, E.Prod((E.Delta(i,j))::e3))=>
340 :     let fun part([], e2)=([], e2)
341 :     | part(p1::ps, [E.Delta(i,j)])=if (p1=j) then ([i]@ps,[])
342 :     else (let val (a,b)=part(ps, [E.Delta(i,j)])
343 :     in ([p1]@a, b ) end)
344 :     val (e1,e2)= part(p, [E.Delta(i,j)])
345 :     in E.Apply(E.Partial e1, E.Prod(e2@e3)) end
346 :    
347 :     | E.Apply(d,e)=> ( let val (t1,t2)= mkApply(E.Apply(rewriteBody d, rewriteBody e))
348 :     in if (t1=1) then (changed :=true;t2) else t2 end )
349 :     |_=> body
350 :    
351 :     (*end case*))
352 :    
353 :     fun loop body = let
354 :     val body' = rewriteBody body
355 :     in
356 :     if !changed
357 :     then (changed := false; loop body')
358 :     else body'
359 :     end
360 :     val b = loop body
361 :     in
362 :     ((Ein.EIN{params=params, index=index, body=b}))
363 :     end
364 :     end
365 :    
366 :    
367 :    
368 :    
369 :    
370 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0