Home My Page Projects Code Snippets Project Openings diderot
 Summary Activity Tracker Tasks SCM

# SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
 [diderot] / branches / charisee_dev / src / compiler / high-il / normalize-ein.sml

# Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

revision 2448, Tue Oct 1 00:57:08 2013 UTC revision 2449, Thu Oct 3 20:15:16 2013 UTC
# Line 4  Line 4
4      local      local
5
6      structure E = Ein      structure E = Ein
7       structure P=Printer
8
9      in      in
10
11        (*
12  (*Flattens Add constructor: change, expression *)  (*Flattens Add constructor: change, expression *)
13  fun mkAdd [e]=(1,e)  fun mkAdd [e]=(1,e)
# Line 29  Line 29
29                  (* end case *)                  (* end case *)
30       end       end
31
32    (*
33  fun mkProd [e]=(1,e)  fun mkProd [e]=(1,e)
34      | mkProd(e)=let      | mkProd(e)=let
35      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')
# Line 111  Line 112
112      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.
113       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.
114    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )
115     This is useful since we can normalize the second list without having to normalize the epsilons again.*)     This is useful since we can normalize the second list without having to normalize the epsilons again.
116            4(Eps Eps)
117           3( Delta_liDelta mj- Delta_mi Delta_lj)
118             Ai-
119            *)
120
121
122                       *)
123
124
125  fun epsToDels(E.Sum(count,E.Prod e))= let  fun epsToDels(E.Sum(count,E.Prod e))= let
126      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,e3)=      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,eps,e3)=
127          let          let
128          fun createDeltas(s,t,u,v, e3)=
129              (1,  E.Sub(E.Sum(2,E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),          (*Function is called when eps are being changed to deltas*)
130                      E.Sum(2,E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3))))          fun createDeltas(i,s,t,u,v, e3)= let
131          in if(a=d) then createDeltas(b,c,e,f, e3)
132             else if(a=e) then createDeltas(b,c,f,d, e3)              (*remove index from original index list*)
133             else if(a=f) then createDeltas(b,c,d,e, e3)              (*currrent, left, sumIndex*)
134             else if(b=d) then createDeltas(c,a,e,f, e3)
135             else if(b=e) then createDeltas(c,a,f,d,e3)              fun rmIndex(_,_,[])=[]
136             else if(b=f) then createDeltas(c,a,d,e,e3)              | rmIndex([],[],cs)=cs
137             else if(c=d) then createDeltas(a,b,e,f,e3)              | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
138             else if(c=e) then createDeltas(a,b,f,d,e3)              | rmIndex(i::ix,rest ,(E.V c)::cs)=
139             else if(c=f) then createDeltas(a,b,d,e,e3)                     if(i=c) then rmIndex(rest@ix,[],cs)
140             else (0,(E.Prod((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::e3)))                     else rmIndex(ix,rest@[i],(E.V c)::cs)
141
142                val s'= rmIndex([i,s,t,u,v],[],count)
143                val s''=[E.V s, E.V t ,E.V u, E.V v]
144                val deltas= E.Sub(
145                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),
146                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)))
147
148                in (case (eps,s')
149                    of ([],[]) =>(1,deltas)
150                    |([],_)=>(1,E.Sum(s',deltas))
151                    |(_,[])=>(1,E.Prod(eps@[deltas]))
152                    |(_,_) =>(1, E.Sum(s', E.Prod(eps@[deltas])))
153                       )
154                 end
155
156            in if(a=d) then createDeltas(a,b,c,e,f, e3)
157               else if(a=e) then createDeltas(a,b,c,f,d, e3)
158               else if(a=f) then createDeltas(a,b,c,d,e, e3)
159               else if(b=d) then createDeltas(b,c,a,e,f, e3)
160               else if(b=e) then createDeltas(b,c,a,f,d,e3)
161               else if(b=f) then createDeltas(b,c,a,d,e,e3)
162               else if(c=d) then createDeltas(c,a,b,e,f,e3)
163               else if(c=e) then createDeltas(c,a,b,f,d,e3)
164               else if(c=f) then createDeltas(c,a,b,d,e,e3)
165               else (0,E.Const 0.0)
166          end          end
167      fun findeps(e,[])= (e,[])      fun findeps(e,[])= (e,[])
168        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)
169        | findeps(e,es)= (e, es)        | findeps(e,es)= (e, es)
170      fun distribute([], s)=(0, [],s)
171        | distribute([e1], s)=(0, [e1], s)
172        | distribute(e1::es, s)= let val(i, exp)=doubleEps(e1::es, s)      fun dist([],eps,rest)=(0,eps,rest)
173            in if(i=1) then (1, tl(es), [exp])       | dist([e],eps,rest)=(0,eps@[e],rest)
174               else let val(a,b,c)= distribute(es, s)       | dist(c1::current,eps,rest)=let
175                    in (a, [e1]@b, c) end              val(i, exp)= doubleEps(c1::current,eps,rest)
176            in  (case i of 1=>(i,[exp],[E.Const 2.0])
177                |_=> dist(current, eps@[c1],rest))
178              end              end
val (change, eps,rest)= distribute(findeps([], e))
in (change, eps,rest) end
179
180
181
182        val (es,rest)=findeps([],e)
183
184        in
185            dist(es,[],rest)
186        end
187
188    (*
189
190
191
# Line 211  Line 252
252
253
254
255                    *)
256
257    fun reduceDelta(E.Sum(c,E.Prod p))=let
258
259        fun findDeltas(dels,rest,E.Delta d::es)= findDeltas(dels@[E.Delta d], rest, es)
260        | findDeltas(dels,rest,E.Epsilon eps::es)=findDeltas(dels,rest@[E.Epsilon eps],es)
261        | findDeltas(dels,rest,es)=  (dels,rest,es)
262
263        fun rmIndex(_,_,[])=[]
264            | rmIndex([],[],cs)=cs
265            | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
266            | rmIndex(i::ix,rest ,c::cs)=
267                if(i=c) then rmIndex(rest@ix,[],cs)
268                else rmIndex(ix,rest@[i],c::cs)
269
270        fun distribute(change,d,dels,[],done)=(change,dels@d,done)
271            | distribute(change,[],[],e,done)=(change,[],done@e)
272            | distribute(change,E.Delta(i,j)::ds,dels,E.Tensor(id,[tx])::es,done)=
273                if(j=tx) then distribute(change@[j],dels@ds,[] ,es ,done@[E.Tensor(id,[i])])
274                else distribute(change,ds,dels@[E.Delta(i,j)],E.Tensor(id,[tx])::es,done)
275            | distribute(change,d,dels,e::es,done)=distribute(change,dels@d,[],es,done@[e])
276
277        val (dels,eps,es)=findDeltas([],[],p)
278        val (change,dels',done)=distribute([],dels,[],es,[])
279        val index=rmIndex(change,[],c)
280
281      in
282           (change, E.Sum(index,E.Prod (eps@dels'@done)))
283      end
284
285
286
287  (*Apply normalize to each term in product list  (*Apply normalize to each term in product list
288  or Apply normalize to tail of each list*)  or Apply normalize to tail of each list*)
# Line 224  Line 296
296                | E.Delta _ => body                | E.Delta _ => body
297                | E.Value _ =>body                | E.Value _ =>body
298                | E.Epsilon _=>body                | E.Epsilon _=>body
299
300                | E.Neg e => E.Neg(rewriteBody e)                | E.Neg e => E.Neg(rewriteBody e)
301                | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)                | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)
302                     in if (b=1) then ( changed:=true;a) else a end                     in if (b=1) then ( changed:=true;a) else a end
# Line 234  Line 307
307                | E.Probe(u,v)=>  E.Probe(rewriteBody u, rewriteBody v)                | E.Probe(u,v)=>  E.Probe(rewriteBody u, rewriteBody v)
308                | E.Image es => E.Image(List.map rewriteBody es)                | E.Image es => E.Image(List.map rewriteBody es)
309
310                    (*Product*)
311  (************Summation *************)                | E.Prod [e1] => rewriteBody e1

| E.Sum(0, e)=>e
| E.Sum(_, (E.Const c))=> E.Const c
| E.Sum(c,(E.Add l))=> E.Add(List.map (fn e => E.Sum(c,e)) l)

| E.Sum(c,E.Prod((E.Delta d)::es))=>(
let val (i,dels, e)= mkDel((E.Delta d)::es)
val rest=(case e of [e1]=> rewriteBody e1
|_=> rewriteBody(E.Prod(e)))
val soln= (case rest of E.Prod r=> E.Sum(c-i, E.Prod(dels@r))
|_=>E.Sum(c-i, E.Prod(dels@[rest])))
val q= checkDot(soln)
in if (i=0) then q
else (changed :=true;q)
end )

| E.Sum(c,E.Prod((E.Epsilon e1 )::(E.Epsilon e2)::xs))=>
let val (i,eps, e)= epsToDels(body)
in
if (i=0) then let val e'=rewriteBody(E.Prod(e)) in (case e'
of E.Prod m=> let val (i2, p)= mkProd(eps @ m)
in E.Sum(c, p) end
|_=>E.Sum(c, E.Prod(eps@ [e']))) end
else(let val [list]=e
val ans=rewriteBody(list)
val soln=(case ans
of E.Sub (E.Sum(c1,(E.Prod s1)),E.Sum(c2,(E.Prod s2))) =>
E.Sum(c-3+c1, E.Sub(E.Prod(eps@s1),E.Prod(eps@s2)))
| E.Sub (E.Sum(c1,s1),E.Sum(c2,s2)) =>
E.Sum(c-3+c1, E.Prod(eps@ [E.Sub(s1,s2)]))
|_=> E.Prod(eps@ [ans]))
in (changed :=true;soln) end
) end

| E.Sum(c, E.Apply(E.Partial p,   E.Prod((E.Delta(i,j))::e3 )))=>

let
fun part([], e2, counter)=([], e2, counter)
| part(p1::ps, [E.Delta(i,j)],counter)=
if (p1=j) then ([i]@ps,[],counter-1)
else (let
val (a,b,counter)=part(ps, [E.Delta(i,j)],counter)
in ([p1]@a, b,counter )  end)
val (e1,e2,counter)= part(p, [E.Delta(i,j)],c)

in  E.Sum(counter, E.Apply(E.Partial e1, E.Prod(e2@e3))) end

| E.Sum(c, E.Apply(p, e))=>let
val e'= rewriteBody(E.Sum(c, e))
val p'= rewriteBody p
val (i, e2)= (case e'
of E.Sum(c',exp)=> mkSumApply(E.Sum(c', E.Apply(p', exp)))
|_=>mkApply( E.Apply(p', e')))
in if(i=1) then (changed :=true;e2) else e2 end
| E.Sum(c, e)=> E.Sum(c, rewriteBody e)

(************Product**********)
| E.Prod([e1])=>(rewriteBody e1 )
313                      (changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))                      (changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))
314                | E.Prod(e1::(E.Sub(e2,e3))::e4)=>                | E.Prod(e1::(E.Sub(e2,e3))::e4)=>
315                      (changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))                      (changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))
316                | E.Prod[E.Partial r1,E.Conv(f, deltas)]=>                | E.Prod[E.Partial r1,E.Conv(f, deltas)]=>
317                      (changed:=true; E.Conv(f,deltas@r1))                      (changed:=true; E.Conv(f,deltas@r1))
318                | E.Prod((E.Partial r1)::(E.Partial r2)::e) =>                | E.Prod (E.Partial r1::E.Conv(f,deltas)::ps)=>
319                      (changed := true; E.Prod([E.Partial (r1@r2)] @ e))                     (changed:=true; E.Prod([E.Conv(f,deltas@r1)]@ps))

| E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=>
if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))
else body

| E.Prod((E.Epsilon eps1)::es)=> (let
val rest=(case es
of [e1] => rewriteBody e1
| _=> rewriteBody( E.Prod es))

val (i, solution)=(case rest
of E.Prod m=> mkProd ([E.Epsilon eps1] @m )
|_=>  mkProd([E.Epsilon eps1]@ [rest]))
in if (i=1) then (changed:=true;solution)
else solution
end)

| E.Prod (e::es) => (let
val r=rewriteBody(E.Prod es)
val (i,solution)= (case r
of E.Prod m => mkProd([e]@m )
|_=> mkProd([e]@ [r]))
in if (i=1) then (changed:=true;solution)
else solution
end)

(**************Apply*******************)

320
| E.Apply(E.Partial p, E.Prod((E.Delta(i,j))::e3))=>
let fun part([], e2)=([], e2)
| part(p1::ps, [E.Delta(i,j)])=
if (p1=j) then ([i]@ps,[])
else (let val (a,b)=part(ps, [E.Delta(i,j)])
in ([p1]@a, b )  end)
val (e1,e2)= part(p, [E.Delta(i,j)])
in   E.Apply(E.Partial e1, E.Prod(e2@e3)) end
321
322                | E.Apply(E.Partial d,e)=> ( let val (t1,t2)= mkApply(E.Apply(E.Partial d, rewriteBody e))                | E.Prod(e::es)=>let
323                      in if (t1=1) then (changed :=true;t2) else t2 end)                      val e'=rewriteBody e
324                        val e2=rewriteBody(E.Prod es)
325                        in(case e2 of E.Prod p'=> E.Prod([e']@p')
326                            |_=>E.Prod [e',e2])
327                       end
328
329                | E.Apply(E.Prod d,e)=> ( let val (t1,t2)= mkApply(E.Apply(rewriteBody (E.Prod d), rewriteBody e))                (*Apply*)
330                     in if (t1=1) then (changed :=true;t2) else t2 end)                | E.Apply(e1,e2)=>E.Apply(rewriteBody e1, rewriteBody e2)
331
| E.Apply _ => (print "Err Apply ";body)
332
333
334                  (* Sum *)
335                  | E.Sum([],e)=> rewriteBody e
336                       | E.Sum(_,E.Const c)=>(changed:=true;E.Const c)
337                  | E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l))
338                  | E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=>
339                       let val (i,e,rest)=epsToDels(body)
340                       in (case (i, e,rest)
341                       of (1,[e1],_) =>(changed:=true;e1)
342                            |(0,eps,[])=>body
343                            |(0,eps,rest)=>(let
344                                val p'=rewriteBody(E.Prod rest)
345                                val p''= (case p' of E.Prod p=>p |e=>[e])
346                                in E.Sum(c, E.Prod (eps@p'')) end
347                                )
348                |_=> body                |_=> body
349                       ) end
350                  | E.Sum(c, E.Prod(E.Delta d::es))=>let
351                        val (change,body')=reduceDelta(body)
352                       in (case change of []=>body'|_=>(changed:=true;body')) end
353                  | E.Sum(c,e)=>E.Sum(c,rewriteBody e)
354
355              (*end case*))              (*end case*))
356
# Line 366  Line 358
358              val body' = rewriteBody body              val body' = rewriteBody body
359              in              in
360                if !changed                if !changed
361                  then (changed := false; loop body')                     then (changed := false; print " \n \t => \n \t ";print( P.printbody body');print "\n";loop body')
362                  else body'                  else body'
363              end              end
364      val b = loop body      val b = loop body
365      in      in
366      ((Ein.EIN{params=params, index=index, body=b}))      ((Ein.EIN{params=params, index=index, body=b}))
367      end      end

4 end*)

368  end  end
369
370

Legend:
 Removed from v.2448 changed lines Added in v.2449

 root@smlnj-gforge.cs.uchicago.edu ViewVC Help Powered by ViewVC 1.0.0