Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2584 - (view) (download)

1 : cchiw 2522 (* Shift Functions cleans up Params, and shifts down indices*)
2 :     structure shiftHtM = struct
3 :     local
4 :     structure E = Ein
5 : cchiw 2525 structure P=Printer
6 : cchiw 2522
7 :     in
8 :    
9 :    
10 :    
11 :     fun insert (key, value) d =fn s =>
12 :     if s = key then SOME value
13 :     else d s
14 :    
15 :     fun lookup k d = d k
16 :     val empty =fn key =>NONE
17 :    
18 :    
19 :    
20 :     fun flat xs = List.foldr op@ [] xs
21 :    
22 :     (*remap the tensor ids*)
23 :     fun cleanParams(body, params,args)=let
24 :     (*First step build a list of occurances*)
25 :     fun build(body,occur)=
26 :     (case body
27 :     of E.Tensor(id,ix)=> insert(id, 1) occur
28 :     | E.Sum(sx, e)=> build(e, occur)
29 :     | E.Neg e=> build(e,occur)
30 :     | E.Add e=> let
31 :    
32 :     fun add([],dict)=dict
33 :     | add(e1::es,dict)=let
34 :     val dict'=build(e1, dict)
35 :     in add(es, dict') end
36 :     in add (e, occur) end
37 :     | E.Sub(e1,e2) =>let
38 :     val dict'=build(e1, occur)
39 :     in build(e2, dict') end
40 :     | E.Div(e1,e2) =>let
41 :     val dict'=build(e1, occur)
42 :     in build(e2, dict') end
43 :     | E.Prod e => let
44 :     fun add([],dict)=dict
45 :     | add(e1::es,dict)=let
46 :     val dict'=build(e1, dict)
47 :     in add(es, dict') end
48 :     in add (e, occur) end
49 :     | E.Img(id,_,pos) => let
50 :     fun add([],dict)=dict
51 :     | add(e1::es,dict)=let
52 :     val dict'=build(e1, dict)
53 :     in add(es, dict') end
54 :     val d=insert(id, 1) occur
55 :     in add (pos, d) end
56 :     | E.Krn(id,_,pos) => let
57 :     val d=insert(id, 1) occur
58 :     in build(pos,d) end
59 :     | E.Conv(id,_,h,_) => let
60 :     val d= insert(id, 1) occur
61 :     in insert(h, 1) d end
62 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
63 :     | _ => occur
64 :     (*end case*))
65 :    
66 :     val occur=build(body, empty)
67 :    
68 :     (*remove params, args that are not used *)
69 :     fun removeP(_,_,newbie,lftp, [], lfta,_)=(newbie, lftp, lfta)
70 :     | removeP(_,_,newbie,lftp, _, lfta,[])=(newbie, lftp, lfta)
71 :     | removeP(pos,sumcount,newbie,lftp, p::pp, lfta,a::aa)=let
72 :     val c=lookup pos occur
73 :     in case c
74 :     of NONE => removeP(pos+1, sumcount,newbie@[0], lftp, pp,lfta,aa)
75 :     | SOME _=> removeP(pos+1,sumcount+1, newbie@[sumcount], lftp@[p], pp,lfta@[a],aa)
76 :     end
77 :    
78 :     val (newbie, params',args')=removeP(0,0,[],[],params,[],args)
79 :    
80 :     (*remap the tensor ids*)
81 :     fun remap body=(case body
82 :     of E.Tensor(id, ix)=> let val g=List.nth( newbie,id) in E.Tensor(g, ix) end
83 :     | E.Neg e=> E.Neg(remap e)
84 :     | E.Sum(sx, e)=> E.Sum(sx, remap e)
85 :     | E.Add e=> E.Add (List.map remap e)
86 :     | E.Prod e=> E.Prod (List.map remap e)
87 :     | E.Sub (e1,e2)=> E.Sub(remap e1, remap e2)
88 :     | E.Div(e1,e2)=> E.Div(remap e1, remap e2)
89 :     | E.Img(id,alpha,pos)=>let val g=List.nth( newbie,id)
90 :     in E.Img(g,alpha,(List.map remap pos)) end
91 :     | E.Krn(id,delta,pos)=>let val g=List.nth( newbie,id)
92 :     in E.Krn(g,delta,(remap pos)) end
93 :     | E.Conv(id,alpha,h,pos)=>let
94 :     val id'=List.nth( newbie,id)
95 :     val h'=List.nth( newbie,h)
96 :     in E.Conv(id',alpha,h',pos) end
97 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
98 :     | _=>body
99 :     (*end case*))
100 :    
101 :     val body'=remap body
102 :    
103 :     in (params',body',args')
104 :     end
105 :    
106 :    
107 :    
108 :     (*Remaps all the indices*)
109 :    
110 :     fun cleanIndex(e, intialn,index)=let
111 :     (*Each element in the list is unique*)
112 : cchiw 2584 (* val h=print "IN SHIFT"*)
113 : cchiw 2525
114 : cchiw 2555 fun uniq list1 =let
115 : cchiw 2522 fun m([],l)=l
116 :     | m(e1::es,l)= (case e1
117 : cchiw 2555 of E.V v=> let val a=List.find (fn x => x = e1) l
118 : cchiw 2522 in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end
119 :     |_ => m(es,l))
120 :    
121 :     in m(list1,[])
122 :     end
123 :    
124 : cchiw 2553 fun filterIndex(E.V v)= if(v>intialn) then [] else [E.V v]
125 : cchiw 2522
126 :    
127 : cchiw 2553 (*find all indices *)
128 : cchiw 2555 fun findOuterIndex body=(case body
129 : cchiw 2522 of E.Tensor(id,ix)=> ix
130 :     | E.Const _=> []
131 : cchiw 2555 | E.Add e=> (*findOuterIndex(e1,n)*) let
132 :     val e'=List.map (fn e1=>findOuterIndex e1) e
133 :     in uniq(flat e')
134 :     end
135 :     | E.Sub(e1,e2)=> (*findOuterIndex(e1,n)*)let
136 :     val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
137 :     in uniq(flat e')
138 :     end
139 :     | E.Div(e1,e2)=> (*findOuterIndex(e1,n)*)let
140 :     val e'=List.map (fn e1=>findOuterIndex e1) [e1,e2]
141 :     in uniq(flat e')
142 :     end
143 : cchiw 2553 | E.Value e1=> [E.V e1]
144 : cchiw 2584 | E.Sum(sx,e)=> (findOuterIndex e)
145 : cchiw 2522 | E.Prod e=> let
146 : cchiw 2555 val e'=List.map (fn e1=>findOuterIndex e1) e
147 :     in uniq(flat e')
148 : cchiw 2522 end
149 :     | E.Delta(i,j)=>[i,j]
150 :     | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]
151 : cchiw 2555 | E.Neg e=> findOuterIndex e
152 :     | E.Img(v,alpha,pos)=>let
153 :     val e'=List.map (fn e1=>findOuterIndex e1) pos
154 :     in uniq(flat ([alpha]@e'))
155 :     end
156 :    
157 : cchiw 2584 | E.Krn(v,dels,pos)=> ((List.map (fn(e1,e2)=> e2) dels) @ (findOuterIndex pos))
158 : cchiw 2522 | E.Conv(_,alpha,_,dx)=> alpha@dx
159 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
160 :     | _=> []
161 :     (*end case*))
162 :    
163 :    
164 : cchiw 2576
165 : cchiw 2555 val ix=findOuterIndex e
166 : cchiw 2522
167 : cchiw 2576
168 : cchiw 2555 (*
169 : cchiw 2553 val g=print(String.concat["\n\n --",P.printbody(e),"length of binding-", Int.toString(length(index))," Outer", Int.toString(length(ix)),"\n"])
170 :    
171 :     fun q(E.V p,E.V n)=print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])
172 : cchiw 2555 *)
173 : cchiw 2576
174 : cchiw 2555 (*Mapps just outer indices *)
175 : cchiw 2553
176 : cchiw 2555 (*Various ways to mapp summation indices
177 :     -one way is to subtract adjustments, but that assumes the summaiton indices are in order*)
178 : cchiw 2553
179 : cchiw 2555 (*
180 : cchiw 2553 fun g([],index',_,c,mapp,outer)=(index',c,mapp,outer)
181 :     | g(e::es,index',n,c, mapp,outer)= let
182 : cchiw 2522 val b=List.find (fn(E.V v)=>v=n) ix
183 :     in case b
184 : cchiw 2553 of NONE=>(g(es,index',n+1, c, mapp,outer))
185 : cchiw 2522 |_=> let val mapp'=insert(E.V n, E.V c) mapp
186 : cchiw 2555 (*val gg=print "Found, and inserting"*)
187 : cchiw 2553 in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)]) end
188 : cchiw 2555 end
189 : cchiw 2522
190 : cchiw 2555
191 : cchiw 2553 val (index',c,mapp,outer)=g(index,[],0,0,empty,[])
192 : cchiw 2555 val adjustment= intialn-c
193 :     fun addIndextoMapp([],mapp)=mapp
194 :     | addIndextoMapp((s,_,_)::es,mapp)= let
195 : cchiw 2525 val E.V p=s
196 : cchiw 2553 val n'=(p-adjustment)
197 :     val m=insert(s, E.V n') mapp
198 : cchiw 2555 in addIndextoMapp(es, m) end
199 :     *)
200 : cchiw 2522
201 : cchiw 2584 (*Problem Sigma_{ij} T_j..=> Sigma_{ii} T_i, first i is not removed from summation expression*)
202 : cchiw 2522
203 : cchiw 2555 (*Get Max*)
204 :     fun getMax max [] = max
205 :     | getMax max (E.V n::ns)= getMax (if n>max then n else max) ns
206 :     val max=getMax 0 ix
207 :    
208 : cchiw 2576
209 : cchiw 2555 fun g(_,index',_,c,mapp,outer,0)=(index',c,mapp,outer)
210 :     | g([],index',n,c, mapp,outer,maxx)= let
211 :     val b=List.find (fn(E.V v)=>v=n) ix
212 :     in case b
213 :     of NONE=>g([],index',n+1, c, mapp,outer,maxx-1)
214 :     |_=> let val mapp'=insert(E.V n, E.V c) mapp
215 :     in g([], index', n+1, c+1, mapp',outer,maxx-1) end
216 :     end
217 :     | g(e::es,index',n,c, mapp,outer,maxx)= let
218 :     val b=List.find (fn(E.V v)=>v=n) ix
219 :     in case b
220 :     of NONE=>g(es,index',n+1, c, mapp,outer,maxx-1)
221 :     |_=> let val mapp'=insert(E.V n, E.V c) mapp
222 : cchiw 2576
223 : cchiw 2555 in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)],maxx-1) end
224 :     end
225 : cchiw 2576
226 : cchiw 2555
227 :     val mapp'=insert(E.V 0, E.V 0) empty
228 : cchiw 2576
229 : cchiw 2555 val (index',c,mapp,outer)=g(index,[],0,0,mapp',[],max+1)
230 : cchiw 2576
231 : cchiw 2555
232 : cchiw 2522 fun rewriteIndex(e, smapp) =(case e
233 :     of E.V v =>let val l=lookup e smapp
234 : cchiw 2525 in case l of NONE=> raise Fail("error Could not find :"^Int.toString(v))
235 : cchiw 2522 | SOME s=> s end
236 :     | E.C _=> e
237 :     (*end case*))
238 :    
239 :     fun singleIndex(e,smapp)=let
240 :     val l=lookup (E.V e) smapp
241 : cchiw 2525 (* val g=print(String.concat["\n SingleIndex:", Int.toString(e)])*)
242 : cchiw 2522 in case l
243 :     of NONE=> (raise Fail" error could not find index" )
244 :     | SOME(E.V s)=> s
245 :     end
246 :    
247 : cchiw 2555 fun rewrite (body,smapp)=(case body
248 : cchiw 2522 of E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))
249 :     | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))
250 : cchiw 2525 | E.Value i=> (E.Value(singleIndex (i,smapp)))
251 : cchiw 2522 | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))
252 : cchiw 2555 | E.Add e=> E.Add(List.map (fn(e1)=>rewrite(e1,smapp)) e)
253 :     | E.Sub(e1,e2)=> E.Sub(rewrite(e1,smapp),rewrite(e2,smapp))
254 :     | E.Div(e1,e2)=> E.Div(rewrite(e1,smapp),rewrite(e2,smapp))
255 :     | E.Sum(sx,e)=> (*let
256 :     val mm=addIndextoMapp(sx,smapp)
257 :     val m=E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),rewrite(e,mm))
258 : cchiw 2525 in m end
259 : cchiw 2555
260 : cchiw 2584 *)(E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,smapp),lb,ub)) sx),rewrite(e,smapp)))
261 : cchiw 2555 | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,smapp)) e)
262 :     | E.Neg e=> E.Neg(rewrite(e, smapp))
263 : cchiw 2522 | E.Krn (h,dx, pos)=>
264 : cchiw 2555 E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,smapp))
265 : cchiw 2522 | E.Img (v,alpha, pos)=>
266 :     E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),
267 : cchiw 2555 (List.map (fn e=>rewrite(e, smapp)) pos))
268 : cchiw 2584 | E.Conv(v,alpha,h, dx)=> (E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx)))
269 : cchiw 2522 | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"
270 :     | _=> body
271 :     (*end case*))
272 : cchiw 2576
273 : cchiw 2555 val e'=rewrite(e, mapp)
274 : cchiw 2576
275 : cchiw 2522 in
276 : cchiw 2553 (outer,index',e')
277 : cchiw 2522 end
278 :    
279 :    
280 :     fun clean(params, index, body, args)=let
281 :     val (p',body',args')= cleanParams(body, params, args)
282 :     val (_,i',b')=cleanIndex(body', length index, index)
283 : cchiw 2576
284 : cchiw 2555 (*val hh=print(String.concat["\n ~~~\n",P.printbody(body),"===>\n",P.printbody(b'),"\n ~~~\n"])*)
285 : cchiw 2522 in (p',i',b',args')
286 :     end
287 :    
288 :     end (* local *)
289 :    
290 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0