Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/eval-kern.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/eval-kern.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4317 - (view) (download)

1 : jhr 3737 (* eval-kern.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
structure EvalKern : sig

  (* `expand (result, d, h, k, [x])`
   *
   * expands the EvalKernel operations into vector operations.  The parameters
   * are
   *    result  -- the lhs variable to store the result
   *    d       -- the vector width of the operation, which should be equal
   *               to twice the support of the kernel
   *    h       -- the kernel
   *    k       -- the derivative of the kernel to evaluate
   *    x       -- the d-wide vector that specifies the values at which the
   *               kernel is being evaluated.  It is passed as the only element
   *               of the argument list; any other list length raises Fail.
   *
   * The generated code is computing
   *
   *    result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
   *
   * as a d-wide vector operation, where n is the degree of the kth derivative
   * of h and the a_i are d-wide coefficient vectors that have an element for
   * each piece of h.  The computation is implemented as follows
   *
   *    m_n     = x * a_n
   *    s_{n-1} = a_{n-1} + m_n
   *    m_{n-1} = x * s_{n-1}
   *    s_{n-2} = a_{n-2} + m_{n-1}
   *    m_{n-2} = x * s_{n-2}
   *    ...
   *    s_1     = a_1 + m_2
   *    m_1     = x * s_1
   *    result  = a_0 + m_1
   *
   * Note that the coefficient vectors are flipped: element j of a_i comes
   * from the segment list returned by Kernel.curve taken in reverse order.
   * The returned list defines the a_i first (scalar literals followed by a
   * CONS for each vector), followed by the evaluation statements in
   * execution order.
   *)
    val expand : LowIR.var * int * Kernel.t * int * LowIR.var list
          -> (LowIR.var * LowIR.rhs) list

  end = struct

    structure IR = LowIR
    structure Ty = LowTypes
    structure Op = LowOps

  (* convert a rational to a RealLit.t value.  We do this by long division,
   * producing a leading digit followed by at most 12 further digits; when
   * the division does not terminate, the last digit is rounded (half up)
   * based on the next digit.
   *)
    fun ratToFloat r = (case Rational.explode r
           of {sign=0, ...} => RealLit.zero false
            | {sign, num, denom=1} => RealLit.fromInt(IntInf.fromInt sign * num)
            | {sign, num, denom} => let
              (* scale denom by powers of 10 so that num <= denom *)
                val (denom, exp) = let
                      fun scaleDenom (n, denom) = if (denom < num)
                            then scaleDenom (n+1, denom*10)
                            else (denom, n)
                      in
                        scaleDenom (1, denom)
                      end
              (* scale num by powers of 10 so that denom <= 10*num; together
               * with the previous step we get num <= denom <= 10*num (equality
               * on the right is possible, e.g., for 1/10).
               *)
                val (num, exp) = let
                      fun scaleNum (n, num) = if (10*num < denom)
                            then scaleNum (n-1, 10*num)
                            else (num, n)
                      in
                        scaleNum (exp, num)
                      end
              (* divide num/denom, computing the resulting digits.  The first
               * component of the result is the quotient digit at this position;
               * it may be 10 when a rounding carry arrives from the next
               * position, in which case the caller absorbs the carry.
               *)
                fun divLp (n, a) = let
                      val (q, r) = IntInf.divMod(a, denom)
                      in
                        if (r = 0) then (q, [])
                        else if (n < 12) then let
                          val (d, dd) = divLp(n+1, 10*r)
                          in
                            if (d < 10)
                              then (q, (IntInf.toInt d)::dd)
                              else (q+1, 0::dd) (* propagate carry *)
                          end
                        else if (IntInf.div(10*r, denom) < 5)
                          then (q, [])
                          else (q+1, []) (* round up *)
                      end
              (* d_0 is the integer part of num/denom (0 or 1 after scaling),
               * so the rational equals 0.d_0 d_1 d_2 ... * 10^exp (up to
               * rounding); presumably the form RealLit.fromDigits expects.
               *)
                val digits = let
                      val (d, dd) = divLp (0, num)
                      in
                        (IntInf.toInt d)::dd
                      end
                in
                  RealLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
                end
          (* end case *))

    fun expand (result, d, h, k, [x]) = let
          val {isCont = _, segs} = Kernel.curve (h, k)
        (* degree of polynomial; all segments are assumed to have the same
         * number of coefficients as the first one.
         *)
          val deg = (case segs
                 of seg0::_ => List.length seg0 - 1
                  | [] => raise Fail "EvalKern.expand: kernel curve has no segments"
                (* end case *))
        (* convert to a vector of vectors to give fast access; the segment
         * order is reversed (the "flip" mentioned in the interface comment).
         *)
          val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
        (* get the kernel coefficient value for the term'th power of the
         * seg'th (flipped) segment.
         *)
          fun coefficient term seg =
                Literal.Real(ratToFloat (Vector.sub (Vector.sub(segs, seg), term)))
          val ty = Ty.vecTy d
          val coeffs = List.tabulate (deg+1, fn i => IR.Var.new("a"^Int.toString i, ty))
        (* code to define the coefficient vectors.  Folding from the right
         * (highest degree first) and prepending keeps the definitions in
         * the order a_0, a_1, ..., a_n.  Each a_i is built from d scalar
         * literals, one per (flipped) segment.
         *)
          val coeffVecs = let
                fun mk (a, (i, code)) = let
                      val lits = List.tabulate(d, coefficient i)
                      val vars = List.tabulate(d, fn _ => IR.Var.new("_f", Ty.realTy))
                      val code =
                            ListPair.map (fn (f, lit) => (f, IR.LIT lit)) (vars, lits) @
                            (a, IR.CONS(vars, IR.Var.ty a)) :: code
                      in
                        (i-1, code)
                      end
                in
                  #2 (List.foldr mk (deg, []) coeffs)
                end
        (* build the evaluation of the polynomials in reverse order *)
          fun pTmp i = IR.Var.new("prod" ^ Int.toString i, ty)
          fun sTmp i = IR.Var.new("sum" ^ Int.toString i, ty)
        (* eval (i, [a_i, ..., a_n]) returns (m_i, stms), where stms computes
         * m_i with the last statement first.
         *)
          fun eval (i, [coeff]) = let
                val m = pTmp i
                in
                  (m, [(m, IR.OP(Op.VMul d, [x, coeff]))])
                end
            | eval (i, coeff::r) = let
                val (m, stms) = eval(i+1, r)
                val s = sTmp i
                val m' = pTmp i
                val stms =
                      (m', IR.OP(Op.VMul d, [x, s])) ::
                      (s, IR.OP(Op.VAdd d, [coeff, m])) ::
                      stms
                in
                  (m', stms)
                end
            | eval (_, []) = raise Fail "EvalKern.expand: no coefficients to evaluate"
          val evalCode = (case coeffs
                 of [a0] => (* constant function *)
                      [(result, IR.VAR a0)]
                  | a0::r => let
                      val (m, stms) = eval (1, r)
                      in
                        List.rev ((result, IR.OP(Op.VAdd d, [a0, m]))::stms)
                      end
                  | [] => raise Fail "EvalKern.expand: kernel segment has no coefficients"
                (* end case *))
          in
            coeffVecs @ evalCode
          end
      | expand (_, _, _, _, xs) = raise Fail(concat[
            "EvalKern.expand: expected exactly one argument, but got ",
            Int.toString(List.length xs)
          ])

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0