Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2680 - (view) (download)

1 : cchiw 2522 (*hashs Ein Function after substitution*)
2 :     structure genKrn = struct
3 :     local
4 : cchiw 2612
5 :     structure DstOp = LowOps
6 :     structure DstTy = LowILTypes
7 :     structure DstIL = LowIL
8 : cchiw 2522 structure E = Ein
9 : cchiw 2612 structure evalKrn =evalKrn
10 :     structure S3=step3
11 : cchiw 2624 structure tS= toStringEin
12 : cchiw 2553
13 : cchiw 2522 in
14 :    
15 : cchiw 2612 val testing=0
16 :     val Sca=DstTy.TensorTy []
17 : cchiw 2676 val addR=DstOp.addSca
18 : cchiw 2522
19 : cchiw 2553 fun insert (key, value) d =fn s =>
20 : cchiw 2612 if s = key then SOME value
21 :     else d s
22 : cchiw 2553 fun lookup k d = d k
23 : cchiw 2624 fun find(v, mapp)=(case (lookup v mapp)
24 :     of NONE=> raise Fail "Outside Bound"
25 :     |SOME s => s)
26 :    
27 : cchiw 2553
28 : cchiw 2612 fun errS str=raise Fail(str)
29 : cchiw 2637 fun mkInt n= S3.mkInt n
30 : cchiw 2628
31 :    
32 : cchiw 2612 fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)
33 : cchiw 2676 fun mkAddSca rest= S3.aaV(addR,rest,"addSca",Sca)
34 : cchiw 2624 fun mkCons(shape, rest)=let
35 :     val ty=DstTy.TensorTy shape
36 :     val a=DstIL.Var.new("Cons" ,ty)
37 :     val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
38 :     val _=print("\n****"^tS.toStringAll(ty,code))
39 :     in
40 :     (a, [code])
41 :     end
42 : cchiw 2553
43 : cchiw 2624
44 : cchiw 2522 (*Add,Subtract Scalars*)
45 :     fun mkSimpleOp(mapp,e,args)=let
46 :     fun subP e=(case e
47 : cchiw 2612 of E.Tensor(t1,ix1)=>S3.mkSca(mapp,(t1,ix1,args))
48 :     | E.Value v1=> mkInt (find(v1,mapp))
49 :     | _ => errS("ill-formed Kernel position")
50 : cchiw 2522 (*end case*))
51 :     in (case e
52 :     of E.Sub(e1,e2)=> let
53 :     val (vA,A)=subP e1
54 :     val (vB,B)=subP e2
55 : cchiw 2612 val (vD,D)=mkSubSca [vA,vB]
56 : cchiw 2522 in (vD,A@B@D) end
57 :     | E.Add[e1,e2]=> let
58 :     val (vA,A)=subP e1
59 :     val (vB,B)=subP e2
60 : cchiw 2612 val (vD,D)=mkAddSca [vA,vB]
61 : cchiw 2522 in (vD,A@B@D) end
62 : cchiw 2615 | _ => raise Fail"Probed position is not subtraction or addition"
63 :     (* end case*))
64 : cchiw 2522 end
65 :    
66 :    
67 : cchiw 2525 (*FIX TYPE ON CONS TYPE HERE *)
68 : cchiw 2522 (*con everything on the list, make a vectors*)
69 :    
70 : cchiw 2612 fun consfn([],rest, code,_)=(rest,code)
71 :     | consfn(e::es,rest,code,dim)=let
72 :     val (vA,A)= mkCons([length(e)],List.rev e)
73 :     in
74 :     consfn(es, [vA]@rest, A@code,dim)
75 :     end
76 : cchiw 2522
77 : cchiw 2612
78 : cchiw 2522 (*sort expression into kernels and images*)
79 :     fun sortK(a,b,[])=(a,b)
80 : cchiw 2612 | sortK(a,b,e::es)=(case e
81 :     of E.Krn k1=>sortK(a,b@[k1],es)
82 :     | E.Img img1=>sortK(a@[img1],b,es)
83 : cchiw 2615 | _ =>raise Fail"Non Image or Krn in summation expression"
84 : cchiw 2612 (*end case*))
85 : cchiw 2522
86 :    
87 : cchiw 2680 (*
88 : cchiw 2529 fun sumP(a,b,last)=let
89 : cchiw 2680 val (vD, D)=S3.aaV(DstOp.prodVec(last,1),[a, b],"prodV",DstTy.TensorTy([last]))
90 :     val (vE, E)=S3.aaV(DstOp.sumVec(last,1),[vD],"sumVec",DstTy.intTy)
91 : cchiw 2529 in (vE,D@E) end
92 : cchiw 2680 *)
93 :     (*added dot vec operator here *)
94 :     fun sumP(a,b,last)=let
95 :     val (vE, E)=S3.aaV(DstOp.dotVec(last,1),[a,b],"dotVec",DstTy.intTy)
96 :     in (vE,E) end
97 : cchiw 2525
98 : cchiw 2522
99 : cchiw 2680
100 :     (*Images, ivar:original argument *)
101 :     fun mkImg(mappOrig,sx,starter,v,vNew,info,sid,lb,ub,top, R)=let
102 : cchiw 2615 val [(fid,ix,px)]=(case starter
103 :     of [(fid,ix,px)]=>[(fid,ix,px)]
104 :     | _=> raise Fail"Non summation range")
105 :    
106 : cchiw 2529 val dim=length(px)
107 : cchiw 2612 val (vlb,BB)= mkInt lb
108 : cchiw 2529
109 : cchiw 2680 val (vBase,base)=S3.aaV(DstOp.baseAddr v,[vNew],"baseAddr",Sca)
110 :    
111 : cchiw 2612 fun createImgVar mapp=let
112 :     fun mkpos(e,rest,code)=(case e
113 :     of [E.Add[E.Tensor(t1,ix1),_]]=> let
114 :     val (vA,A)=S3.mkSca(mapp,(t1,ix1,info))
115 :     val (vC,C)=mkAddSca [vA,vlb]
116 :     in
117 :     (rest@[vC],code@A@C)
118 :     end
119 :     | pos1::es => let
120 :     val (vF,code1)=mkSimpleOp(mapp,pos1,info)
121 :     in
122 :     mkpos(es,rest@[vF],code@code1)
123 :     end
124 : cchiw 2615 | _=> raise Fail "Non-addition in Image Position"
125 : cchiw 2612 (*end case*))
126 : cchiw 2530
127 : cchiw 2529 val (vF,F)= mkpos(px,[],[])
128 : cchiw 2612 val imgType=DstTy.imgIndex(List.map (fn (e1)=> S3.mapIndex(e1,mapp)) ix)
129 : cchiw 2680 val (vA,A)=S3.aaV(DstOp.imgAddr(v,imgType,dim),[vBase]@vF,"Imageaddress",DstTy.intTy)
130 :     val (vB,B)=S3.aaV(DstOp.imgLoad(v,dim,R),[vA],"imgLoad",DstTy.tensorTy([R]))
131 : cchiw 2522 in
132 : cchiw 2529 (vB,F@A@B)
133 : cchiw 2522 end
134 : cchiw 2612 fun sumI1(lft,dict,0,0,code,n')=let
135 :     val mapp=insert (n', lb) dict
136 :     val (lft', code')= createImgVar mapp
137 :     in ([lft']@lft,code'@code) end
138 :     | sumI1(lft,dict,i,0,code,n')=let
139 :     val mapp=insert (n', i-1) dict
140 : cchiw 2529 val (lft', code')=createImgVar mapp
141 : cchiw 2612 in sumI1([lft']@lft,dict,i-1,0,code'@code,n') end
142 :     | sumI1(lft,dict,0,n,code,n')=let
143 :     val mapp=insert (n', lb) dict
144 :     in sumI1(lft,mapp,top,n-1,code,n'+1) end
145 :     | sumI1(lft,dict,i,n, code,n')=let
146 :     val mapp=insert (n', i+lb) dict
147 : cchiw 2553 val (lft',code')=sumI1(lft,mapp,top,n-1,code,n'+1)
148 : cchiw 2612 in sumI1(lft',dict,i-1,n,code',n') end
149 :     val(lft,code)=sumI1([],mappOrig,top,dim-2,[],sid)
150 : cchiw 2529 in
151 : cchiw 2680 (lft,base@BB@code)
152 : cchiw 2529 end
153 :    
154 :    
155 :     (* kernels*)
156 : cchiw 2612 fun mkkrns(mappOrig,sx,dels,h,args, sid,lb,ub,top,R)=let
157 :     val newdels= List.map (fn (id,d1,pos)=>(id,S3.evalDels(mappOrig,d1),pos)) dels
158 : cchiw 2605 val _ =(case testing
159 :     of 1=> let
160 : cchiw 2612 fun mm(e)=Int.toString e
161 :     val _ =String.concat(["Differentiation value of kernels:"]@
162 :     (List.map (fn(id,v, pos)=> mm(v)) newdels)@ ["\n ub:", mm ub, "lb:", mm lb, "Range", mm top])
163 : cchiw 2605 in 1 end
164 :     | _ => 1)
165 : cchiw 2612 fun mkpos(k,fin,l,dict, i,code,n)= (case (k,i)
166 :     of ([],_)=>(fin,code)
167 :     | ((id1,d,pos1)::ks,0)=>let
168 :     val mapp=insert (n, lb) dict
169 : cchiw 2529 val (l', code')=mkSimpleOp(mapp,pos1,args)
170 :     val e=l@[l']
171 : cchiw 2612 val mapp'=insert (n, 0) dict
172 :     in mkpos(ks,fin@[e],[],mapp',top,code@code',n+1)
173 : cchiw 2529 end
174 : cchiw 2612 | (e1::es,_) =>let
175 :     val (id1,d,pos1)=e1
176 :     val mapp= insert (n, lb+i) dict
177 : cchiw 2529 val (l', code')=mkSimpleOp(mapp,pos1,args)
178 : cchiw 2612 in mkpos(k,fin,l@[l'],dict,i-1,code@code',n)
179 : cchiw 2529 end
180 : cchiw 2612 (*end case*))
181 :     fun evalK([],[],_,code,newId)=(newId,code)
182 :     | evalK(kn::kns,x::xs,n,code,newId)=let
183 : cchiw 2536 val (_,dk,_) =kn
184 : cchiw 2612 val (id,kcode)= evalKrn.expandEvalKernel (R,h, dk, x,n)
185 :     in
186 :     evalK(kns,xs,n+1,code@kcode,newId@[id])
187 : cchiw 2536 end
188 : cchiw 2615 |evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
189 :    
190 : cchiw 2612 val(lftkrn,code)=mkpos(newdels,[],[],mappOrig,top,[],sid)
191 :     val (lft,code')=consfn((lftkrn),[],[],top)
192 :     val (ids, evalKcode)=evalK(newdels,lft,0,[],[])
193 : cchiw 2529 in
194 : cchiw 2543 (ids, code@code'@evalKcode)
195 : cchiw 2529 end
196 : cchiw 2536
197 : cchiw 2612
198 :     (*Product of Image and Kernel*)
199 : cchiw 2529 fun prodImgKrn(imgArg,krnArg,R)=let
200 : cchiw 2525
201 : cchiw 2624 fun ConsInt(shape, rest)=let
202 :     val ty=DstTy.TensorTy [R]
203 :     val a=DstIL.Var.new("Cons" ,ty)
204 :     val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
205 :     in (a, [code])
206 :     end
207 : cchiw 2605
208 : cchiw 2624
209 : cchiw 2529 fun dhz([],conslist,rest,code,_,_)=(conslist,code)
210 : cchiw 2612 | dhz(e::es,conslist,rest,code,hz,0)=let
211 : cchiw 2529 val (vA,A)=sumP(e,hz,R)
212 : cchiw 2612 val (vD,D)=ConsInt(R,rest@[vA])
213 : cchiw 2529 in dhz(es,conslist@[vD],[],code@A@D,hz,R-1)
214 : cchiw 2522 end
215 : cchiw 2612 | dhz(e::es,conslist,rest,code,hz,r)=let
216 : cchiw 2529 val (vA,A)=sumP(e,hz,R)
217 :     in dhz(es,conslist,rest@[vA],code@A,hz,r-1)
218 :     end
219 : cchiw 2522
220 :    
221 : cchiw 2529 fun dhy([],rest,code,hy)= let
222 : cchiw 2612 val (vD,D)=ConsInt(length(rest),rest)
223 :     in (vD,code@D) end
224 :     | dhy(e::es,rest,code,hy)=let
225 :     val (vA,A)=sumP(e,hy,R)
226 :     in dhy(es,rest@[vA],code@A,hy)end
227 : cchiw 2522
228 : cchiw 2612 (*Create Product by doing case analysis of the dimension*)
229 : cchiw 2529 in (case krnArg
230 : cchiw 2536 of [hx]=>let
231 :     val [i]=imgArg
232 :     in sumP(i,hx,R)
233 :     end
234 : cchiw 2612 | [hy,hx]=>let
235 : cchiw 2529 val ty=DstTy.TensorTy[R]
236 : cchiw 2530 val (vD,code)=dhy(imgArg,[],[],hy)
237 : cchiw 2529 val (vE,E)=sumP(vD,hx,R)
238 :     in
239 :     (vE,code@E)
240 :     end
241 :     | [hz,hy,hx]=>let
242 :     val (vZ,codeZ)=dhz(imgArg,[],[],[],hz,R-1)
243 :     val (vY,codeY)=dhy(vZ,[],[],hy)
244 :     val (vE,E)=sumP(vY,hx,R)
245 :     in
246 :     (vE,codeZ@codeY@E)
247 :     end
248 : cchiw 2615 | _ => raise Fail "Kernel dimensions not between 1-3"
249 : cchiw 2612 (*end case*))
250 : cchiw 2529 end
251 :    
252 : cchiw 2680 fun evalField(mapp,(E.Sum(sx,E.Prod e),(v,vNew),h,info))=let
253 : cchiw 2605 val _=(case testing
254 :     of 0 => 1
255 : cchiw 2612 | _ => (print "\n\n ************** new direction **********\n\n";1)
256 : cchiw 2605 (*end test*))
257 : cchiw 2530
258 : cchiw 2529 val (img1,k1)=sortK([],[],e)
259 : cchiw 2612 val (E.V sid,lb,ub)=hd(sx)
260 :     val top=(ub-lb)
261 :     val R=top+1
262 : cchiw 2529
263 : cchiw 2680 val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,vNew,info,sid,lb,ub,top,R)
264 : cchiw 2612 val (krnArg, krnCode)= mkkrns(mapp,sx,k1,h,info,sid,lb,ub,top,R)
265 : cchiw 2529 val (vA,A)=prodImgKrn(imgArg,krnArg,R)
266 : cchiw 2612 val _=(case testing
267 :     of 0=> 1
268 :     |_ =>(print ("Number of Assignments in prodImgArg returned"^Int.toString(length(imgArg)));1)
269 :     (*end case*))
270 :     in
271 :     (vA,imgCode@krnCode@A)
272 : cchiw 2525 end
273 : cchiw 2522
274 : cchiw 2615 |evalField _=raise Fail "Incorrect Field Expression"
275 : cchiw 2522
276 :     end (* local *)
277 :    
278 : cchiw 2529 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0