Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-il/normalize-ein.sml
 [diderot] / branches / charisee / src / compiler / high-il / normalize-ein.sml

Diff of /branches/charisee/src/compiler/high-il/normalize-ein.sml

revision 2452, Sat Oct 5 00:43:58 2013 UTC revision 2460, Wed Oct 9 19:09:26 2013 UTC
# Line 5  Line 5
5
6      structure E = Ein      structure E = Ein
7      structure P=Printer      structure P=Printer
8        structure O =OrderEin
9      in      in
10
11
# Line 47  Line 47
47          (* end case *)          (* end case *)
48           end           end
49
(*

fun mkEps(e)= (case e
of E.Apply(E.Partial [E.V a], E.Prod( e2::m ))=> (0,e)
| E.Apply(E.Partial [E.V a,E.V b], E.Prod( (E.Epsilon(i,j,k))::m ))=>
(if(a=i andalso b=j) then (1,E.Const(0.0))
else if(a=i andalso b=k) then (1,E.Const(0.0))
else if(a=j andalso b=i) then (1,E.Const(0.0))
else if(a=j andalso b=k) then (1,E.Const(0.0))
else if(a=k andalso b=j) then (1,E.Const(0.0))
else if(a=k andalso b=i) then (1,E.Const(0.0))
else (0,e))
|_=> (0,e)
(*end case*))

fun mkSumApply(E.Sum(c,E.Apply(d, e))) = (case e
of E.Tensor(a,[])=> (0,E.Const(0.0))
| E.Tensor _=> (0,E.Sum(c,E.Apply(d,e)))
| E.Field _ =>(0, E.Sum(c, E.Apply(d,e)))
| E.Const _=> (1,E.Const(0.0))
| E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(d, e))) l))
| E.Sub(e2, e3) =>(1, E.Sub(E.Sum(c,E.Apply(d, e2)), E.Sum(c,E.Apply(d, e3))))
| E.Prod((E.Epsilon c)::e2)=> mkEps(E.Apply(d,e))
| E.Prod[E.Tensor(a,[]), e2]=>  (0, E.Prod[ E.Tensor(a,[]), E.Sum(c,E.Apply(d, e2))]  )
| E.Prod((E.Tensor(a,[]))::e2)=>  (0, E.Prod[E.Tensor(a,[]), E.Sum(c,E.Apply(d, E.Prod e2))] )
| E.Prod es =>   (let
fun prod [e] = (E.Apply(d, e))
| prod(e1::e2)=(let val l= prod(e2) val m= E.Prod[e1,l]
val lr=e2 @[E.Apply(d,e1)]   val(b,a) =mkProd lr
in ( E.Add[ a, m] ) end)
| prod _= (E.Const(1.0))
in (1, E.Sum(c,prod es))  end)
| _=> (0,E.Sum(c,E.Apply(d,e)))
(*end case*))
50
*)
51
52
53  fun rmEpsIndex(_,_,[])=[]  fun rmEpsIndex(_,_,[])=[]
# Line 112  Line 75
75                      E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),                      E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),
76                      E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)))                      E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)))
77
78              in (case (eps,s')                  in (case (eps,es,s')
79                  of ([],[]) =>(1,deltas)                  of ([],[],[]) =>(1,deltas)
80                  |([],_)=>(1,E.Sum(s',deltas))                  |([],_,[]) =>(1,E.Prod( es@[deltas]))
81                  |(_,[])=>(1,E.Prod(eps@[deltas]))                  |([],[],_)=>(1,E.Sum(s',deltas))
82                  |(_,_) =>(1, E.Sum(s', E.Prod(eps@[deltas])))                  |([],_,_)=>(1,E.Sum(s',E.Prod(es@[deltas])))
83                    |(_,_,[])=>(1,E.Prod(eps@es@[deltas]))
84                    |_ =>(1, E.Sum(s', E.Prod(eps@es@[deltas])))
85                     )                     )
86               end               end
87
# Line 184  Line 149
149         (change, E.Sum(index,E.Prod (eps@dels'@done)))         (change, E.Sum(index,E.Prod (eps@dels'@done)))
150    end    end
151
152  fun mkApply2(E.Apply(d,e))=(case e
153    fun mkApplySum(E.Apply(E.Partial d,E.Sum(c,e)))=(print "apply sum";case e
154      of E.Tensor(a,[])=>(1,E.Const 0.0)      of E.Tensor(a,[])=>(1,E.Const 0.0)
155      | E.Const _ =>(1,E.Const 0.0)      | E.Const _ =>(1,E.Const 0.0)
156      | E.Add l => (1,E.Add(List.map (fn e => E.Apply(d, e)) l))      | E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, E.Sum(c,e))) l))
157      | E.Sub(e2, e3) =>(1, E.Sub(E.Apply(d, e2), E.Apply(d, e3)))      | E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, E.Sum(c,e2)), E.Apply(E.Partial d, E.Sum(c,e3))))
158      | E.Prod(E.Tensor(a,[])::e2)=>(1,E.Prod[E.Tensor(a,[]),E.Apply(d,e)])
159      | E.Prod [e1]=>(1,E.Apply(d,e1))      | E.Prod [e1]=>(1,E.Apply(E.Partial d,E.Sum(c,e1)))
160        | E.Prod(E.Tensor(a,[])::e1::[])=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Sum(c,e1))])
161
162        | E.Prod(E.Tensor(a,[])::e2)=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Sum(c,E.Prod e2))])
163
164      | E.Prod es=> (let      | E.Prod es=> (let
165          fun prod [e1] =E.Apply(d,e1)          fun prod [e1] =E.Apply(E.Partial d,e1)
166            | prod (E.Epsilon eps1::es) = (E.Apply(E.Partial d, E.Prod (E.Epsilon eps1::es)))
167            | prod (E.Delta e1::es) = (E.Apply(E.Partial d, E.Prod (E.Delta e1::es)))
168            | prod (E.Prod e1::es)=prod(e1@es)
169          | prod(e1::e2)=(let          | prod(e1::e2)=(let
170              val l= prod(e2) val m= E.Prod[e1,l]              val l= prod(e2)
171              val lr=e2 @[E.Apply(d,e1)] val(b,a) =mkProd lr              val (_, a)= mkProd[e1,l]
172              in  E.Add[a,m]              val lr=e2 @[E.Apply(E.Partial d,e1)]
173                val(_,b) =mkProd lr
174                in  E.Add[b,a]
175                end)
176            val chainrule=prod es
177            in (1,E.Sum(c, chainrule)) end)
178        |_=>(0,E.Apply(E.Partial d,E.Sum(c,e)))
179        (* end case*))
180
181    fun mkApply2(E.Apply(E.Partial d,e))=(print "aa";case e
182        of E.Tensor(a,[])=>(1,E.Const 0.0)
183        | E.Const _ =>(1,E.Const 0.0)
184        | E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, e)) l))
185        | E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, e2), E.Apply(E.Partial d, e3)))
186        | E.Apply(E.Partial e1,e2)=>(1,E.Apply(E.Partial(d@e1), e2))
187        | E.Prod [e1]=>(1,E.Apply(E.Partial d,e1))
188        | E.Prod(E.Tensor(a,[])::e1::[])=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,e1)])
189        | E.Prod(E.Tensor(a,[])::e2)=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Prod e2)])
190        | E.Prod es=> (let
191            fun prod [e1] =(0,E.Apply(E.Partial d,e1))
192            | prod (E.Epsilon eps1::es) = (0,E.Apply(E.Partial d, E.Prod (E.Epsilon eps1::es)))
193            | prod (E.Delta e1::es) = (0,E.Apply(E.Partial d, E.Prod (E.Delta e1::es)))
194             | prod (E.Prod e1::es)=prod(e1@es)
195            | prod(E.Tensor t::e2)=(let
196                val (_,l)= prod(e2) val m= E.Prod[E.Tensor t,l]
197                val lr=e2 @[E.Apply(E.Partial d,E.Tensor t)] val(b,a) =mkProd lr
198                in  (1,E.Add[a,m])
199                end)
200            | prod(E.Field f::e2)=(let
201                val (_,l)= prod(e2) val m= E.Prod[E.Field f,l]
202                val lr=e2 @[E.Apply(E.Partial d,E.Field f)] val(b,a) =mkProd lr
203                in  (1,E.Add[a,m])
204              end)              end)
205          in (1,prod es) end)          | prod e = (0,E.Apply(E.Partial d, E.Prod e))
206      |_=>(0,E.Apply(d,e))
207
208            val (a,b)= prod es
209
210            in (a, b) end)
211        |_=>(0,E.Apply(E.Partial d,e))
212      (* end case*))      (* end case*))
213
214  fun mkSumApply2(E.Sum(c,E.Apply(E.Partial d,e)))=(case e  fun mkSumApply2(E.Sum(c,E.Apply(E.Partial d,e)))=(print "in here ";case e
215      of E.Const _=>(1,E.Const 0.0)      of E.Const _=>(1,E.Const 0.0)
216        | E.Tensor(_,[])=> (1,E.Const 0.0)
217        | E.Field _=>(0,E.Sum(c,E.Apply(E.Partial d,e)))
218        | E.Apply(E.Partial e1,e2)=>(1,E.Sum(c,E.Apply(E.Partial(d@e1),e2)))
219
220      | E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(E.Partial d, e))) l))      | E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(E.Partial d, e))) l))
221      | E.Sub(e2, e3) =>(1, E.Sub(E.Sum(c,E.Apply(E.Partial d, e2)), E.Sum(c,E.Apply(E.Partial d, e3))))      | E.Sub(e2, e3) =>
222      | E.Prod(E.Tensor(a,[])::e2)=>(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,E.Prod e2))])                  (*(0,E.Sub(e2,e3))
223      | E.Prod [e1]=>(1,E.Sum(c,E.Apply(E.Partial d,e1)))                  *)
224      | E.Prod es =>(let                  (print "sub";(1, E.Sub(E.Sum(c,E.Apply(E.Partial d, e2)), E.Sum(c,E.Apply(E.Partial d, e3)))))
225
226         | E.Prod [e1]=>(print "one";(1,E.Sum(c,E.Apply(E.Partial d,e1))))
227
228
229        | E.Prod(E.Tensor(a,[])::e2::[])=>("in scalar";(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,e2))]))
230
231        | E.Prod(E.Tensor(a,[])::e2)=>("in scalar";(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,E.Prod e2))]))
232
233        | E.Prod es =>(print "in prod";let
234          fun prod (change,rest, sum,partial,[]) = (change,E.Sum(sum,E.Apply(E.Partial partial,E.Prod rest)))          fun prod (change,rest, sum,partial,[]) = (change,E.Sum(sum,E.Apply(E.Partial partial,E.Prod rest)))
235          | prod (change,rest, sum,partial,E.Epsilon(i,j,k)::ps)= let          | prod (change,rest, sum,partial,E.Epsilon(i,j,k)::ps)= let
236              fun matchprod(2,_,_,_)= 1 (*matched 2*)              fun matchprod(2,_,_,_)= 1 (*matched 2*)
# Line 217  Line 239
239              | matchprod(num,[],rest,eps::epsx)=              | matchprod(num,[],rest,eps::epsx)=
240                  matchprod(num,rest,[],epsx)                  matchprod(num,rest,[],epsx)
241              | matchprod(num,E.V p::px,rest,eps::epsx)=              | matchprod(num,E.V p::px,rest,eps::epsx)=
242                  if(p=eps) then matchprod(num+1,px,rest,epsx)                  if(p=eps) then (matchprod(num+1,rest@px,[],epsx))
243                  else matchprod(num,px,rest@[E.V p], eps::epsx)                  else matchprod(num,px,rest@[E.V p], eps::epsx)
244              | matchprod(num,p::px,rest,eps)=              | matchprod(num,p::px,rest,eps)=
245                  matchprod(num,px,rest,eps)                  matchprod(num,px,rest,eps)
# Line 225  Line 247
247              val change'= matchprod(0,d,[],[i,j,k])              val change'= matchprod(0,d,[],[i,j,k])
248              in (case change'              in (case change'
249                  of 1 => (1,E.Const 0.0)                  of 1 => (1,E.Const 0.0)
250                  | _ =>prod(change,rest@[E.Epsilon(i,j,k)],sum,partial,ps))                  | _ =>prod(change,rest@[E.Epsilon(i,j,k)],sum,partial,ps)
251                    (*end case*))
252              end              end
253          | prod (change,rest, sum,partial,E.Delta(i,j)::ps)=let          | prod (change,rest, sum,partial,E.Delta(i,j)::ps)=let
254              fun applyDelPartial([],_)=(0,[])              fun applyDelPartial([],_)=(0,[])
# Line 248  Line 271
271          in          in
272              (change,exp)              (change,exp)
273          end)          end)
274      | _=>(0,E.Sum(c,E.Apply(E.Partial d,e)))          | _=>(print "nope";(0,E.Sum(c,E.Apply(E.Partial d,e))))
275          (* end case*))          (* end case*))
276
277  (*  (*
# Line 262  Line 285
285  (*Apply normalize to each term in product list  (*Apply normalize to each term in product list
286  or Apply normalize to tail of each list*)  or Apply normalize to tail of each list*)
287  fun normalize (Ein.EIN{params, index, body}) = let  fun normalize (Ein.EIN{params, index, body}) = let
288
289        val changed = ref false        val changed = ref false
290
291        fun rewriteBody body = (case body        fun rewriteBody body = (case body
292               of E.Const _=> body               of E.Const _=> body
293                | E.Tensor _ =>body                | E.Tensor _ =>body
# Line 297  Line 322
322                | E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=>                | E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=>
323                      if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))                      if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))
324                      else body                      else body
325                 | E.Prod [E.Partial r1, E.Tensor(_,[])]=> (changed:=true;E.Const(0.0))
326               | E.Prod(E.Partial r1::E.Partial r2::p)=>               | E.Prod(E.Partial r1::E.Partial r2::p)=>
327                     (changed:=true;E.Prod([E.Partial(r1@r2)]@p))                     (changed:=true;E.Prod([E.Partial(r1@r2)]@p))
328                 | E.Prod [E.Partial _, _] =>body
329
330                 | E.Prod (E.Partial p1::es)=> (let
331                    fun prod [e1] =E.Apply(E.Partial p1,e1)
332                    | prod(e1::e2)=(let
333                        val l= prod(e2) val m= E.Prod[e1,l]
334                        val lr=e2 @[E.Apply(E.Partial p1,e1)] val(b,a) =mkProd lr
335                        in  E.Add[a,m]
336                        end)
337                    in (changed:=true;prod es) end)
338
339                | E.Prod(e::es)=>let                | E.Prod(e::es)=>let
340                      val e'=rewriteBody e                      val e'=rewriteBody e
341                      val e2=rewriteBody(E.Prod es)                      val e2=rewriteBody(E.Prod es)
# Line 309  Line 345
345                     end                     end
346
347                (*Apply*)                (*Apply*)
348
349                  | E.Apply(E.Partial d,E.Sum(c,e))=>let
350                        val(c,e')=mkApplySum(E.Apply(E.Partial d,E.Sum(c, rewriteBody e)))
351                        val e''=(case e'
352                            of E.Apply(d,E.Sum s)=>E.Apply(d,rewriteBody(E.Sum s))
353                            |_=> e')
354                    in (print "bb";case c of 1=>(changed:=true;e'')
355                        |_=> e'')end
356                | E.Apply(E.Partial [],e)=> e                | E.Apply(E.Partial [],e)=> e
357
358                | E.Apply(E.Partial p, e)=>let                | E.Apply(E.Partial p, e)=>let
359                      val body'=E.Apply(E.Partial p, rewriteBody e)                      val body'=E.Apply(E.Partial p, rewriteBody e)
360                      val (c, e')=mkApply2(body')                      val (c, e')=mkApply2(body')
# Line 323  Line 368
368                | E.Sum([],e)=> (changed:=true;rewriteBody e)                | E.Sum([],e)=> (changed:=true;rewriteBody e)
369                | E.Sum(_,E.Const c)=>(changed:=true;E.Const c)                | E.Sum(_,E.Const c)=>(changed:=true;E.Const c)
370                | E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l))                | E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l))
371                  | E.Sum(c,E.Sub(e1,e2))=>(changed:=true; E.Sub(E.Sum(c,e1),E.Sum(c,e2)))
372                | E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=>                | E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=>
373                     let val (i,e,rest)=epsToDels(body)                     let val (i,e,rest)=epsToDels(body)
374                     in (case (i, e,rest)                  in (print "eps to dels \n ";case (i, e,rest)
375                          of (1,[e1],_) =>(changed:=true;e1)                  of (1,[e1],r) =>(print "changed\n";changed:=true;e1)
376                          |(0,eps,[])=>body                  |(0,eps,[])=>(print "non";body)
377                          |(0,eps,rest)=>(let                          |(0,eps,rest)=>(let
378                              val p'=rewriteBody(E.Prod rest)                              val p'=rewriteBody(E.Prod rest)
379                              val p''= (case p' of E.Prod p=>p |e=>[e])                              val p''= (case p' of E.Prod p=>p |e=>[e])
# Line 342  Line 388
388                          of E.Prod p=> mkProd p                          of E.Prod p=> mkProd p
389                          |_=> (0,a))                          |_=> (0,a))
390                     in (case change of []=>body'|_=>(changed:=true;body')) end                     in (case change of []=>body'|_=>(changed:=true;body')) end
391                | E.Sum(c,E.Apply(d,e))=>let
392                      val(c',e')=mkSumApply2(body)                | E.Sum(c,E.Apply(E.Partial _,e))=>let
393                  in (case c' of 1=>(changed:=true;e') |_=>e')                      val (change,exp)=mkSumApply2(body)
394                        val exp'=(case exp
395                            of  E.Const c => E.Const c
396                            | E.Sum(c',E.Apply(d',e'))  => (let
397                                val s'=rewriteBody(E.Sum(c',e'))
398                               in (case s'
399                                    of E.Sum([],e'')=> rewriteBody (E.Apply(d',e''))
400                                    | E.Sum(s'',e'') => E.Sum(s'',rewriteBody(E.Apply(d',e'')))
401                                    | _ => E.Apply(d',s'))
402
403                                end)
404
405
406                            | _ =>exp
407                            (* end case *))
408
409                    in (case change of 1=>(changed:=true;exp') |_=>exp')
410                  end                  end
411
412
413                | E.Sum(c,e)=>E.Sum(c,rewriteBody e)                | E.Sum(c,e)=>E.Sum(c,rewriteBody e)
414
415              (*end case*))              (*end case*))
# Line 354  Line 418
418              val body' = rewriteBody body              val body' = rewriteBody body
419              in              in
420                if !changed                if !changed
421                     then (changed := false; print " \n \t => \n \t ";print( P.printbody body');print "\n";loop body')                     then (changed := false; loop body')
422                  else body'                  else body'
423              end              end
424      val b = loop body      val b = loop body

Legend:
 Removed from v.2452 changed lines Added in v.2460

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0