Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3542 - (view) (download)

1 : cchiw 3444 (*evalKrn
2 :     *Evaluate an EIN kernel expression to low-IL ops
3 :     *iterate over the range of the support, determine differentiation, and calls segements
4 :     *)
5 :     structure EvalKrnSet = struct
6 :     local
7 :     structure DstIL = LowIL
8 :     structure DstTy = LowILTypes
9 :     structure Var = LowIL.Var
10 :     structure E=Ein
11 :     structure H=HelperSet
12 :    
13 :     in
14 :    
15 : cchiw 3542
16 : cchiw 3444 fun lookup e =H.lookup e
17 :     fun insert e=H.insert e
18 :     fun find e=H.find e
19 :     fun intToReal n=H.intToReal n
20 :     fun assgnCons e=H.assgnCons e
21 :     fun indexTensor e=H.indexTensor e
22 :     fun mkAddVec e=H.mkAddVec e
23 :     fun mkSubSca e= H.mkSubSca e
24 :     fun mkProdVec e =H.mkProdVec e
25 :     fun mkDotVec(setO,a,b,last)=H.mkDotVec(setO,last,[a,b])
26 :     fun iTos n=Int.toString n
27 :     fun err str=raise Fail(str)
28 :     val realTy=DstTy.TensorTy []
29 :     val intTy=DstTy.IntTy
30 : cchiw 3542 val testing=false
31 :     fun testp n =if (testing) then (print(String.concat n);1) else 1
32 : cchiw 3444
33 :     (* convert a rational to a FloatLit.float value. We do this by long division
34 :     * with a cutoff when we get to 12 digits.
35 :     *)
36 :     fun ratToFloat r = (case Rational.explode r
37 :     of {sign=0, ...} => FloatLit.zero false
38 :     | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
39 :     | {sign, num, denom} => let
40 :     (* normalize so that num <= denom *)
41 :     val (denom, exp) = let
42 :     fun lp (n, denom) = if (denom < num)
43 :     then lp(n+1, denom*10)
44 :     else (denom, n)
45 :     in
46 :     lp (1, denom)
47 :     end
48 :     (* normalize so that num <= denom < 10*num *)
49 :     val (num, exp) = let
50 :     fun lp (n, num) = if (10*num < denom)
51 :     then lp(n-1, 10*num)
52 :     else (num, n)
53 :     in
54 :     lp (exp, num)
55 :     end
56 :     (* divide num/denom, computing the resulting digits *)
57 :     fun divLp (n, a) = let
58 :     val (q, r) = IntInf.divMod(a, denom)
59 :     in
60 :     if (r = 0) then (q, [])
61 :     else if (n < 12) then let
62 :     val (d, dd) = divLp(n+1, 10*r)
63 :     in
64 :     if (d < 10)
65 :     then (q, (IntInf.toInt d)::dd)
66 :     else (q+1, 0::dd)
67 :     end
68 :     else if (IntInf.div(10*r, denom) < 5)
69 :     then (q, [])
70 :     else (q+1, []) (* round up *)
71 :     end
72 :     val digits = let
73 :     val (d, dd) = divLp (0, num)
74 :     in
75 :     (IntInf.toInt d)::dd
76 :     end
77 :     in
78 :     FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
79 :     end
80 :     (* end case *))
81 :    
82 :    
83 :     (* expand the EvalKernel operations into vector operations. The parameters
84 :     * are
85 :     * result -- the lhs variable to store the result
86 :     * d -- the vector width of the operation, which should be equal
87 :     * to twice the support of the kernel
88 :     * h -- the kernel
89 :     * k -- the derivative of the kernel to evaluate
90 :     *
91 :     * The generated code is computing
92 :     *
93 :     * result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
94 :     *
95 :     * as a d-wide vector operation, where n is the degree of the kth derivative
96 :     * of h and the a_i are coefficient vectors that have an element for each
97 :     * piece of h. The computation is implemented as follows
98 :     *
99 :     * m_n = x * a_n
100 :     * s_{n-1} = a_{n-1} + m_n
101 :     * m_{n-1} = x * s_{n-1}
102 :     * s_{n-2} = a_{n-2} + m_{n-1}
103 :     * m_{n-2} = x * s_{n-2}
104 :     * ...
105 :     * s_1 = a_1 + m_2
106 :     * m_1 = x * s_1
107 :     * result = a_0 + m_1
108 :     *
109 :     * Note that the coeffient vectors are flipped (cf high-to-low/probe.sml).
110 :     *)
111 :     fun expandEvalKernel (setOrig,pre,d, h, k, x) = let
112 :     val {isCont, segs} = Kernel.curve (h, k)
113 :     (* degree of polynomial *)
114 :     val deg = List.length(hd segs) - 1
115 :     (*segs is length 2*support, inner list is listof poynomial*)
116 :     val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
117 :     fun coefficient d i =
118 :     Literal.Float(ratToFloat (Vector.sub (Vector.sub(segs, i), d)))
119 :     val ty = DstTy.vecTy d
120 :     val coeffs = List.tabulate (deg+1,fn i => Var.new("P"^Int.toString i, ty))
121 :     fun filterLit([],vars,code,opset)=(vars,code,opset)
122 : cchiw 3542 | filterLit((lhs,rhs)::es,vars,code,opset)=let
123 :     val (opset,var) = lowSet.filter(opset,(lhs,rhs))
124 :     in (case var
125 :     of NONE=> filterLit(es,vars@[lhs], code@[(lhs, rhs)],opset)
126 :     | SOME v =>filterLit(es,vars@[v],code,opset)
127 :     (*end case*))
128 :     end
129 : cchiw 3444 val (coeffVecs,setOrig) = let
130 : cchiw 3542 fun mk (x, (i::xs, code,opset)) = let
131 :     val lits = List.tabulate(d, coefficient i)
132 :     val vars = List.tabulate(d, fn _ => Var.new("_f", DstTy.realTy))
133 :     val code0 =ListPair.map (fn (x, lit) => (x, DstIL.LIT lit)) (vars, lits)
134 :     val (vars,code0,opset)=filterLit(code0,[],[],opset)
135 :     val code =code@code0@[(x, DstIL.CONS(Var.ty x, vars))]
136 :     in
137 :     (xs, code,opset)
138 :     end
139 :     val n=List.tabulate(deg+1, fn e=>e )
140 :     val (a,b,c)=(List.foldr mk (n, [],setOrig) (List.rev coeffs))
141 : cchiw 3444 in
142 : cchiw 3542 (b,c)
143 : cchiw 3444 end
144 :     val coeffVecsCode=List.map (fn(x,y)=>DstIL.ASSGN (x,y)) coeffVecs
145 :     fun getSet([],done,opset,cnt)=(done,opset,cnt)
146 :     | getSet( DstIL.ASSGN(lhs,rhs)::es,done,opset,cnt)=let
147 :     val (opset,var) = lowSet.filter(opset,(lhs,rhs))
148 :     in (case var
149 :     of NONE => getSet(es,done@[DstIL.ASSGN(lhs,rhs)], opset,cnt)
150 :     | SOME v=> (("TASH:replacing"^DstIL.Var.toString(lhs));getSet(es,done@[DstIL.ASSGN(lhs,DstIL.VAR v)], opset,cnt+1))
151 :     (*end case*))
152 :     end
153 :     | getSet (e1::es, done, opset,cnt)=getSet(es,done@[e1],opset,cnt)
154 :     (*get dot product and addition of list of coeffs*)
155 :     fun mkdot(setO,[e2,e1],code)=let
156 :     val (setA,vA,A)= mkProdVec(setO,d,[x,e2])
157 :     val (setB,vB,B)=mkAddVec(setA,d,[e1,vA])
158 :     in
159 :     (setB,vB,code@A@B)
160 :     end
161 :     | mkdot(setO,e2::e1::es,code)= let
162 :     val (setA,vA,A)= mkProdVec(setO,d,[x,e2])
163 :     val (setB,vB,B)=mkAddVec(setA,d,[e1,vA])
164 :     in
165 :     mkdot(setB,vB::es,code@A@B)
166 :     end
167 :     | mkdot (setO,[e1],[])= mkProdVec(setO,d,[x,e1])
168 :     | mkdot _ = raise Fail "0 or 1 item in Kernel coeffs"
169 :     val (setC,vC,code)= mkdot(setOrig,List.rev coeffs, [])
170 :     val _ =(String.concat["\n coeffVecs code :",Int.toString(length(coeffVecsCode))," other code code: ",Int.toString(length(code))])
171 :     in
172 :     (setC,vC,coeffVecsCode@code)
173 :     end
174 :    
175 :    
176 :     (*mkkrns:dict*string*E.params*Var List*sum_id list*(E.mu*E.mu) list*Kernel*int*int*int*int
177 :     * kernels
178 :     * comments on functions
179 :     *
180 :     * evalDels:dictionary*(E.mu*E.mu)list->int
181 :     * evaluate each delta and therefore each differentiation level for each kernel
182 :     *
183 :     * mkSimpleOp:dict*string*E.params*Var list*E.body
184 :     * -> Var * LowIL.assign list
185 :     * turn position into low-IL op
186 :     *
187 :     *mkpos:(E.kernel_id*int*E.pos)list* Var list* Var list *dict*int* LowIL.assign list* int
188 :     * -> Var * LowIL.assign list
189 :     * bind summation indices by creating mapp and evaluate position
190 :     *
191 :     * consfn: Var list list*Var list*LowIL.assign list
192 :     * ->Var * LowIL.assign list
193 :     * con everything on the list, makes vectors
194 :     *
195 :     * evalK:(E.kernel_id*int*E.pos)list* var list*int*int*param_id*LowIL.assign list
196 :     * ->Var * LowIL.assign list
197 :     * evaluate kernel with segments
198 :     *)
199 :     fun mkkrns(setOrig,mappOrig,lhs,params,args, krns,h,sid,lb,range0,range1)=let
200 :    
201 :     fun evalDels(mapp,dels)=List.foldl(fn(x,y)=>x+y) 0 (List.map (fn(i,j)=>H.deltaToInt(mapp,i,j)) dels)
202 :    
203 : cchiw 3448 fun mkSimpleOp(setO,mapp,lhs,params,args, E.Op2(E.Sub,E.Tensor(t1,ix1),E.Value v1))=let
204 : cchiw 3444 val (setA,vA,A)=indexTensor(setO,mapp,(lhs,params,args, t1,ix1,realTy))
205 :     in (case (find(v1,mapp))
206 :     of 0=>(setA,vA,A)
207 :     | j=>let
208 :     val (setB,vB,B)= intToReal(setA, j)
209 :     val (setC,vC,C)=mkSubSca(setB, [vA,vB])
210 :     in (setC,vC,A@B@C) end
211 :     (*end case*))
212 :     end
213 :    
214 :     fun mkpos(setO,k,fin,l,dict, i,code,n)= (case (k,i)
215 :     of ([],_)=>(setO,fin,code)
216 :     | ((_,_,pos1)::ks,0)=>let
217 :     val _=testp ["\n insert",iTos n, "->",iTos lb]
218 :     val mapp=insert (n, lb) dict
219 :     val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args, pos1)
220 :     val e=l@[l']
221 :     val mapp'=insert (n, 0) dict
222 :     in
223 :     mkpos(opset',ks,fin@[e],[],mapp',range0,code@code',n+1)
224 :     end
225 :     | ((_,_,pos1)::es,_) =>let
226 :     val _=testp ["\n insert",iTos n, "->",iTos (lb+i)]
227 :     val mapp= insert (n, lb+i) dict
228 :     val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args, pos1)
229 :     in
230 :     mkpos(opset',k,fin,l@[l'],dict,i-1,code@code',n)
231 :     end
232 :     (*end case*))
233 :    
234 :     fun consfn(setO,[],_,rest, code)=(setO,rest,code)
235 :     | consfn(setO,e1::es,n,rest,code)=let
236 :     val (setA, vA,A)= assgnCons(setO, "h"^iTos n,[length(e1)],List.rev e1)
237 :     in
238 :     consfn(setA,es, n+1,rest@[vA], code@A)
239 :     end
240 :    
241 :     fun evalK(setO,[],[],newId,code)=(setO,newId,code)
242 :     | evalK(setO,kn::kns,x::xs,newId,code)=let
243 :     val (_,dk,pos) =kn
244 :     val directionX= (case pos
245 : cchiw 3448 of E.Op2(E.Sub,E.Tensor (_,[E.C directionX]),_)=>directionX
246 : cchiw 3444 | _ => 0
247 :     (*end case*))
248 :     val name=String.concat["h",iTos directionX,"_",iTos dk]
249 :     val _ =testp["\n",Var.toString x," = ",name]
250 :     val (opsetK,id,kcode)= expandEvalKernel (setO,name,range1,h,dk, x)
251 :     in
252 :     evalK(opsetK,kns,xs,newId@[id],code@kcode)
253 :     end
254 :     | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
255 :    
256 :     val newkrns= List.map (fn (id,d1,pos)=>(id,evalDels(mappOrig,d1),pos)) krns
257 :     val _ = testp["\n\n ***** Differentiation value of kernels ****\n "]
258 :     val (set2,lftkrn,poscode)=mkpos(setOrig,newkrns,[],[],mappOrig,range0,[],sid)
259 :     val (set3,lft,conscode)=consfn(set2,lftkrn,0,[],[])
260 :     val (set4,ids, evalKcode)=evalK(set3,newkrns, lft,[],[])
261 :     val _ = List.map (fn e=>testp["\n IDS",Var.toString e,","]) ids
262 :     (*returns list in order h0, h1, h2*)
263 :     in
264 :     (set4, ids, poscode@conscode@evalKcode)
265 :     end
266 :    
267 :    
268 :     end (* local *)
269 :    
270 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0