1 : |
cchiw |
2522 |
(*hashs Ein Function after substitution*)
|
2 : |
|
|
structure genKrn = struct
|
3 : |
|
|
local
|
4 : |
|
|
structure E = Ein
|
5 : |
|
|
structure DstOp = LowOps
|
6 : |
cchiw |
2553 |
|
7 : |
cchiw |
2522 |
structure evalKrn =evalKrn
|
8 : |
|
|
structure SrcIL = MidIL
|
9 : |
|
|
structure SrcOp = MidOps
|
10 : |
|
|
structure SrcSV = SrcIL.StateVar
|
11 : |
|
|
structure SrcTy = MidILTypes
|
12 : |
|
|
structure VTbl = SrcIL.Var.Tbl
|
13 : |
|
|
structure DstIL = LowIL
|
14 : |
|
|
structure DstTy = LowILTypes
|
15 : |
cchiw |
2525 |
structure Var = LowIL.Var
|
16 : |
cchiw |
2522 |
in
|
17 : |
|
|
|
18 : |
|
|
|
19 : |
cchiw |
2553 |
fun insert (key, value) d =fn s =>
|
20 : |
|
|
if s = key then SOME value
|
21 : |
|
|
else d s
|
22 : |
cchiw |
2522 |
|
23 : |
cchiw |
2553 |
fun lookup k d = d k
|
24 : |
|
|
|
25 : |
|
|
fun find(v, mapp)=let
|
26 : |
|
|
val a=lookup v mapp
|
27 : |
|
|
in (case a of NONE=> raise Fail "Outside Bound"
|
28 : |
|
|
|SOME s => s)
|
29 : |
|
|
end
|
30 : |
|
|
|
31 : |
|
|
val empty =fn key =>NONE
|
32 : |
|
|
|
33 : |
cchiw |
2605 |
val testing=1
|
34 : |
cchiw |
2522 |
(*Add,Subtract Scalars*)
|
35 : |
|
|
fun mkSimpleOp(mapp,e,args)=let
|
36 : |
|
|
|
37 : |
|
|
fun subP e=(case e
|
38 : |
cchiw |
2553 |
of E.Tensor(t1,ix1)=>gHelper.mkSca(mapp,(t1,ix1,args))
|
39 : |
|
|
(*| E.Value v1=> gHelper.aaV(DstOp.C (List.nth(mapp,v1)),[],"Const",DstTy.TensorTy([]))*)
|
40 : |
|
|
| E.Value v1=> gHelper.aaV(DstOp.C (find(v1,mapp)),[],"Const",DstTy.TensorTy([]))
|
41 : |
|
|
| E.Const c=> gHelper.aaV(DstOp.C 9,[],"Const",DstTy.TensorTy([]))
|
42 : |
cchiw |
2522 |
(*end case*))
|
43 : |
|
|
|
44 : |
|
|
in (case e
|
45 : |
|
|
of E.Sub(e1,e2)=> let
|
46 : |
|
|
val (vA,A)=subP e1
|
47 : |
|
|
val (vB,B)=subP e2
|
48 : |
cchiw |
2553 |
val (vD,D)=gHelper.aaV(DstOp.subSca,[vA,vB],"Subsca",DstTy.TensorTy([]))
|
49 : |
cchiw |
2522 |
in (vD,A@B@D) end
|
50 : |
|
|
| E.Add[e1,e2]=> let
|
51 : |
|
|
val (vA,A)=subP e1
|
52 : |
|
|
val (vB,B)=subP e2
|
53 : |
cchiw |
2553 |
val (vD,D)=gHelper.aaV(DstOp.addSca, [vA,vB],"addsca",DstTy.TensorTy([]))
|
54 : |
cchiw |
2522 |
in (vD,A@B@D) end
|
55 : |
|
|
(* ebd case*))
|
56 : |
|
|
end
|
57 : |
|
|
|
58 : |
|
|
|
59 : |
cchiw |
2525 |
(*FIX TYPE ON CONS TYPE HERE *)
|
60 : |
cchiw |
2522 |
(*con everything on the list, make a vectors*)
|
61 : |
cchiw |
2529 |
fun consfn([],rest, code,dim,n)=(rest,code)
|
62 : |
|
|
| consfn(e::es,rest,code,dim,n)=let
|
63 : |
cchiw |
2525 |
val gg=length(e)
|
64 : |
cchiw |
2553 |
val (vA,A)=gHelper.aaV(DstOp.cons(DstTy.TensorTy [gg],0),List.rev e,"Cons "^Int.toString(n)^":--",DstTy.TensorTy([gg]))
|
65 : |
cchiw |
2529 |
in consfn(es, [vA]@rest, A@code,dim,n+1)
|
66 : |
cchiw |
2522 |
end
|
67 : |
|
|
|
68 : |
|
|
|
69 : |
|
|
(*sort expression into kernels and images*)
|
70 : |
|
|
fun sortK(a,b,[])=(a,b)
|
71 : |
|
|
| sortK(a,b,e::es)=(case e
|
72 : |
|
|
of E.Krn k1=>sortK(a,b@[k1],es)
|
73 : |
|
|
| E.Img img1=>sortK(a@[img1],b,es)
|
74 : |
|
|
(*end case*))
|
75 : |
|
|
|
76 : |
|
|
|
77 : |
|
|
|
78 : |
cchiw |
2529 |
val bV= ref 0
|
79 : |
cchiw |
2522 |
|
80 : |
cchiw |
2525 |
|
81 : |
|
|
|
82 : |
|
|
|
83 : |
cchiw |
2529 |
fun sumP(a,b,last)=let
|
84 : |
cchiw |
2553 |
val (vD, D)=gHelper.aaV(DstOp.prodVec(last),[a, b],"prodV",DstTy.TensorTy([last]))
|
85 : |
|
|
val (vE, E)=gHelper.aaV(DstOp.sumVec(last),[vD],"sumVec",DstTy.intTy)
|
86 : |
cchiw |
2529 |
in (vE,D@E) end
|
87 : |
cchiw |
2525 |
|
88 : |
cchiw |
2529 |
fun ccons(rest,shape)= let
|
89 : |
cchiw |
2553 |
val(vE,E)=gHelper.aaV(DstOp.cons(DstTy.TensorTy shape,0),rest,"Cons",DstTy.TensorTy(shape))
|
90 : |
cchiw |
2529 |
in (vE,E) end
|
91 : |
cchiw |
2522 |
|
92 : |
cchiw |
2525 |
|
93 : |
cchiw |
2553 |
|
94 : |
cchiw |
2529 |
(*Images*)
|
95 : |
cchiw |
2533 |
fun mkImg(mappOrig,sx,[(fid,ix,px)],v,args)=let
|
96 : |
cchiw |
2529 |
|
97 : |
cchiw |
2553 |
val (E.V vid,lb,ub)=hd(sx)
|
98 : |
cchiw |
2529 |
val top=ub-lb
|
99 : |
|
|
val R=top+1
|
100 : |
|
|
val dim=length(px)
|
101 : |
|
|
val sx'=List.tabulate(dim, fn _ =>top)
|
102 : |
|
|
val sx''=List.map (fn n=>n+1) sx'
|
103 : |
|
|
val argType=DstTy.tensorTy (List.tabulate(dim, fn _ =>R))
|
104 : |
|
|
|
105 : |
cchiw |
2555 |
val (vlb,BB)= gHelper.mkC lb
|
106 : |
cchiw |
2530 |
|
107 : |
cchiw |
2529 |
fun createImgVar mapp=let
|
108 : |
|
|
fun mkpos([E.Add[E.Tensor(t1,ix1),_]],rest,code)= let
|
109 : |
cchiw |
2553 |
val (vA,A)=gHelper.mkSca(mapp,(t1,ix1,args))
|
110 : |
cchiw |
2529 |
|
111 : |
cchiw |
2555 |
val (vC,C)=gHelper.aaV(DstOp.addSca, [vA,vlb],"addsca",DstTy.TensorTy([]))
|
112 : |
|
|
|
113 : |
cchiw |
2522 |
in
|
114 : |
cchiw |
2555 |
(rest@[vC],code@A@C)
|
115 : |
cchiw |
2522 |
end
|
116 : |
|
|
|
117 : |
cchiw |
2529 |
| mkpos(pos1::es,rest,code)= let
|
118 : |
|
|
val (vF,code1)=mkSimpleOp(mapp,pos1,args)
|
119 : |
|
|
in mkpos(es,rest@[vF],code@code1)
|
120 : |
|
|
end
|
121 : |
cchiw |
2553 |
val ix1=List.map (fn (e1)=> gHelper.mapIndex(e1,mapp)) ix
|
122 : |
cchiw |
2529 |
val (vF,F)= mkpos(px,[],[])
|
123 : |
cchiw |
2530 |
val imgType=DstTy.imgIndex ix1
|
124 : |
cchiw |
2553 |
val (vA,A)=gHelper.aaV(DstOp.imgAddr(v,imgType,dim),vF,"Imageaddress",DstTy.intTy)
|
125 : |
|
|
val (vB,B)=gHelper.aaV(DstOp.imgLoad(v,dim,R),[vA],"imgLoad---",DstTy.tensorTy([R]))
|
126 : |
cchiw |
2522 |
in
|
127 : |
cchiw |
2529 |
(vB,F@A@B)
|
128 : |
cchiw |
2522 |
end
|
129 : |
|
|
|
130 : |
|
|
|
131 : |
cchiw |
2553 |
fun sumI1(lft,ix,0,0,code,n')=let
|
132 : |
|
|
val mapp=insert (n', lb) ix
|
133 : |
cchiw |
2529 |
val (lft', code')= createImgVar mapp
|
134 : |
|
|
in ([lft']@lft,code'@code)
|
135 : |
|
|
end
|
136 : |
cchiw |
2553 |
| sumI1(lft,ix,i,0,code,n')=let
|
137 : |
|
|
val mapp=insert (n', i-1) ix
|
138 : |
cchiw |
2529 |
val (lft', code')=createImgVar mapp
|
139 : |
cchiw |
2553 |
in sumI1([lft']@lft,ix,i-1,0,code'@code,n')
|
140 : |
cchiw |
2522 |
end
|
141 : |
cchiw |
2553 |
| sumI1(lft,ix,0,n,code,n')=let
|
142 : |
|
|
val mapp=insert (n', lb) ix
|
143 : |
|
|
in
|
144 : |
|
|
sumI1(lft,mapp,top,n-1,code,n'+1)
|
145 : |
|
|
end
|
146 : |
|
|
| sumI1(lft,ix,i,n, code,n')=let
|
147 : |
|
|
val mapp=insert (n', i+lb) ix
|
148 : |
|
|
val (lft',code')=sumI1(lft,mapp,top,n-1,code,n'+1)
|
149 : |
|
|
in sumI1(lft',ix,i-1,n,code',n') end
|
150 : |
|
|
|
151 : |
|
|
|
152 : |
|
|
val(lft,code)=sumI1([],mappOrig,top,dim-2,[],vid)
|
153 : |
cchiw |
2605 |
|
154 : |
cchiw |
2529 |
in
|
155 : |
cchiw |
2555 |
(lft,BB@code)
|
156 : |
cchiw |
2522 |
|
157 : |
cchiw |
2529 |
end
|
158 : |
|
|
|
159 : |
|
|
|
160 : |
|
|
(* kernels*)
|
161 : |
|
|
|
162 : |
cchiw |
2553 |
fun mkkrns2(mappOrig,sx,k1,h,args)=let
|
163 : |
|
|
|
164 : |
cchiw |
2605 |
|
165 : |
cchiw |
2553 |
val k= List.map (fn (id,d1,pos)=>(id,gHelper.evalDelta(d1,mappOrig),pos)) k1
|
166 : |
cchiw |
2529 |
|
167 : |
cchiw |
2605 |
|
168 : |
cchiw |
2553 |
val (E.V sid,lb,ub)=hd(sx)
|
169 : |
cchiw |
2533 |
val R=(ub-lb)
|
170 : |
cchiw |
2536 |
val R'=R+1
|
171 : |
cchiw |
2533 |
|
172 : |
|
|
fun mm(e)=Int.toString e
|
173 : |
|
|
|
174 : |
cchiw |
2605 |
val _ =(case testing
|
175 : |
|
|
of 1=> let
|
176 : |
|
|
val _ =print "Differentiation value of kernels:"
|
177 : |
|
|
val _= List.map (fn(id,v, pos)=> print(Int.toString(v))) k
|
178 : |
|
|
val _ =print(String.concat["\n ub:", mm ub, "lb:", mm lb, "Range", mm R ])
|
179 : |
|
|
in 1 end
|
180 : |
|
|
| _ => 1)
|
181 : |
|
|
|
182 : |
|
|
|
183 : |
cchiw |
2553 |
fun q([],fin,l,ix, i,code,n')=(fin,code)
|
184 : |
|
|
| q((id1,d,pos1)::ks,fin,l,ix,0,code,n')=let
|
185 : |
|
|
val mapp=insert (n', lb) ix
|
186 : |
cchiw |
2529 |
val (l', code')=mkSimpleOp(mapp,pos1,args)
|
187 : |
|
|
val e=l@[l']
|
188 : |
cchiw |
2553 |
val mapp'=insert (n', 0) ix
|
189 : |
|
|
in q(ks,fin@[e],[],mapp',R,code@code',n'+1)
|
190 : |
cchiw |
2529 |
end
|
191 : |
cchiw |
2553 |
| q(k::ks,fin,l,ix, i,code,n')=let
|
192 : |
cchiw |
2533 |
val (id1,d,pos1)=k
|
193 : |
cchiw |
2553 |
val mapp= insert (n', lb+i) ix
|
194 : |
cchiw |
2529 |
val (l', code')=mkSimpleOp(mapp,pos1,args)
|
195 : |
cchiw |
2553 |
in q(k::ks,fin,l@[l'],ix,i-1,code@code',n')
|
196 : |
cchiw |
2529 |
end
|
197 : |
|
|
|
198 : |
cchiw |
2533 |
|
199 : |
cchiw |
2553 |
val(lftkrn,code)=q(k,[],[],mappOrig,R,[],sid)
|
200 : |
cchiw |
2533 |
val (lft,code')=consfn((lftkrn),[],[],R,0)
|
201 : |
cchiw |
2536 |
|
202 : |
cchiw |
2543 |
|
203 : |
cchiw |
2605 |
|
204 : |
cchiw |
2543 |
fun evalK([],[],n,code,newId)=(newId,code)
|
205 : |
cchiw |
2536 |
| evalK(kn::kns,x::xs,n,code,newId)=let
|
206 : |
|
|
val (_,dk,_) =kn
|
207 : |
|
|
val (id,kcode)= evalKrn.expandEvalKernel (R', h, dk, x,n)
|
208 : |
|
|
in evalK(kns,xs,n+1,code@kcode,newId@[id])
|
209 : |
|
|
end
|
210 : |
|
|
|
211 : |
cchiw |
2543 |
val (ids, evalKcode)=evalK(k,lft,0,[],[])
|
212 : |
cchiw |
2536 |
|
213 : |
cchiw |
2529 |
in
|
214 : |
cchiw |
2543 |
(* (lft,code@code')*)
|
215 : |
|
|
(ids, code@code'@evalKcode)
|
216 : |
cchiw |
2529 |
end
|
217 : |
cchiw |
2536 |
|
218 : |
cchiw |
2529 |
(*Written for 2-d and 3-d*)
|
219 : |
|
|
fun prodImgKrn(imgArg,krnArg,R)=let
|
220 : |
cchiw |
2525 |
|
221 : |
cchiw |
2605 |
|
222 : |
cchiw |
2529 |
val tyM=DstTy.TensorTy[R,R]
|
223 : |
|
|
val tyV=DstTy.TensorTy[R]
|
224 : |
cchiw |
2605 |
val _=(case testing of 0=> 1
|
225 : |
|
|
| _ =>(print ("Number of Assignments in prodImgArg returned"^Int.toString(length(imgArg)));1))
|
226 : |
cchiw |
2525 |
|
227 : |
cchiw |
2529 |
fun dhz([],conslist,rest,code,_,_)=(conslist,code)
|
228 : |
|
|
| dhz(e::es,conslist,rest,code,hz,0)=let
|
229 : |
|
|
val (vA,A)=sumP(e,hz,R)
|
230 : |
cchiw |
2553 |
val (vD,D)=gHelper.aaV(DstOp.cons(DstTy.intTy,R),rest@[vA],"Cons",tyV)
|
231 : |
cchiw |
2529 |
in dhz(es,conslist@[vD],[],code@A@D,hz,R-1)
|
232 : |
cchiw |
2522 |
end
|
233 : |
cchiw |
2529 |
| dhz(e::es,conslist,rest,code,hz,r)=let
|
234 : |
|
|
val (vA,A)=sumP(e,hz,R)
|
235 : |
|
|
in dhz(es,conslist,rest@[vA],code@A,hz,r-1)
|
236 : |
|
|
end
|
237 : |
cchiw |
2522 |
|
238 : |
|
|
|
239 : |
cchiw |
2529 |
fun dhy([],rest,code,hy)= let
|
240 : |
cchiw |
2533 |
val n=length(rest)
|
241 : |
cchiw |
2553 |
val (vD,D)=gHelper.aaV(DstOp.cons(DstTy.intTy,n),rest,"Cons",tyV)
|
242 : |
cchiw |
2529 |
in
|
243 : |
|
|
(vD,code@D) end
|
244 : |
|
|
| dhy(e::es,rest,code,hy)=let
|
245 : |
|
|
val (vA,A)=sumP(e,hy,R)
|
246 : |
|
|
in dhy(es,rest@[vA],code@A,hy)
|
247 : |
|
|
end
|
248 : |
cchiw |
2522 |
|
249 : |
cchiw |
2529 |
in (case krnArg
|
250 : |
cchiw |
2536 |
of [hx]=>let
|
251 : |
|
|
val [i]=imgArg
|
252 : |
|
|
in sumP(i,hx,R)
|
253 : |
|
|
end
|
254 : |
|
|
|
255 : |
|
|
| [hy,hx]=>let
|
256 : |
cchiw |
2529 |
val ty=DstTy.TensorTy[R]
|
257 : |
cchiw |
2530 |
val (vD,code)=dhy(imgArg,[],[],hy)
|
258 : |
cchiw |
2529 |
val (vE,E)=sumP(vD,hx,R)
|
259 : |
|
|
in
|
260 : |
|
|
(vE,code@E)
|
261 : |
|
|
end
|
262 : |
|
|
| [hz,hy,hx]=>let
|
263 : |
|
|
|
264 : |
|
|
|
265 : |
|
|
val (vZ,codeZ)=dhz(imgArg,[],[],[],hz,R-1)
|
266 : |
|
|
val (vY,codeY)=dhy(vZ,[],[],hy)
|
267 : |
|
|
val (vE,E)=sumP(vY,hx,R)
|
268 : |
|
|
in
|
269 : |
|
|
(vE,codeZ@codeY@E)
|
270 : |
|
|
end
|
271 : |
|
|
|
272 : |
|
|
(*end case*))
|
273 : |
|
|
end
|
274 : |
|
|
|
275 : |
cchiw |
2553 |
fun evalField(mapp,(E.Sum(sx,E.Prod e),v,h,args))=let
|
276 : |
cchiw |
2605 |
val _=(case testing
|
277 : |
|
|
of 0 => 1
|
278 : |
|
|
| _ => (print "\n\n ************** new direction **********\n\n Outer Bound:";1)
|
279 : |
|
|
(*end test*))
|
280 : |
cchiw |
2530 |
|
281 : |
cchiw |
2553 |
|
282 : |
cchiw |
2529 |
val (img1,k1)=sortK([],[],e)
|
283 : |
|
|
val (_,lb,ub)=hd(sx)
|
284 : |
|
|
val R=(ub-lb)+1
|
285 : |
|
|
|
286 : |
cchiw |
2553 |
val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,args)
|
287 : |
|
|
val (krnArg, krnCode)= mkkrns2(mapp,sx,k1,h,args)
|
288 : |
cchiw |
2529 |
val (vA,A)=prodImgKrn(imgArg,krnArg,R)
|
289 : |
|
|
in (vA,imgCode@krnCode@A)
|
290 : |
|
|
|
291 : |
cchiw |
2533 |
|
292 : |
cchiw |
2525 |
end
|
293 : |
cchiw |
2522 |
|
294 : |
|
|
|
295 : |
|
|
|
296 : |
|
|
end (* local *)
|
297 : |
|
|
|
298 : |
cchiw |
2529 |
end
|