Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2585 - (view) (download)

1 : cchiw 2522 (* Shift Functions cleans up Params, and shifts down indices*)
2 :     structure shiftHtM = struct
3 :     local
4 :     structure E = Ein
5 : cchiw 2525 structure P=Printer
6 : cchiw 2522
7 :     in
8 :    
9 :    
10 :    
11 :     fun insert (key, value) d =fn s =>
12 :     if s = key then SOME value
13 :     else d s
14 :    
15 :     fun lookup k d = d k
16 :     val empty =fn key =>NONE
17 :    
18 :    
19 :    
20 :     fun flat xs = List.foldr op@ [] xs
21 :    
22 :     (*remap the tensor ids*)
23 :     fun cleanParams(body, params,args)=let
24 :     (*First step build a list of occurances*)
25 :     fun build(body,occur)=
26 :     (case body
27 :     of E.Tensor(id,ix)=> insert(id, 1) occur
28 :     | E.Sum(sx, e)=> build(e, occur)
29 :     | E.Neg e=> build(e,occur)
30 :     | E.Add e=> let
31 :    
32 :     fun add([],dict)=dict
33 :     | add(e1::es,dict)=let
34 :     val dict'=build(e1, dict)
35 :     in add(es, dict') end
36 :     in add (e, occur) end
37 :     | E.Sub(e1,e2) =>let
38 :     val dict'=build(e1, occur)
39 :     in build(e2, dict') end
40 :     | E.Div(e1,e2) =>let
41 :     val dict'=build(e1, occur)
42 :     in build(e2, dict') end
43 :     | E.Prod e => let
44 :     fun add([],dict)=dict
45 :     | add(e1::es,dict)=let
46 :     val dict'=build(e1, dict)
47 :     in add(es, dict') end
48 :     in add (e, occur) end
49 :     | E.Img(id,_,pos) => let
50 :     fun add([],dict)=dict
51 :     | add(e1::es,dict)=let
52 :     val dict'=build(e1, dict)
53 :     in add(es, dict') end
54 :     val d=insert(id, 1) occur
55 :     in add (pos, d) end
56 :     | E.Krn(id,_,pos) => let
57 :     val d=insert(id, 1) occur
58 :     in build(pos,d) end
59 :     | E.Conv(id,_,h,_) => let
60 :     val d= insert(id, 1) occur
61 :     in insert(h, 1) d end
62 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
63 :     | _ => occur
64 :     (*end case*))
65 :    
66 :     val occur=build(body, empty)
67 :    
68 :     (*remove params, args that are not used *)
69 :     fun removeP(_,_,newbie,lftp, [], lfta,_)=(newbie, lftp, lfta)
70 :     | removeP(_,_,newbie,lftp, _, lfta,[])=(newbie, lftp, lfta)
71 :     | removeP(pos,sumcount,newbie,lftp, p::pp, lfta,a::aa)=let
72 :     val c=lookup pos occur
73 :     in case c
74 :     of NONE => removeP(pos+1, sumcount,newbie@[0], lftp, pp,lfta,aa)
75 :     | SOME _=> removeP(pos+1,sumcount+1, newbie@[sumcount], lftp@[p], pp,lfta@[a],aa)
76 :     end
77 :    
78 :     val (newbie, params',args')=removeP(0,0,[],[],params,[],args)
79 :    
80 :     (*remap the tensor ids*)
81 :     fun remap body=(case body
82 :     of E.Tensor(id, ix)=> let val g=List.nth( newbie,id) in E.Tensor(g, ix) end
83 :     | E.Neg e=> E.Neg(remap e)
84 :     | E.Sum(sx, e)=> E.Sum(sx, remap e)
85 :     | E.Add e=> E.Add (List.map remap e)
86 :     | E.Prod e=> E.Prod (List.map remap e)
87 :     | E.Sub (e1,e2)=> E.Sub(remap e1, remap e2)
88 :     | E.Div(e1,e2)=> E.Div(remap e1, remap e2)
89 :     | E.Img(id,alpha,pos)=>let val g=List.nth( newbie,id)
90 :     in E.Img(g,alpha,(List.map remap pos)) end
91 :     | E.Krn(id,delta,pos)=>let val g=List.nth( newbie,id)
92 :     in E.Krn(g,delta,(remap pos)) end
93 :     | E.Conv(id,alpha,h,pos)=>let
94 :     val id'=List.nth( newbie,id)
95 :     val h'=List.nth( newbie,h)
96 :     in E.Conv(id',alpha,h',pos) end
97 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
98 :     | _=>body
99 :     (*end case*))
100 :    
101 :     val body'=remap body
102 :    
103 :     in (params',body',args')
104 :     end
105 :    
106 :    
107 :    
108 :     (*Remaps all the indices*)
109 :    
110 :     fun cleanIndex(e, intialn,index)=let
111 :     (*Each element in the list is unique*)
112 : cchiw 2584 (* val h=print "IN SHIFT"*)
113 : cchiw 2525
114 : cchiw 2555 fun uniq list1 =let
115 : cchiw 2522 fun m([],l)=l
116 :     | m(e1::es,l)= (case e1
117 : cchiw 2555 of E.V v=> let val a=List.find (fn x => x = e1) l
118 : cchiw 2522 in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end
119 :     |_ => m(es,l))
120 :    
121 :     in m(list1,[])
122 :     end
123 :    
124 : cchiw 2553 fun filterIndex(E.V v)= if(v>intialn) then [] else [E.V v]
125 : cchiw 2522
126 :    
127 : cchiw 2553 (*find all indices *)
128 : cchiw 2555 fun findOuterIndex body=(case body
129 : cchiw 2522 of E.Tensor(id,ix)=> ix
130 :     | E.Const _=> []
131 : cchiw 2555 | E.Add e=> (*findOuterIndex(e1,n)*) let
132 :     val e'=List.map (fn e1=>findOuterIndex e1) e
133 :     in uniq(flat e')
134 :     end
135 :     | E.Sub(e1,e2)=> (*findOuterIndex(e1,n)*)let
136 :     val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
137 :     in uniq(flat e')
138 :     end
139 :     | E.Div(e1,e2)=> (*findOuterIndex(e1,n)*)let
140 :     val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
141 :     in uniq(flat e')
142 :     end
143 : cchiw 2553 | E.Value e1=> [E.V e1]
144 : cchiw 2584 | E.Sum(sx,e)=> (findOuterIndex e)
145 : cchiw 2522 | E.Prod e=> let
146 : cchiw 2555 val e'=List.map (fn e1=>findOuterIndex e1) e
147 :     in uniq(flat e')
148 : cchiw 2522 end
149 :     | E.Delta(i,j)=>[i,j]
150 :     | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]
151 : cchiw 2555 | E.Neg e=> findOuterIndex e
152 :     | E.Img(v,alpha,pos)=>let
153 :     val e'=List.map (fn e1=>findOuterIndex e1) pos
154 :     in uniq(flat ([alpha]@e'))
155 :     end
156 :    
157 : cchiw 2584 | E.Krn(v,dels,pos)=> ((List.map (fn(e1,e2)=> e2) dels) @ (findOuterIndex pos))
158 : cchiw 2522 | E.Conv(_,alpha,_,dx)=> alpha@dx
159 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
160 :     | _=> []
161 :     (*end case*))
162 :    
163 :    
164 : cchiw 2576
165 : cchiw 2555 val ix=findOuterIndex e
166 : cchiw 2522
167 : cchiw 2576
168 : cchiw 2555 (*
169 : cchiw 2553 val g=print(String.concat["\n\n --",P.printbody(e),"length of binding-", Int.toString(length(index))," Outer", Int.toString(length(ix)),"\n"])
170 :    
171 :     fun q(E.V p,E.V n)=print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])
172 : cchiw 2555 *)
173 : cchiw 2576
174 : cchiw 2555 (*Mapps just outer indices *)
175 : cchiw 2553
176 : cchiw 2555 (*Various ways to mapp summation indices
177 :     -one way is to subtract adjustments, but that assumes the summaiton indices are in order*)
178 : cchiw 2553
179 : cchiw 2555 (*
180 : cchiw 2553 fun g([],index',_,c,mapp,outer)=(index',c,mapp,outer)
181 :     | g(e::es,index',n,c, mapp,outer)= let
182 : cchiw 2522 val b=List.find (fn(E.V v)=>v=n) ix
183 :     in case b
184 : cchiw 2553 of NONE=>(g(es,index',n+1, c, mapp,outer))
185 : cchiw 2522 |_=> let val mapp'=insert(E.V n, E.V c) mapp
186 : cchiw 2555 (*val gg=print "Found, and inserting"*)
187 : cchiw 2553 in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)]) end
188 : cchiw 2555 end
189 : cchiw 2522
190 : cchiw 2555
191 : cchiw 2553 val (index',c,mapp,outer)=g(index,[],0,0,empty,[])
192 : cchiw 2555 val adjustment= intialn-c
193 :     fun addIndextoMapp([],mapp)=mapp
194 :     | addIndextoMapp((s,_,_)::es,mapp)= let
195 : cchiw 2525 val E.V p=s
196 : cchiw 2553 val n'=(p-adjustment)
197 :     val m=insert(s, E.V n') mapp
198 : cchiw 2555 in addIndextoMapp(es, m) end
199 :     *)
200 : cchiw 2522
201 : cchiw 2584 (*Problem Sigma_{ij} T_j..=> Sigma_{ii} T_i, first i is not removed from summation expression*)
202 : cchiw 2522
203 : cchiw 2555 (*Get Max*)
204 :     fun getMax max [] = max
205 :     | getMax max (E.V n::ns)= getMax (if n>max then n else max) ns
206 :     val max=getMax 0 ix
207 :    
208 : cchiw 2576
209 : cchiw 2555 fun g(_,index',_,c,mapp,outer,0)=(index',c,mapp,outer)
210 :     | g([],index',n,c, mapp,outer,maxx)= let
211 :     val b=List.find (fn(E.V v)=>v=n) ix
212 :     in case b
213 :     of NONE=>g([],index',n+1, c, mapp,outer,maxx-1)
214 :     |_=> let val mapp'=insert(E.V n, E.V c) mapp
215 :     in g([], index', n+1, c+1, mapp',outer,maxx-1) end
216 :     end
217 :     | g(e::es,index',n,c, mapp,outer,maxx)= let
218 :     val b=List.find (fn(E.V v)=>v=n) ix
219 :     in case b
220 :     of NONE=>g(es,index',n+1, c, mapp,outer,maxx-1)
221 :     |_=> let val mapp'=insert(E.V n, E.V c) mapp
222 : cchiw 2576
223 : cchiw 2555 in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)],maxx-1) end
224 :     end
225 : cchiw 2576
226 : cchiw 2555
227 :     val mapp'=insert(E.V 0, E.V 0) empty
228 : cchiw 2576
229 : cchiw 2555 val (index',c,mapp,outer)=g(index,[],0,0,mapp',[],max+1)
230 : cchiw 2576
231 : cchiw 2555
232 : cchiw 2522 fun rewriteIndex(e, smapp) =(case e
233 :     of E.V v =>let val l=lookup e smapp
234 : cchiw 2525 in case l of NONE=> raise Fail("error Could not find :"^Int.toString(v))
235 : cchiw 2522 | SOME s=> s end
236 :     | E.C _=> e
237 :     (*end case*))
238 :    
239 : cchiw 2585
240 :     fun rewriteSumIndex(es,smapp) =let
241 :     fun rewrite []=[]
242 :     |rewrite((e, lb,ub)::e2)=
243 :     (case e
244 :     of E.V v =>let
245 :     val l=lookup e smapp
246 :     in case l
247 :     of NONE=> []@rewrite(e2)
248 :     | SOME s=> [(s, lb, ub)]@rewrite(e2)
249 :     end
250 :     | E.C _=> [(e, lb, ub)]@rewrite(e2)
251 :     (*end case*))
252 :     in rewrite es
253 :     end
254 :    
255 : cchiw 2522 fun singleIndex(e,smapp)=let
256 :     val l=lookup (E.V e) smapp
257 : cchiw 2525 (* val g=print(String.concat["\n SingleIndex:", Int.toString(e)])*)
258 : cchiw 2522 in case l
259 :     of NONE=> (raise Fail" error could not find index" )
260 :     | SOME(E.V s)=> s
261 : cchiw 2585 end
262 : cchiw 2522
263 : cchiw 2585 (*Add in algebraic clean up *)
264 :     fun rewrite (body,smapp)=let
265 :    
266 :     in (case body
267 : cchiw 2522 of E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))
268 :     | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))
269 : cchiw 2525 | E.Value i=> (E.Value(singleIndex (i,smapp)))
270 : cchiw 2522 | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))
271 : cchiw 2585
272 :     | E.Add (E.Const 0::es)=> (print "ADD XXX zero"; rewrite(E.Add(es),smapp))
273 :     (*filters through zeros*)
274 :     (* | E.Add e=> E.Add( List.map (fn e1=>rewrite(e1,smapp)) e)
275 :     *)
276 :    
277 :    
278 :     | E.Add e=> let
279 :    
280 :     fun mkadd([],[e2])=e2
281 :     | mkadd([e1],[])= rewrite(e1, smapp)
282 :     | mkadd([],e2)= E.Add (e2)
283 :     | mkadd(E.Const 0::es,e2)=mkadd(es,e2)
284 :     | mkadd(e1::es,e2)= let
285 :     val a=rewrite(e1, smapp)
286 :     in (case a
287 :     of E.Const 0 => mkadd(es,e2)
288 :     | _ =>mkadd(es,e2@[a])
289 :     (* end case*))
290 :     end
291 :     in
292 :     mkadd (e,[])
293 :     end
294 :    
295 :     | E.Sub(E.Const 0, E.Const 0)=> E.Const 0
296 :     | E.Sub(E.Const 0, e2)=>E.Neg(rewrite(e2,smapp))
297 :     | E.Sub(e1, E.Const 0)=> rewrite(e1,smapp)
298 : cchiw 2555 | E.Sub(e1,e2)=> E.Sub(rewrite(e1,smapp),rewrite(e2,smapp))
299 :     | E.Div(e1,e2)=> E.Div(rewrite(e1,smapp),rewrite(e2,smapp))
300 : cchiw 2585 | E.Sum(_,E.Const c)=>E.Const c
301 :     | E.Sum([],e)=> rewrite(e, smapp)
302 :     | E.Sum(sx,e)=>let
303 :    
304 :     val sx'=rewriteSumIndex(sx,smapp)
305 :     in (case sx'
306 :     of []=> rewrite(e,smapp)
307 :     | _=> E.Sum( rewriteSumIndex(sx,smapp),rewrite(e,smapp ))
308 :     (*end case*))
309 :     end
310 :     | E.Prod (E.Const 0::es)=> (print "ProductXXX zero";E.Const 0)
311 : cchiw 2555 | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,smapp)) e)
312 : cchiw 2585 | E.Neg(E.Const 0)=> E.Const 0
313 : cchiw 2555 | E.Neg e=> E.Neg(rewrite(e, smapp))
314 : cchiw 2522 | E.Krn (h,dx, pos)=>
315 : cchiw 2555 E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,smapp))
316 : cchiw 2522 | E.Img (v,alpha, pos)=>
317 :     E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),
318 : cchiw 2555 (List.map (fn e=>rewrite(e, smapp)) pos))
319 : cchiw 2584 | E.Conv(v,alpha,h, dx)=> (E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx)))
320 : cchiw 2522 | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"
321 : cchiw 2585 | _=> (print "other";body)
322 : cchiw 2522 (*end case*))
323 : cchiw 2585 end
324 : cchiw 2576
325 : cchiw 2555 val e'=rewrite(e, mapp)
326 : cchiw 2576
327 : cchiw 2522 in
328 : cchiw 2553 (outer,index',e')
329 : cchiw 2522 end
330 :    
331 :    
332 :     fun clean(params, index, body, args)=let
333 :     val (p',body',args')= cleanParams(body, params, args)
334 :     val (_,i',b')=cleanIndex(body', length index, index)
335 : cchiw 2576
336 : cchiw 2555 (*val hh=print(String.concat["\n ~~~\n",P.printbody(body),"===>\n",P.printbody(b'),"\n ~~~\n"])*)
337 : cchiw 2522 in (p',i',b',args')
338 :     end
339 :    
340 :     end (* local *)
341 :    
342 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0