Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalKrn.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2615 - (view) (download)

1 : cchiw 2522 (* Split Functions before code generation process*)
2 :     structure evalKrn = struct
3 :     local
4 :     structure DstIL = LowIL
5 :     structure DstTy = LowILTypes
6 :     structure DstOp = LowOps
7 : cchiw 2612 structure Var = LowIL.Var
8 :     structure E = Ein
9 : cchiw 2522 structure P=Printer
10 : cchiw 2612 structure S3=step3
11 : cchiw 2522
12 :     in
13 :    
14 : cchiw 2612 fun createProdVec(args,dim)=S3.aaV( DstOp.prodVec(dim),args,"prodVec",DstTy.TensorTy([dim]))
15 :     fun createAddVec(args,dim)=S3.aaV(DstOp.addVec(dim),args,"AddVec",DstTy.TensorTy([dim]))
16 : cchiw 2522
17 :    
18 :    
19 :     fun iadd (r : DstIL.var, a, b) = (r, DstIL.OP(DstOp.IAdd, [a, b]))
20 :     fun ilit (r : DstIL.var, n) = (r, DstIL.LIT(Literal.Int(IntInf.fromInt n)))
21 :     fun ilit (r : DstIL.var, n) = (r, DstIL.LIT(Literal.Int(IntInf.fromInt n)))
22 :     fun imul (r : DstIL.var, a, b) = (r, DstIL.OP(DstOp.IMul, [a, b]))
23 :    
24 :     (* convert a rational to a FloatLit.float value. We do this by long division
25 :     * with a cutoff when we get to 12 digits.
26 :     *)
27 :    
28 :     fun ratToFloat r = (case Rational.explode r
29 :     of {sign=0, ...} => FloatLit.zero false
30 :     | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
31 :     | {sign, num, denom} => let
32 :     (* normalize so that num <= denom *)
33 :     val (denom, exp) = let
34 :     fun lp (n, denom) = if (denom < num)
35 :     then lp(n+1, denom*10)
36 :     else (denom, n)
37 :     in
38 :     lp (1, denom)
39 :     end
40 :     (* normalize so that num <= denom < 10*num *)
41 :     val (num, exp) = let
42 :     fun lp (n, num) = if (10*num < denom)
43 :     then lp(n-1, 10*num)
44 :     else (num, n)
45 :     in
46 :     lp (exp, num)
47 :     end
48 :     (* divide num/denom, computing the resulting digits *)
49 :     fun divLp (n, a) = let
50 :     val (q, r) = IntInf.divMod(a, denom)
51 :     in
52 :     if (r = 0) then (q, [])
53 :     else if (n < 12) then let
54 :     val (d, dd) = divLp(n+1, 10*r)
55 :     in
56 :     if (d < 10)
57 :     then (q, (IntInf.toInt d)::dd)
58 :     else (q+1, 0::dd)
59 :     end
60 :     else if (IntInf.div(10*r, denom) < 5)
61 :     then (q, [])
62 :     else (q+1, []) (* round up *)
63 :     end
64 :     val digits = let
65 :     val (d, dd) = divLp (0, num)
66 :     in
67 :     (IntInf.toInt d)::dd
68 :     end
69 :     in
70 :     FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
71 :     end
72 :     (* end case *))
73 :    
74 :    
75 :     (* expand the EvalKernel operations into vector operations. The parameters
76 :     * are
77 :     * result -- the lhs variable to store the result
78 :     * d -- the vector width of the operation, which should be equal
79 :     * to twice the support of the kernel
80 :     * h -- the kernel
81 :     * k -- the derivative of the kernel to evaluate
82 :     *
83 :     * The generated code is computing
84 :     *
85 :     * result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
86 :     *
87 :     * as a d-wide vector operation, where n is the degree of the kth derivative
88 :     * of h and the a_i are coefficient vectors that have an element for each
89 :     * piece of h. The computation is implemented as follows
90 :     *
91 :     * m_n = x * a_n
92 :     * s_{n-1} = a_{n-1} + m_n
93 :     * m_{n-1} = x * s_{n-1}
94 :     * s_{n-2} = a_{n-2} + m_{n-1}
95 :     * m_{n-2} = x * s_{n-2}
96 :     * ...
97 :     * s_1 = a_1 + m_2
98 :     * m_1 = x * s_1
99 :     * result = a_0 + m_1
100 :     *
101 :     * Note that the coeffient vectors are flipped (cf high-to-low/probe.sml).
102 :     *)
103 :    
104 :     (* >*)
105 :    
106 :     (* Polynomial:3, length of position is 2.
107 :    
108 :     a_0=[d0,e0]
109 :     a_1=[d_1,e_1]
110 :     a_2=[d_1,e_1]
111 :     -----------------------
112 :    
113 :     [d0e0],[d1e1],[d2e2]
114 :     result=a_0 +x(a_1+x a_2)
115 :    
116 :    
117 :     --OtherwiseWise need to flip
118 :    
119 :     [d0,d1,d2],[e0,e1,e2]=> [d0e0],[d1e1],[d2e2]
120 :     result=a_0 +x(a_1+ x a_2)
121 :    
122 :    
123 :     *)
124 :    
125 :    
126 :    
127 : cchiw 2525
128 : cchiw 2522
129 : cchiw 2525 fun expandEvalKernel (d, h, k, x,axis) = let
130 :    
131 : cchiw 2576 (* val tetser=print "\n\n ###### In eval kernel ######## \n\n "*)
132 : cchiw 2525
133 : cchiw 2522 val {isCont, segs} = Kernel.curve (h, k)
134 :     (* degree of polynomial *)
135 :     val deg = List.length(hd segs) - 1
136 : cchiw 2525 val dd=List.length(hd segs)
137 : cchiw 2522
138 : cchiw 2525 val ss=Kernel.support(h)
139 :     (*segs is length 2*support, inner list is listof poynomial*)
140 : cchiw 2576 (*val m=print(String.concat["Axis: ", Int.toString(axis)," Length of segs:",Int.toString(length(segs))," segs next",Int.toString(dd)," k:", Int.toString(k),"SUpport:",Int.toString(ss)])*)
141 : cchiw 2522
142 :     val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
143 :     fun coefficient d i =
144 :     Literal.Float(ratToFloat (Vector.sub (Vector.sub(segs, i), d)))
145 :    
146 :    
147 :     val ty = DstTy.vecTy d
148 : cchiw 2612 val coeffs = List.tabulate (deg+1,fn i => Var.new("P"^Int.toString i, ty))
149 : cchiw 2522
150 :    
151 :     (* code to define the coefficient vectors *)
152 :     val coeffVecs = let
153 :     fun mk (x, (i, code)) = let
154 :     val lits = List.tabulate(d, coefficient i)
155 : cchiw 2612 val vars = List.tabulate(d, fn _ => Var.new("_f", DstTy.realTy))
156 : cchiw 2522 val code =
157 : cchiw 2612 ListPair.map (fn (x, lit) => (x, DstIL.LIT lit)) (vars, lits) @(x, DstIL.CONS(Var.ty x, vars)) :: code
158 : cchiw 2522 in
159 :     (i-1, code)
160 :     end
161 :     in
162 :     #2 (List.foldr mk (deg, []) coeffs)
163 :     end
164 :    
165 : cchiw 2525 val q=List.map (fn(x,y)=>DstIL.ASSGN (x,y)) coeffVecs
166 : cchiw 2576 (*val tester= List.map gHelper.printX q*)
167 : cchiw 2522
168 : cchiw 2525
169 : cchiw 2522 fun m([e2,e1],code)=let
170 : cchiw 2612 val (vA,A)= createProdVec([x,e2],d)
171 :     val (vB,B)=createAddVec([e1,vA],d)
172 : cchiw 2522 in
173 : cchiw 2525 (vB,code@A@B)
174 : cchiw 2522 end
175 :     | m (e2::e1::es,code)= let
176 : cchiw 2612 val (vA,A)= createProdVec([x,e2],d)
177 :     val (vB,B)=createAddVec([e1,vA],d)
178 : cchiw 2522 in
179 : cchiw 2525 m(vB::es,code@A@B)
180 : cchiw 2522 end
181 : cchiw 2615 | m _ = raise Fail "0 or 1 item in Kernel coeffs"
182 : cchiw 2522 val (vC,code')= m(List.rev coeffs, [])
183 : cchiw 2615
184 : cchiw 2522 in
185 : cchiw 2525 (vC,q@code')
186 : cchiw 2522 end
187 :    
188 :    
189 :    
190 :     end (* local *)
191 :    
192 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0