SCM Repository
Annotation of /trunk/src/compiler/translate/translate-basis.sml
Parent Directory
|
Revision Log
Revision 3349 - (view) (download)
(* translate-basis.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
 * Translation of basis operations from the Simple AST to HighIL code.
 *)

structure TranslateBasis : sig

  (* translate(lhs, f, mvs, args) translates the application of the basis
   * function f (specialized to the instantiated meta variables mvs) to the
   * argument variables args, producing a list of SSA assignments that defines
   * lhs.  The list is in evaluation order: any temporary introduced by the
   * translation is assigned before the assignment that uses it, and the
   * assignment that defines lhs comes last.
   *)
    val translate : (HighIL.var * Var.var * SimpleTypes.meta_arg list * HighIL.var list)
          -> HighIL.assignment list

  end = struct

    structure BV = BasisVars
    structure IL = HighIL
    structure DstTy = HighILTypes
    structure Op = HighOps
    structure Ty = SimpleTypes
    structure VTbl = Var.Tbl

  (* convert a type meta argument to a HighIL type; raises Fail for any other
   * kind of meta argument.
   *)
    fun trType (Ty.TY ty) = TranslateTy.tr ty
      | trType _ = raise Fail "expected type"

  (* extract the integer from a dimension meta argument; raises Fail for any
   * other kind of meta argument.
   *)
    fun dimVarToInt (Ty.DIM d) = d
      | dimVarToInt _ = raise Fail "expected dim"

  (* the rank-1 tensor (vector) type whose size is given by a dimension meta argument *)
    fun dimVarToTensor dv = DstTy.tensorTy[dimVarToInt dv]

  (* the square-matrix type whose size is given by a dimension meta argument *)
    fun dimVarToMatrix dv = let
          val d = dimVarToInt dv
          in
            DstTy.tensorTy[d, d]        (* square matrix type *)
          end

  (* the tensor type whose shape is given by a shape meta argument; raises Fail
   * for any other kind of meta argument.
   *)
    fun shapeVarToTensor (Ty.SHAPE shp) = DstTy.tensorTy shp
      | shapeVarToTensor _ = raise Fail "expected shape"

  (* a one-element assignment list: y = rator(xs) *)
    fun assign (y, rator, xs) = [IL.ASSGN(y, IL.OP(rator, xs))]

  (* translation of a call to the external function `name` (e.g., a C math
   * function) as y = name(xs).  Expects no meta arguments; any other arity
   * of meta arguments raises Match.
   *)
    fun basisFn name (y, [], xs) = [IL.ASSGN(y, IL.APPLY(name, xs))]

  (* translation of an operator that takes no meta arguments (otherwise Match) *)
    fun simpleOp rator (y, [], xs) = assign (y, rator, xs)

  (* translation of an operator indexed by the tensor type described by its
   * single shape meta argument (otherwise Match)
   *)
    fun tensorOp rator (y, [sv], xs) = assign (y, rator(shapeVarToTensor sv), xs)

  (* translation of an operator indexed by the vector type described by its
   * single dimension meta argument (otherwise Match)
   *)
    fun vectorOp rator (y, [dv], xs) = assign (y, rator(dimVarToTensor dv), xs)

  (* translation of a kernel constant: no meta arguments and no arguments.
   * NOTE(review): the 0 passed to Op.Kernel is presumably the derivative
   * level -- confirm against HighOps.
   *)
    fun kernel h (y, [], []) = assign(y, Op.Kernel(h, 0), [])

  (* utility functions for synthesizing eigenvector/eigenvalue code.  The eigen
   * operator `rator` defines two results, (evals, evecs), with a single
   * multi-assignment.  Each helper binds the result it wants to y and binds the
   * other result to a fresh, otherwise unused variable.
   *)
    fun eigenVec (rator, dim) = let
        (* type of the discarded eigenvalue sequence *)
          val ty = DstTy.SeqTy(DstTy.realTy, dim)
          in
            fn (y, _, [m]) => let
                val v = IL.Var.new("evals", ty)
                in
                  [IL.MASSGN([v, y], rator, [m])]
                end
          end
    fun eigenVal (rator, dim) = let
        (* type of the discarded eigenvector sequence *)
          val ty = DstTy.SeqTy(DstTy.vecTy dim, dim)
          in
            fn (y, _, [m]) => let
                val v = IL.Var.new("evecs", ty)
                in
                  [IL.MASSGN([y, v], rator, [m])]
                end
          end

  (* build a table that maps Basis variables to their translation functions.
   * Each translation function takes (lhs, meta arguments, arguments); a
   * function whose patterns do not match its inputs raises Match.
   *)
    val tbl : ((IL.var * Ty.meta_arg list * IL.var list) -> IL.assignment list) VTbl.hash_table = let
          val tbl = VTbl.mkTable (128, Fail "Translate table")
          val insert = VTbl.insert tbl
          in
            List.app insert [
              (* comparisons *)
                (BV.lt_ii,              simpleOp(Op.LT DstTy.IntTy)),
                (BV.lt_rr,              simpleOp(Op.LT DstTy.realTy)),
                (BV.lte_ii,             simpleOp(Op.LTE DstTy.IntTy)),
                (BV.lte_rr,             simpleOp(Op.LTE DstTy.realTy)),
                (BV.gte_ii,             simpleOp(Op.GTE DstTy.IntTy)),
                (BV.gte_rr,             simpleOp(Op.GTE DstTy.realTy)),
                (BV.gt_ii,              simpleOp(Op.GT DstTy.IntTy)),
                (BV.gt_rr,              simpleOp(Op.GT DstTy.realTy)),
                (BV.equ_bb,             simpleOp(Op.EQ DstTy.BoolTy)),
                (BV.equ_ii,             simpleOp(Op.EQ DstTy.IntTy)),
                (BV.equ_ss,             simpleOp(Op.EQ DstTy.StringTy)),
                (BV.equ_rr,             simpleOp(Op.EQ DstTy.realTy)),
                (BV.neq_bb,             simpleOp(Op.NEQ DstTy.BoolTy)),
                (BV.neq_ii,             simpleOp(Op.NEQ DstTy.IntTy)),
                (BV.neq_ss,             simpleOp(Op.NEQ DstTy.StringTy)),
                (BV.neq_rr,             simpleOp(Op.NEQ DstTy.realTy)),
              (* addition; OffsetField takes its field operand first *)
                (BV.add_ii,             simpleOp(Op.Add DstTy.IntTy)),
                (BV.add_tt,             tensorOp Op.Add),
                (BV.add_ff,             fn (y, _, [f, g]) => assign(y, Op.AddField, [f, g])),
                (BV.add_fr,             fn (y, _, [f, s]) => assign(y, Op.OffsetField, [f, s])),
                (BV.add_rf,             fn (y, _, [s, f]) => assign(y, Op.OffsetField, [f, s])),
              (* subtraction *)
                (BV.sub_ii,             simpleOp(Op.Sub DstTy.IntTy)),
                (BV.sub_tt,             tensorOp Op.Sub),
                (BV.sub_ff,             fn (y, _, [f, g]) => assign(y, Op.SubField, [f, g])),
              (* f - s is translated as f + (-s) *)
                (BV.sub_fr,             fn (y, _, [f, s]) => let
                    val s' = IL.Var.copy s
                    in [
                      IL.ASSGN(s', IL.OP(Op.Neg DstTy.realTy, [s])),
                      IL.ASSGN(y, IL.OP(Op.OffsetField, [f, s']))
                    ] end),
              (* s - f is translated as (-f) + s *)
                (BV.sub_rf,             fn (y, _, [s, f]) => let
                    val f' = IL.Var.copy f
                    in [
                      IL.ASSGN(f', IL.OP(Op.NegField, [f])),
                      IL.ASSGN(y, IL.OP(Op.OffsetField, [f', s]))
                    ] end),
              (* multiplication; Scale and ScaleField take the scalar operand first *)
                (BV.mul_ii,             simpleOp(Op.Mul DstTy.IntTy)),
                (BV.mul_rr,             simpleOp(Op.Mul DstTy.realTy)),
                (BV.mul_rt,             tensorOp Op.Scale),
                (BV.mul_tr,             fn (y, sv, [t, r]) => tensorOp Op.Scale (y, sv, [r, t])),
                (BV.mul_rf,             fn (y, _, [s, f]) => assign(y, Op.ScaleField, [s, f])),
                (BV.mul_fr,             fn (y, _, [f, s]) => assign(y, Op.ScaleField, [s, f])),
              (* division; x / s is translated as (1/s) * x *)
                (BV.div_ii,             simpleOp(Op.Div DstTy.IntTy)),
                (BV.div_rr,             simpleOp(Op.Div DstTy.realTy)),
                (BV.div_tr,             fn (y, [sv], [x, s]) => let
                    val one = IL.Var.new("one", DstTy.realTy)
                    val s' = IL.Var.new("sInv", DstTy.realTy)
                    in [
                      IL.ASSGN(one, IL.LIT(Literal.Float(FloatLit.one))),
                      IL.ASSGN(s', IL.OP(Op.Div DstTy.realTy, [one, s])),
                      IL.ASSGN(y, IL.OP(Op.Scale(shapeVarToTensor sv), [s', x]))
                    ] end),
                (BV.div_fr,             fn (y, _, [f, s]) => let
                    val one = IL.Var.new("one", DstTy.realTy)
                    val s' = IL.Var.new("sInv", DstTy.realTy)
                    in [
                      IL.ASSGN(one, IL.LIT(Literal.Float(FloatLit.one))),
                      IL.ASSGN(s', IL.OP(Op.Div DstTy.realTy, [one, s])),
                      IL.ASSGN(y, IL.OP(Op.ScaleField, [s', f]))
                    ] end),
              (* exponentiation *)
                (BV.exp_ri,             simpleOp Op.Power),
                (BV.exp_rr,             basisFn MathFuns.pow),
              (* field operations *)
                (BV.curl2D,             fn (y, _, xs) => assign(y, Op.CurlField 2, xs)),
                (BV.curl3D,             fn (y, _, xs) => assign(y, Op.CurlField 3, xs)),
                (BV.convolve_vk,        fn (y, [_, Ty.DIM d, _], xs) => assign(y, Op.Field d, xs)),
              (* k ⊛ v is normalized to the image-first argument order of v ⊛ k *)
                (BV.convolve_kv,        fn (y, [_, Ty.DIM d, _], [k, v]) => assign(y, Op.Field d, [v, k])),
              (* negation *)
                (BV.neg_i,              simpleOp(Op.Neg DstTy.IntTy)),
                (BV.neg_t,              tensorOp Op.Neg),
                (BV.neg_f,              fn (y, _, xs) => assign(y, Op.NegField, xs)),
              (* probing and differentiation *)
                (BV.op_probe,           fn (y, [_, dv, sv], xs) =>
                    assign(y, Op.Probe(dimVarToTensor dv, shapeVarToTensor sv), xs)),
                (BV.op_D,               fn (y, _, xs) => assign(y, Op.DiffField, xs)),
                (BV.op_Dotimes,         fn (y, _, xs) => assign(y, Op.DiffField, xs)),
              (* the norm of a scalar is its absolute value *)
                (BV.op_norm,            fn (y, [sv], xs) => (case shapeVarToTensor sv
                     of DstTy.TensorTy[] => assign(y, Op.Abs DstTy.realTy, xs)
                      | ty => assign(y, Op.Norm ty, xs)
                    (* end case *))),
                (BV.op_not,             simpleOp Op.Not),
                (BV.op_cross,           simpleOp Op.Cross),
                (BV.op_outer,           fn (y, [Ty.DIM d1, Ty.DIM d2], xs) =>
                    assign (y, Op.Outer(DstTy.tensorTy[d1, d2]), xs)),
              (* inner product: the HighIL operator is selected by the ranks of
               * the two operand shapes.
               * NOTE(review): only the ranks are matched here; presumably the
               * type checker has already verified that the contracted
               * dimensions agree -- confirm.
               *)
                (BV.op_inner,           fn (y, [Ty.SHAPE dd1, Ty.SHAPE dd2, _], xs) => let
                    val ty1 = DstTy.TensorTy dd1
                    val ty2 = DstTy.TensorTy dd2
                    val rator = (case (dd1, dd2)
                           of ([_], [_]) => Op.Dot ty1
                            | ([_], [_, _]) => Op.MulVecMat ty2
                            | ([_, _], [_]) => Op.MulMatVec ty1
                            | ([_, _], [_, _]) => Op.MulMatMat(ty1, ty2)
                            | ([_], [_, _, _]) => Op.MulVecTen3 ty2
                            | ([_, _, _], [_]) => Op.MulTen3Vec ty1
                            | _ => raise Fail(concat[
                                  "unsupported inner-product: ",
                                  DstTy.toString ty1, " * ", DstTy.toString ty2
                                ])
                          (* end case *))
                    in
                      assign (y, rator, xs)
                    end),
                (BV.op_colon,           fn (y, [sh1, sh2, _], xs) => let
                    val ty1 = shapeVarToTensor sh1
                    val ty2 = shapeVarToTensor sh2
                    in
                      assign (y, Op.ColonMul(ty1, ty2), xs)
                    end),
                (BV.fn_inside,          fn (y, [_, Ty.DIM d, _], xs) =>
                    assign(y, Op.Inside d, xs)),
                (BV.clamp_rrr,          simpleOp(Op.Clamp DstTy.realTy)),
                (BV.clamp_vvv,          vectorOp Op.Clamp),
                (BV.lerp3,              tensorOp Op.Lerp),
              (* lerp(a, b, x0, x, x1) is translated as lerp(a, b, (x - x0) / (x1 - x0)) *)
                (BV.lerp5,              fn (y, [sv], [a, b, x0, x, x1]) => let
                    val t1 = IL.Var.new("t1", DstTy.realTy)
                    val t2 = IL.Var.new("t2", DstTy.realTy)
                    val t3 = IL.Var.new("t3", DstTy.realTy)
                    in [
                      IL.ASSGN(t1, IL.OP(Op.Sub DstTy.realTy, [x, x0])),
                      IL.ASSGN(t2, IL.OP(Op.Sub DstTy.realTy, [x1, x0])),
                      IL.ASSGN(t3, IL.OP(Op.Div DstTy.realTy, [t1, t2])),
                      IL.ASSGN(y, IL.OP(Op.Lerp(shapeVarToTensor sv), [a, b, t3]))
                    ] end),
                (BV.evals2x2,           eigenVal (Op.Eigen2x2, 2)),
                (BV.evals3x3,           eigenVal (Op.Eigen3x3, 3)),
                (BV.evecs2x2,           eigenVec (Op.Eigen2x2, 2)),
                (BV.evecs3x3,           eigenVec (Op.Eigen3x3, 3)),
                (BV.fn_max,             simpleOp Op.Max),
                (BV.fn_min,             simpleOp Op.Min),
                (BV.fn_modulate,        vectorOp Op.Mul),
                (BV.fn_normalize,       vectorOp Op.Normalize),
                (BV.fn_principleEvec,   vectorOp Op.PrincipleEvec),
                (BV.fn_trace,           fn (y, [dv], xs) => assign(y, Op.Trace(dimVarToMatrix dv), xs)),
                (BV.fn_transpose,       fn (y, [Ty.DIM d1, Ty.DIM d2], xs) =>
                    assign(y, Op.Transpose(d1, d2), xs)),
              (* kernels.
               * NOTE(review): kn_c2ctmr and kn_c1tent share the ctmr and tent
               * implementations; presumably they differ only in the continuity
               * declared to the type checker -- confirm.
               *)
                (BV.kn_bspln3,          kernel Kernel.bspln3),
                (BV.kn_bspln5,          kernel Kernel.bspln5),
                (BV.kn_ctmr,            kernel Kernel.ctmr),
                (BV.kn_c2ctmr,          kernel Kernel.ctmr),
                (BV.kn_c4hexic,         kernel Kernel.c4hexic),
                (BV.kn_tent,            kernel Kernel.tent),
                (BV.kn_c1tent,          kernel Kernel.tent),
              (* conversions and constants *)
                (BV.i2r,                simpleOp Op.IntToReal),
                (BV.identity,           fn (y, [Ty.DIM d], []) =>
                    assign(y, Op.Identity d, [])),
                (BV.zero,               fn (y, [sv], []) =>
                    assign(y, Op.Zero(shapeVarToTensor sv), [])),
              (* sequence subscript: meta arguments are the element type and the length *)
                (BV.subscript,          fn (y, [tv, Ty.DIM d], args) =>
                    assign(y, Op.SeqSub(DstTy.SeqTy(trType tv, d)), args))
              ];
          (* add C math functions *)
            List.app (fn (n, x) => insert(x, basisFn n)) BV.mathFns;
            tbl
          end

  (* look up the translation function for f and apply it.  An unknown basis
   * function raises Fail.  Any exception raised during translation (including
   * that Fail, and Match from a translation function whose patterns do not fit
   * its inputs) is re-raised after a one-line diagnostic is printed to stdout.
   *)
    fun translate (y, f, mvs, xs) = (case VTbl.find tbl f
           of SOME transFn => transFn(y, mvs, xs)
            | NONE => raise Fail("TranslateBasis.translate: unknown basis function " ^ Var.uniqueNameOf f)
          (* end case *))
            handle ex => (print(concat["translate (", IL.Var.toString y, ", ",
              Var.uniqueNameOf f, ", ...)\n"]); raise ex)

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |