Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3557 - (view) (download)

1 : cchiw 2397
2 :     structure NormalizeEin = struct
3 :    
4 :     local
5 : cchiw 2445
6 : cchiw 2397 structure E = Ein
7 : cchiw 2603 structure P=Printer
8 :     structure F=Filter
9 :     structure G=EpsHelpers
10 : cchiw 2870 structure Eq=EqualEin
11 :     structure R=RationalEin
12 : cchiw 2605
13 : cchiw 2397 in
14 :    
15 : cchiw 3153 val testing=0
16 : cchiw 2845 fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])
17 :     fun mkProd e= F.mkProd e
18 :     fun filterSca e=F.filterSca e
19 :     fun mkAdd e=F.mkAdd e
20 :     fun filterGreek e=F.filterGreek e
21 : cchiw 3138 fun mkapply e= derivativeEin.mkapply e
22 : cchiw 2845 fun testp n=(case testing
23 :     of 0=> 1
24 :     | _ =>(print(String.concat n);1)
25 : cchiw 3441 (*end case*))
26 : cchiw 2608
27 : cchiw 3441 val zero=E.B(E.Const 0)
28 :     fun setConst e = E.setConst e
29 :     fun setNeg e = E.setNeg e
30 :     fun setExp e = E.setExp e
31 :     fun setDiv e= E.setDiv e
32 :     fun setSub e= E.setSub e
33 :     fun setProd e= E.setProd e
34 :     fun setAdd e= E.setAdd e
35 : cchiw 3138
36 : cchiw 2845 (*mkSum:sum_indexid list * ein_exp->int *ein_exp
37 :     *distribute summation expression
38 :     *)
39 : cchiw 3441 fun mkSum(sx1,b)=(case b
40 :     of E.Lift e => (1,E.Lift(E.Sum(sx1,e)))
41 : cchiw 3440 | E.Tensor(_,[]) => (1,b)
42 : cchiw 3441 | E.B _ => (1,b)
43 :     | E.Opn(E.Prod, es) => filterSca(sx1,es)
44 :     | _ => (0,E.Sum(sx1,b))
45 : cchiw 2922 (*end case*))
46 : cchiw 2923
47 : cchiw 2845 (*mkprobe:ein_exp* ein_exp-> int ein_exp
48 :     *rewritten probe
49 :     *)
50 : cchiw 3440 fun mkprobe(b,x)=let
51 :     val (c,rtn)=(case b
52 : cchiw 3441 of (E.B _) => (0,b)
53 :     | E.Tensor _ => err("Tensor without Lift")
54 :     | E.G _ => (0,b)
55 :     | E.Field _ => (0,E.Probe(b,x))
56 :     | E.Lift e1 => (1,e1)
57 : cchiw 3440 | E.Conv _ => (0,E.Probe(b,x))
58 :     | E.Partial _ => err("Probe Partial")
59 : cchiw 3441 | E.Apply _ => (0,E.Probe(b,x))
60 : cchiw 3440 | E.Probe _ => err("Probe of a Probe")
61 :     | E.Value _ => err("Value used before expand")
62 :     | E.Img _ => err("Probe used before expand")
63 : cchiw 3441 | E.Krn _ => err("Krn used before expand")
64 :     | E.Sum(sx1,e1) => (1,E.Sum(sx1,E.Probe(e1,x)))
65 :     | E.Op1(op1, e1) => (1,E.Op1(op1, E.Probe(e1,x)))
66 :     | E.Op2(op2, e1,e2) => (1,E.Op2(op2, E.Probe(e1,x), E.Probe(e2,x)))
67 :     | E.Opn(opn, []) => err("Probe of empty operator")
68 :     | E.Opn(opn, es) => (1,E.Opn(opn, List.map(fn e1=> E.Probe(e1,x)) es))
69 : cchiw 2976 (*end case*))
70 :     in
71 :     (c,rtn)
72 :     end
73 : cchiw 2603
74 : cchiw 2845 (* normalize: EIN->EIN
75 :     * rewrite body of EIN
76 :     * note "c" keeps track if ein_exp is changed
77 :     *)
78 : cchiw 2870 fun normalize (ee as Ein.EIN{params, index, body},args) = let
79 : cchiw 2397 val changed = ref false
80 : cchiw 3441 fun rewrite body =(case body
81 :     of E.B _ => body
82 : cchiw 3440 | E.Tensor _ => body
83 : cchiw 3441 | E.G _ => body
84 :     (************** Field Terms **************)
85 : cchiw 3440 | E.Field _ => body
86 : cchiw 3441 | E.Lift e1 => E.Lift(rewrite e1)
87 : cchiw 3440 | E.Conv _ => body
88 :     | E.Partial _ => body
89 : cchiw 3441 | E.Apply(E.Partial [],e1) => e1
90 :     | E.Apply(E.Partial d1, e1) =>
91 : cchiw 2845 let
92 : cchiw 3441 val e2 = rewrite e1
93 :     val (c,e3)=mkapply(E.Partial d1,e2)
94 :     val _= testp["\nafter apply:",P.printbody body,"-->",P.printbody e3]
95 :     in
96 :     (case c of 1=>(changed:=true;e3)| _ =>e3 (*end case*))
97 :     end
98 :     | E.Apply _ => err " Not well-formed Apply expression"
99 :     | E.Probe(e1,e2) =>
100 : cchiw 2845 let
101 : cchiw 3441 val (c',b')=mkprobe(rewrite e1,rewrite e2)
102 : cchiw 2845 in (case c'
103 :     of 1=> (changed:=true;b')
104 : cchiw 3441 | _ => b'
105 : cchiw 2845 (*end case*))
106 :     end
107 : cchiw 3441 (************** Field Terms **************)
108 :     | E.Value _ => err "Value before Expand"
109 :     | E.Img _ => err "Img before Expand"
110 :     | E.Krn _ => err "Krn before Expand"
111 :     (************** Sum **************)
112 :    
113 :     | E.Sum([],e1) => (changed:=true;rewrite e1)
114 :     | E.Sum(sx1,e1) => let
115 :     val e2=rewrite e1
116 :     val (c,e')=mkSum(sx1,e2)
117 :     val _= testp["\nafter mksum:\n\t",P.printbody body,"\n\t-->",P.printbody e2,"\n\t-->",P.printbody e']
118 :     in
119 :     (case c of 0 => e'|_ => (changed:=true;e'))
120 :     end
121 :     (*************Algebraic Rewrites Op1 **************)
122 :    
123 :     | E.Op1(E.Neg,e1) => (case e1
124 :     of E.Op1(E.Neg,e2) => rewrite e2
125 :     | E.B(E.Const 0) =>(changed:=true;zero)
126 :     | _ => E.Op1(E.Neg,rewrite e1)
127 :     (*end case*))
128 :     | E.Op1(op1,e1) => E.Op1(op1,rewrite e1)
129 :     (*************Algebraic Rewrites Op2 **************)
130 :     | E.Op2(E.Sub,e1,e2) => (case (e1,e2)
131 :     of (E.B(E.Const 0),_) => (changed:=true;setNeg(rewrite e2))
132 :     | (_,E.B(E.Const 0)) => (changed:=true;rewrite e1)
133 :     | _ => setSub(rewrite e1, rewrite e2)
134 :     (*end case*))
135 :     | E.Op2(E.Div,e1,e2) =>(case (e1,e2)
136 :     of (E.B(E.Const 0),_) => (changed:=true;zero)
137 :     |(E.Op2(E.Div,a,b), E.Op2(E.Div,c,d)) => rewrite(setDiv(setProd[a,d],setProd[b,c]))
138 :     |(E.Op2(E.Div,a,b), c) => rewrite (setDiv(a, setProd[b,c]))
139 :     | (a,E.Op2(E.Div,b,c)) => rewrite (setDiv(setProd[a,c],b))
140 :     | _ => setDiv(rewrite e1, rewrite e2)
141 :     (*end case*))
142 :     (*************Algebraic Rewrites Opn **************)
143 :     | E.Opn(E.Add,es) => let
144 :     val (change,body')= mkAdd(List.map rewrite es)
145 :     in if (change=1) then ( changed:=true;body') else body' end
146 :    
147 : cchiw 2845 (*************Product**************)
148 : cchiw 3441 | E.Opn(E.Prod,[]) => err "missing elements in product"
149 :     | E.Opn(E.Prod,[e1]) => rewrite e1
150 :     | E.Opn(E.Prod,[E.Op1(E.Sqrt,s1),E.Op1(E.Sqrt,s2)])=>
151 : cchiw 3166 if(Eq.isBodyEq(s1,s2)) then (changed :=true;s1)
152 : cchiw 2903 else let
153 : cchiw 3441 val a=rewrite (E.Op1(E.Sqrt,s1))
154 :     val b=rewrite (E.Op1(E.Sqrt,s2))
155 : cchiw 2955 val (_,d)=mkProd ([a,b])
156 :     in d
157 :     end
158 : cchiw 2845 (*************Product EPS **************)
159 : cchiw 3441 | E.Opn(E.Prod,(E.G(E.Epsilon e1)::ps))=> let
160 :     val E.G(E.Epsilon(i,j,k))=E.G(E.Epsilon e1)
161 :     val eps1=E.G(E.Epsilon(i,j,k))
162 :     val p1=List.hd(ps)
163 :     in (case ps
164 :     of (E.Apply(E.Partial d,e)::es)=>let
165 :     val change= G.matchEps(0,d,[],[i,j,k])
166 :     in case (change,es)
167 :     of (1,_) => (changed:=true; zero)
168 :     | (_,[]) => setProd[eps1,rewrite p1]
169 :     | (_,_) => let
170 :     val a=rewrite(setProd([p1]@es))
171 :     val (_,b)=mkProd [eps1,a]
172 :     in b end
173 :     end
174 :     | (E.Conv(V,alpha, h, d)::es)=>let
175 :     val change= G.matchEps(0,d,[],[i,j,k])
176 :     in case (change,es)
177 :     of (1,_) => (changed:=true; E.Lift zero )
178 :     | (_,[]) => setProd[eps1,p1]
179 :     | (_,_) => let
180 :     val a=rewrite(setProd([p1]@es))
181 :     val (_,b) = mkProd [eps1,a]
182 :     in b end
183 :     end
184 :     | [E.Tensor(_,[E.V i1,E.V i2])] =>
185 :     if(j=i1 andalso k=i2) then (changed :=true;zero) else body
186 :     | _ => (case (G.epsToDels(eps1::ps))
187 :     of (1,e,[],_,_) => (changed:=true;e)(* Changed to Deltas*)
188 :     | (1,e,sx,_,_) => (changed:=true;E.Sum(sx,e))
189 :     | (_,_,_,_,[]) => body
190 :     | (_,_,_,epsAll,rest) => let
191 :     val p'=rewrite(setProd rest)
192 :     val(_,b)= mkProd(epsAll@[p'])
193 :     in b end
194 :     (*end case*))
195 :     (*end case*))
196 : cchiw 2845 end
197 : cchiw 3441 | E.Opn(E.Prod,E.Sum(c1,e1)::E.Sum(c2,e2)::es)=>(case (e1,e2,es)
198 :     of (E.Opn(E.Prod,E.G(E.Epsilon e1)::es1),E.Opn(E.Prod,E.G(E.Epsilon e2)::es2),_) =>
199 :     (case G.epsToDels([E.G(E.Epsilon e1), E.G(E.Epsilon e2)]@es1@es2@es)
200 :     of (1,e,sx,_,_)=> (changed:=true; E.Sum(c1@c2@sx,e))
201 :     | (_,_,_,_,_)=>let
202 :     val eA=rewrite(E.Sum(c1,setProd(E.G(E.Epsilon e1)::es1)))
203 :     val eB=rewrite(setProd(E.Sum(c2,setProd(E.G(E.Epsilon e2)::es2))::es))
204 :     val (_,e)=mkProd([eA,eB])
205 :     in e end
206 :     (*end case*))
207 :     | (_,_,[]) =>let
208 :     val (_,b)=mkProd[rewrite(E.Sum(c1,e1)), rewrite(E.Sum(c2,e2))]
209 : cchiw 2845 in b end
210 : cchiw 3441 | _ =>let
211 :     val e'=rewrite (E.Sum(c1,e1))
212 :     val e2=rewrite(E.Opn(E.Prod,E.Sum(c2,e2)::es))
213 :     val(_,b)=(case e2
214 :     of E.Opn(E.Prod, p')=> mkProd([e']@p')
215 :     | _ =>mkProd [e',e2])
216 :     in b end
217 : cchiw 2845 (*end case*))
218 : cchiw 3441 | E.Opn(E.Prod,E.G(E.Delta d)::es)=> (case es
219 :     of [E.Op1(E.Neg, e1)]=> (changed:=true;setNeg(setProd[E.G(E.Delta d), e1]))
220 :     | _=> let
221 :     val (pre',eps, dels,post)= filterGreek(E.G(E.Delta d)::es)
222 :     val _= testp["\n\n Reduce delta--",P.printbody(body)]
223 :     val (change,a)=G.reduceDelta(eps, dels, post)
224 :     val _= testp["\n\n ---delta moved--",P.printbody(a)]
225 :     in (case (change,a)
226 :     of (0, _)=> setProd [E.G(E.Delta d),rewrite(setProd es)]
227 :     | (_, E.Opn(E.Prod, p))=>let
228 :     val (_, p') = mkProd p
229 :     in (changed:=true;p') end
230 :     | _ => (changed:=true;a )
231 :     (*end case*))
232 : cchiw 2845 end
233 :     (*end case*))
234 : cchiw 3441 | E.Opn(E.Prod,[e1,e2])=> let
235 :     val (_,b)=mkProd[rewrite e1, rewrite e2]
236 : cchiw 2845 in b end
237 : cchiw 3441 | E.Opn(E.Prod,e1::es)=>let
238 :     val e'=rewrite e1
239 :     val e2=rewrite(setProd es)
240 : cchiw 2845 val(_,b)=(case e2
241 : cchiw 3441 of E.Opn(Prod, p')=> mkProd([e']@p')
242 : cchiw 2845 |_=>mkProd [e',e2])
243 : cchiw 3441 in b end
244 : cchiw 2845 (*end case*))
245 : cchiw 2584
246 : cchiw 3441
247 : cchiw 3267 val _=testp["\n******** Start Normalize: \n",P.printerE ee,"\n*****\n"]
248 : cchiw 2845 fun loop(body ,count) = let
249 : cchiw 3557 val _= (concat["\n N =>",Int.toString(count)])
250 : cchiw 3441 val body' = rewrite body
251 : cchiw 3557 val _=(EqualEin.boolToString(EqualEin.isBodyEq(body,body')))
252 : cchiw 2845 in
253 :     if !changed
254 :     then (changed := false ;loop(body',count+1))
255 :     else (body',count)
256 :     end
257 : cchiw 3267
258 : cchiw 2845 val (b,count) = loop(body,0)
259 : cchiw 3138 val _ =testp["\n Out of normalize \n",P.printbody(b),
260 : cchiw 2845 "\n Final CounterXX:",Int.toString(count),"\n\n"]
261 :     in
262 :     (Ein.EIN{params=params, index=index, body=b},count)
263 :     end
264 : cchiw 2463 end
265 :    
266 : cchiw 2397
267 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0