Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2553, Sun Mar 2 19:53:33 2014 UTC revision 2555, Mon Mar 3 19:14:57 2014 UTC
# Line 109  Line 109 
109    
110  fun cleanIndex(e, intialn,index)=let  fun cleanIndex(e, intialn,index)=let
111      (*Each element in the list is unique*)      (*Each element in the list is unique*)
112       val h=print "IN SHIFT"
113    
114        fun uniq list1 =let
     fun uniq(list1,n)=let  
115          fun m([],l)=l          fun m([],l)=l
116              | m(e1::es,l)= (case e1              | m(e1::es,l)= (case e1
117                  of E.V v=>  if (v>=n) then m(es,l)                  of E.V v=> let val a=List.find (fn x => x = e1) l
                     else let val a=List.find (fn x => x = e1) l  
118                      in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end                      in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end
119                  |_ => m(es,l))                  |_ => m(es,l))
120    
121          in m(list1,[])          in m(list1,[])
122          end          end
123        val m=print "A"
124    
125    
126      fun filterIndex(E.V v)= if(v>intialn) then [] else [E.V v]      fun filterIndex(E.V v)= if(v>intialn) then [] else [E.V v]
127    
128    
129      (*find all indices *)      (*find all indices *)
130      fun findOuterIndex(body,n)=(case body      fun findOuterIndex body=(case body
131          of  E.Tensor(id,ix)=> ix          of  E.Tensor(id,ix)=> ix
132              | E.Const _=> []              | E.Const _=> []
133              | E.Add (e1::es)=> findOuterIndex(e1,n)              | E.Add e=> (*findOuterIndex(e1,n)*) let
134              | E.Sub(e1,e2)=> findOuterIndex(e1,n)                  val e'=List.map (fn e1=>findOuterIndex e1) e
135              | E.Div(e1,e2)=> findOuterIndex(e1,n)                  in  uniq(flat e')
136                    end
137                | E.Sub(e1,e2)=> (*findOuterIndex(e1,n)*)let
138                    val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
139                    in  uniq(flat e')
140                    end
141                | E.Div(e1,e2)=> (*findOuterIndex(e1,n)*)let
142                    val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
143                    in  uniq(flat e')
144                    end
145              | E.Value e1=> [E.V e1]              | E.Value e1=> [E.V e1]
146              | E.Sum(sx,e)=> findOuterIndex(e,n)              | E.Sum(sx,e)=> (print "in summ";findOuterIndex e)
147              | E.Prod e=> let              | E.Prod e=> let
148                  val e'=List.map (fn e1=>findOuterIndex(e1,n)) e                  val e'=List.map (fn e1=>findOuterIndex e1) e
149                  in  uniq((flat e'),n)                  in  uniq(flat e')
150                  end                  end
151              | E.Delta(i,j)=>[i,j]              | E.Delta(i,j)=>[i,j]
152              | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]              | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]
153              | E.Neg e=> findOuterIndex(e,n)              | E.Neg e=> findOuterIndex e
154              | E.Img(v,alpha,pos)=>alpha              | E.Img(v,alpha,pos)=>let
155              | E.Krn(v,dels,pos)=> List.map (fn(e1,e2)=> e2) dels                  val e'=List.map (fn e1=>findOuterIndex e1) pos
156                        in  uniq(flat ([alpha]@e'))
157                        end
158    
159                | E.Krn(v,dels,pos)=> (print "krn";(List.map (fn(e1,e2)=> e2) dels) @ (findOuterIndex pos))
160              | E.Conv(_,alpha,_,dx)=> alpha@dx              | E.Conv(_,alpha,_,dx)=> alpha@dx
161              | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"              | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
162              | _=> []              | _=> []
163              (*end case*))              (*end case*))
164    
165    
166      val m=print "B"
167        val ix=findOuterIndex e
168    
169      val ix=findOuterIndex(e, intialn)    val m=print "C"
170    (*
   
   
171      val g=print(String.concat["\n\n --",P.printbody(e),"length of binding-", Int.toString(length(index))," Outer", Int.toString(length(ix)),"\n"])      val g=print(String.concat["\n\n --",P.printbody(e),"length of binding-", Int.toString(length(index))," Outer", Int.toString(length(ix)),"\n"])
172    
173      fun q(E.V p,E.V n)=print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])      fun q(E.V p,E.V n)=print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])
174    *)
175          val m=print "D"
176        (*Mapps just outer indices *)
177    
178        (*Various ways to mapp summation indices
179            -one way is to subtract adjustments, but that assumes the summaiton indices are in order*)
180    
181      (*Mapps just outer indices *)  (*
182      fun g([],index',_,c,mapp,outer)=(index',c,mapp,outer)      fun g([],index',_,c,mapp,outer)=(index',c,mapp,outer)
183          | g(e::es,index',n,c, mapp,outer)= let          | g(e::es,index',n,c, mapp,outer)= let
184              val b=List.find (fn(E.V v)=>v=n) ix              val b=List.find (fn(E.V v)=>v=n) ix
185                  in case b                  in case b
186                      of NONE=>(g(es,index',n+1, c, mapp,outer))                      of NONE=>(g(es,index',n+1, c, mapp,outer))
187                      |_=> let val mapp'=insert(E.V n, E.V c) mapp                      |_=> let val mapp'=insert(E.V n, E.V c) mapp
188                      val gg=print "Found, and inserting"                      (*val gg=print "Found, and inserting"*)
189                      in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)]) end                      in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)]) end
190                  end                  end
191    
     val (index',c,mapp,outer)=g(index,[],0,0,empty,[])  
192    
193      val mm=print ("FInal count"^Int.toString(c))      val (index',c,mapp,outer)=g(index,[],0,0,empty,[])
194      val adjustment= intialn-c      val adjustment= intialn-c
195        fun addIndextoMapp([],mapp)=mapp
196      fun addIndextoMapp([],n,mapp)=(mapp,n)          | addIndextoMapp((s,_,_)::es,mapp)= let
         | addIndextoMapp((s,_,_)::es,n,mapp)= let  
197              val E.V p=s              val E.V p=s
198              val n'=(p-adjustment)              val n'=(p-adjustment)
             val y=q( s,E.V n')  
199              val m=insert(s, E.V n') mapp              val m=insert(s, E.V n') mapp
200              in addIndextoMapp(es,n+1, m) end              in addIndextoMapp(es, m) end
201            *)
202    
203    
204    
205            (*Get Max*)
206            fun getMax max [] = max
207            | getMax max (E.V n::ns)= getMax  (if n>max then n else max) ns
208            val max=getMax 0 ix
209            val gg=print("MAX:"^Int.toString(max))
210          val m=print "YY"
211    
212            fun g(_,index',_,c,mapp,outer,0)=(index',c,mapp,outer)
213            | g([],index',n,c, mapp,outer,maxx)= let
214                val b=List.find (fn(E.V v)=>v=n) ix
215                in case b
216                    of NONE=>g([],index',n+1, c, mapp,outer,maxx-1)
217                    |_=> let val mapp'=insert(E.V n, E.V c) mapp
218                        in g([], index', n+1, c+1, mapp',outer,maxx-1) end
219                    end
220            | g(e::es,index',n,c, mapp,outer,maxx)= let
221                val b=List.find (fn(E.V v)=>v=n) ix
222                in case b
223                    of NONE=>g(es,index',n+1, c, mapp,outer,maxx-1)
224                    |_=> let val mapp'=insert(E.V n, E.V c) mapp
225                        val y=print("Addign to outerIndex"^Int.toString(n))
226                        in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)],maxx-1) end
227                end
228          val m=print "LL"
229    
230            val mapp'=insert(E.V 0, E.V 0) empty
231              val m=print "UU"
232            val (index',c,mapp,outer)=g(index,[],0,0,mapp',[],max+1)
233              val m=print "OO"
234    
235      fun rewriteIndex(e, smapp) =(case e      fun rewriteIndex(e, smapp) =(case e
236          of E.V v =>let val l=lookup e smapp          of E.V v =>let val l=lookup e smapp
# Line 200  Line 247 
247              | SOME(E.V s)=> s              | SOME(E.V s)=> s
248          end          end
249    
250      fun rewrite (body,n,smapp,embed)=(case body      fun rewrite (body,smapp)=(case body
251          of  E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))          of  E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))
252          | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))          | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))
253          | E.Value i=> (E.Value(singleIndex (i,smapp)))          | E.Value i=> (E.Value(singleIndex (i,smapp)))
254          | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))          | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))
255          | E.Add e=> E.Add(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)          | E.Add e=> E.Add(List.map (fn(e1)=>rewrite(e1,smapp)) e)
256          | E.Sub(e1,e2)=>  E.Sub(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))          | E.Sub(e1,e2)=>  E.Sub(rewrite(e1,smapp),rewrite(e2,smapp))
257          | E.Div(e1,e2)=>  E.Div(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))          | E.Div(e1,e2)=>  E.Div(rewrite(e1,smapp),rewrite(e2,smapp))
258          | E.Sum(sx,e)=> let          | E.Sum(sx,e)=> (*let
259              val (mm,nn)=addIndextoMapp(sx,n,smapp)              val mm=addIndextoMapp(sx,smapp)
260              val m=E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),rewrite(e,nn,mm,embed+1))              val m=E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),rewrite(e,mm))
261              in m end              in m end
262          | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)  
263          | E.Neg e=> E.Neg(rewrite(e, n, smapp,embed))          *)(print "pre sum";E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,smapp),lb,ub)) sx),rewrite(e,smapp)))
264            | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,smapp)) e)
265            | E.Neg e=> E.Neg(rewrite(e, smapp))
266          |  E.Krn (h,dx, pos)=>          |  E.Krn (h,dx, pos)=>
267              E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,n,smapp,embed))              E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,smapp))
268          |  E.Img (v,alpha, pos)=>          |  E.Img (v,alpha, pos)=>
269              E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),              E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),
270                  (List.map (fn e=>rewrite(e, n, smapp,embed)) pos))                  (List.map (fn e=>rewrite(e, smapp)) pos))
271          | E.Conv(v,alpha,h, dx)=> E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx))          | E.Conv(v,alpha,h, dx)=> (print "conv";E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx)))
272          | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"          | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"
273          | _=> body          | _=> body
274          (*end case*))          (*end case*))
275          val m=print "E"
276      val e'=rewrite(e, length ix, mapp,0)      val e'=rewrite(e, mapp)
277          val m=print "F"
278      in      in
279          (outer,index',e')          (outer,index',e')
280      end      end
# Line 233  Line 282 
282    
283  fun clean(params, index, body, args)=let  fun clean(params, index, body, args)=let
284      val (p',body',args')= cleanParams(body, params, args)      val (p',body',args')= cleanParams(body, params, args)
285          val m=print "H"
286      val (_,i',b')=cleanIndex(body', length index, index)      val (_,i',b')=cleanIndex(body', length index, index)
287          val m=print "I"
288        (*val hh=print(String.concat["\n ~~~\n",P.printbody(body),"===>\n",P.printbody(b'),"\n ~~~\n"])*)
289      in  (p',i',b',args')      in  (p',i',b',args')
290      end      end
291    

Legend:
Removed from v.2553  
changed lines
  Added in v.2555

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0