Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2867, Tue Feb 10 06:52:58 2015 UTC revision 2903, Sat Feb 28 04:09:29 2015 UTC
# Line 7  Line 7 
7      structure P=Printer      structure P=Printer
8      structure F=Filter      structure F=Filter
9      structure G=EpsHelpers      structure G=EpsHelpers
10        structure Eq=EqualEin
11        structure R=RationalEin
12    
13      in      in
14    
# Line 56  Line 58 
58          | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Sum(c1,a)) e))          | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Sum(c1,a)) e))
59          | E.Div (a,b) => (1,E.Div(E.Sum(c1,a),E.Sum(c1,b)))          | E.Div (a,b) => (1,E.Div(E.Sum(c1,a),E.Sum(c1,b)))
60          | E.Lift e    => (1,E.Lift(E.Sum(c1,e)))          | E.Lift e    => (1,E.Lift(E.Sum(c1,e)))
61            | E.PowReal(e,n1)=>(1,E.PowReal(E.Sum(c1,e),n1))
62          | E.Sqrt e    => (1,E.Sqrt(E.Sum(c1,e)))          | E.Sqrt e    => (1,E.Sqrt(E.Sum(c1,e)))
63          | E.Sum(c2,e2)=> (1,E.Sum(c1@c2,e2))          | E.Sum(c2,e2)=> (1,E.Sum(c1@c2,e2))
64          | E.Prod p    => filterSca(c1,p)          | E.Prod p    => filterSca(c1,p)
65          | E.Const 0   => (1,E.Const 0) (*expression could have been changed to 0*)          | E.Const 0   => (1,E.Const 0) (*expression could have been changed to 0*)
66        | Ein.ConstR _          => err("Sum of Const")
67          | E.Const _   => err("Sum of Const")          | E.Const _   => err("Sum of Const")
68          | E.Partial _ => err("Sum of Partial")          | E.Partial _ => err("Sum of Partial")
69          | E.Krn _     => err("Krn used before expand")          | E.Krn _     => err("Krn used before expand")
# Line 70  Line 74 
74      (* mkapply:mu list*ein_exp->int*ein_exp      (* mkapply:mu list*ein_exp->int*ein_exp
75      * rewrite Apply      * rewrite Apply
76      *)      *)
77      fun mkapply(d1,e1)=(case e1      fun mkapply(d1,e1)=let
78    
79            val (c,g) =(case e1
80          of E.Lift e   => (1,E.Const 0)          of E.Lift e   => (1,E.Const 0)
81          | E.Sqrt e   =>  (1,E.Prod[ E.Div(E.Const 1 ,E.Prod[E.Const 2, e1]), E.Apply( d1,e)])          | E.Sqrt e  => let
82                val half=E.Div(E.Const 1 ,E.Const 2)
83                val  E.Partial dels=d1
84                val del0=E.Partial([List.hd(dels)])
85                val deln=E.Partial( List.tl(dels))
86                val applydel0=E.Apply(del0,e)
87                (*distribute just one of the derivatives over the sqrt.*)
88                val g=(case deln
89                    of E.Partial []=>  E.Prod[half, E.Div(applydel0,e1)]
90                    | _  =>  E.Prod[half,E.Apply(deln, E.Div(applydel0,e1))]
91                    (*end case*))
92                val _ = testp["\n*****\n found sqrt \n",
93                        P.printbody(E.Apply(d1,e1)),"\n==>\n",P.printbody g,"\n ***\n\n"]
94                in
95                    (1,g)
96                end
97    (*
98            | E.Sqrt e=>let
99                val half=E.Div(E.Const 1 ,E.Const 2)
100                val  E.Partial dels=d1
101                val del0=E.Partial([List.hd(dels)])
102                val deln=E.Partial( List.tl(dels))
103                val applydel0=E.Apply(del0,e)
104                val e1'=E.PowReal(e,E.Sub(E.Const 1,half))
105                val g=(case deln
106                    of E.Partial []=>E.Prod[half,e1',applydel0]
107                    | _ =>E.Prod[half,E.Apply(deln,E.Prod[e1',applydel0])]
108                (*end case*))
109                val _ = print(String.concat["\n*****\n found sqrt \n",
110                    P.printbody(E.Apply(d1,e1)),"\n==>\n",P.printbody g,"\n ***\n\n"])
111                in
112                    (1,g)
113                end
114    *)
115            | E.PowReal(e2,n2)=> let
116                val  E.Partial dels=d1
117                val del0=E.Partial([List.hd(dels)])
118                val deln=E.Partial( List.tl(dels))
119                val applydel0=E.Apply(del0,e2)
120                in
121                    (1,E.Prod[E.ConstR n2,E.Apply(deln,E.Prod[E.PowReal(e2,R.-(R.fromInt 1 ,n2)),applydel0])])
122                end
123          | E.Prod []   => err("Apply of empty product")          | E.Prod []   => err("Apply of empty product")
124          | E.Add []    => err("Apply of empty Addition")          | E.Add []    => err("Apply of empty Addition")
125          | E.Conv(v, alpha, h, d2)    =>let          | E.Conv(v, alpha, h, d2)    =>let
# Line 92  Line 139 
139          | E.Neg e2    => (1,E.Neg(E.Apply(d1,e2)))          | E.Neg e2    => (1,E.Neg(E.Apply(d1,e2)))
140          | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Apply(d1,a)) e))          | E.Add e     => (1,E.Add (List.map (fn(a)=>E.Apply(d1,a)) e))
141          | E.Sub (a,b) => (1,E.Sub(E.Apply(d1,a),E.Apply(d1,b)))          | E.Sub (a,b) => (1,E.Sub(E.Apply(d1,a),E.Apply(d1,b)))
142          | E.Div (g,b) => (case filterField[b]          | E.Div (E.Const g, b) =>(1,E.Div(E.Const g,E.Apply(d1,b)))
143            | E.Div (g,E.Const b) =>(1,E.Div(E.Apply(d1,g),E.Const b))
144            | E.Div (g,b) => let
145                val (c,EE)=(case filterField[b]
146              of (_,[]) => (1,E.Div(E.Apply(d1,g),b)) (*Division by a real*)              of (_,[]) => (1,E.Div(E.Apply(d1,g),b)) (*Division by a real*)
147              | (pre,h) => let              | (pre,h) => let
148                  (*quotient rule*)                  (*quotient rule*)
149                  val g'=E.Apply(d1,g)                  val  E.Partial dels=d1
150                  val h'=E.Apply(d1,flatProd(h))                  val del0=E.Partial([List.hd(dels)])
151                    val deln=E.Partial( List.tl(dels))
152                    val g'=E.Apply(del0,g)
153                    val h'=E.Apply(del0,flatProd(h))
154                  val num=E.Sub(E.Prod([g']@h),E.Prod[g,h'])                  val num=E.Sub(E.Prod([g']@h),E.Prod[g,h'])
155                  val denom=E.Prod(pre@h@h)                  val denom=E.Prod(pre@h@h)
156                  in (1,E.Div(num,denom))                      val e=(case deln
157                        of E.Partial []=>E.Div(num,denom)
158                        | _=>E.Apply(deln,E.Div(num,denom))
159                    (*end case*))
160                    in (1,e)
161                  end                  end
162              (*end case*))              (*end case*))
163                in
164                    (c,EE)
165                end
166          | E.Prod p =>let          | E.Prod p =>let
167    
168              val (pre, post)= filterField p              val (pre, post)= filterField p
169              val E.Partial d3=d1              val E.Partial d3=d1
170              in mkProd(pre@[prodAppPartial(post,d3)])              val (c,g)= mkProd(pre@[prodAppPartial(post,d3)])
171                val _ = testp["\n*****\n Product rule \n",
172                        P.printbody(E.Apply(d1,e1)),"\n==>\n",P.printbody g,"\n ***\n\n"]
173                in (c,g)
174              end              end
175          | E.Const _   => (1,E.Const 0)(*err("Const without Lift")*)          | E.Const _   => (1,E.Const 0)(*err("Const without Lift")*)
176            | Ein.ConstR _          =>(1,E.Const 0)
177          | E.Tensor _  => err("Tensor without Lift")          | E.Tensor _  => err("Tensor without Lift")
178          | E.Delta _   => err("Apply of Delta")          | E.Delta _   => err("Apply of Delta")
179          | E.Epsilon _ => err("Apply of Eps")          | E.Epsilon _ => err("Apply of Eps")
# Line 119  Line 184 
184          | E.Img _     => err("Probe used before expand")          | E.Img _     => err("Probe used before expand")
185          (*end case*))          (*end case*))
186    
187    
188        in
189            (c,g)
190        end
191    
192      (*mkprobe:ein_exp* ein_exp-> int ein_exp      (*mkprobe:ein_exp* ein_exp-> int ein_exp
193      *rewritten probe      *rewritten probe
194      *)      *)
195      fun mkprobe(e1,x)=(case e1      fun mkprobe(e1,x)=(case e1
196          of E.Lift e   => (1,e)          of E.Lift e   => (1,e)
197          | E.Sqrt a    => (1,E.Sqrt(E.Probe(a,x)))          | E.Sqrt a    => (1,E.Sqrt(E.Probe(a,x)))
198            | E.PowReal(a,n1)    => (1,E.PowReal(E.Probe(a,x),n1))
199          | E.Prod []   => err("Probe of empty product")          | E.Prod []   => err("Probe of empty product")
200          | E.Prod p    => (1,E.Prod (List.map (fn(a)=>E.Probe(a,x)) p))          | E.Prod p    => (1,E.Prod (List.map (fn(a)=>E.Probe(a,x)) p))
201          | E.Apply _   => (0,E.Probe(e1,x))          | E.Apply _   => (0,E.Probe(e1,x))
# Line 136  Line 207 
207          | E.Neg a    => (1,E.Neg(E.Probe(a,x)))          | E.Neg a    => (1,E.Neg(E.Probe(a,x)))
208          | E.Div (a,b) => (1,E.Div(E.Probe(a,x),E.Probe(b,x)))          | E.Div (a,b) => (1,E.Div(E.Probe(a,x),E.Probe(b,x)))
209          | E.Const _   => (1,e1)(*err("Const without Lift")*)          | E.Const _   => (1,e1)(*err("Const without Lift")*)
210     | Ein.ConstR _          =>(1,e1)
211          | E.Tensor _  => err("Tensor without Lift")          | E.Tensor _  => err("Tensor without Lift")
212          | E.Delta _   => (0,e1)          | E.Delta _   => (0,e1)
213          | E.Epsilon _ => (0,e1)          | E.Epsilon _ => (0,e1)
# Line 151  Line 223 
223      * rewrite body of EIN      * rewrite body of EIN
224      * note "c" keeps track if ein_exp is changed      * note "c" keeps track if ein_exp is changed
225      *)      *)
226      fun normalize (ee as Ein.EIN{params, index, body}) = let      fun normalize (ee as Ein.EIN{params, index, body},args) = let
227        val changed = ref false        val changed = ref false
228        fun rewriteBody body =(case body        fun rewriteBody body =(case body
229          of E.Const _    => body          of E.Const _    => body
230            | Ein.ConstR _          =>body
231          | E.Tensor _    => body          | E.Tensor _    => body
232          | E.Field _     => body          | E.Field _     => body
233          | E.Delta _     => body          | E.Delta _     => body
# Line 170  Line 243 
243          | E.Neg e           => E.Neg(rewriteBody e)          | E.Neg e           => E.Neg(rewriteBody e)
244          | E.Lift e          => E.Lift(rewriteBody e)          | E.Lift e          => E.Lift(rewriteBody e)
245          | E.Sqrt e          => E.Sqrt(rewriteBody e)          | E.Sqrt e          => E.Sqrt(rewriteBody e)
246            | E.PowInt(e,n1)        => E.PowInt(rewriteBody e,n1)
247            | E.PowReal(e,n1)       => E.PowReal(rewriteBody e,n1)
248          | E.Add es          => let          | E.Add es          => let
249              val (change,body')= mkAdd(List.map rewriteBody es)              val (change,body')= mkAdd(List.map rewriteBody es)
250              in if (change=1) then ( changed:=true;body') else body' end              in if (change=1) then ( changed:=true;body') else body' end
251          | E.Sub(a, E.Field f)=> (changed:=true;E.Add[a, E.Neg(E.Field(f))])          (*| E.Sub(a, E.Field f)=> (changed:=true;E.Add[a, E.Neg(E.Field(f))])
252          | E.Sub(E.Sub(a,b),E.Sub(c,d))  => rewriteBody(E.Sub(E.Add[a,d],E.Add[b,c]))          | E.Sub(E.Sub(a,b),E.Sub(c,d))  => rewriteBody(E.Sub(E.Add[a,d],E.Add[b,c]))
253          | E.Sub(E.Sub(a,b),e2)          => rewriteBody (E.Sub(a,E.Add[b,e2]))          | E.Sub(E.Sub(a,b),e2)          => rewriteBody (E.Sub(a,E.Add[b,e2]))
254          | E.Sub(e1,E.Sub(c,d))          => rewriteBody(E.Add([E.Sub(e1,c),d]))          | E.Sub(e1,E.Sub(c,d))          => rewriteBody(E.Add([E.Sub(e1,c),d]))*)
255          | E.Sub (a,b)                   => E.Sub(rewriteBody a, rewriteBody b)          | E.Sub (a,b)                   => E.Sub(rewriteBody a, rewriteBody b)
256          | E.Div(e1 as E.Tensor(_,[_]),e2 as E.Tensor(_,[]))=>          | E.Div(e1 as E.Tensor(_,[_]),e2 as E.Tensor(_,[]))=>
257                  rewriteBody (E.Prod[E.Div(E.Const 1, e2),e1])                  rewriteBody (E.Prod[E.Div(E.Const 1, e2),e1])
# Line 214  Line 289 
289          | E.Prod [e1] => rewriteBody e1          | E.Prod [e1] => rewriteBody e1
290          | E.Prod((E.Add(e2))::e3)=>          | E.Prod((E.Add(e2))::e3)=>
291             (changed := true; E.Add(List.map (fn e=> E.Prod([e]@e3)) e2))             (changed := true; E.Add(List.map (fn e=> E.Prod([e]@e3)) e2))
292    (*
293          | E.Prod((E.Sub(e2,e3))::e4)=>          | E.Prod((E.Sub(e2,e3))::e4)=>
294              (changed :=true; E.Sub(E.Prod([e2]@e4), E.Prod([e3]@e4 )))              (changed :=true; E.Sub(E.Prod([e2]@e4), E.Prod([e3]@e4 )))*)
295          | E.Prod((E.Div(e2,e3))::e4)=> (changed :=true; E.Div(E.Prod([e2]@e4), e3 ))          | E.Prod((E.Div(e2,e3))::e4)=> (changed :=true; E.Div(E.Prod([e2]@e4), e3 ))
296          | E.Prod(e1::E.Add(e2)::e3)=>          | E.Prod(e1::E.Add(e2)::e3)=>
297              (changed := true; E.Add(List.map (fn e=> E.Prod([e1,e]@e3)) e2))              (changed := true; E.Add(List.map (fn e=> E.Prod([e1,e]@e3)) e2))
298          | E.Prod(e1::E.Sub(e2,e3)::e4)=>          | E.Prod(e1::E.Sub(e2,e3)::e4)=>
299              (changed :=true; E.Sub(E.Prod([e1,e2]@e4), E.Prod([e1,e3]@e4 )))              (changed :=true; E.Sub(E.Prod([e1,e2]@e4), E.Prod([e1,e3]@e4 )))
300    
301            | E.Prod((e1 as E.Sqrt(s1))::(e2 as E.Sqrt(s2))::es)=>s1 (*
302                if(Eq.isEqual3(s1,s2,args)=0) then (print"prod sqrt";s1)
303                else let
304    val _ =print"prodsqrt:tried equal and did not find it"
305                    val (_,b)=mkProd([rewriteBody e1, rewriteBody e2]@es)
306                    in b end*)
307          (*************Product EPS **************)          (*************Product EPS **************)
308    
309          | E.Prod(E.Epsilon(i,j,k)::E.Apply(E.Partial d,e)::es)=>let          | E.Prod(E.Epsilon(i,j,k)::E.Apply(E.Partial d,e)::es)=>let
# Line 321  Line 404 
404    
405      (*end case*))      (*end case*))
406    
407        val _=testp["\n********Normalize",P.printerE ee,"\n*****\n"]
408      fun loop(body ,count) = let      fun loop(body ,count) = let
409          val _= testp["\n\n N =>",Int.toString(count),"--",P.printbody(body)]          val _= testp["\n\n N =>",Int.toString(count),"--",P.printbody(body)]
410          val body' = rewriteBody body          val body' = rewriteBody body

Legend:
Removed from v.2867  
changed lines
  Added in v.2903

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0