Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml
ViewVC logotype

Diff of /branches/charisee_dev/src/compiler/high-il/normalize-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2448, Tue Oct 1 00:57:08 2013 UTC revision 2449, Thu Oct 3 20:15:16 2013 UTC
# Line 4  Line 4 
4      local      local
5    
6      structure E = Ein      structure E = Ein
7       structure P=Printer
8    
9      in      in
10    
11        (*  
12  (*Flattens Add constructor: change, expression *)  (*Flattens Add constructor: change, expression *)
13  fun mkAdd [e]=(1,e)  fun mkAdd [e]=(1,e)
14      | mkAdd(e)=let      | mkAdd(e)=let
# Line 29  Line 29 
29                  (* end case *)                  (* end case *)
30       end       end
31    
32    (*
33  fun mkProd [e]=(1,e)  fun mkProd [e]=(1,e)
34      | mkProd(e)=let      | mkProd(e)=let
35      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')      fun flatten(i,((E.Prod l)::l'))= flatten(1,l@l')
# Line 111  Line 112 
112      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.      The epsToDels Function searches for Epsilons in the expression, checks for this identity in all adjacent Epsilons and if needed, does the transformation.
113       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.       The Function returns two separate list, 1 is the remaining list of Epsilons that have not be changed to deltas, and the second is the Product of the remaining expression.
114    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )    Ex:(Epsilon_ijk Epsilon_ilm) Epsilon_stu e =>([Epsilon_stu], [Delta_jl,Delta_km,e -Delta_jm Delta_kl, e] )
115     This is useful since we can normalize the second list without having to normalize the epsilons again.*)     This is useful since we can normalize the second list without having to normalize the epsilons again.
116            4(Eps Eps)
117           3( Delta_liDelta mj- Delta_mi Delta_lj)
118             Ai-
119            *)
120    
121    
122                       *)
123    
124    
125  fun epsToDels(E.Sum(count,E.Prod e))= let  fun epsToDels(E.Sum(count,E.Prod e))= let
126      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,e3)=      fun doubleEps((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::es,eps,e3)=
127          let          let
128          fun createDeltas(s,t,u,v, e3)=  
129              (1,  E.Sub(E.Sum(2,E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),          (*Function is called when eps are being changed to deltas*)
130                      E.Sum(2,E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3))))          fun createDeltas(i,s,t,u,v, e3)= let
131          in if(a=d) then createDeltas(b,c,e,f, e3)  
132             else if(a=e) then createDeltas(b,c,f,d, e3)              (*remove index from original index list*)
133             else if(a=f) then createDeltas(b,c,d,e, e3)              (*currrent, left, sumIndex*)
134             else if(b=d) then createDeltas(c,a,e,f, e3)  
135             else if(b=e) then createDeltas(c,a,f,d,e3)              fun rmIndex(_,_,[])=[]
136             else if(b=f) then createDeltas(c,a,d,e,e3)              | rmIndex([],[],cs)=cs
137             else if(c=d) then createDeltas(a,b,e,f,e3)              | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
138             else if(c=e) then createDeltas(a,b,f,d,e3)              | rmIndex(i::ix,rest ,(E.V c)::cs)=
139             else if(c=f) then createDeltas(a,b,d,e,e3)                     if(i=c) then rmIndex(rest@ix,[],cs)
140             else (0,(E.Prod((E.Epsilon (a,b,c))::(E.Epsilon(d,e,f))::e3)))                     else rmIndex(ix,rest@[i],(E.V c)::cs)
141    
142                val s'= rmIndex([i,s,t,u,v],[],count)
143                val s''=[E.V s, E.V t ,E.V u, E.V v]
144                val deltas= E.Sub(
145                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)),
146                        E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)))
147    
148                in (case (eps,s')
149                    of ([],[]) =>(1,deltas)
150                    |([],_)=>(1,E.Sum(s',deltas))
151                    |(_,[])=>(1,E.Prod(eps@[deltas]))
152                    |(_,_) =>(1, E.Sum(s', E.Prod(eps@[deltas])))
153                       )
154                 end
155    
156            in if(a=d) then createDeltas(a,b,c,e,f, e3)
157               else if(a=e) then createDeltas(a,b,c,f,d, e3)
158               else if(a=f) then createDeltas(a,b,c,d,e, e3)
159               else if(b=d) then createDeltas(b,c,a,e,f, e3)
160               else if(b=e) then createDeltas(b,c,a,f,d,e3)
161               else if(b=f) then createDeltas(b,c,a,d,e,e3)
162               else if(c=d) then createDeltas(c,a,b,e,f,e3)
163               else if(c=e) then createDeltas(c,a,b,f,d,e3)
164               else if(c=f) then createDeltas(c,a,b,d,e,e3)
165               else (0,E.Const 0.0)
166          end          end
167      fun findeps(e,[])= (e,[])      fun findeps(e,[])= (e,[])
168        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)        | findeps(e,(E.Epsilon eps)::es)=  findeps(e@[E.Epsilon eps],es)
169        | findeps(e,es)= (e, es)        | findeps(e,es)= (e, es)
170      fun distribute([], s)=(0, [],s)  
171        | distribute([e1], s)=(0, [e1], s)  
172        | distribute(e1::es, s)= let val(i, exp)=doubleEps(e1::es, s)      fun dist([],eps,rest)=(0,eps,rest)
173            in if(i=1) then (1, tl(es), [exp])       | dist([e],eps,rest)=(0,eps@[e],rest)
174               else let val(a,b,c)= distribute(es, s)       | dist(c1::current,eps,rest)=let
175                    in (a, [e1]@b, c) end              val(i, exp)= doubleEps(c1::current,eps,rest)
176            in  (case i of 1=>(i,[exp],[E.Const 2.0])
177                |_=> dist(current, eps@[c1],rest))
178              end              end
     val (change, eps,rest)= distribute(findeps([], e))  
     in (change, eps,rest) end  
179    
180    
181    
182        val (es,rest)=findeps([],e)
183    
184        in
185            dist(es,[],rest)
186        end
187    
188    (*
189    
190    
191    
# Line 211  Line 252 
252    
253    
254    
255                    *)
256    
257    fun reduceDelta(E.Sum(c,E.Prod p))=let
258    
259        fun findDeltas(dels,rest,E.Delta d::es)= findDeltas(dels@[E.Delta d], rest, es)
260        | findDeltas(dels,rest,E.Epsilon eps::es)=findDeltas(dels,rest@[E.Epsilon eps],es)
261        | findDeltas(dels,rest,es)=  (dels,rest,es)
262    
263        fun rmIndex(_,_,[])=[]
264            | rmIndex([],[],cs)=cs
265            | rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs)
266            | rmIndex(i::ix,rest ,c::cs)=
267                if(i=c) then rmIndex(rest@ix,[],cs)
268                else rmIndex(ix,rest@[i],c::cs)
269    
270        fun distribute(change,d,dels,[],done)=(change,dels@d,done)
271            | distribute(change,[],[],e,done)=(change,[],done@e)
272            | distribute(change,E.Delta(i,j)::ds,dels,E.Tensor(id,[tx])::es,done)=
273                if(j=tx) then distribute(change@[j],dels@ds,[] ,es ,done@[E.Tensor(id,[i])])
274                else distribute(change,ds,dels@[E.Delta(i,j)],E.Tensor(id,[tx])::es,done)
275            | distribute(change,d,dels,e::es,done)=distribute(change,dels@d,[],es,done@[e])
276    
277        val (dels,eps,es)=findDeltas([],[],p)
278        val (change,dels',done)=distribute([],dels,[],es,[])
279        val index=rmIndex(change,[],c)
280    
281      in
282           (change, E.Sum(index,E.Prod (eps@dels'@done)))
283      end
284    
285    
286    
287  (*Apply normalize to each term in product list  (*Apply normalize to each term in product list
288  or Apply normalize to tail of each list*)  or Apply normalize to tail of each list*)
# Line 224  Line 296 
296                | E.Delta _ => body                | E.Delta _ => body
297                | E.Value _ =>body                | E.Value _ =>body
298                | E.Epsilon _=>body                | E.Epsilon _=>body
299    
300                | E.Neg e => E.Neg(rewriteBody e)                | E.Neg e => E.Neg(rewriteBody e)
301                | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)                | E.Add es => let val (b,a)= mkAdd(List.map rewriteBody es)
302                     in if (b=1) then ( changed:=true;a) else a end                     in if (b=1) then ( changed:=true;a) else a end
# Line 234  Line 307 
307                | E.Probe(u,v)=>  E.Probe(rewriteBody u, rewriteBody v)                | E.Probe(u,v)=>  E.Probe(rewriteBody u, rewriteBody v)
308                | E.Image es => E.Image(List.map rewriteBody es)                | E.Image es => E.Image(List.map rewriteBody es)
309    
310                    (*Product*)
311  (************Summation *************)                | E.Prod [e1] => rewriteBody e1
   
               | E.Sum(0, e)=>e  
               | E.Sum(_, (E.Const c))=> E.Const c  
               | E.Sum(c,(E.Add l))=> E.Add(List.map (fn e => E.Sum(c,e)) l)  
   
               | E.Sum(c,E.Prod((E.Delta d)::es))=>(  
                 let val (i,dels, e)= mkDel((E.Delta d)::es)  
                     val rest=(case e of [e1]=> rewriteBody e1  
                             |_=> rewriteBody(E.Prod(e)))  
                     val soln= (case rest of E.Prod r=> E.Sum(c-i, E.Prod(dels@r))  
                         |_=>E.Sum(c-i, E.Prod(dels@[rest])))  
                     val q= checkDot(soln)  
                     in if (i=0) then q  
                    else (changed :=true;q)  
                    end )  
   
   
   
               | E.Sum(c,E.Prod((E.Epsilon e1 )::(E.Epsilon e2)::xs))=>  
                    let val (i,eps, e)= epsToDels(body)  
                    in  
                    if (i=0) then let val e'=rewriteBody(E.Prod(e)) in (case e'  
                         of E.Prod m=> let val (i2, p)= mkProd(eps @ m)  
                                     in E.Sum(c, p) end  
                         |_=>E.Sum(c, E.Prod(eps@ [e']))) end  
                    else(let val [list]=e  
                         val ans=rewriteBody(list)  
                         val soln=(case ans  
                             of E.Sub (E.Sum(c1,(E.Prod s1)),E.Sum(c2,(E.Prod s2))) =>  
                                 E.Sum(c-3+c1, E.Sub(E.Prod(eps@s1),E.Prod(eps@s2)))  
                             | E.Sub (E.Sum(c1,s1),E.Sum(c2,s2)) =>  
                                 E.Sum(c-3+c1, E.Prod(eps@ [E.Sub(s1,s2)]))  
                             |_=> E.Prod(eps@ [ans]))  
                         in (changed :=true;soln) end  
                    ) end  
   
   
             | E.Sum(c, E.Apply(E.Partial p,   E.Prod((E.Delta(i,j))::e3 )))=>  
   
                 let  
                    fun part([], e2, counter)=([], e2, counter)  
                       | part(p1::ps, [E.Delta(i,j)],counter)=  
                             if (p1=j) then ([i]@ps,[],counter-1)  
                             else (let  
                                     val (a,b,counter)=part(ps, [E.Delta(i,j)],counter)  
                                 in ([p1]@a, b,counter )  end)  
                    val (e1,e2,counter)= part(p, [E.Delta(i,j)],c)  
   
                    in  E.Sum(counter, E.Apply(E.Partial e1, E.Prod(e2@e3))) end  
   
             | E.Sum(c, E.Apply(p, e))=>let  
                    val e'= rewriteBody(E.Sum(c, e))  
                    val p'= rewriteBody p  
                    val (i, e2)= (case e'  
                         of E.Sum(c',exp)=> mkSumApply(E.Sum(c', E.Apply(p', exp)))  
                         |_=>mkApply( E.Apply(p', e')))  
                    in if(i=1) then (changed :=true;e2) else e2 end  
               | E.Sum(c, e)=> E.Sum(c, rewriteBody e)  
   
   
   
   
   
   
 (************Product**********)  
               | E.Prod([e1])=>(rewriteBody e1 )  
312                | E.Prod(e1::(E.Add(e2))::e3)=>                | E.Prod(e1::(E.Add(e2))::e3)=>
313                      (changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))                      (changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2))
314                | E.Prod(e1::(E.Sub(e2,e3))::e4)=>                | E.Prod(e1::(E.Sub(e2,e3))::e4)=>
315                      (changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))                      (changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 )))
316                | E.Prod[E.Partial r1,E.Conv(f, deltas)]=>                | E.Prod[E.Partial r1,E.Conv(f, deltas)]=>
317                      (changed:=true; E.Conv(f,deltas@r1))                      (changed:=true; E.Conv(f,deltas@r1))
318                | E.Prod((E.Partial r1)::(E.Partial r2)::e) =>                | E.Prod (E.Partial r1::E.Conv(f,deltas)::ps)=>
319                      (changed := true; E.Prod([E.Partial (r1@r2)] @ e))                     (changed:=true; E.Prod([E.Conv(f,deltas@r1)]@ps))
   
   
               | E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=>  
                     if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0))  
                     else body  
   
               | E.Prod((E.Epsilon eps1)::es)=> (let  
                     val rest=(case es  
                         of [e1] => rewriteBody e1  
                          | _=> rewriteBody( E.Prod es))  
   
                     val (i, solution)=(case rest  
                         of E.Prod m=> mkProd ([E.Epsilon eps1] @m )  
                         |_=>  mkProd([E.Epsilon eps1]@ [rest]))  
                 in if (i=1) then (changed:=true;solution)  
                     else solution  
                 end)  
   
              | E.Prod (e::es) => (let  
                     val r=rewriteBody(E.Prod es)  
                     val (i,solution)= (case r  
                         of E.Prod m => mkProd([e]@m )  
                         |_=> mkProd([e]@ [r]))  
                 in if (i=1) then (changed:=true;solution)  
                         else solution  
                 end)  
   
 (**************Apply*******************)  
   
320    
               | E.Apply(E.Partial p, E.Prod((E.Delta(i,j))::e3))=>  
                     let fun part([], e2)=([], e2)  
                           | part(p1::ps, [E.Delta(i,j)])=  
                             if (p1=j) then ([i]@ps,[])  
                             else (let val (a,b)=part(ps, [E.Delta(i,j)])  
                                 in ([p1]@a, b )  end)  
                         val (e1,e2)= part(p, [E.Delta(i,j)])  
                     in   E.Apply(E.Partial e1, E.Prod(e2@e3)) end  
321    
322                | E.Apply(E.Partial d,e)=> ( let val (t1,t2)= mkApply(E.Apply(E.Partial d, rewriteBody e))                | E.Prod(e::es)=>let
323                      in if (t1=1) then (changed :=true;t2) else t2 end)                      val e'=rewriteBody e
324                        val e2=rewriteBody(E.Prod es)
325                        in(case e2 of E.Prod p'=> E.Prod([e']@p')
326                            |_=>E.Prod [e',e2])
327                       end
328    
329                | E.Apply(E.Prod d,e)=> ( let val (t1,t2)= mkApply(E.Apply(rewriteBody (E.Prod d), rewriteBody e))                (*Apply*)
330                     in if (t1=1) then (changed :=true;t2) else t2 end)                | E.Apply(e1,e2)=>E.Apply(rewriteBody e1, rewriteBody e2)
331    
               | E.Apply _ => (print "Err Apply ";body)  
332    
333    
334                  (* Sum *)
335                  | E.Sum([],e)=> rewriteBody e
336                       | E.Sum(_,E.Const c)=>(changed:=true;E.Const c)
337                  | E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l))
338                  | E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=>
339                       let val (i,e,rest)=epsToDels(body)
340                       in (case (i, e,rest)
341                       of (1,[e1],_) =>(changed:=true;e1)
342                            |(0,eps,[])=>body
343                            |(0,eps,rest)=>(let
344                                val p'=rewriteBody(E.Prod rest)
345                                val p''= (case p' of E.Prod p=>p |e=>[e])
346                                in E.Sum(c, E.Prod (eps@p'')) end
347                                )
348                |_=> body                |_=> body
349                       ) end
350                  | E.Sum(c, E.Prod(E.Delta d::es))=>let
351                        val (change,body')=reduceDelta(body)
352                       in (case change of []=>body'|_=>(changed:=true;body')) end
353                  | E.Sum(c,e)=>E.Sum(c,rewriteBody e)
354    
355              (*end case*))              (*end case*))
356    
# Line 366  Line 358 
358              val body' = rewriteBody body              val body' = rewriteBody body
359              in              in
360                if !changed                if !changed
361                  then (changed := false; loop body')                     then (changed := false; print " \n \t => \n \t ";print( P.printbody body');print "\n";loop body')
362                  else body'                  else body'
363              end              end
364      val b = loop body      val b = loop body
365      in      in
366      ((Ein.EIN{params=params, index=index, body=b}))      ((Ein.EIN{params=params, index=index, body=b}))
367      end      end
   
   
   
                    4 end*)  
   
   
368  end  end
369    
370    

Legend:
Removed from v.2448  
changed lines
  Added in v.2449

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0