Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/shiftHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2553 - (view) (download)

1 : cchiw 2522 (* Shift Functions cleans up Params, and shifts down indices*)
2 :     structure shiftHtM = struct
3 :     local
4 :     structure E = Ein
5 : cchiw 2525 structure P=Printer
6 : cchiw 2522
7 :     in
8 :    
9 :    
10 :    
11 :     fun insert (key, value) d =fn s =>
12 :     if s = key then SOME value
13 :     else d s
14 :    
15 :     fun lookup k d = d k
16 :     val empty =fn key =>NONE
17 :    
18 :    
19 :    
20 :     fun flat xs = List.foldr op@ [] xs
21 :    
22 :     (*remap the tensor ids*)
23 :     fun cleanParams(body, params,args)=let
24 :     (*First step build a list of occurances*)
25 :     fun build(body,occur)=
26 :     (case body
27 :     of E.Tensor(id,ix)=> insert(id, 1) occur
28 :     | E.Sum(sx, e)=> build(e, occur)
29 :     | E.Neg e=> build(e,occur)
30 :     | E.Add e=> let
31 :    
32 :     fun add([],dict)=dict
33 :     | add(e1::es,dict)=let
34 :     val dict'=build(e1, dict)
35 :     in add(es, dict') end
36 :     in add (e, occur) end
37 :     | E.Sub(e1,e2) =>let
38 :     val dict'=build(e1, occur)
39 :     in build(e2, dict') end
40 :     | E.Div(e1,e2) =>let
41 :     val dict'=build(e1, occur)
42 :     in build(e2, dict') end
43 :     | E.Prod e => let
44 :     fun add([],dict)=dict
45 :     | add(e1::es,dict)=let
46 :     val dict'=build(e1, dict)
47 :     in add(es, dict') end
48 :     in add (e, occur) end
49 :     | E.Img(id,_,pos) => let
50 :     fun add([],dict)=dict
51 :     | add(e1::es,dict)=let
52 :     val dict'=build(e1, dict)
53 :     in add(es, dict') end
54 :     val d=insert(id, 1) occur
55 :     in add (pos, d) end
56 :     | E.Krn(id,_,pos) => let
57 :     val d=insert(id, 1) occur
58 :     in build(pos,d) end
59 :     | E.Conv(id,_,h,_) => let
60 :     val d= insert(id, 1) occur
61 :     in insert(h, 1) d end
62 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
63 :     | _ => occur
64 :     (*end case*))
65 :    
66 :     val occur=build(body, empty)
67 :    
68 :     (*remove params, args that are not used *)
69 :     fun removeP(_,_,newbie,lftp, [], lfta,_)=(newbie, lftp, lfta)
70 :     | removeP(_,_,newbie,lftp, _, lfta,[])=(newbie, lftp, lfta)
71 :     | removeP(pos,sumcount,newbie,lftp, p::pp, lfta,a::aa)=let
72 :     val c=lookup pos occur
73 :     in case c
74 :     of NONE => removeP(pos+1, sumcount,newbie@[0], lftp, pp,lfta,aa)
75 :     | SOME _=> removeP(pos+1,sumcount+1, newbie@[sumcount], lftp@[p], pp,lfta@[a],aa)
76 :     end
77 :    
78 :     val (newbie, params',args')=removeP(0,0,[],[],params,[],args)
79 :    
80 :     (*remap the tensor ids*)
81 :     fun remap body=(case body
82 :     of E.Tensor(id, ix)=> let val g=List.nth( newbie,id) in E.Tensor(g, ix) end
83 :     | E.Neg e=> E.Neg(remap e)
84 :     | E.Sum(sx, e)=> E.Sum(sx, remap e)
85 :     | E.Add e=> E.Add (List.map remap e)
86 :     | E.Prod e=> E.Prod (List.map remap e)
87 :     | E.Sub (e1,e2)=> E.Sub(remap e1, remap e2)
88 :     | E.Div(e1,e2)=> E.Div(remap e1, remap e2)
89 :     | E.Img(id,alpha,pos)=>let val g=List.nth( newbie,id)
90 :     in E.Img(g,alpha,(List.map remap pos)) end
91 :     | E.Krn(id,delta,pos)=>let val g=List.nth( newbie,id)
92 :     in E.Krn(g,delta,(remap pos)) end
93 :     | E.Conv(id,alpha,h,pos)=>let
94 :     val id'=List.nth( newbie,id)
95 :     val h'=List.nth( newbie,h)
96 :     in E.Conv(id',alpha,h',pos) end
97 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
98 :     | _=>body
99 :     (*end case*))
100 :    
101 :     val body'=remap body
102 :    
103 :     in (params',body',args')
104 :     end
105 :    
106 :    
107 :    
108 :     (*Remaps all the indices*)
109 :    
110 :     fun cleanIndex(e, intialn,index)=let
111 :     (*Each element in the list is unique*)
112 : cchiw 2553
113 : cchiw 2525
114 : cchiw 2522 fun uniq(list1,n)=let
115 :     fun m([],l)=l
116 :     | m(e1::es,l)= (case e1
117 :     of E.V v=> if (v>=n) then m(es,l)
118 :     else let val a=List.find (fn x => x = e1) l
119 :     in (case a of NONE=> m(es,l@[e1]) | _=> m(es,l)) end
120 :     |_ => m(es,l))
121 :    
122 :     in m(list1,[])
123 :     end
124 :    
125 : cchiw 2553 fun filterIndex(E.V v)= if(v>intialn) then [] else [E.V v]
126 : cchiw 2522
127 :    
128 : cchiw 2553 (*find all indices *)
129 : cchiw 2522 fun findOuterIndex(body,n)=(case body
130 :     of E.Tensor(id,ix)=> ix
131 :     | E.Const _=> []
132 :     | E.Add (e1::es)=> findOuterIndex(e1,n)
133 :     | E.Sub(e1,e2)=> findOuterIndex(e1,n)
134 :     | E.Div(e1,e2)=> findOuterIndex(e1,n)
135 : cchiw 2553 | E.Value e1=> [E.V e1]
136 : cchiw 2522 | E.Sum(sx,e)=> findOuterIndex(e,n)
137 :     | E.Prod e=> let
138 :     val e'=List.map (fn e1=>findOuterIndex(e1,n)) e
139 :     in uniq((flat e'),n)
140 :     end
141 :     | E.Delta(i,j)=>[i,j]
142 :     | E.Epsilon(i,j,k)=>[E.V i, E.V j, E.V k]
143 :     | E.Neg e=> findOuterIndex(e,n)
144 :     | E.Img(v,alpha,pos)=>alpha
145 :     | E.Krn(v,dels,pos)=> List.map (fn(e1,e2)=> e2) dels
146 :     | E.Conv(_,alpha,_,dx)=> alpha@dx
147 :     | E.Probe(e,x)=> raise Fail "Probe- Should have been expanded"
148 :     | _=> []
149 :     (*end case*))
150 :    
151 :    
152 :    
153 :     val ix=findOuterIndex(e, intialn)
154 :    
155 : cchiw 2525
156 : cchiw 2553
157 :     val g=print(String.concat["\n\n --",P.printbody(e),"length of binding-", Int.toString(length(index))," Outer", Int.toString(length(ix)),"\n"])
158 :    
159 :     fun q(E.V p,E.V n)=print(String.concat["\n", Int.toString(p),"===>>",Int.toString(n),"\n"])
160 :    
161 :    
162 :     (*Mapps just outer indices *)
163 :     fun g([],index',_,c,mapp,outer)=(index',c,mapp,outer)
164 :     | g(e::es,index',n,c, mapp,outer)= let
165 : cchiw 2522 val b=List.find (fn(E.V v)=>v=n) ix
166 :     in case b
167 : cchiw 2553 of NONE=>(g(es,index',n+1, c, mapp,outer))
168 : cchiw 2522 |_=> let val mapp'=insert(E.V n, E.V c) mapp
169 : cchiw 2553 val gg=print "Found, and inserting"
170 :     in g(es, index'@[e], n+1, c+1, mapp',outer@[(E.V n)]) end
171 : cchiw 2522 end
172 :    
173 : cchiw 2553 val (index',c,mapp,outer)=g(index,[],0,0,empty,[])
174 : cchiw 2522
175 : cchiw 2553 val mm=print ("FInal count"^Int.toString(c))
176 :     val adjustment= intialn-c
177 : cchiw 2522
178 : cchiw 2553 fun addIndextoMapp([],n,mapp)=(mapp,n)
179 :     | addIndextoMapp((s,_,_)::es,n,mapp)= let
180 : cchiw 2525 val E.V p=s
181 : cchiw 2553 val n'=(p-adjustment)
182 :     val y=q( s,E.V n')
183 :     val m=insert(s, E.V n') mapp
184 :     in addIndextoMapp(es,n+1, m) end
185 : cchiw 2522
186 :    
187 :    
188 :     fun rewriteIndex(e, smapp) =(case e
189 :     of E.V v =>let val l=lookup e smapp
190 : cchiw 2525 in case l of NONE=> raise Fail("error Could not find :"^Int.toString(v))
191 : cchiw 2522 | SOME s=> s end
192 :     | E.C _=> e
193 :     (*end case*))
194 :    
195 :     fun singleIndex(e,smapp)=let
196 :     val l=lookup (E.V e) smapp
197 : cchiw 2525 (* val g=print(String.concat["\n SingleIndex:", Int.toString(e)])*)
198 : cchiw 2522 in case l
199 :     of NONE=> (raise Fail" error could not find index" )
200 :     | SOME(E.V s)=> s
201 :     end
202 :    
203 : cchiw 2525 fun rewrite (body,n,smapp,embed)=(case body
204 : cchiw 2522 of E.Tensor(id,ix)=> E.Tensor(id, (List.map (fn e=>rewriteIndex(e, smapp)) ix))
205 :     | E.Epsilon(i,j,k)=>E.Epsilon(singleIndex(i, smapp),singleIndex(j, smapp),singleIndex(k, smapp))
206 : cchiw 2525 | E.Value i=> (E.Value(singleIndex (i,smapp)))
207 : cchiw 2522 | E.Delta(i,j)=> E.Delta(rewriteIndex(i,smapp), rewriteIndex(j,smapp))
208 : cchiw 2525 | E.Add e=> E.Add(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)
209 :     | E.Sub(e1,e2)=> E.Sub(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))
210 :     | E.Div(e1,e2)=> E.Div(rewrite(e1,n,smapp,embed),rewrite(e2,n,smapp,embed))
211 : cchiw 2522 | E.Sum(sx,e)=> let
212 : cchiw 2553 val (mm,nn)=addIndextoMapp(sx,n,smapp)
213 : cchiw 2525 val m=E.Sum((List.map (fn(e1, lb,ub)=>(rewriteIndex(e1,mm),lb,ub)) sx),rewrite(e,nn,mm,embed+1))
214 :     in m end
215 :     | E.Prod e=> E.Prod(List.map (fn(e1)=>rewrite(e1,n,smapp,embed)) e)
216 :     | E.Neg e=> E.Neg(rewrite(e, n, smapp,embed))
217 : cchiw 2522 | E.Krn (h,dx, pos)=>
218 : cchiw 2525 E.Krn(h,(List.map (fn (e1,e2)=>(e1,rewriteIndex(e2, smapp))) dx), rewrite(pos,n,smapp,embed))
219 : cchiw 2522 | E.Img (v,alpha, pos)=>
220 :     E.Img(v,(List.map (fn e=>rewriteIndex(e, smapp)) alpha),
221 : cchiw 2525 (List.map (fn e=>rewrite(e, n, smapp,embed)) pos))
222 : cchiw 2522 | E.Conv(v,alpha,h, dx)=> E.Conv(v, (List.map (fn e=>rewriteIndex(e, smapp)) alpha),h,(List.map (fn e=>rewriteIndex(e, smapp)) dx))
223 :     | E.Probe(e1,e2)=>raise Fail "Probe- Should have been expanded"
224 :     | _=> body
225 :     (*end case*))
226 :    
227 : cchiw 2525 val e'=rewrite(e, length ix, mapp,0)
228 : cchiw 2522
229 :     in
230 : cchiw 2553 (outer,index',e')
231 : cchiw 2522 end
232 :    
233 :    
234 :     fun clean(params, index, body, args)=let
235 :     val (p',body',args')= cleanParams(body, params, args)
236 :     val (_,i',b')=cleanIndex(body', length index, index)
237 :     in (p',i',b',args')
238 :     end
239 :    
240 :     end (* local *)
241 :    
242 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0