SCM Repository
Annotation of /trunk/src/compiler/simplify/eval.sml
Parent Directory
|
Revision Log
Revision 1910 - (view) (download)
1 : | jhr | 231 | (* eval.sml |
2 : | * | ||
3 : | jhr | 435 | * COPYRIGHT (c) 2010 The Diderot Project (http://diderot-language.cs.uchicago.edu) |
4 : | jhr | 231 | * All rights reserved. |
5 : | * | ||
6 : | * Evaluation of "static" expressions. | ||
7 : | *) | ||
8 : | |||
(* Static evaluation: walks a Simple-IR block and records, in an environment, the
 * value of every variable whose value can be computed at compile time (literals,
 * program inputs, image loads, and supported basis operations applied to values
 * that are already known).
 *)
structure Eval : sig

  (* raised if there is an error due to faulty code or input values (e.g., loading an
   * image of the wrong shape).
   *)
    exception Error of string list

    datatype value
      = BV of bool
      | SV of string
      | IV of IntInf.int
      | RV of real
      | TV of value list                (* tensors: values will either be RV or TV *)
      | ImgV of ImageInfo.info * Var.var

  (* evalStatics (statics, blk) evaluates the statements of blk and returns the map
   * from variables to their statically-known values.  The statics set determines
   * which input variables are actually read.
   *)
    val evalStatics : Var.Set.set * Simple.block -> value Var.Map.map

  end = struct

    structure Ty = Types
    structure BV = BasisVars
    structure S = Simple
    structure VMap = Var.Map
    structure VSet = Var.Set
    structure VTbl = Var.Tbl

    exception Error of string list

    datatype value
      = BV of bool
      | SV of string
      | IV of IntInf.int
(* FIXME: we probably should use FloatLit.float values instead of reals! *)
      | RV of real
      | TV of value list                (* tensors: values will either be RV or TV *)
      | ImgV of ImageInfo.info * Var.var

  (* render a value as a string (used for logging); tensors are printed element-wise
   * and images include the variable that holds the file name.
   *)
    fun toString (BV b) = Bool.toString b
      | toString (IV i) = IntInf.toString i
      | toString (SV s) = concat["\"", String.toString s, "\""]
      | toString (RV r) = Real.toString r
      | toString (TV vs) = concat["[", String.concatWith "," (List.map toString vs), "]"]
      | toString (ImgV(info, x)) =
          concat[ImageInfo.toString info, " (", Var.uniqueNameOf x, ")"]

  (* table mapping basis-function variables to their compile-time implementations.
   * Each implementation receives the already-evaluated argument values; applying one
   * to arguments of an unexpected form raises Match (reported by `apply`).
   *)
    val tbl : (value list -> value) VTbl.hash_table = let
          val tbl = VTbl.mkTable (128, Fail "Eval table")
          fun intBinOp rator [IV a, IV b] = IV(rator(a, b))
        (* lift a real operator to element-wise application over two tensors of the
         * same shape; ListPair.mapEq raises UnequalLengths on a shape mismatch.
         *)
          fun tensorBinOp rator [v1, v2] = let
                fun f (TV v1, TV v2) = TV(ListPair.mapEq f (v1, v2))
                  | f (RV r1, RV r2) = RV(rator(r1, r2))
                in
                  f (v1, v2)
                end
          fun realBinOp rator [RV a, RV b] = RV(rator(a, b))
          fun realUnOp rator [RV a] = RV(rator a)
          fun intCmp rator [IV a, IV b] = BV(rator(a, b))
          fun realCmp rator [RV a, RV b] = BV(rator(a, b))
          fun boolCmp rator [BV a, BV b] = BV(rator(a, b))
          fun stringCmp rator [SV a, SV b] = BV(rator(a, b))
          in
            List.app (VTbl.insert tbl) [
                (BV.add_ii,             intBinOp (op +)),
                (BV.add_tt,             tensorBinOp (op +)),
                (BV.sub_ii,             intBinOp (op -)),
                (BV.sub_tt,             tensorBinOp (op -)),
                (BV.mul_ii,             intBinOp (op * )),
                (BV.mul_rr,             realBinOp (op * )),
(*
                (BV.mul_rt,             tensorOp Op.Scale),
                (BV.mul_tr,             fn (y, sv, [t, r]) => tensorOp Op.Scale (y, sv, [r, t])),
*)
                (BV.div_ii,             intBinOp IntInf.quot),
                (BV.div_rr,             realBinOp (op /)),
(*
                (BV.div_tr,             tensorOp Op.InvScale),
*)
                (BV.lt_ii,              intCmp (op <)),
                (BV.lt_rr,              realCmp (op <)),
                (BV.lte_ii,             intCmp (op <=)),
                (BV.lte_rr,             realCmp (op <=)),
                (BV.gte_ii,             intCmp (op >=)),
                (BV.gte_rr,             realCmp (op >=)),
                (BV.gt_ii,              intCmp (op >)),
                (BV.gt_rr,              realCmp (op >)),
                (BV.equ_bb,             boolCmp (op =)),
                (BV.equ_ii,             intCmp (op =)),
                (BV.equ_ss,             stringCmp (op =)),
                (BV.equ_rr,             realCmp Real.==),
                (BV.neq_bb,             boolCmp (op <>)),
                (BV.neq_ii,             intCmp (op <>)),
                (BV.neq_ss,             stringCmp (op <>)),
                (BV.neq_rr,             realCmp Real.!=),
                (BV.neg_i,              fn [IV i] => IV(~i)),
(*
                (BV.neg_t,              tensorOp Op.Neg),
                (BV.neg_f,              fn [FV fld] => FV(FieldDef.neg fld)),
                (BV.op_at,              fn (y, _, xs) => assign(y, Op.Probe, xs)),
                (BV.op_D,               fn [FV fld] => FV(FieldDef.diff fld)),
                (BV.op_norm,            tensorOp Op.Norm),
*)
                (BV.op_not,             fn [BV b] => BV(not b)),
(*
                (BV.fn_CL,              fn (y, _, xs) => assign(y, Op.CL, xs)),
                (BV.op_convolve,        fn [Img info, KV h] => FV(FieldDef.CONV(0, info, h))),
*)
                (BV.fn_cos,             realUnOp Math.cos),
(*
                (BV.fn_inside,          fn (y, _, xs) => assign(y, Op.Inside, xs)),
*)
                (BV.fn_log10,           realUnOp Math.log10),
                (BV.fn_ln,              realUnOp Math.ln),
              (* max and min map to the like-named Real operations *)
                (BV.fn_max,             realBinOp Real.max),
                (BV.fn_min,             realBinOp Real.min),
                (BV.fn_modulate,        tensorBinOp (op * )),
(*
                (BV.fn_principleEvec,   vectorOp Op.PrincipleEvec),
*)
                (BV.fn_sin,             realUnOp Math.sin),
(*
                (BV.kn_bspln3,          kernel Kernel.bspln3),
                (BV.kn_bspln5,          kernel Kernel.bspln5),
                (BV.kn_ctmr,            kernel Kernel.ctmr),
                (BV.kn_tent,            kernel Kernel.tent),
                (BV.kn_c1tent,          kernel Kernel.c1tent),
                (BV.kn_c2ctmr,          kernel Kernel.c2ctmr),
*)
              (* convert through LargeInt so that arbitrary-precision integers do not
               * raise Overflow (as IntInf.toInt would) *)
                (BV.i2r,                fn [IV i] => RV(Real.fromLargeInt(IntInf.toLarge i)))
              ];
            tbl
          end

  (* loadImage (mvs, SV filename): read the header information of the image file and
   * check that its dimension and voxel shape match the expected image type given by
   * the meta-arguments [DIM, SHAPE].  Returns the image info; raises Error on a
   * mismatch.
   *)
    fun loadImage ([Ty.DIM dimTy, Ty.SHAPE shpTy], SV filename) = let
        (* the dimension the program expects *)
          val Ty.DimConst expectedDim = TypeUtil.resolveDim dimTy
        (* the voxel shape the program expects, as a list of integer dimensions *)
          val expectedShape = let
                val Ty.Shape dims = TypeUtil.resolveShape shpTy
                fun doDim (Ty.DimConst d) = d
                  | doDim (Ty.DimVar dv) = let val Ty.DimConst d = TypeUtil.resolveDim dv in d end
                in
                  List.map doDim dims
                end
        (* the actual dimension and voxel shape, read from the file *)
          val info as ImageInfo.ImgInfo{dim, ty=(rng, _), ...} = ImageInfo.getInfo filename
          fun rngToS [] = "real"
            | rngToS dd = concat["tensor[", String.concatWith "," (List.map Int.toString dd), "]"]
          fun error msg = raise Error("image file \"" :: filename :: "\" " :: msg)
          in
          (* check that the expected dimension and actual dimension match *)
            if (expectedDim <> dim)
              then error ["has dimension ", Int.toString dim, ", expected ", Int.toString expectedDim]
          (* check that the expected shape and actual shape match *)
            else if not(ListPair.allEq (op =) (expectedShape, rng))
              then error ["has range ", rngToS rng, ", expected ", rngToS expectedShape]
              else ();
            info
          end

  (* look up the value of x in env; an unbound variable is an internal error *)
    fun evalVar env x = (case VMap.find (env, x)
           of SOME v => v
            | NONE => raise Fail("undefined variable " ^ Var.uniqueNameOf x)
          (* end case *))

  (* apply (env, f, mvs, xs): statically evaluate the application of f (with
   * meta-arguments mvs) to the variables xs.  Returns NONE when some argument has no
   * static value or f has no compile-time implementation.  Error exceptions are
   * propagated unchanged; any other exception is reported on stdErr and re-raised.
   *)
    fun apply (env, f, mvs, xs) =
          if List.all (fn x => VMap.inDomain(env, x)) xs
            then (* try *)(
              if Var.same(f, BV.fn_load)
                then let
                  val [imgName] = xs
                  in
                    SOME(ImgV(loadImage(mvs, evalVar env imgName), imgName))
                  end
                else (case VTbl.find tbl f
                   of SOME evalFn => SOME(evalFn (List.map (evalVar env) xs))
                    | NONE => NONE
                  (* end case *))
              ) handle ex as Error _ => raise ex
                     | ex => (
                        TextIO.output (TextIO.stdErr, concat [
                            Var.uniqueNameOf f, "(",
                            String.concatWith "," (List.map Var.uniqueNameOf xs),
                            ") fails with exception ", exnName ex, "\n"
                          ]);
                        raise ex)
            else NONE

  (* evaluate a Simple-IR expression in env; returns NONE when its value cannot be
   * determined statically.
   *)
    fun evalExp (env, e) = (case e
           of S.E_Var x => VMap.find (env, x)
            | S.E_Lit(Literal.Int i) => SOME(IV i)
            | S.E_Lit(Literal.Float f) => SOME(RV(FloatLit.toReal f))
            | S.E_Lit(Literal.String s) => SOME(SV s)
            | S.E_Lit(Literal.Bool b) => SOME(BV b)
            | S.E_Tuple _ => raise Fail "E_Tuple"
            | S.E_Apply(f, mvs, xs, _) => apply(env, f, mvs, xs)
            | S.E_Cons xs => (case evalArgs(env, xs)
                 of NONE => NONE
                  | SOME vs => SOME(TV vs)
                (* end case *))
            | S.E_Slice(x, indices, _) => (case VMap.find (env, x)
                 of SOME v => let
                    (* a SOME index selects one element along an axis; a NONE index
                     * keeps the whole axis.  An index that is unknown or out of range
                     * raises Subscript (or Overflow), making the slice non-static.
                     *)
                      fun slice (TV vs, SOME ix :: ixs) = (case VMap.find (env, ix)
                             of SOME(IV i) => slice (List.nth(vs, IntInf.toInt i), ixs)
                              | SOME _ => raise Fail "type error"
                              | NONE => raise Subscript
                            (* end case *))
                        | slice (TV vs, NONE :: ixs) =
                            TV(List.map (fn v => slice(v, ixs)) vs)
                        | slice (v, []) = v
                      in
                        SOME(slice(v, indices))
                          handle Subscript => NONE
                               | Overflow => NONE
                      end
                  | NONE => NONE
                (* end case *))
          (* inputs are handled directly by evalStatics *)
            | S.E_Input _ => raise Fail "impossible"
            | S.E_LoadImage info => SOME(ImgV info)
          (* end case *))

  (* look up the values of all of args in env; NONE if any is missing *)
    and evalArgs (env, args) = let
          fun eval ([], vs) = SOME(List.rev vs)
            | eval (x::xs, vs) = (case VMap.find(env, x)
                 of SOME v => eval(xs, v::vs)
                  | NONE => NONE
                (* end case *))
          in
            eval (args, [])
          end

  (* read the input called `name` of type ty, using optDefault when provided.
   * Scalars (bool, int, string, real) and 1-D real tensors are supported; a 1-D
   * tensor is written as its components separated by commas and/or white space.
   *)
    fun getInput (ty, name, optDefault) = (case ty
           of Ty.T_Bool =>
                Inputs.getInput(name, (Option.map BV) o Bool.fromString, optDefault)
            | Ty.T_Int =>
                Inputs.getInput(name, (Option.map IV) o IntInf.fromString, optDefault)
            | Ty.T_String => Inputs.getInput(name, fn s => SOME(SV s), optDefault)
            | Ty.T_Tensor(Ty.Shape[]) =>
                Inputs.getInput(name, (Option.map RV) o Real.fromString, optDefault)
            | Ty.T_Tensor(Ty.Shape[Ty.DimConst d]) => let
                fun fromString s = let
                    (* first split into fields by "," *)
                      val flds = String.fields (fn #"," => true | _ => false) s
                    (* then tokenize by white space and flatten *)
                      val toks = List.concat(List.map (String.tokens Char.isSpace) flds)
                    (* then convert to reals (valOf raises Option on a bad token) *)
                      val vals = List.map (RV o valOf o Real.fromString) toks
                      in
                        if (List.length vals = d)
                          then SOME(TV(vals))
                          else NONE
                      end
                        handle _ => NONE
                in
                  Inputs.getInput(name, fromString, optDefault)
                end
            | Ty.T_Tensor _ => raise Fail "TODO: general tensor inputs"
            | _ => raise Fail(concat[
                  "input ", name, " has invalid type ", TypeUtil.toString ty
                ])
          (* end case *))

    fun evalStatics (statics, blk) = let
        (* raised to abandon evaluation as soon as control flow depends on a value that
         * is not known statically; it carries the environment computed so far.  It is
         * handled only at the top level, so an unknown conditional inside a nested
         * block also stops evaluation of the statements that follow in the enclosing
         * blocks (which could otherwise observe stale variable values).
         *)
          exception Done of value VMap.map
          fun evalBlock (env, S.Block stms) = List.foldl evalStm env stms
          and evalStm (stm, env) = (case stm
                 of S.S_Var _ => raise Fail "unexpected variable decl"
                  | S.S_Assign(x, S.E_Input(ty, name, desc, optDefault)) =>
                    (* only inputs for variables in the statics set are read *)
                      if VSet.member(statics, x)
                        then let
                          val optDefault = Option.map (evalVar env) optDefault
                          val input = getInput (ty, name, optDefault)
                          in
                            case input
                             of SOME v => VMap.insert(env, x, v)
                              | NONE => raise Fail("error getting required input " ^ name)
                            (* end case *)
                          end
                        else env
                  | S.S_Assign(x, e) => (case evalExp(env, e)
                       of SOME v =>
                            (Log.msg(concat["eval assignment: ", Var.uniqueNameOf x, " = ", toString v, "\n"]);
                             VMap.insert(env, x, v)
                            )
                        | NONE => env
                      (* end case *))
                  | S.S_IfThenElse(x, b1, b2) => (case VMap.find(env, x)
                       of SOME(BV true) => evalBlock(env, b1)
                        | SOME(BV false) => evalBlock(env, b2)
                        | SOME _ => raise Fail "type error"
                        | NONE => raise (Done env)
                      (* end case *))
                  | S.S_New _ => raise Fail "unexpected new strand"
                  | S.S_Die => raise Fail "unexpected die"
                  | S.S_Stabilize => raise Fail "unexpected stabilize"
                  | S.S_Print _ => raise Fail "unexpected print"
                (* end case *))
          in
            evalBlock (VMap.empty, blk) handle Done env => env
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |