Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3553 - (view) (download)

(* evalKrn
 * Evaluate an EIN kernel expression to low-IL ops: iterate over the range of
 * the support, determine the differentiation level, and evaluate the kernel
 * segments.
 *)
structure EvalKrnSet = struct
    local
        structure DstIL = LowIL
        structure DstTy = LowILTypes
        structure Var = LowIL.Var
        structure E = Ein
        structure H = HelperSet
    in

  (* Thin wrappers so the code below can call the HelperSet operations unqualified. *)
    fun lookup e = H.lookup e
    fun insert e = H.insert e
    fun find e = H.find e
    fun intToReal n = H.intToReal n
    fun assgnCons e = H.assgnCons e
    fun indexTensor e = H.indexTensor e
    fun mkAddVec e = H.mkAddVec e
    fun mkSubSca e = H.mkSubSca e
    fun mkProdVec e = H.mkProdVec e
  (* note the argument reordering: HelperSet.mkDotVec takes (set, last, [a, b]) *)
    fun mkDotVec (setO, a, b, last) = H.mkDotVec (setO, last, [a, b])
    fun iTos n = Int.toString n
    fun err str = raise Fail (str)
    val realTy = DstTy.TensorTy []
    val intTy = DstTy.IntTy

  (* Debug tracing: when `testing` is true, `testp` prints the concatenation of
   * its string list.  It always returns 1.
   *)
    val testing = false
    fun testp n = if (testing) then (print (String.concat n); 1) else 1

  (* convert a rational to a FloatLit.float value.  We do this by long division
   * with a cutoff when we get to 12 digits.
   *)
    fun ratToFloat r = (case Rational.explode r
           of {sign=0, ...} => FloatLit.zero false
            | {sign, num, denom=1} => FloatLit.fromInt (IntInf.fromInt sign * num)
            | {sign, num, denom} => let
              (* normalize so that num <= denom *)
                val (denom, exp) = let
                      fun lp (n, denom) = if (denom < num)
                            then lp (n+1, denom*10)
                            else (denom, n)
                      in
                        lp (1, denom)
                      end
              (* normalize so that num <= denom <= 10*num *)
                val (num, exp) = let
                      fun lp (n, num) = if (10*num < denom)
                            then lp (n-1, 10*num)
                            else (num, n)
                      in
                        lp (exp, num)
                      end
              (* divide num/denom, computing the resulting digits *)
                fun divLp (n, a) = let
                      val (q, r) = IntInf.divMod (a, denom)
                      in
                        if (r = 0) then (q, [])
                        else if (n < 12) then let
                          val (d, dd) = divLp (n+1, 10*r)
                          in
                            if (d < 10)
                              then (q, (IntInf.toInt d)::dd)
                              else (q+1, 0::dd) (* propagate a round-up carry *)
                          end
                        else if (IntInf.div (10*r, denom) < 5)
                          then (q, [])
                          else (q+1, []) (* round up *)
                      end
                val digits = let
                      val (d, dd) = divLp (0, num)
                      in
                        (IntInf.toInt d)::dd
                      end
                in
                  FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
                end
          (* end case *))

  (* expand the EvalKernel operations into vector operations.  The parameters
   * are
   *    setOrig -- the set value threaded through lowSet.filter and the
   *               HelperSet constructors; the updated set is returned
   *    pre     -- a name prefix supplied by the caller (currently unused)
   *    d       -- the vector width of the operation, which should be equal
   *               to twice the support of the kernel
   *    h       -- the kernel
   *    k       -- the derivative of the kernel to evaluate
   *    x       -- the variable the pieces are evaluated at (multiplied
   *               element-wise with d-wide vectors via mkProdVec)
   * and the result is a triple (set, v, code), where v is the variable that
   * holds the evaluated kernel vector once `code` has run.
   *
   * The generated code is computing
   *
   *    result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
   *
   * as a d-wide vector operation, where n is the degree of the kth derivative
   * of h and the a_i are coefficient vectors that have an element for each
   * piece of h.  The computation is implemented as follows
   *
   *    m_n = x * a_n
   *    s_{n-1} = a_{n-1} + m_n
   *    m_{n-1} = x * s_{n-1}
   *    s_{n-2} = a_{n-2} + m_{n-1}
   *    m_{n-2} = x * s_{n-2}
   *    ...
   *    s_1 = a_1 + m_2
   *    m_1 = x * s_1
   *    result = a_0 + m_1
   *
   * When n = 0 (piecewise-constant pieces) the result is the coefficient
   * vector a_0 itself; no multiplication by x is generated.
   *
   * Note that the coefficient vectors are flipped (cf high-to-low/probe.sml).
   *)
    fun expandEvalKernel (setOrig, pre, d, h, k, x) = let
          val {isCont, segs} = Kernel.curve (h, k)
        (* degree of the polynomial pieces *)
          val deg = List.length (hd segs) - 1
        (* one entry per piece of h (presumably 2*support pieces); each entry is
         * that piece's list of polynomial coefficients.  The pieces are reversed.
         *)
          val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
        (* the c'th polynomial coefficient of piece `seg`, as a float literal *)
          fun coefficient c seg =
                Literal.Float (ratToFloat (Vector.sub (Vector.sub (segs, seg), c)))
          val ty = DstTy.vecTy d
          val coeffs = List.tabulate (deg+1, fn i => Var.new ("P"^Int.toString i, ty))
        (* Run each literal assignment through lowSet.filter: when the filter
         * returns an existing variable, reuse it and drop the assignment;
         * otherwise keep the assignment and its lhs.
         *)
          fun filterLit ([], vars, code, opset) = (vars, code, opset)
            | filterLit ((lhs, rhs)::es, vars, code, opset) = let
                val (opset, var) = lowSet.filter (opset, (lhs, rhs))
                in
                  (case var
                   of NONE => filterLit (es, vars@[lhs], code@[(lhs, rhs)], opset)
                    | SOME v => filterLit (es, vars@[v], code, opset)
                  (* end case *))
                end
        (* Build the assignments that define each coefficient vector P_i as a
         * CONS of the i'th coefficient of every piece.
         *)
          val (coeffVecs, setOrig) = let
                fun mk (coeffVar, (i::rest, code, opset)) = let
                      val lits = List.tabulate (d, coefficient i)
                      val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.realTy))
                      val code0 = ListPair.map (fn (v, lit) => (v, DstIL.LIT lit)) (vars, lits)
                      val (vars, code0, opset) = filterLit (code0, [], [], opset)
                      val code = code@code0@[(coeffVar, DstIL.CONS (Var.ty coeffVar, vars))]
                      in
                        (rest, code, opset)
                      end
                val indices = List.tabulate (deg+1, fn e => e)
                val (_, code, opset) = List.foldr mk (indices, [], setOrig) (List.rev coeffs)
                in
                  (code, opset)
                end
          val coeffVecsCode = List.map DstIL.ASSGN coeffVecs
        (* Horner evaluation over the coefficient vectors, highest degree first:
         * each step computes next + x*hi and feeds the sum back in as the new
         * highest coefficient, until only the final value remains.
         *)
          fun mkdot (setO, [a0], code) = (setO, a0, code)
            | mkdot (setO, hi::next::rest, code) = let
                val (setA, vA, A) = mkProdVec (setO, d, [x, hi])
                val (setB, vB, B) = mkAddVec (setA, d, [next, vA])
                in
                  mkdot (setB, vB::rest, code@A@B)
                end
            | mkdot (_, [], _) = raise Fail "0 items in Kernel coeffs"
          val (setC, vC, code) = mkdot (setOrig, List.rev coeffs, [])
          val _ = testp [
                  "\n coeffVecs code :", Int.toString (length coeffVecsCode),
                  " other code code: ", Int.toString (length code)
                ]
          in
            (setC, vC, coeffVecsCode@code)
          end

  (* mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1)
   * generates the low-IL code for the kernels in krns and returns
   * (set, ids, code), where ids holds one result variable per kernel, in
   * kernel order (h0, h1, h2, ...).
   *    setOrig  -- set value threaded through the helper constructors
   *    mappOrig -- index map used by H.deltaToInt, find, and insert
   *    lhs, params, args -- passed through to indexTensor
   *    krns     -- list of (id, dels, pos) triples; dels are (i, j) pairs
   *                summed via H.deltaToInt to give the differentiation level,
   *                pos must have the form E.Op2(E.Sub, E.Tensor _, E.Value _)
   *    h        -- the kernel handed to expandEvalKernel
   *    sid      -- the index bound for the first kernel (incremented per kernel)
   *    lb       -- offset added to the bound index
   *    range0   -- per kernel, the index is bound to lb+range0 down to lb
   *    range1   -- vector width passed to expandEvalKernel
   *
   * Local helpers:
   *    evalDels   -- sum the delta values, giving the differentiation level
   *    mkSimpleOp -- turn a position expression into low-IL ops
   *    mkpos      -- bind the summation index for each kernel and evaluate
   *                  the position at every bound value
   *    consfn     -- CONS each kernel's list of positions into a vector
   *    evalK      -- evaluate each kernel's polynomial pieces at its vector
   *)
    fun mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1) = let

          fun evalDels (mapp, dels) =
                List.foldl (fn (x, y) => x+y) 0
                  (List.map (fn (i, j) => H.deltaToInt (mapp, i, j)) dels)

          fun mkSimpleOp (setO, mapp, lhs, params, args,
                E.Op2 (E.Sub, E.Tensor (t1, ix1), E.Value v1)) = let
                val (setA, vA, A) = indexTensor (setO, mapp, (lhs, params, args, t1, ix1, realTy))
                in
                  (case find (v1, mapp)
                   of 0 => (setA, vA, A)
                    | j => let
                        val (setB, vB, B) = intToReal (setA, j)
                        val (setC, vC, C) = mkSubSca (setB, [vA, vB])
                        in
                          (setC, vC, A@B@C)
                        end
                  (* end case *))
                end

        (* i counts down from range0 to 0 for each kernel; at i = 0 the
         * collected positions are appended to fin and the next kernel starts
         * with index n+1.
         *)
          fun mkpos (setO, k, fin, l, dict, i, code, n) = (case (k, i)
                 of ([], _) => (setO, fin, code)
                  | ((_, _, pos1)::ks, 0) => let
                      val _ = testp ["\n insert", iTos n, "->", iTos lb]
                      val mapp = insert (n, lb) dict
                      val (opset', l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)
                      val e = l@[l']
                      val mapp' = insert (n, 0) dict
                      in
                        mkpos (opset', ks, fin@[e], [], mapp', range0, code@code', n+1)
                      end
                  | ((_, _, pos1)::_, _) => let
                      val _ = testp ["\n insert", iTos n, "->", iTos (lb+i)]
                      val mapp = insert (n, lb+i) dict
                      val (opset', l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)
                      in
                        mkpos (opset', k, fin, l@[l'], dict, i-1, code@code', n)
                      end
                (* end case *))

          fun consfn (setO, [], _, rest, code) = (setO, rest, code)
            | consfn (setO, e1::es, n, rest, code) = let
                val (setA, vA, A) = assgnCons (setO, "h"^iTos n, [length e1], List.rev e1)
                in
                  consfn (setA, es, n+1, rest@[vA], code@A)
                end

          fun evalK (setO, [], [], newId, code) = (setO, newId, code)
            | evalK (setO, kn::kns, x::xs, newId, code) = let
                val (_, dk, pos) = kn
                val directionX = (case pos
                       of E.Op2 (E.Sub, E.Tensor (_, [E.C directionX]), _) => directionX
                        | _ => 0
                      (* end case *))
                val name = String.concat ["h", iTos directionX, "_", iTos dk]
                val _ = testp ["\n", Var.toString x, " = ", name]
                val (opsetK, id, kcode) = expandEvalKernel (setO, name, range1, h, dk, x)
                in
                  evalK (opsetK, kns, xs, newId@[id], code@kcode)
                end
            | evalK _ = raise Fail "Non-equal variable list, error in mkKrns"

          val newkrns = List.map (fn (id, d1, pos) => (id, evalDels (mappOrig, d1), pos)) krns
          val _ = testp ["\n\n ***** Differentiation value of kernels ****\n "]
          val (set2, lftkrn, poscode) = mkpos (setOrig, newkrns, [], [], mappOrig, range0, [], sid)
          val (set3, lft, conscode) = consfn (set2, lftkrn, 0, [], [])
          val (set4, ids, evalKcode) = evalK (set3, newkrns, lft, [], [])
          val _ = List.app (fn e => ignore (testp ["\n IDS", Var.toString e, ","])) ids
          in
          (* ids are in kernel order: h0, h1, h2, ... *)
            (set4, ids, poscode@conscode@evalKcode)
          end

    end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0