Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2397, Sat Jul 6 20:50:46 2013 UTC revision 2450, Thu Oct 3 20:17:08 2013 UTC
# Line 2  Line 2 
2  structure NormalizeEin = struct  structure NormalizeEin = struct
3    
4      local      local
     structure G = GenericEin  
     structure E = Ein  
     structure S = Specialize  
     structure R = Rewrite  
   
5    
6        structure E = Ein
7      (* structure P=Printer*)
8    
9      in      in
10    
11    
   
 (*  
 If changed is true then I know the expression will run through the funciton again.  
 However, if not, then I want to make sure that every expression in the Product is examined, and not just individually but as a group.  
 Prod[t1,t2,(t3+t4)] indivually=> same  
 Prod[t1] @ Prod[t2,(t3+t4)]=> Notice rule here  
 Prod[t1] @ Add(Prod (t2, t3), Prod (t2, t4))  
 => Add( Prod[t1, Prod(t2,t3)]..)  
 => Add (Prod[t1,t2,t3]) Flattened  
   
 *)  
   
   
   
   
   
12  (*Flattens Add constructor: change, expression *)  (*Flattens Add constructor: change, expression *)
13  fun mkAdd [e]=(1,e)  fun mkAdd [e]=(1,e)
14      | mkAdd(e)=let      | mkAdd(e)=let
# Line 48  Line 29 
29                  (* end case *)                  (* end case *)
30       end       end
31    
32    (*
33  fun mkProd [e]=(1,e)  fun mkProd [e]=(1,e)
34      | mkProd(e)=let      | mkProd(e)=let
35      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')
# Line 71  Line 53 
53    
54    
55  fun mkEps(e)= (case e  fun mkEps(e)= (case e
56      of E.Apply(E.Partial [a], E.Prod( e2::m ))=> (0,e)      of E.Apply(E.Partial [E.V a], E.Prod( e2::m ))=> (0,e)
57       | E.Apply(E.Partial [a,b], E.Prod( (E.Epsilon(i,j,k))::m ))=>       | E.Apply(E.Partial [E.V a,E.V b], E.Prod( (E.Epsilon(i,j,k))::m ))=>
58          (if(a=i andalso b=j) then (1,E.Const(0.0))          (if(a=i andalso b=j) then (1,E.Const(0.0))
59          else if(a=i andalso b=k) then (1,E.Const(0.0))          else if(a=i andalso b=k) then (1,E.Const(0.0))
60          else if(a=j andalso b=i) then (1,E.Const(0.0))          else if(a=j andalso b=i) then (1,E.Const(0.0))
# Line 130  Line 112 
112      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.
113       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.
114    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )
115     This is useful since we can normalize the second list without having to normalize the epsilons again.*)     This is useful since we can normalize the second list without having to normalize the epsilons again.
116            4(Eps Eps)
117           3( Delta_liDelta mj- Delta_mi Delta_lj)
118             Ai-
119            *)
120    
121    
122                       *)
123    
124    
125  fun epsToDels(E.Sum(count,E.Prod e))= let  fun epsToDels(E.Sum(count,E.Prod e))= let
126      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,e3)=      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,eps,e3)=
127          let          let
128          fun createDeltas(s,t,u,v, e3)=  
129              (1,  E.Sub(E.Sum(2,E.Prod([E.Delta(s,u), E.Delta(t,v)] @e3)),          (*Function is called when eps are being changed to deltas*)
130                      E.Sum(2,E.Prod([E.Delta(s,v), E.Delta(t,u)]@e3))))          fun createDeltas(i,s,t,u,v, e3)= let
131          in if(a=d) then createDeltas(b,c,e,f, e3)  
132             else if(a=e) then createDeltas(b,c,f,d, e3)              (*remove index from original index list*)
133             else if(a=f) then createDeltas(b,c,d,e, e3)              (*currrent, left, sumIndex*)
134             else if(b=d) then createDeltas(c,a,e,f, e3)  
135             else if(b=e) then createDeltas(c,a,f,d,e3)              fun rmIndex(_,_,[])=[]
136             else if(b=f) then createDeltas(c,a,d,e,e3)              | rmIndex([],[],cs)=cs
137             else if(c=d) then createDeltas(a,b,e,f,e3)              | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
138             else if(c=e) then createDeltas(a,b,f,d,e3)              | rmIndex(i::ix,rest ,(E.V c)::cs)=
139             else if(c=f) then createDeltas(a,b,d,e,e3)                     if(i=c) then rmIndex(rest@ix,[],cs)
140             else (0,(E.Prod((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::e3)))                     else rmIndex(ix,rest@[i],(E.V c)::cs)
141    
142                val s'= rmIndex([i,s,t,u,v],[],count)
143                val s''=[E.V s, E.V t ,E.V u, E.V v]
144                val deltas= E.Sub(
145                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),
146                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)))
147    
148                in (case (eps,s')
149                    of ([],[]) =>(1,deltas)
150                    |([],_)=>(1,E.Sum(s',deltas))
151                    |(_,[])=>(1,E.Prod(eps@[deltas]))
152                    |(_,_) =>(1, E.Sum(s', E.Prod(eps@[deltas])))
153                       )
154                 end
155    
156            in if(a=d) then createDeltas(a,b,c,e,f, e3)
157               else if(a=e) then createDeltas(a,b,c,f,d, e3)
158               else if(a=f) then createDeltas(a,b,c,d,e, e3)
159               else if(b=d) then createDeltas(b,c,a,e,f, e3)
160               else if(b=e) then createDeltas(b,c,a,f,d,e3)
161               else if(b=f) then createDeltas(b,c,a,d,e,e3)
162               else if(c=d) then createDeltas(c,a,b,e,f,e3)
163               else if(c=e) then createDeltas(c,a,b,f,d,e3)
164               else if(c=f) then createDeltas(c,a,b,d,e,e3)
165               else (0,E.Const 0.0)
166          end          end
167      fun findeps(e,[])= (e,[])      fun findeps(e,[])= (e,[])
168        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)
169        | findeps(e,es)= (e, es)        | findeps(e,es)= (e, es)
170      fun distribute([], s)=(0, [],s)  
171        | distribute([e1], s)=(0, [e1], s)  
172        | distribute(e1::es, s)= let val(i, exp)=doubleEps(e1::es, s)      fun dist([],eps,rest)=(0,eps,rest)
173            in if(i=1) then (1, tl(es), [exp])       | dist([e],eps,rest)=(0,eps@[e],rest)
174               else let val(a,b,c)= distribute(es, s)       | dist(c1::current,eps,rest)=let
175                    in (a, [e1]@b, c) end              val(i, exp)= doubleEps(c1::current,eps,rest)
176            in  (case i of 1=>(i,[exp],[E.Const 2.0])
177                |_=> dist(current, eps@[c1],rest))
178              end              end
     val (change, eps,rest)= distribute(findeps([], e))  
     in (change, eps,rest) end  
179    
180    
181    
182        val (es,rest)=findeps([],e)
183    
184        in
185            dist(es,[],rest)
186        end
187    
188    (*
189    
190    
191    
# Line 230  Line 252 
252    
253    
254    
255                    *)
256    
257    fun reduceDelta(E.Sum(c,E.Prod p))=let
258    
259        fun findDeltas(dels,rest,E.Delta d::es)= findDeltas(dels@[E.Delta d], rest, es)
260        | findDeltas(dels,rest,E.Epsilon eps::es)=findDeltas(dels,rest@[E.Epsilon eps],es)
261        | findDeltas(dels,rest,es)=  (dels,rest,es)
262    
263        fun rmIndex(_,_,[])=[]
264            | rmIndex([],[],cs)=cs
265            | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
266            | rmIndex(i::ix,rest ,c::cs)=
267                if(i=c) then rmIndex(rest@ix,[],cs)
268                else rmIndex(ix,rest@[i],c::cs)
269    
270        fun distribute(change,d,dels,[],done)=(change,dels@d,done)
271            | distribute(change,[],[],e,done)=(change,[],done@e)
272            | distribute(change,E.Delta(i,j)::ds,dels,E.Tensor(id,[tx])::es,done)=
273                if(j=tx) then distribute(change@[j],dels@ds,[] ,es ,done@[E.Tensor(id,[i])])
274                else distribute(change,ds,dels@[E.Delta(i,j)],E.Tensor(id,[tx])::es,done)
275            | distribute(change,d,dels,e::es,done)=distribute(change,dels@d,[],es,done@[e])
276    
277        val (dels,eps,es)=findDeltas([],[],p)
278        val (change,dels',done)=distribute([],dels,[],es,[])
279        val index=rmIndex(change,[],c)
280    
281      in
282           (change, E.Sum(index,E.Prod (eps@dels'@done)))
283      end
284    
285    
286    
287  (*Apply normalize to each term in product list  (*Apply normalize to each term in product list
288  or Apply normalize to tail of each list*)  or Apply normalize to tail of each list*)
# Line 239  Line 292 
292               of E.Const _=> body               of E.Const _=> body
293                | E.Tensor _ =>body                | E.Tensor _ =>body
294                | E.Field _=> body                | E.Field _=> body
295                  | E.Kernel _ =>body
296                | E.Delta _ => body                | E.Delta _ => body
297                  | E.Value _ =>body
298                | E.Epsilon _=>body                | E.Epsilon _=>body
299                | E.Conv _=> body  
300                | E.Partial _=>body                | E.Neg e => E.Neg(rewriteBody e)
301                            | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)                            | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)
302                      in if (b=1) then ( changed:=true;a) else a end                      in if (b=1) then ( changed:=true;a) else a end
303                | E.Sub (a,b)=>  E.Sub(rewriteBody a, rewriteBody b)                | E.Sub (a,b)=>  E.Sub(rewriteBody a, rewriteBody b)
304                | E.Div (a, b) => E.Div(rewriteBody a, rewriteBody b)                | E.Div (a, b) => E.Div(rewriteBody a, rewriteBody b)
305                | E.Probe(u,v)=> (  E.Probe(rewriteBody u, v))                | E.Partial _=>body
306                | E.Sum(0, e)=>e                | E.Conv (V, alpha)=> E.Conv(rewriteBody V, alpha)
307                | E.Sum(_, (E.Const c))=> E.Const c                | E.Probe(u,v)=>  E.Probe(rewriteBody u, rewriteBody v)
308                | E.Sum(c,(E.Add l))=> E.Add(List.map (fn e => E.Sum(c,e)) l)                | E.Image es => E.Image(List.map rewriteBody es)
   
               | E.Sum(c,E.Prod((E.Delta d)::es))=>(  
                 let val (i,dels, e)= mkDel((E.Delta d)::es)  
                     val rest=(case e of [e1]=> rewriteBody e1  
                             |_=> rewriteBody(E.Prod(e)))  
                     val soln= (case rest of E.Prod r=> E.Sum(c-i, E.Prod(dels@r))  
                         |_=>E.Sum(c-i, E.Prod(dels@[rest])))  
                     val q= checkDot(soln)  
                     in if (i=0) then q  
                    else (changed :=true;q)  
                    end )  
   
               | E.Sum(c,E.Prod((E.Epsilon e1 )::(E.Epsilon e2)::xs))=>  
                    let val (i,eps, e)= epsToDels(body)  
                    in  
                    if (i=0) then let val e'=rewriteBody(E.Prod(e)) in (case e'  
                         of E.Prod m=> let val (i2, p)= mkProd(eps @ m)  
                                     in E.Sum(c, p) end  
                         |_=>E.Sum(c, E.Prod(eps@ [e']))) end  
                    else(let val [list]=e  
                         val ans=rewriteBody(list)  
                         val soln=(case ans  
                             of E.Sub (E.Sum(c1,(E.Prod s1)),E.Sum(c2,(E.Prod s2))) =>  
                                 E.Sum(c-3+c1, E.Sub(E.Prod(eps@s1),E.Prod(eps@s2)))  
                             | E.Sub (E.Sum(c1,s1),E.Sum(c2,s2)) =>  
                                 E.Sum(c-3+c1, E.Prod(eps@ [E.Sub(s1,s2)]))  
                             |_=> E.Prod(eps@ [ans]))  
                         in (changed :=true;soln) end  
                    ) end  
309    
310                    (*Product*)
311                  | E.Prod [e1] => rewriteBody e1
312                  | E.Prod(e1::(E.Add(e2))::e3)=>
313                       (changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))
314                  | E.Prod(e1::(E.Sub(e2,e3))::e4)=>
315                       (changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))
316                  | E.Prod [E.Partial r1,E.Conv(f,deltas)]=>
317                       (changed :=true;E.Conv(f,deltas@r1))
318                  | E.Prod (E.Partial r1::E.Conv(f,deltas)::ps)=>
319                       (changed:=true; E.Prod([E.Conv(f,deltas@r1)]@ps))
320    
             | E.Sum(c, E.Apply(E.Partial p,   E.Prod((E.Delta(i,j))::e3 )))=>  
321    
322                  let fun part([], e2, counter)=([], e2, counter)                | E.Prod(e::es)=>let
323                     | part(p1::ps, [E.Delta(i,j)],counter)=if (p1=j) then ([i]@ps,[],counter-1)                      val e'=rewriteBody e
324                          else (let val (a,b,counter)=part(ps, [E.Delta(i,j)],counter)                      val e2=rewriteBody(E.Prod es)
325                          in ([p1]@a, b,counter )  end)                      in(case e2 of E.Prod p'=> E.Prod([e']@p')
326                  val (e1,e2,counter)= part(p, [E.Delta(i,j)],c)                          |_=>E.Prod [e',e2])
327                       end
                    in   (E.Sum(counter, E.Apply(E.Partial e1, E.Prod(e2@e3)))) end  
   
             | E.Sum(c, E.Apply(p, e))=>let  
                    val e'= rewriteBody(E.Sum(c, e))  
                    val p'= rewriteBody p  
                    val (i, e2)= (case e'  
                         of E.Sum(c',exp)=> mkSumApply(E.Sum(c', E.Apply(p', exp)))  
                         |_=>mkApply( E.Apply(p', e')))  
                    in if(i=1) then (changed :=true;e2) else e2 end  
328    
329                  (*Apply*)
330                  | E.Apply(e1,e2)=>E.Apply(rewriteBody e1, rewriteBody e2)
331    
               | E.Sum(c, e)=> E.Sum(c, rewriteBody e)  
332    
               | E.Prod([e1])=>(rewriteBody e1 )  
               | E.Prod(e1::(E.Add(e2))::e3)=>  
                     (changed := true;  
                     E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))  
               | E.Prod(e1::(E.Sub(e2,e3))::e4)=>  
                     ( changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))  
               | E.Prod[E.Partial r1,E. Conv(i, j, k, l)]=>  
                     (changed:=true; ( let val j1=  
                                         List.map (fn(x)=> (l,x))  r1 in E.Conv(i, j1@j, k, l) end ))  
               | E.Prod((E.Partial r1)::(E.Partial r2)::e) =>  
                     (changed := true; E.Prod([E.Partial (r1@r2)] @ e)  )  
               | E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[i1,i2])]=>  
                     if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))  
                     else body  
               | E.Prod((E.Epsilon eps1)::es)=> (let  
                    val rest=(case es of [e1] => rewriteBody e1  
                    |_=> rewriteBody(E.Prod(es)))  
                    val (i, solution)=(case rest  
                         of E.Prod m=> mkProd ([E.Epsilon eps1] @m )  
                             |_=>  mkProd([E.Epsilon eps1]@ [rest]))  
                    in if (i=1) then (changed:=true;solution)  
                    else solution end)  
   
              | E.Prod (e::es) => (let val r=rewriteBody(E.Prod es)  
                    val (i,solution)= (case r of E.Prod m => mkProd([e]@m )  
                                     |_=> mkProd([e]@ [r]))  
                    in if (i=1) then (changed:=true;solution)  
                    else solution end)  
               | E.Apply(E.Const _,_) => (E.Const(0.0))  
               | E.Apply(E.Partial p, E.Prod((E.Delta(i,j))::e3))=>  
                     let fun part([], e2)=([], e2)  
                         | part(p1::ps, [E.Delta(i,j)])=if (p1=j) then ([i]@ps,[])  
                                 else (let val (a,b)=part(ps, [E.Delta(i,j)])  
                                 in ([p1]@a, b )  end)  
                     val (e1,e2)= part(p, [E.Delta(i,j)])  
                     in   E.Apply(E.Partial e1, E.Prod(e2@e3)) end  
333    
334                | E.Apply(d,e)=> ( let val (t1,t2)= mkApply(E.Apply(rewriteBody d, rewriteBody e))                (* Sum *)
335                      in if (t1=1) then (changed :=true;t2) else t2 end )                | E.Sum([],e)=> rewriteBody e
336                       | E.Sum(_,E.Const c)=>(changed:=true;E.Const c)
337                  | E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l))
338                  | E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=>
339                       let val (i,e,rest)=epsToDels(body)
340                       in (case (i, e,rest)
341                       of (1,[e1],_) =>(changed:=true;e1)
342                            |(0,eps,[])=>body
343                            |(0,eps,rest)=>(let
344                                val p'=rewriteBody(E.Prod rest)
345                                val p''= (case p' of E.Prod p=>p |e=>[e])
346                                in E.Sum(c, E.Prod (eps@p'')) end
347                                )
348                |_=> body                |_=> body
349                       ) end
350                  | E.Sum(c, E.Prod(E.Delta d::es))=>let
351                        val (change,body')=reduceDelta(body)
352                       in (case change of []=>body'|_=>(changed:=true;body')) end
353                  | E.Sum(c,e)=>E.Sum(c,rewriteBody e)
354    
355              (*end case*))              (*end case*))
356    
# Line 350  Line 358 
358              val body' = rewriteBody body              val body' = rewriteBody body
359              in              in
360                if !changed                if !changed
361                  then (changed := false; loop body')                     then (changed := false; (*print " \n \t => \n \t ";print( P.printbody body');print "\n";*)loop body')
362                  else body'                  else body'
363              end              end
364      val b = loop body      val b = loop body
# Line 360  Line 368 
368    end    end
369    
370    
   
   
   
371  end (* local *)  end (* local *)

Legend:
Removed from v.2397  
changed lines
  Added in v.2450

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0