Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2605 - (view) (download)

1 : cchiw 2522 (*hashs Ein Function after substitution*)
2 :     structure genKrn = struct
3 :     local
4 :     structure E = Ein
5 :     structure DstOp = LowOps
6 : cchiw 2553
7 : cchiw 2522 structure evalKrn =evalKrn
8 :     structure SrcIL = MidIL
9 :     structure SrcOp = MidOps
10 :     structure SrcSV = SrcIL.StateVar
11 :     structure SrcTy = MidILTypes
12 :     structure VTbl = SrcIL.Var.Tbl
13 :     structure DstIL = LowIL
14 :     structure DstTy = LowILTypes
15 : cchiw 2525 structure Var = LowIL.Var
16 : cchiw 2522 in
17 :    
18 :    
19 : cchiw 2553 fun insert (key, value) d =fn s =>
20 :     if s = key then SOME value
21 :     else d s
22 : cchiw 2522
23 : cchiw 2553 fun lookup k d = d k
24 :    
25 :     fun find(v, mapp)=let
26 :     val a=lookup v mapp
27 :     in (case a of NONE=> raise Fail "Outside Bound"
28 :     |SOME s => s)
29 :     end
30 :    
31 :     val empty =fn key =>NONE
32 :    
33 : cchiw 2605 val testing=1
34 : cchiw 2522 (*Add,Subtract Scalars*)
35 :     fun mkSimpleOp(mapp,e,args)=let
36 :    
37 :     fun subP e=(case e
38 : cchiw 2553 of E.Tensor(t1,ix1)=>gHelper.mkSca(mapp,(t1,ix1,args))
39 :     (*| E.Value v1=> gHelper.aaV(DstOp.C (List.nth(mapp,v1)),[],"Const",DstTy.TensorTy([]))*)
40 :     | E.Value v1=> gHelper.aaV(DstOp.C (find(v1,mapp)),[],"Const",DstTy.TensorTy([]))
41 :     | E.Const c=> gHelper.aaV(DstOp.C 9,[],"Const",DstTy.TensorTy([]))
42 : cchiw 2522 (*end case*))
43 :    
44 :     in (case e
45 :     of E.Sub(e1,e2)=> let
46 :     val (vA,A)=subP e1
47 :     val (vB,B)=subP e2
48 : cchiw 2553 val (vD,D)=gHelper.aaV(DstOp.subSca,[vA,vB],"Subsca",DstTy.TensorTy([]))
49 : cchiw 2522 in (vD,A@B@D) end
50 :     | E.Add[e1,e2]=> let
51 :     val (vA,A)=subP e1
52 :     val (vB,B)=subP e2
53 : cchiw 2553 val (vD,D)=gHelper.aaV(DstOp.addSca, [vA,vB],"addsca",DstTy.TensorTy([]))
54 : cchiw 2522 in (vD,A@B@D) end
55 :     (* ebd case*))
56 :     end
57 :    
58 :    
59 : cchiw 2525 (*FIX TYPE ON CONS TYPE HERE *)
60 : cchiw 2522 (*con everything on the list, make a vectors*)
61 : cchiw 2529 fun consfn([],rest, code,dim,n)=(rest,code)
62 :     | consfn(e::es,rest,code,dim,n)=let
63 : cchiw 2525 val gg=length(e)
64 : cchiw 2553 val (vA,A)=gHelper.aaV(DstOp.cons(DstTy.TensorTy [gg],0),List.rev e,"Cons "^Int.toString(n)^":--",DstTy.TensorTy([gg]))
65 : cchiw 2529 in consfn(es, [vA]@rest, A@code,dim,n+1)
66 : cchiw 2522 end
67 :    
68 :    
69 :     (*sort expression into kernels and images*)
70 :     fun sortK(a,b,[])=(a,b)
71 :     | sortK(a,b,e::es)=(case e
72 :     of E.Krn k1=>sortK(a,b@[k1],es)
73 :     | E.Img img1=>sortK(a@[img1],b,es)
74 :     (*end case*))
75 :    
76 :    
77 :    
78 : cchiw 2529 val bV= ref 0
79 : cchiw 2522
80 : cchiw 2525
81 :    
82 :    
83 : cchiw 2529 fun sumP(a,b,last)=let
84 : cchiw 2553 val (vD, D)=gHelper.aaV(DstOp.prodVec(last),[a, b],"prodV",DstTy.TensorTy([last]))
85 :     val (vE, E)=gHelper.aaV(DstOp.sumVec(last),[vD],"sumVec",DstTy.intTy)
86 : cchiw 2529 in (vE,D@E) end
87 : cchiw 2525
88 : cchiw 2529 fun ccons(rest,shape)= let
89 : cchiw 2553 val(vE,E)=gHelper.aaV(DstOp.cons(DstTy.TensorTy shape,0),rest,"Cons",DstTy.TensorTy(shape))
90 : cchiw 2529 in (vE,E) end
91 : cchiw 2522
92 : cchiw 2525
93 : cchiw 2553
94 : cchiw 2529 (*Images*)
95 : cchiw 2533 fun mkImg(mappOrig,sx,[(fid,ix,px)],v,args)=let
96 : cchiw 2529
97 : cchiw 2553 val (E.V vid,lb,ub)=hd(sx)
98 : cchiw 2529 val top=ub-lb
99 :     val R=top+1
100 :     val dim=length(px)
101 :     val sx'=List.tabulate(dim, fn _ =>top)
102 :     val sx''=List.map (fn n=>n+1) sx'
103 :     val argType=DstTy.tensorTy (List.tabulate(dim, fn _ =>R))
104 :    
105 : cchiw 2555 val (vlb,BB)= gHelper.mkC lb
106 : cchiw 2530
107 : cchiw 2529 fun createImgVar mapp=let
108 :     fun mkpos([E.Add[E.Tensor(t1,ix1),_]],rest,code)= let
109 : cchiw 2553 val (vA,A)=gHelper.mkSca(mapp,(t1,ix1,args))
110 : cchiw 2529
111 : cchiw 2555 val (vC,C)=gHelper.aaV(DstOp.addSca, [vA,vlb],"addsca",DstTy.TensorTy([]))
112 :    
113 : cchiw 2522 in
114 : cchiw 2555 (rest@[vC],code@A@C)
115 : cchiw 2522 end
116 :    
117 : cchiw 2529 | mkpos(pos1::es,rest,code)= let
118 :     val (vF,code1)=mkSimpleOp(mapp,pos1,args)
119 :     in mkpos(es,rest@[vF],code@code1)
120 :     end
121 : cchiw 2553 val ix1=List.map (fn (e1)=> gHelper.mapIndex(e1,mapp)) ix
122 : cchiw 2529 val (vF,F)= mkpos(px,[],[])
123 : cchiw 2530 val imgType=DstTy.imgIndex ix1
124 : cchiw 2553 val (vA,A)=gHelper.aaV(DstOp.imgAddr(v,imgType,dim),vF,"Imageaddress",DstTy.intTy)
125 :     val (vB,B)=gHelper.aaV(DstOp.imgLoad(v,dim,R),[vA],"imgLoad---",DstTy.tensorTy([R]))
126 : cchiw 2522 in
127 : cchiw 2529 (vB,F@A@B)
128 : cchiw 2522 end
129 :    
130 :    
131 : cchiw 2553 fun sumI1(lft,ix,0,0,code,n')=let
132 :     val mapp=insert (n', lb) ix
133 : cchiw 2529 val (lft', code')= createImgVar mapp
134 :     in ([lft']@lft,code'@code)
135 :     end
136 : cchiw 2553 | sumI1(lft,ix,i,0,code,n')=let
137 :     val mapp=insert (n', i-1) ix
138 : cchiw 2529 val (lft', code')=createImgVar mapp
139 : cchiw 2553 in sumI1([lft']@lft,ix,i-1,0,code'@code,n')
140 : cchiw 2522 end
141 : cchiw 2553 | sumI1(lft,ix,0,n,code,n')=let
142 :     val mapp=insert (n', lb) ix
143 :     in
144 :     sumI1(lft,mapp,top,n-1,code,n'+1)
145 :     end
146 :     | sumI1(lft,ix,i,n, code,n')=let
147 :     val mapp=insert (n', i+lb) ix
148 :     val (lft',code')=sumI1(lft,mapp,top,n-1,code,n'+1)
149 :     in sumI1(lft',ix,i-1,n,code',n') end
150 :    
151 :    
152 :     val(lft,code)=sumI1([],mappOrig,top,dim-2,[],vid)
153 : cchiw 2605
154 : cchiw 2529 in
155 : cchiw 2555 (lft,BB@code)
156 : cchiw 2522
157 : cchiw 2529 end
158 :    
159 :    
160 :     (* kernels*)
161 :    
162 : cchiw 2553 fun mkkrns2(mappOrig,sx,k1,h,args)=let
163 :    
164 : cchiw 2605
165 : cchiw 2553 val k= List.map (fn (id,d1,pos)=>(id,gHelper.evalDelta(d1,mappOrig),pos)) k1
166 : cchiw 2529
167 : cchiw 2605
168 : cchiw 2553 val (E.V sid,lb,ub)=hd(sx)
169 : cchiw 2533 val R=(ub-lb)
170 : cchiw 2536 val R'=R+1
171 : cchiw 2533
172 :     fun mm(e)=Int.toString e
173 :    
174 : cchiw 2605 val _ =(case testing
175 :     of 1=> let
176 :     val _ =print "Differentiation value of kernels:"
177 :     val _= List.map (fn(id,v, pos)=> print(Int.toString(v))) k
178 :     val _ =print(String.concat["\n ub:", mm ub, "lb:", mm lb, "Range", mm R ])
179 :     in 1 end
180 :     | _ => 1)
181 :    
182 :    
183 : cchiw 2553 fun q([],fin,l,ix, i,code,n')=(fin,code)
184 :     | q((id1,d,pos1)::ks,fin,l,ix,0,code,n')=let
185 :     val mapp=insert (n', lb) ix
186 : cchiw 2529 val (l', code')=mkSimpleOp(mapp,pos1,args)
187 :     val e=l@[l']
188 : cchiw 2553 val mapp'=insert (n', 0) ix
189 :     in q(ks,fin@[e],[],mapp',R,code@code',n'+1)
190 : cchiw 2529 end
191 : cchiw 2553 | q(k::ks,fin,l,ix, i,code,n')=let
192 : cchiw 2533 val (id1,d,pos1)=k
193 : cchiw 2553 val mapp= insert (n', lb+i) ix
194 : cchiw 2529 val (l', code')=mkSimpleOp(mapp,pos1,args)
195 : cchiw 2553 in q(k::ks,fin,l@[l'],ix,i-1,code@code',n')
196 : cchiw 2529 end
197 :    
198 : cchiw 2533
199 : cchiw 2553 val(lftkrn,code)=q(k,[],[],mappOrig,R,[],sid)
200 : cchiw 2533 val (lft,code')=consfn((lftkrn),[],[],R,0)
201 : cchiw 2536
202 : cchiw 2543
203 : cchiw 2605
204 : cchiw 2543 fun evalK([],[],n,code,newId)=(newId,code)
205 : cchiw 2536 | evalK(kn::kns,x::xs,n,code,newId)=let
206 :     val (_,dk,_) =kn
207 :     val (id,kcode)= evalKrn.expandEvalKernel (R', h, dk, x,n)
208 :     in evalK(kns,xs,n+1,code@kcode,newId@[id])
209 :     end
210 :    
211 : cchiw 2543 val (ids, evalKcode)=evalK(k,lft,0,[],[])
212 : cchiw 2536
213 : cchiw 2529 in
214 : cchiw 2543 (* (lft,code@code')*)
215 :     (ids, code@code'@evalKcode)
216 : cchiw 2529 end
217 : cchiw 2536
218 : cchiw 2529 (*Written for 2-d and 3-d*)
219 :     fun prodImgKrn(imgArg,krnArg,R)=let
220 : cchiw 2525
221 : cchiw 2605
222 : cchiw 2529 val tyM=DstTy.TensorTy[R,R]
223 :     val tyV=DstTy.TensorTy[R]
224 : cchiw 2605 val _=(case testing of 0=> 1
225 :     | _ =>(print ("Number of Assignments in prodImgArg returned"^Int.toString(length(imgArg)));1))
226 : cchiw 2525
227 : cchiw 2529 fun dhz([],conslist,rest,code,_,_)=(conslist,code)
228 :     | dhz(e::es,conslist,rest,code,hz,0)=let
229 :     val (vA,A)=sumP(e,hz,R)
230 : cchiw 2553 val (vD,D)=gHelper.aaV(DstOp.cons(DstTy.intTy,R),rest@[vA],"Cons",tyV)
231 : cchiw 2529 in dhz(es,conslist@[vD],[],code@A@D,hz,R-1)
232 : cchiw 2522 end
233 : cchiw 2529 | dhz(e::es,conslist,rest,code,hz,r)=let
234 :     val (vA,A)=sumP(e,hz,R)
235 :     in dhz(es,conslist,rest@[vA],code@A,hz,r-1)
236 :     end
237 : cchiw 2522
238 :    
239 : cchiw 2529 fun dhy([],rest,code,hy)= let
240 : cchiw 2533 val n=length(rest)
241 : cchiw 2553 val (vD,D)=gHelper.aaV(DstOp.cons(DstTy.intTy,n),rest,"Cons",tyV)
242 : cchiw 2529 in
243 :     (vD,code@D) end
244 :     | dhy(e::es,rest,code,hy)=let
245 :     val (vA,A)=sumP(e,hy,R)
246 :     in dhy(es,rest@[vA],code@A,hy)
247 :     end
248 : cchiw 2522
249 : cchiw 2529 in (case krnArg
250 : cchiw 2536 of [hx]=>let
251 :     val [i]=imgArg
252 :     in sumP(i,hx,R)
253 :     end
254 :    
255 :     | [hy,hx]=>let
256 : cchiw 2529 val ty=DstTy.TensorTy[R]
257 : cchiw 2530 val (vD,code)=dhy(imgArg,[],[],hy)
258 : cchiw 2529 val (vE,E)=sumP(vD,hx,R)
259 :     in
260 :     (vE,code@E)
261 :     end
262 :     | [hz,hy,hx]=>let
263 :    
264 :    
265 :     val (vZ,codeZ)=dhz(imgArg,[],[],[],hz,R-1)
266 :     val (vY,codeY)=dhy(vZ,[],[],hy)
267 :     val (vE,E)=sumP(vY,hx,R)
268 :     in
269 :     (vE,codeZ@codeY@E)
270 :     end
271 :    
272 :     (*end case*))
273 :     end
274 :    
275 : cchiw 2553 fun evalField(mapp,(E.Sum(sx,E.Prod e),v,h,args))=let
276 : cchiw 2605 val _=(case testing
277 :     of 0 => 1
278 :     | _ => (print "\n\n ************** new direction **********\n\n Outer Bound:";1)
279 :     (*end test*))
280 : cchiw 2530
281 : cchiw 2553
282 : cchiw 2529 val (img1,k1)=sortK([],[],e)
283 :     val (_,lb,ub)=hd(sx)
284 :     val R=(ub-lb)+1
285 :    
286 : cchiw 2553 val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,args)
287 :     val (krnArg, krnCode)= mkkrns2(mapp,sx,k1,h,args)
288 : cchiw 2529 val (vA,A)=prodImgKrn(imgArg,krnArg,R)
289 :     in (vA,imgCode@krnCode@A)
290 :    
291 : cchiw 2533
292 : cchiw 2525 end
293 : cchiw 2522
294 :    
295 :    
296 :     end (* local *)
297 :    
298 : cchiw 2529 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0