Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2843, Mon Dec 8 01:27:25 2014 UTC revision 2844, Tue Dec 9 18:05:29 2014 UTC
# Line 12  Line 12 
12    
13  fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])  fun err str=raise Fail (String.concat["Ill-formed EIN Operator",str])
14  val testing=1  val testing=1
15    fun flatProd e =F.rewriteProd e
16    fun mkProd e= F.mkProd e
17    fun filterSca e=F.filterSca e
18    fun filterField e=F.filterField e
19    fun mkAdd e=F.mkAdd e
20    fun filterGreek e=F.filterGreek e
21    
 fun flatProd [e]=e  
 | flatProd e=E.Prod e  
22    
23    fun testp n=(case testing
24    of 0=> 1
25    | _ =>(print(String.concat n);1)
26    (*end case*))
27    
28    (*prodAppPartia:ein_exp list * mu list ->ein_exp
29    * chain rule
30    *)
31  fun prodAppPartial(es,p1)=(case es  fun prodAppPartial(es,p1)=(case es
32      of []      => raise Fail "Empty App Partial"      of []      => err "Empty App Partial"
33      | [e1]     => E.Apply(E.Partial p1,e1)      | [e1]     => E.Apply(E.Partial p1,e1)
34      | (e1::e2) => let      | (e1::e2) => let
35          val l= prodAppPartial(e2,p1)          val l= prodAppPartial(e2,p1)
36          val (_,e2')= F.mkProd[e1,l]          val (_,e2')= mkProd[e1,l]
37          val (_,e1')=F.mkProd(e2@ [E.Apply(E.Partial p1, e1)])          val (_,e1')=mkProd(e2@ [E.Apply(E.Partial p1, e1)])
38          in          in
39              E.Add[e1',e2']              E.Add[e1',e2']
40          end          end
41      (* end case *))      (* end case *))
42    
43  (*rewritten Sum*)  (*mkSum:sum_indexid list * ein_exp->int *ein_exp
44    *distribute summation expression
45    *)
46  fun mkSum(c1,e1)=(case e1  fun mkSum(c1,e1)=(case e1
47      of E.Conv _   => (0,E.Sum(c1,e1))      of E.Conv _   => (0,E.Sum(c1,e1))
48      | E.Field _   => (0,E.Sum(c1,e1))      | E.Field _   => (0,E.Sum(c1,e1))
# Line 38  Line 51 
51      | E.Delta _   => (0,E.Sum(c1,e1))      | E.Delta _   => (0,E.Sum(c1,e1))
52      | E.Epsilon _ => (0,E.Sum(c1,e1))      | E.Epsilon _ => (0,E.Sum(c1,e1))
53      | E.Eps2 _    => (0,E.Sum(c1,e1))      | E.Eps2 _    => (0,E.Sum(c1,e1))
54      (*| E.Tensor []  => (1,e1)*)      | E.Tensor(_,[]) => (1,e1)
55      | E.Tensor _  => (0,E.Sum(c1,e1))      | E.Tensor _  => (0,E.Sum(c1,e1))
56      | E.Neg e2    => (1,E.Neg(E.Sum(c1,e2)))      | E.Neg e2    => (1,E.Neg(E.Sum(c1,e2)))
57      | E.Sub (a,b) => (1,E.Sub(E.Sum(c1,a),E.Sum(c1,b)))      | E.Sub (a,b) => (1,E.Sub(E.Sum(c1,a),E.Sum(c1,b)))
# Line 46  Line 59 
59      | E.Div (a,b) => (1,E.Div(E.Sum(c1,a),E.Sum(c1,b)))      | E.Div (a,b) => (1,E.Div(E.Sum(c1,a),E.Sum(c1,b)))
60      | E.Lift e    => (1,E.Lift(E.Sum(c1,e)))      | E.Lift e    => (1,E.Lift(E.Sum(c1,e)))
61      | E.Sum(c2,e2)=> (1,E.Sum(c1@c2,e2))      | E.Sum(c2,e2)=> (1,E.Sum(c1@c2,e2))
62      | E.Prod p     =>F.filterSca(c1,p)      | E.Prod p    => filterSca(c1,p)
63      | E.Const _   => err("Sum of Const")      | E.Const _   => err("Sum of Const")
64      | E.Partial _ => err("Sum of Partial")      | E.Partial _ => err("Sum of Partial")
65      | E.Krn _     => err("Krn used before expand")      | E.Krn _     => err("Krn used before expand")
# Line 54  Line 67 
67      | E.Img _     => err("Probe used before expand")      | E.Img _     => err("Probe used before expand")
68      (*end case*))      (*end case*))
69    
70  (*rewritten Apply*)  (* mkapply:mu list*ein_exp->int*ein_exp
71    * rewrite Apply
72    *)
73  fun mkapply(d1,e1)=(case e1  fun mkapply(d1,e1)=(case e1
74      of E.Lift e   => (1,E.Const 0)      of E.Lift e   => (1,E.Const 0)
75      | E.Prod []   => err("Apply of empty product")      | E.Prod []   => err("Apply of empty product")
76      | E.Add []    => err("Apply of empty Addition")      | E.Add []    => err("Apply of empty Addition")
77      | E.Conv(v, alpha, h, d2)    =>let      | E.Conv(v, alpha, h, d2)    =>let
78                          val E.Partial d3=d1                          val E.Partial d3=d1
79                          in (1,E.Conv(v,alpha,h,d2@d3)) end              in
80                    (1,E.Conv(v,alpha,h,d2@d3))
81                end
82      | E.Field _   => (0,E.Apply(d1,e1))      | E.Field _   => (0,E.Apply(d1,e1))
83      | E.Probe _   => (0,E.Apply(d1,e1)) (*FIX ME, Should be error actually apply of a tensor result*)      | E.Probe _   => (0,E.Apply(d1,e1))
84      | E.Apply(E.Partial d2,e2)  => let      | E.Apply(E.Partial d2,e2)  => let
85                          val E.Partial d3=d1                          val E.Partial d3=d1
86                          in (1,E.Apply(E.Partial(d3@d2),e2)) end              in
87                    (1,E.Apply(E.Partial(d3@d2),e2))
88                end
89      | E.Apply _   => err" Apply of non-Partial expression"      | E.Apply _   => err" Apply of non-Partial expression"
90      | E.Sum(c2,e2)=> (1,E.Sum(c2,E.Apply(d1,e2)))      | E.Sum(c2,e2)=> (1,E.Sum(c2,E.Apply(d1,e2)))
91      | E.Neg e2    => (1,E.Neg(E.Apply(d1,e2)))      | E.Neg e2    => (1,E.Neg(E.Apply(d1,e2)))
92      | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Apply(d1,a)) e))      | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Apply(d1,a)) e))
93      | E.Sub (a,b) => (1,E.Sub(E.Apply(d1,a),E.Apply(d1,b)))      | E.Sub (a,b) => (1,E.Sub(E.Apply(d1,a),E.Apply(d1,b)))
94      | E.Div (g,b) => let      | E.Div (g,b) => (case filterField[b]
         in  
         (case F.filterField[b]  
95          of (_,[]) => (1,E.Div(E.Apply(d1,g),b)) (*Division by a real*)          of (_,[]) => (1,E.Div(E.Apply(d1,g),b)) (*Division by a real*)
96          | (pre,h) => let          | (pre,h) => let
97                (*quotient rule*)
98              val g'=E.Apply(d1,g)              val g'=E.Apply(d1,g)
99              val h'=E.Apply(d1,flatProd(h))              val h'=E.Apply(d1,flatProd(h))
100              val num=E.Sub(E.Prod([g']@h),E.Prod[g,h'])              val num=E.Sub(E.Prod([g']@h),E.Prod[g,h'])
# Line 84  Line 102 
102              in (1,E.Div(num,denom))              in (1,E.Div(num,denom))
103              end              end
104          (*end case*))          (*end case*))
         end  
   
105      | E.Prod p =>let      | E.Prod p =>let
106          val (pre, post)= F.filterField p          val (pre, post)= filterField p
107          val E.Partial d3=d1          val E.Partial d3=d1
108          in F.mkProd(pre@[prodAppPartial(post,d3)])          in mkProd(pre@[prodAppPartial(post,d3)])
109          end          end
110      | E.Const _   => err("Const without Lift")      | E.Const _   => err("Const without Lift")
111      | E.Tensor _  => err("Tensor without Lift")      | E.Tensor _  => err("Tensor without Lift")
# Line 103  Line 119 
119      (*end case*))      (*end case*))
120    
121    
122  (*rewritten probe*)  (*mkprobe:ein_exp* ein_exp-> int ein_exp
123    *rewritten probe
124    *)
125  fun mkprobe(e1,x)=(case e1  fun mkprobe(e1,x)=(case e1
126      of E.Lift e   => (1,e)      of E.Lift e   => (1,e)
127      | E.Prod []   => err("Probe of empty product")      | E.Prod []   => err("Probe of empty product")
# Line 129  Line 147 
147  (*end case*))  (*end case*))
148    
149    
150    (*normalize: EIN->EIN
151    *rewrite body of EIN
152  (*Apply normalize to each term in product list  *)
 or Apply normalize to tail of each list*)  
153  fun normalize (ee as Ein.EIN{params, index, body}) = let  fun normalize (ee as Ein.EIN{params, index, body}) = let
154        val changed = ref false        val changed = ref false
155    
# Line 153  Line 170 
170              | E.Neg(E.Neg e)    => rewriteBody e              | E.Neg(E.Neg e)    => rewriteBody e
171              | E.Neg e           => E.Neg(rewriteBody e)              | E.Neg e           => E.Neg(rewriteBody e)
172              | E.Lift e          => E.Lift(rewriteBody e)              | E.Lift e          => E.Lift(rewriteBody e)
173              | E.Add es          => let val (change,body')= F.mkAdd(List.map rewriteBody es)              | E.Add es          => let val (change,body')= mkAdd(List.map rewriteBody es)
174                     in if (change=1) then ( changed:=true;body') else body' end                     in if (change=1) then ( changed:=true;body') else body' end
175              | E.Sub(a, E.Field f)=> (changed:=true;E.Add[a, E.Neg(E.Field(f))])              | E.Sub(a, E.Field f)=> (changed:=true;E.Add[a, E.Neg(E.Field(f))])
176              | E.Sub(E.Sub(a,b),E.Sub(c,d))  => rewriteBody(E.Sub(E.Add[a,d],E.Add[b,c]))              | E.Sub(E.Sub(a,b),E.Sub(c,d))  => rewriteBody(E.Sub(E.Add[a,d],E.Add[b,c]))
# Line 212  Line 229 
229                      | (_,[]) =>E.Prod[E.Epsilon(i,j,k),rewriteBody (E.Apply(E.Partial d,e))]                      | (_,[]) =>E.Prod[E.Epsilon(i,j,k),rewriteBody (E.Apply(E.Partial d,e))]
230                      |(_,_)=> let                      |(_,_)=> let
231                          val a=rewriteBody(E.Prod([E.Apply(E.Partial d,e)]@ es))                          val a=rewriteBody(E.Prod([E.Apply(E.Partial d,e)]@ es))
232                          val (_,b)=F.mkProd [E.Epsilon(i,j,k),a]                          val (_,b)=mkProd [E.Epsilon(i,j,k),a]
233                          in b end                          in b end
234                  end                  end
235                | E.Prod(E.Epsilon(i,j,k)::E.Conv(V,alpha, h, d)::es)=>let                | E.Prod(E.Epsilon(i,j,k)::E.Conv(V,alpha, h, d)::es)=>let
# Line 222  Line 239 
239                          | (_,[]) =>E.Prod[E.Epsilon(i,j,k),E.Conv(V,alpha, h, d)]                          | (_,[]) =>E.Prod[E.Epsilon(i,j,k),E.Conv(V,alpha, h, d)]
240                          | (_,_) =>let                          | (_,_) =>let
241                              val a=rewriteBody(E.Prod([E.Conv(V,alpha, h, d)]@ es))                              val a=rewriteBody(E.Prod([E.Conv(V,alpha, h, d)]@ es))
242                              val (_,b) = F.mkProd [E.Epsilon(i,j,k),a]                              val (_,b) = mkProd [E.Epsilon(i,j,k),a]
243                              in b end                              in b end
244                      end                      end
245    
# Line 236  Line 253 
253                  | (_,_,_,_,[])   =>  body                  | (_,_,_,_,[])   =>  body
254                  | (_,_,_,epsAll,rest) => let                  | (_,_,_,epsAll,rest) => let
255                          val p'=rewriteBody(E.Prod rest)                          val p'=rewriteBody(E.Prod rest)
256                          val(_,b)= F.mkProd(epsAll@[p'])                          val(_,b)= mkProd(epsAll@[p'])
257                          in b end                          in b end
258                  (*end case*))                  (*end case*))
259    
# Line 246  Line 263 
263                  | (_,_,_,_,_)=>let                  | (_,_,_,_,_)=>let
264                      val eA=rewriteBody(E.Sum(c1,E.Prod(E.Epsilon e1::es1)))                      val eA=rewriteBody(E.Sum(c1,E.Prod(E.Epsilon e1::es1)))
265                      val eB=rewriteBody(E.Prod(E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es))                      val eB=rewriteBody(E.Prod(E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es))
266                      val (_,e)=F.mkProd([eA,eB])                      val (_,e)=mkProd([eA,eB])
267                      in e                      in e
268                      end                      end
269                  (*end case*))                  (*end case*))
270    
271              | E.Prod[E.Delta d, E.Neg e]=> (changed:=true;E.Neg(E.Prod[E.Delta d, e]))              | E.Prod[E.Delta d, E.Neg e]=> (changed:=true;E.Neg(E.Prod[E.Delta d, e]))
272              | E.Prod(E.Delta d::es)=>let              | E.Prod(E.Delta d::es)=>let
273                  val (pre',eps, dels,post)= F.filterGreek(E.Delta d::es)                  val (pre',eps, dels,post)= filterGreek(E.Delta d::es)
274                  val (change,a)=G.reduceDelta(eps, dels, post)                  val (change,a)=G.reduceDelta(eps, dels, post)
275                  in (case (change,a)                  in (case (change,a)
276                      of (0, _)=> E.Prod [E.Delta d,rewriteBody(E.Prod es)]                      of (0, _)=> E.Prod [E.Delta d,rewriteBody(E.Prod es)]
277                      | (_, E.Prod p)=>let                      | (_, E.Prod p)=>let
278                          val (_, p') = F.mkProd p                          val (_, p') = mkProd p
279                          in (changed:=true;p') end                          in (changed:=true;p') end
280                      | _ => (changed:=true;a )                      | _ => (changed:=true;a )
281                      (*end case*))                      (*end case*))
282                      end                      end
283    
284                | E.Prod[e1,e2]=> let val (_,b)=F.mkProd[rewriteBody e1, rewriteBody e2] in b end                | E.Prod[e1,e2]=> let val (_,b)=mkProd[rewriteBody e1, rewriteBody e2] in b end
285                | E.Prod(e::es)=>let                | E.Prod(e::es)=>let
286                      val e'=rewriteBody e                      val e'=rewriteBody e
287                      val e2=rewriteBody(E.Prod es)                      val e2=rewriteBody(E.Prod es)
288                      val(_,b)=(case e2                      val(_,b)=(case e2
289                          of E.Prod p'=> F.mkProd([e']@p')                          of E.Prod p'=> mkProd([e']@p')
290                          |_=>F.mkProd [e',e2])                          |_=>mkProd [e',e2])
291                  in b                  in b
292                     end                     end
293    
294              (*end case*))              (*end case*))
295    
296              fun loop(body ,count) = let              fun loop(body ,count) = let
297                  val _ =print(String.concat["\n\n N =>",Int.toString(count),"--",P.printbody(body)])                  val _= testp["\n\n N =>",Int.toString(count),"--",P.printbody(body)]
298                  val body' = rewriteBody body                  val body' = rewriteBody body
299    
300                             in                             in
# Line 285  Line 302 
302                  then  (changed := false ;loop(body',count+1))                  then  (changed := false ;loop(body',count+1))
303                  else (body',count)                  else (body',count)
304              end              end
305      val _ =print(String.concat["\n ******************* \n Start Normalize \n\n "])              val _ =testp["\n ******************* \n Start Normalize \n\n "]
306      val (b,count) = loop(body,0)      val (b,count) = loop(body,0)
307      val _ =print(String.concat["\n Out of normalize \n",P.printbody(b),"\n    Final CounterXX:",Int.toString(count),"\n\n"])              val _ =testp["\n Out of normalize \n",P.printbody(b),
308                    "\n    Final CounterXX:",Int.toString(count),"\n\n"]
309          in          in
310                  (Ein.EIN{params=params, index=index, body=b},count)                  (Ein.EIN{params=params, index=index, body=b},count)
311      end      end

Legend:
Removed from v.2843  
changed lines
  Added in v.2844

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0