Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2525 - (view) (download)

1 : cchiw 2522 (* Shift Functions cleans up Params, and shifts down indices*)
2 :     structure shiftHtM = struct
3 :     local
4 :     structure E = Ein
5 : cchiw 2525 structure P=Printer
6 : cchiw 2522
7 :     in
8 :    
9 :    
10 :    
11 :     fun insert (key, value) d =fn s =>
12 :     if s = key then SOME value
13 :     else d s
14 :    
15 :     fun lookup k d = d k
16 :     val empty =fn key =>NONE
17 :    
18 :    
19 :    
20 :     fun flat xs = List.foldr op@ [] xs
21 :    
22 :     (*remap the tensor ids*)
23 :     fun cleanParams(body, params,args)=let
24 :     (*First step build a list of occurances*)
25 :     fun build(body,occur)=
26 :     (case body
27 :     of E.Tensor(id,ix)=> insert(id, 1) occur
28 :     | E.Sum(sx, e)=> build(e, occur)
29 :     | E.Neg e=> build(e,occur)
30 :     | E.Add e=> let
31 :    
32 :     fun add([],dict)=dict
33 :     | add(e1::es,dict)=let
34 :     val dict'=build(e1, dict)
35 :     in add(es, dict') end
36 :     in add (e, occur) end
37 :     | E.Sub(e1,e2) =>let
38 :     val dict'=build(e1, occur)
39 :     in build(e2, dict') end
40 :     | E.Div(e1,e2) =>let
41 :     val dict'=build(e1, occur)
42 :     in build(e2, dict') end
43 :     | E.Prod e => let
44 :     fun add([],dict)=dict
45 :     | add(e1::es,dict)=let
46 :     val dict'=build(e1, dict)
47 :     in add(es, dict') end
48 :     in add (e, occur) end
49 :     | E.Img(id,_,pos) => let
50 :     fun add([],dict)=dict
51 :     | add(e1::es,dict)=let
52 :     val dict'=build(e1, dict)
53 :     in add(es, dict') end
54 :     val d=insert(id, 1) occur
55 :     in add (pos, d) end
56 :     | E.Krn(id,_,pos) => let
57 :     val d=insert(id, 1) occur
58 :     in build(pos,d) end
59 :     | E.Conv(id,_,h,_) => let
60 :     val d= insert(id, 1) occur
61 :     in insert(h, 1) d end
62 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
63 :     | _ => occur
64 :     (*end case*))
65 :    
66 :     val occur=build(body, empty)
67 :    
68 :     (*remove params, args that are not used *)
69 :     fun removeP(_,_,newbie,lftp, [], lfta,_)=(newbie, lftp, lfta)
70 :     | removeP(_,_,newbie,lftp, _, lfta,[])=(newbie, lftp, lfta)
71 :     | removeP(pos,sumcount,newbie,lftp, p::pp, lfta,a::aa)=let
72 :     val c=lookup pos occur
73 :     in case c
74 :     of NONE => removeP(pos+1, sumcount,newbie@[0], lftp, pp,lfta,aa)
75 :     | SOME _=> removeP(pos+1,sumcount+1, newbie@[sumcount], lftp@[p], pp,lfta@[a],aa)
76 :     end
77 :    
78 :     val (newbie, params',args')=removeP(0,0,[],[],params,[],args)
79 :    
80 :     (*remap the tensor ids*)
81 :     fun remap body=(case body
82 :     of E.Tensor(id, ix)=> let val g=List.nth( newbie,id) in E.Tensor(g, ix) end
83 :     | E.Neg e=> E.Neg(remap e)
84 :     | E.Sum(sx, e)=> E.Sum(sx, remap e)
85 :     | E.Add e=> E.Add (List.map remap e)
86 :     | E.Prod e=> E.Prod (List.map remap e)
87 :     | E.Sub (e1,e2)=> E.Sub(remap e1, remap e2)
88 :     | E.Div(e1,e2)=> E.Div(remap e1, remap e2)
89 :     | E.Img(id,alpha,pos)=>let val g=List.nth( newbie,id)
90 :     in E.Img(g,alpha,(List.map remap pos)) end
91 :     | E.Krn(id,delta,pos)=>let val g=List.nth( newbie,id)
92 :     in E.Krn(g,delta,(remap pos)) end
93 :     | E.Conv(id,alpha,h,pos)=>let
94 :     val id'=List.nth( newbie,id)
95 :     val h'=List.nth( newbie,h)
96 :     in E.Conv(id',alpha,h',pos) end
97 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
98 :     | _=>body
99 :     (*end case*))
100 :    
101 :     val body'=remap body
102 :    
103 :     in (params',body',args')
104 :     end
105 :    
106 :    
107 :    
108 :     (*Remaps all the indices*)
109 :    
110 :     fun cleanIndex(e, intialn,index)=let
111 :     (*Each element in the list is unique*)
112 : cchiw 2525
113 :    
114 : cchiw 2522 fun uniq(list1,n)=let
115 :     fun m([],l)=l
116 :     | m(e1::es,l)= (case e1
117 :     of E.V v=> if (v>=n) then m(es,l)
118 :     else let val a=List.find (fn x => x = e1) l
119 :     in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end
120 :     |_ => m(es,l))
121 :    
122 :     in m(list1,[])
123 :     end
124 :    
125 :    
126 :    
127 :    
128 :     (*find outer index for expression, removes summation indices *)
129 :     fun findOuterIndex(body,n)=(case body
130 :     of E.Tensor(id,ix)=> ix
131 :     | E.Const _=> []
132 :     | E.Add (e1::es)=> findOuterIndex(e1,n)
133 :     | E.Sub(e1,e2)=> findOuterIndex(e1,n)
134 :     | E.Div(e1,e2)=> findOuterIndex(e1,n)
135 :     | E.Value e1=> [E.V e1]
136 :     | E.Sum(sx,e)=> findOuterIndex(e,n)
137 :     | E.Prod e=> let
138 :     val e'=List.map (fn e1=>findOuterIndex(e1,n)) e
139 :     in uniq((flat e'),n)
140 :     end
141 :     | E.Delta(i,j)=>[i,j]
142 :     | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]
143 :     | E.Neg e=> findOuterIndex(e,n)
144 :     | E.Img(v,alpha,pos)=>alpha
145 :     | E.Krn(v,dels,pos)=> List.map (fn(e1,e2)=> e2) dels
146 :     | E.Conv(_,alpha,_,dx)=> alpha@dx
147 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
148 :     | _=> []
149 :     (*end case*))
150 :    
151 :    
152 :    
153 :     val ix=findOuterIndex(e, intialn)
154 :    
155 : cchiw 2525 fun q(E.V p,E.V n)=""(*"print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])*)
156 :    
157 : cchiw 2522 fun g([],index',_,_,mapp)=(index',mapp)
158 :     | g(e::es,index',n,c, mapp)= let
159 :     val b=List.find (fn(E.V v)=>v=n) ix
160 :     in case b
161 :     of NONE=>g(es,index',n+1, c, mapp)
162 :     |_=> let val mapp'=insert(E.V n, E.V c) mapp
163 : cchiw 2525 val y=q(E.V n,E.V c)
164 : cchiw 2522 in g(es, index'@[e], n+1, c+1, mapp') end
165 :     end
166 :    
167 :     val (index',mapp)=g(index,[],0,0,empty)
168 :    
169 :    
170 : cchiw 2525
171 :    
172 : cchiw 2522 fun createMapp([],n,mapp)=(mapp,n)
173 :     | createMapp((s,_,_)::es,n,mapp)= let
174 : cchiw 2525 val E.V p=s
175 :     (* val qq=print ("\n Inserting")*)
176 :     val y=q(s,E.V n)
177 : cchiw 2522 val m=insert(s, E.V n) mapp (*check here*)
178 :     in createMapp(es,n+1, m) end
179 :    
180 :    
181 :    
182 :     fun rewriteIndex(e, smapp) =(case e
183 :     of E.V v =>let val l=lookup e smapp
184 : cchiw 2525 in case l of NONE=> raise Fail("error Could not find :"^Int.toString(v))
185 : cchiw 2522 | SOME s=> s end
186 :     | E.C _=> e
187 :     (*end case*))
188 :    
189 :     fun singleIndex(e,smapp)=let
190 :     val l=lookup (E.V e) smapp
191 : cchiw 2525 (* val g=print(String.concat["\n SingleIndex:", Int.toString(e)])*)
192 : cchiw 2522 in case l
193 :     of NONE=> (raise Fail" error could not find index" )
194 :     | SOME(E.V s)=> s
195 :     end
196 :    
197 : cchiw 2525 fun rewrite (body,n,smapp,embed)=(case body
198 : cchiw 2522 of E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))
199 :     | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))
200 : cchiw 2525 | E.Value i=> (E.Value(singleIndex (i,smapp)))
201 : cchiw 2522 | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))
202 : cchiw 2525 | E.Add e=> E.Add(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)
203 :     | E.Sub(e1,e2)=> E.Sub(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))
204 :     | E.Div(e1,e2)=> E.Div(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))
205 :    
206 :     | E.Sum(sx,E.Prod e) =>let
207 :    
208 :     (* val level=(Int.toString(embed))
209 :     val q=print "\n START *************************************\n"
210 :     val qqqq=print level
211 :     *)
212 :     val (mm,nn)=createMapp(sx,n,smapp)
213 :    
214 :    
215 :    
216 :     val k=E.Prod(List.map (fn(e1)=> rewrite(e1,nn,mm,embed+1)) e)
217 :     in E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),k) end
218 :    
219 : cchiw 2522 | E.Sum(sx,e)=> let
220 : cchiw 2525 (*
221 :     val q=print "\n START *************************************\n "
222 :     val level=(Int.toString(embed))
223 :     val qqqq=print level*)
224 :    
225 :     val (mm,nn)=createMapp(sx,n,smapp)
226 :    
227 :    
228 :    
229 :     val m=E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),rewrite(e,nn,mm,embed+1))
230 :     (*
231 :     val qq=print level
232 :     val q=print "END *************************************\n "
233 :    
234 :     *)
235 :     in m end
236 :    
237 :     | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)
238 :     | E.Neg e=> E.Neg(rewrite(e, n, smapp,embed))
239 : cchiw 2522 | E.Krn (h,dx, pos)=>
240 : cchiw 2525 E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,n,smapp,embed))
241 : cchiw 2522 | E.Img (v,alpha, pos)=>
242 :     E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),
243 : cchiw 2525 (List.map (fn e=>rewrite(e, n, smapp,embed)) pos))
244 : cchiw 2522 | E.Conv(v,alpha,h, dx)=> E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx))
245 :     | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"
246 :     | _=> body
247 :     (*end case*))
248 :    
249 : cchiw 2525 val e'=rewrite(e, length ix, mapp,0)
250 : cchiw 2522
251 :     in
252 :     (ix,index',e')
253 :     end
254 :    
255 :    
256 :     fun clean(params, index, body, args)=let
257 :     val (p',body',args')= cleanParams(body, params, args)
258 :     val (_,i',b')=cleanIndex(body', length index, index)
259 :     in (p',i',b',args')
260 :     end
261 :    
262 :     end (* local *)
263 :    
264 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0