SCM Repository
Annotation of /branches/charisee/src/compiler/high-il/normalize-ein.sml
Parent Directory
|
Revision Log
Revision 2795 - (view) (download)
1 : | cchiw | 2397 | |
2 : | structure NormalizeEin = struct | ||
3 : | |||
4 : | local | ||
5 : | cchiw | 2445 | |
6 : | cchiw | 2397 | structure E = Ein |
7 : | cchiw | 2603 | structure P=Printer |
8 : | structure F=Filter | ||
9 : | structure G=EpsHelpers | ||
10 : | cchiw | 2605 | |
11 : | cchiw | 2397 | in |
12 : | |||
(* err : string -> 'a
 * Abort normalization by raising Fail with a message that identifies the
 * ill-formed EIN operator.  A ": " separator is inserted between the fixed
 * prefix and the caller-supplied detail so the two do not run together.
 *)
fun err str = raise Fail (String.concat ["Ill-formed EIN Operator: ", str])
(* Debug switch read by normalize: when it equals 1, normalize prints the body
 * after every changing pass and once more on exit. *)
val testing : int = 1
15 : | cchiw | 2603 | |
(* flatProd es
 * A one-element list collapses to that element; any other list (including the
 * empty one) is wrapped in E.Prod.
 *)
fun flatProd es = (case es
       of [single] => single
        | _ => E.Prod es
      (* end case *))
18 : | |||
19 : | cchiw | 2515 | |
(* prodAppPartial (es, p1)
 * Product rule for applying the partial-derivative indices p1 to the product
 * of the expressions in es:
 *     d(e * rest) = rest * d(e)  +  e * d(rest)
 * A single factor becomes E.Apply(E.Partial p1, e).
 * Raises Fail "Empty App Partial" when es is empty.
 *)
fun prodAppPartial (es, p1) = let
      (* derivative of one factor *)
      fun deriv e = E.Apply(E.Partial p1, e)
      in
        case es
         of [] => raise Fail "Empty App Partial"
          | [only] => deriv only
          | (first :: rest) => let
              (* derivative of the remaining factors, computed recursively *)
              val dRest = prodAppPartial(rest, p1)
              val (_, firstTimesDRest) = F.mkProd [first, dRest]
              val (_, restTimesDFirst) = F.mkProd (rest @ [deriv first])
              in
                E.Add [restTimesDFirst, firstTimesDRest]
              end
        (* end case *)
      end
31 : | cchiw | 2494 | |
(* mkSum (c1, e1)
 * Rewrite a summation over the index list c1 whose body e1 has already been
 * rewritten.  Returns (code, e) where code = 0 means the summation was left
 * as E.Sum(c1, e1); any other code means a rewrite happened (the E.Prod case
 * forwards whatever code F.filterSca returns).
 *
 *   - Conv, Field, Probe, Apply, Delta, Epsilon, Tensor: left unchanged.
 *   - Neg, Sub, Add, Div, Lift: the summation is pushed onto the operands.
 *   - nested Sum: the two index lists are concatenated.
 *   - Prod: delegated to F.filterSca.
 *   - Const, Partial, Krn, Value, Img: raise Fail through err.
 *)
fun mkSum (c1, e1) = let
      (* the summation as given, tagged "not changed" *)
      fun unchanged () = (0, E.Sum(c1, e1))
      (* wrap a sub-expression in the same summation *)
      fun over e = E.Sum(c1, e)
      in
        case e1
         of E.Conv _ => unchanged ()
          | E.Field _ => unchanged ()
          | E.Probe _ => unchanged ()
          | E.Apply _ => unchanged ()
          | E.Delta _ => unchanged ()
          | E.Epsilon _ => unchanged ()
          | E.Tensor _ => unchanged ()
          | E.Neg e2 => (1, E.Neg(over e2))
          | E.Sub(a, b) => (1, E.Sub(over a, over b))
          | E.Add es => (1, E.Add(List.map over es))
          (* NOTE(review): summing numerator and denominator separately is only
           * valid when the denominator does not depend on the indices in c1;
           * behavior kept as-is -- confirm this is intended. *)
          | E.Div(a, b) => (1, E.Div(over a, over b))
          | E.Lift e => (1, E.Lift(over e))
          | E.Sum(c2, e2) => (1, E.Sum(c1 @ c2, e2))
          | E.Prod p => F.filterSca(c1, p)
          | E.Const _ => err("Sum of Const")
          | E.Partial _ => err("Sum of Partial")
          | E.Krn _ => err("Krn used before expand")
          | E.Value _ => err("Value used before expand")
          (* this branch used to report "Probe", which named the wrong operator *)
          | E.Img _ => err("Img used before expand")
        (* end case *)
      end
54 : | cchiw | 2452 | |
(* mkapply (d1, e1)
 * Rewrite the application of the derivative operator d1 (expected to be an
 * E.Partial) to the already-rewritten expression e1.  Returns (code, e); the
 * branches written here use code = 1 for "rewritten" and 0 for "left as
 * E.Apply(d1, e1)", and the E.Prod branch forwards the code of F.mkProd.
 *
 *   - Lift: result is E.Const 0.
 *   - Conv / nested Apply: the partial index lists are concatenated.
 *   - Sum, Neg, Add, Sub: the operator is pushed onto the operands.
 *   - Div: quotient rule, unless F.filterField finds no field part in the
 *     denominator, in which case only the numerator is differentiated.
 *   - Prod: product rule (prodAppPartial) over the field factors.
 *   - everything else raises Fail through err.
 *)
fun mkapply (d1, e1) = let
      (* Index list carried by d1.  Only forced by the branches that need it,
       * so a malformed d1 fails in exactly those branches -- now with a
       * descriptive Fail instead of a Bind from a refutable val pattern. *)
      fun partials () = (case d1
             of E.Partial d => d
              | _ => err("Apply with non-Partial operator")
            (* end case *))
      (* apply the same operator to a sub-expression *)
      fun app e = E.Apply(d1, e)
      in
        case e1
         of E.Lift _ => (1, E.Const 0)
          | E.Prod [] => err("Apply of empty product")
          | E.Add [] => err("Apply of empty Addition")
          | E.Conv(v, alpha, h, d2) => (1, E.Conv(v, alpha, h, d2 @ partials ()))
          | E.Field _ => (0, app e1)
          (* FIX ME (from the original): should probably be an error, since this
           * applies a derivative to a tensor result *)
          | E.Probe _ => (0, app e1)
          | E.Apply(E.Partial d2, e2) => (1, E.Apply(E.Partial(partials () @ d2), e2))
          | E.Apply _ => err" Apply of non-Partial expression"
          | E.Sum(c2, e2) => (1, E.Sum(c2, app e2))
          | E.Neg e2 => (1, E.Neg(app e2))
          | E.Add es => (1, E.Add(List.map app es))
          | E.Sub(a, b) => (1, E.Sub(app a, app b))
          | E.Div(g, b) => (case F.filterField [b]
               of (_, []) => (1, E.Div(app g, b))   (* division by a real *)
                | (pre, h) => let
                    (* (g/h)' = (g'*h - g*h') / (h*h) *)
                    val g' = app g
                    val h' = app (flatProd h)
                    val num = E.Sub(E.Prod(g' :: h), E.Prod [g, h'])
                    val denom = E.Prod(pre @ h @ h)
                    in
                      (1, E.Div(num, denom))
                    end
              (* end case *))
          | E.Prod p => let
              val (pre, post) = F.filterField p
              in
                F.mkProd(pre @ [prodAppPartial(post, partials ())])
              end
          | E.Const _ => err("Const without Lift")
          | E.Tensor _ => err("Tensor without Lift")
          | E.Delta _ => err("Apply of Delta")
          | E.Epsilon _ => err("Apply of Eps")
          | E.Partial _ => err("Apply of Partial")
          | E.Krn _ => err("Krn used before expand")
          | E.Value _ => err("Value used before expand")
          (* this branch used to report "Probe", which named the wrong operator *)
          | E.Img _ => err("Img used before expand")
        (* end case *)
      end
101 : | |||
102 : | cchiw | 2515 | |
(* mkprobe (e1, x)
 * Rewrite a probe of the already-rewritten expression e1 at position x.
 * Returns (code, e) with code = 1 when a rewrite happened and 0 when the
 * probe is left as E.Probe(e1, x).
 *
 *   - Lift e: the probe disappears and e is returned.
 *   - Prod, Sum, Add, Sub, Neg, Div: the probe is pushed onto the operands.
 *   - Apply, Conv, Field: left unchanged.
 *   - everything else raises Fail through err.
 *)
fun mkprobe (e1, x) = let
      (* probe a sub-expression at the same position *)
      fun at e = E.Probe(e, x)
      in
        case e1
         of E.Lift e => (1, e)
          | E.Prod [] => err("Probe of empty product")
          | E.Prod p => (1, E.Prod(List.map at p))
          | E.Apply _ => (0, at e1)
          | E.Conv _ => (0, at e1)
          | E.Field _ => (0, at e1)
          | E.Sum(c, e') => (1, E.Sum(c, at e'))
          | E.Add es => (1, E.Add(List.map at es))
          | E.Sub(a, b) => (1, E.Sub(at a, at b))
          | E.Neg e' => (1, E.Neg(at e'))
          | E.Div(a, b) => (1, E.Div(at a, at b))
          | E.Const _ => err("Const without Lift")
          | E.Tensor _ => err("Tensor without Lift")
          | E.Delta _ => err("Probe of Delta")
          | E.Epsilon _ => err("Probe of Eps")
          | E.Partial _ => err("Probe Partial")
          | E.Probe _ => err("Probe of a Probe")
          | E.Krn _ => err("Krn used before expand")
          | E.Value _ => err("Value used before expand")
          (* this branch used to report "Probe", which named the wrong operator *)
          | E.Img _ => err("Img used before expand")
        (* end case *)
      end
126 : | cchiw | 2515 | |
127 : | cchiw | 2499 | |
128 : | cchiw | 2397 | |
129 : | cchiw | 2525 | |
(* normalize (Ein.EIN{params, index, body})
 * Repeatedly applies rewriteBody to the body until a pass completes without
 * any rewrite setting the `changed` flag.  Returns the EIN operator with the
 * final body together with the number of passes that changed something.
 * When `testing` is 1 the body is printed after every changing pass and once
 * more on exit.
 *
 * Raises Fail for operators that must not appear here (Krn, Img, Value, an
 * empty product, a malformed Apply) and propagates the failures of mkSum,
 * mkapply and mkprobe.
 *)
fun normalize (Ein.EIN{params, index, body}) = let
      (* set by any rewrite that fires during the current pass *)
      val changed = ref false

      (* interpret a (code, expr) result where code = 1 means "rewritten" *)
      fun flagIfOne (1, e) = (changed := true; e)
        | flagIfOne (_, e) = e

      (* build a product through F.mkProd, discarding its change code *)
      fun prod es = let val (_, p) = F.mkProd es in p end

      fun rewriteBody body = (case body
             of E.Const _ => body
              | E.Tensor _ => body
              | E.Field _ => body
              | E.Delta _ => body
              | E.Epsilon _ => body
              | E.Conv _ => body
              | E.Partial _ => body
              | E.Krn _ => raise Fail"Krn before Expand"
              | E.Img _ => raise Fail"Img before Expand"
              | E.Value _ => raise Fail"Value before Expand"

              (************* Algebraic Rewrites **************)
              (* The restructuring rules below change the tree, so they flag the
               * pass; previously they did not, which let the loop stop (and the
               * pass count under-report) right after such a rewrite. *)
              | E.Neg(E.Neg e) => (changed := true; rewriteBody e)
              | E.Neg e => E.Neg(rewriteBody e)
              | E.Lift e => E.Lift(rewriteBody e)
              | E.Add es => flagIfOne(F.mkAdd(List.map rewriteBody es))
              | E.Sub(a, E.Field f) => (changed := true; E.Add[a, E.Neg(E.Field f)])
              (* (a-b)-(c-d) = (a+d)-(b+c) *)
              | E.Sub(E.Sub(a, b), E.Sub(c, d)) =>
                  (changed := true; rewriteBody(E.Sub(E.Add[a, d], E.Add[b, c])))
              (* (a-b)-e2 = a-(b+e2) *)
              | E.Sub(E.Sub(a, b), e2) =>
                  (changed := true; rewriteBody(E.Sub(a, E.Add[b, e2])))
              (* e1-(c-d) = (e1-c)+d *)
              | E.Sub(e1, E.Sub(c, d)) =>
                  (changed := true; rewriteBody(E.Add[E.Sub(e1, c), d]))
              | E.Sub(a, b) => E.Sub(rewriteBody a, rewriteBody b)
              (* (a/b)/(c/d) = (a*d)/(b*c) *)
              | E.Div(E.Div(a, b), E.Div(c, d)) =>
                  (changed := true; rewriteBody(E.Div(E.Prod[a, d], E.Prod[b, c])))
              (* (a/b)/c = a/(b*c) *)
              | E.Div(E.Div(a, b), c) =>
                  (changed := true; rewriteBody(E.Div(a, E.Prod[b, c])))
              (* a/(b/c) = (a*c)/b *)
              | E.Div(a, E.Div(b, c)) =>
                  (changed := true; rewriteBody(E.Div(E.Prod[a, c], b)))
              | E.Div(a, b) => E.Div(rewriteBody a, rewriteBody b)

              (************** Apply, Sum, Probe **************)
              (* An empty partial is dropped.  The operand is now rewritten and
               * the pass flagged, matching the E.Sum([], e) rule below; before,
               * the operand was returned untouched and could escape
               * normalization when nothing else changed in the pass. *)
              | E.Apply(E.Partial [], e) => (changed := true; rewriteBody e)
              | E.Apply(E.Partial d1, e1) =>
                  flagIfOne(mkapply(E.Partial d1, rewriteBody e1))
              | E.Apply _ => raise Fail" Not well-formed Apply expression"
              | E.Sum([], e) => (changed := true; rewriteBody e)
              | E.Sum(c, e) => (case mkSum(c, rewriteBody e)
                   of (0, e') => e'
                    | (_, e') => (changed := true; e')
                  (* end case *))
              | E.Probe(u, v) => flagIfOne(mkprobe(rewriteBody u, rewriteBody v))

              (************* Product **************)
              | E.Prod [] => raise Fail"missing elements in product"
              (* Deliberately NOT flagged: the Delta, epsilon and Sum-epsilon
               * cases below re-enter rewriteBody with products they build on
               * every pass, and those can be singletons; flagging here would
               * keep `changed` set forever and the loop would never converge. *)
              | E.Prod [e1] => rewriteBody e1
              (* distribute the product over a leading Add / Sub / Div ... *)
              | E.Prod(E.Add e2 :: e3) =>
                  (changed := true; E.Add(List.map (fn e => E.Prod(e :: e3)) e2))
              | E.Prod(E.Sub(e2, e3) :: e4) =>
                  (changed := true; E.Sub(E.Prod(e2 :: e4), E.Prod(e3 :: e4)))
              | E.Prod(E.Div(e2, e3) :: e4) =>
                  (changed := true; E.Div(E.Prod(e2 :: e4), e3))
              (* ... and over an Add / Sub in second position *)
              | E.Prod(e1 :: E.Add e2 :: e3) =>
                  (changed := true; E.Add(List.map (fn e => E.Prod(e1 :: e :: e3)) e2))
              | E.Prod(e1 :: E.Sub(e2, e3) :: e4) =>
                  (changed := true; E.Sub(E.Prod(e1 :: e2 :: e4), E.Prod(e1 :: e3 :: e4)))

              (************* Product EPS **************)
              | E.Prod(E.Epsilon(i, j, k) :: E.Apply(E.Partial d, e) :: es) => let
                  val eps = E.Epsilon(i, j, k)
                  val app = E.Apply(E.Partial d, e)
                  in
                    case (G.matchEps(0, d, [], [i, j, k]), es)
                     of (1, _) => (changed := true; E.Const 0)
                      | (_, []) => E.Prod[eps, rewriteBody app]
                      | _ => let
                          val a = rewriteBody(E.Prod(app :: es))
                          in
                            prod [eps, a]
                          end
                    (* end case *)
                  end
              | E.Prod(E.Epsilon(i, j, k) :: E.Conv(v, alpha, h, d) :: es) => let
                  val eps = E.Epsilon(i, j, k)
                  val conv = E.Conv(v, alpha, h, d)
                  in
                    case (G.matchEps(0, d, [], [i, j, k]), es)
                     of (1, _) => (changed := true; E.Const 0)
                      | (_, []) => E.Prod[eps, conv]
                      | _ => let
                          val a = rewriteBody(E.Prod(conv :: es))
                          in
                            prod [eps, a]
                          end
                    (* end case *)
                  end

              (* NOTE(review): an epsilon contracted with both indices of a
               * 2-index tensor is replaced by zero; that identity presumably
               * relies on the tensor being symmetric -- confirm. *)
              | E.Prod[E.Epsilon(_, e2, e3), E.Tensor(_, [E.V i1, E.V i2])] =>
                  if (e2 = i1 andalso e3 = i2)
                    then (changed := true; E.Const 0)
                    else body

              | E.Prod(E.Epsilon eps1 :: ps) => (case G.epsToDels(E.Epsilon eps1 :: ps)
                   of (1, e, [], _, _) => (changed := true; e)              (* changed to deltas *)
                    | (1, e, sx, _, _) => (changed := true; E.Sum(sx, e))   (* changed to deltas *)
                    | (_, _, _, _, []) => body
                    | (_, _, _, epsAll, rest) => let
                        val p' = rewriteBody(E.Prod rest)
                        in
                          prod (epsAll @ [p'])
                        end
                  (* end case *))

              | E.Prod(E.Sum(c1, E.Prod(E.Epsilon e1 :: es1))
                       :: E.Sum(c2, E.Prod(E.Epsilon e2 :: es2)) :: es) =>
                  (case G.epsToDels([E.Epsilon e1, E.Epsilon e2] @ es1 @ es2 @ es)
                   of (1, e, sx, _, _) => (changed := true; E.Sum(c1 @ c2 @ sx, e))
                    | _ => let
                        val eA = rewriteBody(E.Sum(c1, E.Prod(E.Epsilon e1 :: es1)))
                        val eB = rewriteBody(E.Prod(E.Sum(c2, E.Prod(E.Epsilon e2 :: es2)) :: es))
                        in
                          prod [eA, eB]
                        end
                  (* end case *))

              | E.Prod(E.Delta d :: es) => let
                  val (_, eps, dels, post) = F.filterGreek(E.Delta d :: es)
                  in
                    case G.reduceDelta(eps, dels, post)
                     of (0, _) => E.Prod[E.Delta d, rewriteBody(E.Prod es)]
                      | (_, E.Prod p) => let
                          val p' = prod p
                          in
                            (changed := true; p')
                          end
                      | (_, a) => (changed := true; a)
                    (* end case *)
                  end

              (* generic products: rewrite the factors and rebuild via F.mkProd *)
              | E.Prod[e1, e2] => prod [rewriteBody e1, rewriteBody e2]
              | E.Prod(e :: es) => let
                  val e' = rewriteBody e
                  in
                    case rewriteBody(E.Prod es)
                     of E.Prod p' => prod (e' :: p')
                      | e2 => prod [e', e2]
                    (* end case *)
                  end
            (* end case *))

      (* run passes until one leaves `changed` unset; count = changing passes *)
      fun loop (body, count) = let
            val body' = rewriteBody body
            in
              if !changed
                then (
                  if testing = 1
                    then print(String.concat[
                        "\nN =>", Int.toString(count), "--", P.printbody(body')
                      ])
                    else ();
                  changed := false;
                  loop(body', count+1))
                else (body', count)
            end

      val (b, count) = loop(body, 0)
      val _ = if testing = 1
            then print(String.concat[
                "\n out of normalize \n", P.printbody(b),
                "\n Final CounterXX:", Int.toString(count), "\n\n"
              ])
            else ()
      in
        (Ein.EIN{params=params, index=index, body=b}, count)
      end
294 : | cchiw | 2463 | end |
295 : | |||
296 : | cchiw | 2397 | |
297 : | |||
298 : | end (* local *) |
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |