Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
 [diderot] / branches / charisee_dev / src / compiler / high-il / normalize-ein.sml

# Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

revision 2613, Wed May 7 04:35:38 2014 UTC revision 2843, Mon Dec 8 01:27:25 2014 UTC
# Line 11  Line 11
11      in      in
12
13  fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])  fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])
14  val testing=0  val testing=1
15
16  fun flatProd [e]=e  fun flatProd [e]=e
17  | flatProd e=E.Prod e  | flatProd e=E.Prod e
# Line 37  Line 37
37      | E.Apply _   => (0,E.Sum(c1,e1))      | E.Apply _   => (0,E.Sum(c1,e1))
38      | E.Delta _   => (0,E.Sum(c1,e1))      | E.Delta _   => (0,E.Sum(c1,e1))
39      | E.Epsilon _ => (0,E.Sum(c1,e1))      | E.Epsilon _ => (0,E.Sum(c1,e1))
40        | E.Eps2 _    => (0,E.Sum(c1,e1))
41        (*| E.Tensor []  => (1,e1)*)
42      | E.Tensor _  => (0,E.Sum(c1,e1))      | E.Tensor _  => (0,E.Sum(c1,e1))
43      | E.Neg e2    => (1,E.Neg(E.Sum(c1,e2)))      | E.Neg e2    => (1,E.Neg(E.Sum(c1,e2)))
44      | E.Sub (a,b) => (1,E.Sub(E.Sum(c1,a),E.Sum(c1,b)))      | E.Sub (a,b) => (1,E.Sub(E.Sum(c1,a),E.Sum(c1,b)))
# Line 61  Line 63
63                          val E.Partial d3=d1                          val E.Partial d3=d1
64                          in (1,E.Conv(v,alpha,h,d2@d3)) end                          in (1,E.Conv(v,alpha,h,d2@d3)) end
65      | E.Field _   => (0,E.Apply(d1,e1))      | E.Field _   => (0,E.Apply(d1,e1))
66      | E.Probe _   => (0,E.Apply(d1,e1))      | E.Probe _   => (0,E.Apply(d1,e1)) (*FIX ME, Should be error actually apply of a tensor result*)
67      | E.Apply(E.Partial d2,e2)  => let      | E.Apply(E.Partial d2,e2)  => let
68                          val E.Partial d3=d1                          val E.Partial d3=d1
69                          in (1,E.Apply(E.Partial(d3@d2),e2)) end                          in (1,E.Apply(E.Partial(d3@d2),e2)) end
# Line 93  Line 95
95      | E.Tensor _  => err("Tensor without Lift")      | E.Tensor _  => err("Tensor without Lift")
96      | E.Delta _   => err("Apply of Delta")      | E.Delta _   => err("Apply of Delta")
97      | E.Epsilon _ => err("Apply of Eps")      | E.Epsilon _ => err("Apply of Eps")
98        | E.Eps2 _ => err("Apply of Eps")
99      | E.Partial _ => err("Apply of Partial")      | E.Partial _ => err("Apply of Partial")
100      | E.Krn _     => err("Krn used before expand")      | E.Krn _     => err("Krn used before expand")
101      | E.Value _   => err("Value used before expand")      | E.Value _   => err("Value used before expand")
# Line 115  Line 118
118      | E.Div (a,b) => (1,E.Div(E.Probe(a,x),E.Probe(b,x)))      | E.Div (a,b) => (1,E.Div(E.Probe(a,x),E.Probe(b,x)))
119      | E.Const _   => err("Const without Lift")      | E.Const _   => err("Const without Lift")
120      | E.Tensor _  => err("Tensor without Lift")      | E.Tensor _  => err("Tensor without Lift")
121      | E.Delta _   => err("Probe of Delta")      | E.Delta _   => (0,e1)
122      | E.Epsilon _ => err("Probe of Eps")      | E.Epsilon _ => (0,e1)
123        | E.Eps2 _    => (0,e1)
124      | E.Partial _ => err("Probe Partial")      | E.Partial _ => err("Probe Partial")
125      | E.Probe _   => err("Probe of a Probe")      | E.Probe _   => err("Probe of a Probe")
126      | E.Krn _     => err("Krn used before expand")      | E.Krn _     => err("Krn used before expand")
# Line 129  Line 133
133
134  (*Apply normalize to each term in product list  (*Apply normalize to each term in product list
135  or Apply normalize to tail of each list*)  or Apply normalize to tail of each list*)
136  fun normalize (Ein.EIN{params, index, body}) = let  fun normalize (ee as Ein.EIN{params, index, body}) = let
137        val changed = ref false        val changed = ref false
138
139        fun rewriteBody body =(case body        fun rewriteBody body =(case body
# Line 138  Line 142
142              | E.Field _     => body              | E.Field _     => body
143              | E.Delta _     => body              | E.Delta _     => body
144              | E.Epsilon _   => body              | E.Epsilon _   => body
145                | E.Eps2 _   => body
146              | E.Conv _      => body              | E.Conv _      => body
147              | E.Partial _   => body              | E.Partial _   => body
148              | E.Krn _       => raise Fail"Krn before Expand"              | E.Krn _       => raise Fail"Krn before Expand"
# Line 155  Line 160
160              | E.Sub(E.Sub(a,b),e2)          => rewriteBody (E.Sub(a,E.Add[b,e2]))              | E.Sub(E.Sub(a,b),e2)          => rewriteBody (E.Sub(a,E.Add[b,e2]))
161              | E.Sub(e1,E.Sub(c,d))          => rewriteBody(E.Add([E.Sub(e1,c),d]))              | E.Sub(e1,E.Sub(c,d))          => rewriteBody(E.Add([E.Sub(e1,c),d]))
162              | E.Sub (a,b)                   => E.Sub(rewriteBody a, rewriteBody b)              | E.Sub (a,b)                   => E.Sub(rewriteBody a, rewriteBody b)
163                | E.Div(e1 as E.Tensor(_,[_]),e2 as E.Tensor(_,[]))=>
164                        rewriteBody (E.Prod[E.Div(E.Const 1, e2),e1])
165              | E.Div(E.Div(a,b),E.Div(c,d))  => rewriteBody(E.Div(E.Prod[a,d],E.Prod[b,c]))              | E.Div(E.Div(a,b),E.Div(c,d))  => rewriteBody(E.Div(E.Prod[a,d],E.Prod[b,c]))
166              | E.Div(E.Div(a,b),c)           => rewriteBody (E.Div(a, E.Prod[b,c]))              | E.Div(E.Div(a,b),c)           => rewriteBody (E.Div(a, E.Prod[b,c]))
167              | E.Div(a,E.Div(b,c))           => rewriteBody (E.Div(E.Prod[a,c],b))              | E.Div(a,E.Div(b,c))           => rewriteBody (E.Div(E.Prod[a,c],b))
# Line 226  Line 233
233              | E.Prod(E.Epsilon eps1::ps)=> (case (G.epsToDels(E.Epsilon eps1::ps))              | E.Prod(E.Epsilon eps1::ps)=> (case (G.epsToDels(E.Epsilon eps1::ps))
234                  of (1,e,[],_,_)      =>(changed:=true;e)(* Changed to Deltas *)                  of (1,e,[],_,_)      =>(changed:=true;e)(* Changed to Deltas *)
235                  | (1,e,sx,_,_)      =>(changed:=true;E.Sum(sx,e))(* Changed to Deltas *)                  | (1,e,sx,_,_)      =>(changed:=true;E.Sum(sx,e))(* Changed to Deltas *)
236                  | (0,_,_,_,[])   =>  body                  | (_,_,_,_,[])   =>  body
237                  | (0,_,_,epsAll,rest) => let                  | (_,_,_,epsAll,rest) => let
238                          val p'=rewriteBody(E.Prod rest)                          val p'=rewriteBody(E.Prod rest)
239                          val(_,b)= F.mkProd(epsAll@[p'])                          val(_,b)= F.mkProd(epsAll@[p'])
240                          in b end                          in b end
# Line 236  Line 243
243              | E.Prod(E.Sum(c1,E.Prod(E.Epsilon e1::es1))::E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es) =>              | E.Prod(E.Sum(c1,E.Prod(E.Epsilon e1::es1))::E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es) =>
244                  (case G.epsToDels([E.Epsilon e1, E.Epsilon e2]@es1@es2@es)                  (case G.epsToDels([E.Epsilon e1, E.Epsilon e2]@es1@es2@es)
245                  of (1,e,sx,_,_)=> (changed:=true; E.Sum(c1@c2@sx,e))                  of (1,e,sx,_,_)=> (changed:=true; E.Sum(c1@c2@sx,e))
246                  | (0,_,_,_,_)=>let                  | (_,_,_,_,_)=>let
247                      val eA=rewriteBody(E.Sum(c1,E.Prod(E.Epsilon e1::es1)))                      val eA=rewriteBody(E.Sum(c1,E.Prod(E.Epsilon e1::es1)))
248                      val eB=rewriteBody(E.Prod(E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es))                      val eB=rewriteBody(E.Prod(E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es))
249                      val (_,e)=F.mkProd([eA,eB])                      val (_,e)=F.mkProd([eA,eB])
# Line 244  Line 251
251                      end                      end
252                  (*end case*))                  (*end case*))
253
254                | E.Prod[E.Delta d, E.Neg e]=> (changed:=true;E.Neg(E.Prod[E.Delta d, e]))
255              | E.Prod(E.Delta d::es)=>let              | E.Prod(E.Delta d::es)=>let
256                  val (pre',eps, dels,post)= F.filterGreek(E.Delta d::es)                  val (pre',eps, dels,post)= F.filterGreek(E.Delta d::es)
257                  val (change,a)=G.reduceDelta(eps, dels, post)                  val (change,a)=G.reduceDelta(eps, dels, post)
# Line 269  Line 277
277              (*end case*))              (*end case*))
278
279              fun loop(body ,count) = let              fun loop(body ,count) = let
280                    val _ =print(String.concat["\n\n N =>",Int.toString(count),"--",P.printbody(body)])
281                  val body' = rewriteBody body                  val body' = rewriteBody body
282
283                             in                             in
284                if !changed                if !changed
285                  then let                  then  (changed := false ;loop(body',count+1))
val _= (case testing
of 1=> (print(String.concat["\nN =>",Int.toString(count),"--",P.printbody(body')]);1)
| _=> 1)
in
(changed := false ;loop(body',count+1))
end
286                  else (body',count)                  else (body',count)
287              end              end
288        val _ =print(String.concat["\n ******************* \n Start Normalize \n\n "])
289      val (b,count) = loop(body,0)      val (b,count) = loop(body,0)
290      val _ =(case testing      val _ =print(String.concat["\n Out of normalize \n",P.printbody(b),"\n    Final CounterXX:",Int.toString(count),"\n\n"])
of 1 => (print(String.concat["\n out of normalize \n",P.printbody(b),"\n    Final CounterXX:",Int.toString(count),"\n\n"]);1)
| _=> 1
(*end case*))
291      in      in
292                  (Ein.EIN{params=params, index=index, body=b},count)                  (Ein.EIN{params=params, index=index, body=b},count)
293      end      end

Legend:
 Removed from v.2613 changed lines Added in v.2843

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0