Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3444 - (view) (download)

1 : cchiw 3444 (*evalKrn
2 :     *Evaluate an EIN kernel expression to low-IL ops
3 :     *iterate over the range of the support, determine differentiation, and calls segements
4 :     *)
5 :     structure EvalKrnSet = struct
6 :     local
7 :     structure DstIL = LowIL
8 :     structure DstTy = LowILTypes
9 :     structure Var = LowIL.Var
10 :     structure E=Ein
11 :     structure H=HelperSet
12 :    
13 :     in
14 :    
15 :     val testing=0
16 :     fun lookup e =H.lookup e
17 :     fun insert e=H.insert e
18 :     fun find e=H.find e
19 :     fun intToReal n=H.intToReal n
20 :     fun assgnCons e=H.assgnCons e
21 :     fun indexTensor e=H.indexTensor e
22 :     fun mkAddVec e=H.mkAddVec e
23 :     fun mkSubSca e= H.mkSubSca e
24 :     fun mkProdVec e =H.mkProdVec e
25 :     fun mkDotVec(setO,a,b,last)=H.mkDotVec(setO,last,[a,b])
26 :    
27 :     fun iTos n=Int.toString n
28 :     fun err str=raise Fail(str)
29 :     val realTy=DstTy.TensorTy []
30 :     val intTy=DstTy.IntTy
31 :     fun testp n =(case testing
32 :     of 0 => 1
33 :     | _ => (print(String.concat n);1)
34 :     (*end case *))
35 :    
36 :     (* convert a rational to a FloatLit.float value. We do this by long division
37 :     * with a cutoff when we get to 12 digits.
38 :     *)
39 :     fun ratToFloat r = (case Rational.explode r
40 :     of {sign=0, ...} => FloatLit.zero false
41 :     | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
42 :     | {sign, num, denom} => let
43 :     (* normalize so that num <= denom *)
44 :     val (denom, exp) = let
45 :     fun lp (n, denom) = if (denom < num)
46 :     then lp(n+1, denom*10)
47 :     else (denom, n)
48 :     in
49 :     lp (1, denom)
50 :     end
51 :     (* normalize so that num <= denom < 10*num *)
52 :     val (num, exp) = let
53 :     fun lp (n, num) = if (10*num < denom)
54 :     then lp(n-1, 10*num)
55 :     else (num, n)
56 :     in
57 :     lp (exp, num)
58 :     end
59 :     (* divide num/denom, computing the resulting digits *)
60 :     fun divLp (n, a) = let
61 :     val (q, r) = IntInf.divMod(a, denom)
62 :     in
63 :     if (r = 0) then (q, [])
64 :     else if (n < 12) then let
65 :     val (d, dd) = divLp(n+1, 10*r)
66 :     in
67 :     if (d < 10)
68 :     then (q, (IntInf.toInt d)::dd)
69 :     else (q+1, 0::dd)
70 :     end
71 :     else if (IntInf.div(10*r, denom) < 5)
72 :     then (q, [])
73 :     else (q+1, []) (* round up *)
74 :     end
75 :     val digits = let
76 :     val (d, dd) = divLp (0, num)
77 :     in
78 :     (IntInf.toInt d)::dd
79 :     end
80 :     in
81 :     FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
82 :     end
83 :     (* end case *))
84 :    
85 :    
86 :     (* expand the EvalKernel operations into vector operations. The parameters
87 :     * are
88 :     * result -- the lhs variable to store the result
89 :     * d -- the vector width of the operation, which should be equal
90 :     * to twice the support of the kernel
91 :     * h -- the kernel
92 :     * k -- the derivative of the kernel to evaluate
93 :     *
94 :     * The generated code is computing
95 :     *
96 :     * result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
97 :     *
98 :     * as a d-wide vector operation, where n is the degree of the kth derivative
99 :     * of h and the a_i are coefficient vectors that have an element for each
100 :     * piece of h. The computation is implemented as follows
101 :     *
102 :     * m_n = x * a_n
103 :     * s_{n-1} = a_{n-1} + m_n
104 :     * m_{n-1} = x * s_{n-1}
105 :     * s_{n-2} = a_{n-2} + m_{n-1}
106 :     * m_{n-2} = x * s_{n-2}
107 :     * ...
108 :     * s_1 = a_1 + m_2
109 :     * m_1 = x * s_1
110 :     * result = a_0 + m_1
111 :     *
112 :     * Note that the coeffient vectors are flipped (cf high-to-low/probe.sml).
113 :     *)
114 :     fun expandEvalKernel (setOrig,pre,d, h, k, x) = let
115 :     val _="\n*********************\n"
116 :     val {isCont, segs} = Kernel.curve (h, k)
117 :     (* degree of polynomial *)
118 :     val deg = List.length(hd segs) - 1
119 :     (*segs is length 2*support, inner list is listof poynomial*)
120 :     val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
121 :     fun coefficient d i =
122 :     Literal.Float(ratToFloat (Vector.sub (Vector.sub(segs, i), d)))
123 :     val ty = DstTy.vecTy d
124 :     val coeffs = List.tabulate (deg+1,fn i => Var.new("P"^Int.toString i, ty))
125 :    
126 :     (* code to define the coefficient vectors *)
127 :     (*
128 :     val coeffVecs = let
129 :     fun mk (x, (i, code)) = let
130 :     val lits = List.tabulate(d, coefficient i)
131 :     val vars = List.tabulate(d, fn _ => Var.new("_f", DstTy.realTy))
132 :     (* val code =
133 :     ListPair.map (fn (x, lit) => (x, DstIL.LIT lit)) (vars, lits) @(x, DstIL.CONS(Var.ty x, vars)) :: code
134 :     *)
135 :    
136 :     val code0 =ListPair.map (fn (x, lit) => (x, DstIL.LIT lit)) (vars, lits)
137 :     val code =code0@(x, DstIL.CONS(Var.ty x, vars)) :: code
138 :     val _ = List.map (fn e=>print(LowToString.toStringAll(LowILTypes.realTy, DstIL.ASSGN e))) [(x,DstIL.CONS(Var.ty x, vars))]
139 :     in
140 :     (i-1, code)
141 :     end
142 :     in
143 :     #2 (List.foldr mk (deg, []) coeffs)
144 :     end
145 :    
146 :    
147 :     *)
148 :    
149 :     fun filterLit([],vars,code,opset)=(vars,code,opset)
150 :     | filterLit((lhs,rhs)::es,vars,code,opset)=let
151 :    
152 :     val (opset,var) = lowSet.filter(opset,(lhs,rhs))
153 :     in (case var
154 :     of NONE=> filterLit(es,vars@[lhs], code@[(lhs, rhs)],opset)
155 :     | SOME v =>filterLit(es,vars@[v],code,opset)
156 :     (*end case*))
157 :     end
158 :    
159 :     val (coeffVecs,setOrig) = let
160 :     fun mk (x, (i::xs, code,opset)) = let
161 :     val lits = List.tabulate(d, coefficient i)
162 :     val vars = List.tabulate(d, fn _ => Var.new("_f", DstTy.realTy))
163 :    
164 :    
165 :     val code0 =ListPair.map (fn (x, lit) => (x, DstIL.LIT lit)) (vars, lits)
166 :     val (vars,code0,opset)=filterLit(code0,[],[],opset)
167 :    
168 :     val code =code@code0@[(x, DstIL.CONS(Var.ty x, vars))]
169 :    
170 :    
171 :    
172 :     val _ = List.map (fn e=>(LowToString.toStringAll(LowILTypes.realTy, DstIL.ASSGN e))) [(x,DstIL.CONS(Var.ty x, vars))]
173 :     in
174 :     (xs, code,opset)
175 :     end
176 :     val n=List.tabulate(deg+1, fn e=>e )
177 :     val (a,b,c)=(List.foldr mk (n, [],setOrig) (List.rev coeffs))
178 :     in
179 :     (b,c)
180 :     end
181 :    
182 :    
183 :    
184 :    
185 :     val coeffVecsCode=List.map (fn(x,y)=>DstIL.ASSGN (x,y)) coeffVecs
186 :    
187 :    
188 :     fun getSet([],done,opset,cnt)=(done,opset,cnt)
189 :     | getSet( DstIL.ASSGN(lhs,rhs)::es,done,opset,cnt)=let
190 :     val (opset,var) = lowSet.filter(opset,(lhs,rhs))
191 :     in (case var
192 :     of NONE => getSet(es,done@[DstIL.ASSGN(lhs,rhs)], opset,cnt)
193 :     | SOME v=> (("TASH:replacing"^DstIL.Var.toString(lhs));getSet(es,done@[DstIL.ASSGN(lhs,DstIL.VAR v)], opset,cnt+1))
194 :     (*end case*))
195 :     end
196 :     | getSet (e1::es, done, opset,cnt)=getSet(es,done@[e1],opset,cnt)
197 :     (*
198 :     val (coeffVecsCode,setOrig,cnt)=getSet(coeffVecsCode, [],setOrig,0)
199 :    
200 :     val _ =("\nTASHreplaced"^Int.toString(cnt))
201 :     *)
202 :    
203 :    
204 :    
205 :     val _ = List.map (fn e=>(LowToString.toStringAll(LowILTypes.realTy,e))) coeffVecsCode
206 :    
207 :    
208 :     (*get dot product and addition of list of coeffs*)
209 :     fun mkdot(setO,[e2,e1],code)=let
210 :     val (setA,vA,A)= mkProdVec(setO,d,[x,e2])
211 :     val (setB,vB,B)=mkAddVec(setA,d,[e1,vA])
212 :     in
213 :     (setB,vB,code@A@B)
214 :     end
215 :     | mkdot(setO,e2::e1::es,code)= let
216 :     val (setA,vA,A)= mkProdVec(setO,d,[x,e2])
217 :     val (setB,vB,B)=mkAddVec(setA,d,[e1,vA])
218 :     in
219 :     mkdot(setB,vB::es,code@A@B)
220 :     end
221 :     | mkdot (setO,[e1],[])= mkProdVec(setO,d,[x,e1])
222 :     | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"
223 :     val (setC,vC,code)= mkdot(setOrig,List.rev coeffs, [])
224 :     val _ =(String.concat["\n coeffVecs code :",Int.toString(length(coeffVecsCode))," other code code: ",Int.toString(length(code))])
225 :     in
226 :     (setC,vC,coeffVecsCode@code)
227 :     end
228 :    
229 :    
230 :     (*mkkrns:dict*string*E.params*Var List*sum_id list*(E.mu*E.mu) list*Kernel*int*int*int*int
231 :     * kernels
232 :     * comments on functions
233 :     *
234 :     * evalDels:dictionary*(E.mu*E.mu)list->int
235 :     * evaluate each delta and therefore each differentiation level for each kernel
236 :     *
237 :     * mkSimpleOp:dict*string*E.params*Var list*E.body
238 :     * -> Var * LowIL.assign list
239 :     * turn position into low-IL op
240 :     *
241 :     *mkpos:(E.kernel_id*int*E.pos)list* Var list* Var list *dict*int* LowIL.assign list* int
242 :     * -> Var * LowIL.assign list
243 :     * bind summation indices by creating mapp and evaluate position
244 :     *
245 :     * consfn: Var list list*Var list*LowIL.assign list
246 :     * ->Var * LowIL.assign list
247 :     * con everything on the list, makes vectors
248 :     *
249 :     * evalK:(E.kernel_id*int*E.pos)list* var list*int*int*param_id*LowIL.assign list
250 :     * ->Var * LowIL.assign list
251 :     * evaluate kernel with segments
252 :     *)
253 :     fun mkkrns(setOrig,mappOrig,lhs,params,args, krns,h,sid,lb,range0,range1)=let
254 :    
255 :     fun evalDels(mapp,dels)=List.foldl(fn(x,y)=>x+y) 0 (List.map (fn(i,j)=>H.deltaToInt(mapp,i,j)) dels)
256 :    
257 :     fun mkSimpleOp(setO,mapp,lhs,params,args, E.Sub(E.Tensor(t1,ix1),E.Value v1))=let
258 :     val (setA,vA,A)=indexTensor(setO,mapp,(lhs,params,args, t1,ix1,realTy))
259 :     in (case (find(v1,mapp))
260 :     of 0=>(setA,vA,A)
261 :     | j=>let
262 :     val (setB,vB,B)= intToReal(setA, j)
263 :     val (setC,vC,C)=mkSubSca(setB, [vA,vB])
264 :     in (setC,vC,A@B@C) end
265 :     (*end case*))
266 :     end
267 :    
268 :     fun mkpos(setO,k,fin,l,dict, i,code,n)= (case (k,i)
269 :     of ([],_)=>(setO,fin,code)
270 :     | ((_,_,pos1)::ks,0)=>let
271 :     val _=testp ["\n insert",iTos n, "->",iTos lb]
272 :     val mapp=insert (n, lb) dict
273 :     val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args, pos1)
274 :     val e=l@[l']
275 :     val mapp'=insert (n, 0) dict
276 :     in
277 :     mkpos(opset',ks,fin@[e],[],mapp',range0,code@code',n+1)
278 :     end
279 :     | ((_,_,pos1)::es,_) =>let
280 :     val _=testp ["\n insert",iTos n, "->",iTos (lb+i)]
281 :     val mapp= insert (n, lb+i) dict
282 :     val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args, pos1)
283 :     in
284 :     mkpos(opset',k,fin,l@[l'],dict,i-1,code@code',n)
285 :     end
286 :     (*end case*))
287 :    
288 :     fun consfn(setO,[],_,rest, code)=(setO,rest,code)
289 :     | consfn(setO,e1::es,n,rest,code)=let
290 :     val (setA, vA,A)= assgnCons(setO, "h"^iTos n,[length(e1)],List.rev e1)
291 :     in
292 :     consfn(setA,es, n+1,rest@[vA], code@A)
293 :     end
294 :    
295 :     fun evalK(setO,[],[],newId,code)=(setO,newId,code)
296 :     | evalK(setO,kn::kns,x::xs,newId,code)=let
297 :     val (_,dk,pos) =kn
298 :     val directionX= (case pos
299 :     of E.Sub(E.Tensor (_,[E.C directionX]),_)=>directionX
300 :     | _ => 0
301 :     (*end case*))
302 :     val name=String.concat["h",iTos directionX,"_",iTos dk]
303 :     val _ =testp["\n",Var.toString x," = ",name]
304 :     val (opsetK,id,kcode)= expandEvalKernel (setO,name,range1,h,dk, x)
305 :     in
306 :     evalK(opsetK,kns,xs,newId@[id],code@kcode)
307 :     end
308 :     | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
309 :    
310 :     val newkrns= List.map (fn (id,d1,pos)=>(id,evalDels(mappOrig,d1),pos)) krns
311 :     val _ = testp["\n\n ***** Differentiation value of kernels ****\n "]
312 :     val (set2,lftkrn,poscode)=mkpos(setOrig,newkrns,[],[],mappOrig,range0,[],sid)
313 :     val (set3,lft,conscode)=consfn(set2,lftkrn,0,[],[])
314 :     val (set4,ids, evalKcode)=evalK(set3,newkrns, lft,[],[])
315 :     val _ = List.map (fn e=>testp["\n IDS",Var.toString e,","]) ids
316 :     (*returns list in order h0, h1, h2*)
317 :     in
318 :     (set4, ids, poscode@conscode@evalKcode)
319 :     end
320 :    
321 :    
322 :     end (* local *)
323 :    
324 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0