Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3624 - (view) (download)

1 : jhr 3624 (* evalKrn-set.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 : cchiw 3553 (*evalKrn
10 : cchiw 3444 *Evaluate an EIN kernel expression to low-IL ops
11 :     *iterate over the range of the support, determine differentiation, and calls segements
12 : cchiw 3553 *)
13 : cchiw 3602 structure EvalKrn = struct
14 : cchiw 3444 local
15 :     structure DstIL = LowIL
16 :     structure DstTy = LowILTypes
17 :     structure Var = LowIL.Var
18 :     structure E=Ein
19 : cchiw 3602 structure H=Helper
20 :     structure IMap = IntRedBlackMap
21 : cchiw 3444 in
22 : cchiw 3542
23 : cchiw 3444 fun find e=H.find e
24 :     fun intToReal n=H.intToReal n
25 :     fun assgnCons e=H.assgnCons e
26 :     fun indexTensor e=H.indexTensor e
27 :     fun mkAddVec e=H.mkAddVec e
28 :     fun mkSubSca e= H.mkSubSca e
29 :     fun mkProdVec e =H.mkProdVec e
30 : cchiw 3602 fun mkDotVec (avail,a,b,last) =H.mkDotVec (avail,last,[a,b])
31 : cchiw 3553 fun err str=raise Fail (str)
32 : cchiw 3602 fun lookup k d = IMap.find (d, k)
33 :     fun insert (k, v) d = IMap.insert (d, k, v)
34 : cchiw 3553 (* convert a rational to a FloatLit.float value. We do this by long division
35 : cchiw 3444 * with a cutoff when we get to 12 digits.
36 : cchiw 3553 *)
37 :     fun ratToFloat r = (case Rational.explode r
38 : cchiw 3444 of {sign=0, ...} => FloatLit.zero false
39 : cchiw 3553 | {sign, num, denom=1} => FloatLit.fromInt (IntInf.fromInt sign * num)
40 : cchiw 3444 | {sign, num, denom} => let
41 : cchiw 3553 (* normalize so that num <= denom *)
42 :     val (denom, exp) = let
43 :     fun lp (n, denom) = if (denom < num)
44 :     then lp (n+1, denom*10)
45 :     else (denom, n)
46 : cchiw 3444 in
47 : cchiw 3553 lp (1, denom)
48 : cchiw 3444 end
49 : cchiw 3553 (* normalize so that num <= denom < 10*num *)
50 :     val (num, exp) = let
51 :     fun lp (n, num) = if (10*num < denom)
52 :     then lp (n-1, 10*num)
53 :     else (num, n)
54 : cchiw 3444 in
55 : cchiw 3553 lp (exp, num)
56 : cchiw 3444 end
57 : cchiw 3553 (* divide num/denom, computing the resulting digits *)
58 :     fun divLp (n, a) = let
59 :     val (q, r) = IntInf.divMod (a, denom)
60 : cchiw 3444 in
61 : cchiw 3553 if (r = 0) then (q, [])
62 :     else if (n < 12) then let
63 :     val (d, dd) = divLp (n+1, 10*r)
64 : cchiw 3444 in
65 : cchiw 3553 if (d < 10)
66 :     then (q, (IntInf.toInt d) ::dd)
67 :     else (q+1, 0::dd)
68 : cchiw 3444 end
69 : cchiw 3553 else if (IntInf.div (10*r, denom) < 5)
70 :     then (q, [])
71 :     else (q+1, []) (* round up *)
72 : cchiw 3444 end
73 :     val digits = let
74 : cchiw 3553 val (d, dd) = divLp (0, num)
75 : cchiw 3444 in
76 : cchiw 3553 (IntInf.toInt d) ::dd
77 : cchiw 3444 end
78 :     in
79 : cchiw 3553 FloatLit.fromDigits{isNeg= (sign < 0) , digits=digits, exp=exp}
80 : cchiw 3444 end
81 : cchiw 3553 (* end case *) )
82 : cchiw 3444
83 :    
84 : cchiw 3553 (* expand the EvalKernel operations into vector operations. The parameters
85 : cchiw 3444 * are
86 :     * result -- the lhs variable to store the result
87 :     * d -- the vector width of the operation, which should be equal
88 :     * to twice the support of the kernel
89 :     * h -- the kernel
90 :     * k -- the derivative of the kernel to evaluate
91 :     *
92 :     * The generated code is computing
93 :     *
94 : cchiw 3553 * result = a_0 + x* (a_1 + x* (a_2 + ... x*a_n) ... )
95 : cchiw 3444 *
96 :     * as a d-wide vector operation, where n is the degree of the kth derivative
97 :     * of h and the a_i are coefficient vectors that have an element for each
98 :     * piece of h. The computation is implemented as follows
99 :     *
100 :     * m_n = x * a_n
101 :     * s_{n-1} = a_{n-1} + m_n
102 :     * m_{n-1} = x * s_{n-1}
103 :     * s_{n-2} = a_{n-2} + m_{n-1}
104 :     * m_{n-2} = x * s_{n-2}
105 :     * ...
106 :     * s_1 = a_1 + m_2
107 :     * m_1 = x * s_1
108 :     * result = a_0 + m_1
109 :     *
110 : cchiw 3553 * Note that the coeffient vectors are flipped (cf high-to-low/probe.sml) .
111 :     *)
112 : cchiw 3602 fun expandEvalKernel (avail,pre,d, h, k, x) = let
113 :    
114 :     val {isCont, segs} = Kernel.curve (h, k)
115 :     (* degree of polynomial *)
116 :     val deg = List.length (hd segs) - 1
117 :     (*segs is length 2*support, inner list is listof poynomial*)
118 :     val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
119 :     fun coefficient d i =
120 :     Literal.Float (ratToFloat (Vector.sub (Vector.sub (segs, i) , d)))
121 :     val ty = DstTy.vecTy d
122 :     val coeffs = List.tabulate (deg+1,fn i => Var.new ("P"^Int.toString i, ty))
123 :     val (_, avail, coeffs) = let
124 :     fun mk (x, (i::xs, avail,current) ) = let
125 :     val lits = List.tabulate (d, coefficient i)
126 :     val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.TensorTy []))
127 :     val vars= ListPair.map (fn (x, lit) => (AvailRHS.addAssign avail (x, DstIL.LIT lit))) (vars, lits)
128 :     val var = AvailRHS.addAssign avail (x, DstIL.CONS (Var.ty x, vars))
129 :     in
130 :     (xs, avail,var::current)
131 : cchiw 3542 end
132 : cchiw 3444 in
133 : cchiw 3602 (List.foldr mk (List.tabulate (deg+1, fn e=>e ), avail,[]) (List.rev coeffs))
134 : cchiw 3444 end
135 : cchiw 3602
136 :     (*get dot product and addition of list of coeffs*)
137 :     fun mkdot (avail, [e2,e1]) = let
138 :     val (avail, vA) = mkProdVec (avail, d, [x,e2])
139 :     in mkAddVec (avail, d, [e1,vA]) end
140 :     | mkdot (avail,e2::e1::es) = let
141 :     val (avail, vA) = mkProdVec (avail, d, [x,e2])
142 :     val (avail, vB) = mkAddVec (avail, d, [e1,vA])
143 : cchiw 3444 in
144 : cchiw 3602 mkdot (avail,vB::es)
145 : cchiw 3444 end
146 : cchiw 3602 | mkdot (avail, [e1]) = mkProdVec (avail,d,[x,e1])
147 : cchiw 3444 | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"
148 : cchiw 3602
149 : cchiw 3444 in
150 : cchiw 3602 mkdot (avail, coeffs)
151 : cchiw 3444 end
152 :    
153 :    
154 : cchiw 3553 (*mkkrns:dict*string*E.params*Var List*sum_id list* (E.mu*E.mu) list*Kernel*int*int*int*int
155 : cchiw 3444 * kernels
156 :     * comments on functions
157 :     *
158 : cchiw 3553 * evalDels:dictionary* (E.mu*E.mu) list->int
159 : cchiw 3444 * evaluate each delta and therefore each differentiation level for each kernel
160 :     *
161 :     * mkSimpleOp:dict*string*E.params*Var list*E.body
162 :     * -> Var * LowIL.assign list
163 :     * turn position into low-IL op
164 :     *
165 : cchiw 3553 *mkpos: (E.kernel_id*int*E.pos) list* Var list* Var list *dict*int* LowIL.assign list* int
166 : cchiw 3444 * -> Var * LowIL.assign list
167 :     * bind summation indices by creating mapp and evaluate position
168 :     *
169 :     * consfn: Var list list*Var list*LowIL.assign list
170 :     * ->Var * LowIL.assign list
171 :     * con everything on the list, makes vectors
172 :     *
173 : cchiw 3553 * evalK: (E.kernel_id*int*E.pos) list* var list*int*int*param_id*LowIL.assign list
174 : cchiw 3444 * ->Var * LowIL.assign list
175 :     * evaluate kernel with segments
176 : cchiw 3553 *)
177 : cchiw 3602 fun mkkrns (avail, mappOrig,params, args, krns, h, sid, lb, range0, range1) = let
178 : cchiw 3444
179 : cchiw 3553 fun evalDels (mapp, dels) = List.foldl (fn (x, y) =>x+y) 0 (List.map (fn (i, j) =>H.deltaToInt (mapp, i, j) ) dels)
180 : cchiw 3602 fun mkSimpleOp (avail, mapp, params, args, E.Op2 (E.Sub,E.Tensor (t1,ix1) ,E.Value v1) ) = let
181 :     val (avail, vA) = indexTensor (avail, mapp, ( params, args, t1, ix1, DstTy.TensorTy []) )
182 : cchiw 3553 in (case (find (v1,mapp) )
183 : cchiw 3602 of 0=> (avail, vA)
184 : cchiw 3444 | j=>let
185 : cchiw 3602 val (avail, vB) = intToReal (avail, j)
186 :     in mkSubSca (avail, [vA,vB]) end
187 : cchiw 3553 (*end case*) )
188 : cchiw 3444 end
189 : cchiw 3602 fun mkpos (avail, k, fin, rest, dict, i, n) = (case (k, i)
190 :     of ([], _) => (avail, List.rev fin)
191 : cchiw 3553 | ( (_, _, pos1) ::ks,0) => let
192 :     val mapp=insert (n, lb) dict
193 : cchiw 3602 val (avail,rest') = mkSimpleOp (avail, mapp, params, args, pos1)
194 :     val e=rest@[rest']
195 : cchiw 3553 val mapp'= insert (n, 0) dict
196 : cchiw 3444 in
197 : cchiw 3602 mkpos (avail, ks, e::fin, [], mapp', range0, n+1)
198 : cchiw 3444 end
199 : cchiw 3553 | ( (_,_,pos1) ::es,_) => let
200 :     val mapp= insert (n, lb+i) dict
201 : cchiw 3602 val (avail, rest') = mkSimpleOp (avail, mapp, params, args, pos1)
202 : cchiw 3444 in
203 : cchiw 3602 mkpos (avail, k, fin, rest@[rest'], dict, i-1, n)
204 : cchiw 3444 end
205 : cchiw 3553 (*end case*) )
206 : cchiw 3444
207 : cchiw 3602 fun consfn (avail, [], _, rest) = (avail, List.rev rest)
208 :     | consfn (avail, e1::es, n, rest) = let
209 :     val (avail, vA) = assgnCons (avail, "h"^Int.toString(n), [length (e1)], List.rev e1)
210 : cchiw 3444 in
211 : cchiw 3602 consfn (avail, es, n+1, vA::rest)
212 : cchiw 3444 end
213 :    
214 : cchiw 3602 fun evalK (avail, [], [], newId) = (avail, List.rev newId)
215 :     | evalK (avail, kn::kns, x::xs, newId) = let
216 : cchiw 3553 val (_, dk, pos) = kn
217 :     val directionX= (case pos
218 :     of E.Op2 (E.Sub,E.Tensor (_,[E.C directionX]) ,_) => directionX
219 : cchiw 3444 | _ => 0
220 : cchiw 3553 (*end case*) )
221 : cchiw 3602 val name=String.concat["h",Int.toString(directionX),"_",Int.toString dk]
222 :     val (avail, id) = expandEvalKernel (avail, name, range1, h, dk, x)
223 : cchiw 3444 in
224 : cchiw 3602 evalK (avail, kns, xs, id::newId)
225 : cchiw 3444 end
226 :     | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
227 : cchiw 3602
228 : cchiw 3553 val newkrns = List.map (fn (id, d1, pos) => (id, evalDels (mappOrig,d1) , pos) ) krns
229 : cchiw 3602 val (avail, lftkrn) = mkpos (avail, newkrns, [], [], mappOrig, range0, sid)
230 :     val (avail, lft) = consfn (avail, lftkrn, 0, [])
231 : cchiw 3444 in
232 : cchiw 3602 evalK (avail, newkrns, lft, [])
233 : cchiw 3444 end
234 : cchiw 3602
235 : cchiw 3553 end (* local *)
236 : cchiw 3444
237 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0