Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/ein16/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Annotation of /branches/ein16/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3682 - (view) (download)

1 : cchiw 2397
2 :     structure NormalizeEin = struct
3 :    
4 :     local
5 : cchiw 2445
6 : cchiw 2397 structure E = Ein
7 : cchiw 2603 structure P=Printer
8 :     structure F=Filter
9 :     structure G=EpsHelpers
10 : cchiw 2870 structure Eq=EqualEin
11 :     structure R=RationalEin
12 : cchiw 2605
13 : cchiw 2397 in
14 :    
15 : cchiw 3153 val testing=0
16 : cchiw 2845 fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])
17 :     fun mkProd e= F.mkProd e
18 :     fun filterSca e=F.filterSca e
19 :     fun mkAdd e=F.mkAdd e
20 :     fun filterGreek e=F.filterGreek e
21 : cchiw 3138 fun mkapply e= derivativeEin.mkapply e
22 : cchiw 2845 fun testp n=(case testing
23 :     of 0=> 1
24 :     | _ =>(print(String.concat n);1)
25 : cchiw 3441 (*end case*))
26 : cchiw 2608
27 : cchiw 3441 val zero=E.B(E.Const 0)
28 :     fun setConst e = E.setConst e
29 :     fun setNeg e = E.setNeg e
30 :     fun setExp e = E.setExp e
31 :     fun setDiv e= E.setDiv e
32 :     fun setSub e= E.setSub e
33 :     fun setProd e= E.setProd e
34 :     fun setAdd e= E.setAdd e
35 : cchiw 3138
36 : cchiw 2845 (*mkSum:sum_indexid list * ein_exp->int *ein_exp
37 :     *distribute summation expression
38 :     *)
39 : cchiw 3441 fun mkSum(sx1,b)=(case b
40 :     of E.Lift e => (1,E.Lift(E.Sum(sx1,e)))
41 : cchiw 3440 | E.Tensor(_,[]) => (1,b)
42 : cchiw 3441 | E.B _ => (1,b)
43 :     | E.Opn(E.Prod, es) => filterSca(sx1,es)
44 :     | _ => (0,E.Sum(sx1,b))
45 : cchiw 2922 (*end case*))
46 : cchiw 2923
47 : cchiw 2845 (*mkprobe:ein_exp* ein_exp-> int ein_exp
48 :     *rewritten probe
49 :     *)
50 : cchiw 3440 fun mkprobe(b,x)=let
51 :     val (c,rtn)=(case b
52 : cchiw 3441 of (E.B _) => (0,b)
53 :     | E.Tensor _ => err("Tensor without Lift")
54 :     | E.G _ => (0,b)
55 :     | E.Field _ => (0,E.Probe(b,x))
56 :     | E.Lift e1 => (1,e1)
57 : cchiw 3440 | E.Conv _ => (0,E.Probe(b,x))
58 :     | E.Partial _ => err("Probe Partial")
59 : cchiw 3441 | E.Apply _ => (0,E.Probe(b,x))
60 : cchiw 3440 | E.Probe _ => err("Probe of a Probe")
61 :     | E.Value _ => err("Value used before expand")
62 :     | E.Img _ => err("Probe used before expand")
63 : cchiw 3441 | E.Krn _ => err("Krn used before expand")
64 :     | E.Sum(sx1,e1) => (1,E.Sum(sx1,E.Probe(e1,x)))
65 :     | E.Op1(op1, e1) => (1,E.Op1(op1, E.Probe(e1,x)))
66 :     | E.Op2(op2, e1,e2) => (1,E.Op2(op2, E.Probe(e1,x), E.Probe(e2,x)))
67 :     | E.Opn(opn, []) => err("Probe of empty operator")
68 :     | E.Opn(opn, es) => (1,E.Opn(opn, List.map(fn e1=> E.Probe(e1,x)) es))
69 : cchiw 2976 (*end case*))
70 :     in
71 :     (c,rtn)
72 :     end
73 : cchiw 2603
74 : cchiw 2845 (* normalize: EIN->EIN
75 :     * rewrite body of EIN
76 :     * note "c" keeps track if ein_exp is changed
77 :     *)
78 : cchiw 2870 fun normalize (ee as Ein.EIN{params, index, body},args) = let
79 : cchiw 2397 val changed = ref false
80 : cchiw 3604 fun rewrite body =let
81 :    
82 :     fun prod2(e1, e2,[]) =
83 :     (case (rewrite e1, rewrite e2)
84 :     of (E.B(E.Const 0), e2') => (changed:=true;e2')
85 :     | (e1', E.B(E.Const 0)) => (changed:=true;e1')
86 :     | (e1', e2') => E.Opn(E.Prod,[e1',e2']))
87 :     | prod2(e1, e2,es) = let
88 :     val e2= E.Opn(E.Prod,e2::es)
89 :     in
90 :     (case (rewrite e1, rewrite e2)
91 :     of (E.B(E.Const 0), _) => (changed:=true;E.B(E.Const 0))
92 :     | (_ , E.B(E.Const 0)) => (changed:=true;E.B(E.Const 0))
93 :     | (e1', E.Opn(E.Prod, ps')) => E.Opn(E.Prod, e1'::ps')
94 :     | (e1', e2') =>(changed:=true; E.Opn(E.Prod,[e1',e2']))
95 :     (*end case*))
96 :     end
97 :    
98 :     in (case body
99 : cchiw 3441 of E.B _ => body
100 : cchiw 3440 | E.Tensor _ => body
101 : cchiw 3441 | E.G _ => body
102 :     (************** Field Terms **************)
103 : cchiw 3440 | E.Field _ => body
104 : cchiw 3441 | E.Lift e1 => E.Lift(rewrite e1)
105 : cchiw 3440 | E.Conv _ => body
106 :     | E.Partial _ => body
107 : cchiw 3604 | E.Probe(e1,e2) =>
108 :     let
109 :     val (c',b')=mkprobe(rewrite e1,rewrite e2)
110 :     in (case c'
111 :     of 1=> (changed:=true;b')
112 :     | _ => b'
113 :     (*end case*))
114 :     end
115 : cchiw 3441 | E.Apply(E.Partial [],e1) => e1
116 :     | E.Apply(E.Partial d1, e1) =>
117 : cchiw 2845 let
118 : cchiw 3441 val e2 = rewrite e1
119 :     val (c,e3)=mkapply(E.Partial d1,e2)
120 :     val _= testp["\nafter apply:",P.printbody body,"-->",P.printbody e3]
121 :     in
122 :     (case c of 1=>(changed:=true;e3)| _ =>e3 (*end case*))
123 :     end
124 :     | E.Apply _ => err " Not well-formed Apply expression"
125 : cchiw 3604
126 : cchiw 3441 (************** Field Terms **************)
127 :     | E.Value _ => err "Value before Expand"
128 :     | E.Img _ => err "Img before Expand"
129 :     | E.Krn _ => err "Krn before Expand"
130 :     (************** Sum **************)
131 :    
132 :     | E.Sum([],e1) => (changed:=true;rewrite e1)
133 :     | E.Sum(sx1,e1) => let
134 :     val e2=rewrite e1
135 :     val (c,e')=mkSum(sx1,e2)
136 :     val _= testp["\nafter mksum:\n\t",P.printbody body,"\n\t-->",P.printbody e2,"\n\t-->",P.printbody e']
137 :     in
138 :     (case c of 0 => e'|_ => (changed:=true;e'))
139 :     end
140 :     (*************Algebraic Rewrites Op1 **************)
141 :     | E.Op1(E.Neg,e1) => (case e1
142 :     of E.Op1(E.Neg,e2) => rewrite e2
143 :     | E.B(E.Const 0) =>(changed:=true;zero)
144 :     | _ => E.Op1(E.Neg,rewrite e1)
145 :     (*end case*))
146 :     | E.Op1(op1,e1) => E.Op1(op1,rewrite e1)
147 :     (*************Algebraic Rewrites Op2 **************)
148 :     | E.Op2(E.Sub,e1,e2) => (case (e1,e2)
149 :     of (E.B(E.Const 0),_) => (changed:=true;setNeg(rewrite e2))
150 :     | (_,E.B(E.Const 0)) => (changed:=true;rewrite e1)
151 :     | _ => setSub(rewrite e1, rewrite e2)
152 :     (*end case*))
153 :     | E.Op2(E.Div,e1,e2) =>(case (e1,e2)
154 :     of (E.B(E.Const 0),_) => (changed:=true;zero)
155 :     |(E.Op2(E.Div,a,b), E.Op2(E.Div,c,d)) => rewrite(setDiv(setProd[a,d],setProd[b,c]))
156 :     |(E.Op2(E.Div,a,b), c) => rewrite (setDiv(a, setProd[b,c]))
157 :     | (a,E.Op2(E.Div,b,c)) => rewrite (setDiv(setProd[a,c],b))
158 :     | _ => setDiv(rewrite e1, rewrite e2)
159 :     (*end case*))
160 :     (*************Algebraic Rewrites Opn **************)
161 : cchiw 3678 (*
162 :     | E.Opn(E.Add,[E.Sum([(E.V 4,0,2)],E.Opn(E.Prod,[E.Tensor(0,[E.V 4,E.V 0]), E.Tensor(1,[E.V 4,E.V 1])])),E.Opn(E.Prod,[E.Tensor(2,[]),
163 :     E.Sum([(E.V 8,0,2)],E.Opn(E.Prod,[E.Tensor(0,[E.V 8,E.V 0]), E.Tensor(4,[E.V 8,E.V 1])]))])]) => let
164 :     val a = E.Tensor(0,[E.V 2,E.V 0])
165 :     val b = E.Tensor(1,[E.V 2,E.V 1])
166 :     val c = E.Tensor(4,[E.V 2,E.V 1])
167 :     val s = E.Tensor(2,[])
168 :     val add = E.Opn(E.Add,[b,E.Opn(E.Prod,[s,c])])
169 :     val prod= E.Sum([(E.V 2,0,2)],E.Opn(E.Prod,[a,add]))
170 :     val _ =print(String.concat["\nMatched"])
171 :     in prod end
172 :     *)
173 : cchiw 3441 | E.Opn(E.Add,es) => let
174 :     val (change,body')= mkAdd(List.map rewrite es)
175 :     in if (change=1) then ( changed:=true;body') else body' end
176 :    
177 : cchiw 2845 (*************Product**************)
178 : cchiw 3441 | E.Opn(E.Prod,[]) => err "missing elements in product"
179 :     | E.Opn(E.Prod,[e1]) => rewrite e1
180 : cchiw 3604 | E.Opn(E.Prod,[e1 as E.Op1(E.Sqrt,s1),e2 as E.Op1(E.Sqrt,s2)])=>
181 : cchiw 3166 if(Eq.isBodyEq(s1,s2)) then (changed :=true;s1)
182 : cchiw 3604 else (*let
183 : cchiw 3441 val a=rewrite (E.Op1(E.Sqrt,s1))
184 :     val b=rewrite (E.Op1(E.Sqrt,s2))
185 : cchiw 2955 val (_,d)=mkProd ([a,b])
186 :     in d
187 : cchiw 3604 end*) prod2(e1, e2,[])
188 : cchiw 2845 (*************Product EPS **************)
189 : cchiw 3441 | E.Opn(E.Prod,(E.G(E.Epsilon e1)::ps))=> let
190 :     val E.G(E.Epsilon(i,j,k))=E.G(E.Epsilon e1)
191 :     val eps1=E.G(E.Epsilon(i,j,k))
192 :     val p1=List.hd(ps)
193 :     in (case ps
194 :     of (E.Apply(E.Partial d,e)::es)=>let
195 :     val change= G.matchEps(0,d,[],[i,j,k])
196 :     in case (change,es)
197 :     of (1,_) => (changed:=true; zero)
198 : cchiw 3604 | _ => prod2(eps1, p1, es)
199 : cchiw 3441 end
200 :     | (E.Conv(V,alpha, h, d)::es)=>let
201 :     val change= G.matchEps(0,d,[],[i,j,k])
202 :     in case (change,es)
203 :     of (1,_) => (changed:=true; E.Lift zero )
204 : cchiw 3604 | (_,_) => prod2(eps1, p1 ,es)
205 : cchiw 3441 end
206 :     | _ => (case (G.epsToDels(eps1::ps))
207 :     of (1,e,[],_,_) => (changed:=true;e)(* Changed to Deltas*)
208 :     | (1,e,sx,_,_) => (changed:=true;E.Sum(sx,e))
209 :     | (_,_,_,_,[]) => body
210 : cchiw 3604 | (_,_,_,epsAll,[r]) => E.Opn(E.Prod,epsAll@[rewrite r])
211 :     | (_,_,_,epsAll,rest) => (case (rewrite(E.Opn(E.Prod, rest)))
212 :     of E.Opn(E.Prod, ps')=> E.Opn(E.Prod, epsAll@ ps')
213 :     | t => (changed:=true; E.Opn(E.Prod,epsAll@[t])))
214 : cchiw 3441 (*end case*))
215 :     (*end case*))
216 : cchiw 2845 end
217 : cchiw 3441 | E.Opn(E.Prod,E.Sum(c1,e1)::E.Sum(c2,e2)::es)=>(case (e1,e2,es)
218 :     of (E.Opn(E.Prod,E.G(E.Epsilon e1)::es1),E.Opn(E.Prod,E.G(E.Epsilon e2)::es2),_) =>
219 :     (case G.epsToDels([E.G(E.Epsilon e1), E.G(E.Epsilon e2)]@es1@es2@es)
220 :     of (1,e,sx,_,_)=> (changed:=true; E.Sum(c1@c2@sx,e))
221 : cchiw 3604 | (_,_,_,_,_)=>
222 :     let
223 :     val eA= E.Sum(c1,setProd(E.G(E.Epsilon e1)::es1))
224 :     val eB= E.Sum(c2,setProd(E.G(E.Epsilon e2)::es2))
225 :     in prod2(eA, eB, es) end
226 :    
227 : cchiw 3441 (*end case*))
228 : cchiw 3604 | _ => prod2(E.Sum(c1,e1),E.Sum(c2,e2),es)
229 : cchiw 2845 (*end case*))
230 : cchiw 3604 | E.Opn(E.Prod,E.G(E.Delta d)::es) => let
231 :     val (pre',eps, dels,post)= filterGreek(E.G(E.Delta d)::es)
232 :     val _= testp["\n\n Reduce delta--",P.printbody(body)]
233 :     val (change,a)=G.reduceDelta(eps, dels, post)
234 :     val _= testp["\n\n ---delta moved--",P.printbody(a)]
235 :     in (case (change,a)
236 :     of (0, _)=> prod2(E.G(E.Delta d), List.hd(es) , List.tl(es))
237 :     | (_, E.Opn(E.Prod, p))=>let
238 :     val (_, p') = mkProd p
239 :     in (changed:=true;p') end
240 :     | _ => raise Fail"impossible"
241 :     (*end case*))
242 :     end
243 : cchiw 3441 | E.Opn(E.Prod,[e1,e2])=> let
244 :     val (_,b)=mkProd[rewrite e1, rewrite e2]
245 : cchiw 2845 in b end
246 : cchiw 3441 | E.Opn(E.Prod,e1::es)=>let
247 :     val e'=rewrite e1
248 :     val e2=rewrite(setProd es)
249 : cchiw 2845 val(_,b)=(case e2
250 : cchiw 3441 of E.Opn(Prod, p')=> mkProd([e']@p')
251 : cchiw 2845 |_=>mkProd [e',e2])
252 : cchiw 3441 in b end
253 : cchiw 2845 (*end case*))
254 : cchiw 3604 end
255 : cchiw 2584
256 : cchiw 3441
257 : cchiw 3267 val _=testp["\n******** Start Normalize: \n",P.printerE ee,"\n*****\n"]
258 : cchiw 2845 fun loop(body ,count) = let
259 : cchiw 3557 val _= (concat["\n N =>",Int.toString(count)])
260 : cchiw 3441 val body' = rewrite body
261 : cchiw 3557 val _=(EqualEin.boolToString(EqualEin.isBodyEq(body,body')))
262 : cchiw 2845 in
263 :     if !changed
264 :     then (changed := false ;loop(body',count+1))
265 :     else (body',count)
266 :     end
267 : cchiw 3267
268 : cchiw 2845 val (b,count) = loop(body,0)
269 : cchiw 3138 val _ =testp["\n Out of normalize \n",P.printbody(b),
270 : cchiw 2845 "\n Final CounterXX:",Int.toString(count),"\n\n"]
271 :     in
272 :     (Ein.EIN{params=params, index=index, body=b},count)
273 :     end
274 : cchiw 2463 end
275 :    
276 : cchiw 2397
277 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0