14 |
structure Ty = Types |
structure Ty = Types |
15 |
structure U = Util |
structure U = Util |
16 |
|
|
17 |
|
val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true)) |
18 |
|
|
19 |
(* check a differentiation level, which muse be >= 0 *) |
(* check a differentiation level, which muse be >= 0 *) |
20 |
fun checkDiff (cxt, k) = |
fun checkDiff (cxt, k) = |
21 |
if (k < 0) |
if (k < 0) |
111 |
then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy) |
then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy) |
112 |
else raise Fail "type error for binary operator" |
else raise Fail "type error for binary operator" |
113 |
end |
end |
114 |
| ovldList => raise Fail "unimplemented" (* FIXME *) |
| ovldList => raise Fail "overloaded binops unimplemented" (* FIXME *) |
115 |
(* end case *) |
(* end case *) |
116 |
end |
end |
117 |
| PT.E_UnaryOp(rator, e) => let |
| PT.E_UnaryOp(rator, e) => let |
151 |
val ty = checkTy(cxt, ty) |
val ty = checkTy(cxt, ty) |
152 |
val (args, tys) = checkExprList (env, cxt, args) |
val (args, tys) = checkExprList (env, cxt, args) |
153 |
in |
in |
154 |
raise Fail "E_Cons unimplemented" (* FIXME *) |
case (ty, tys) |
155 |
|
of (Ty.T_Tensor(Ty.Shape[]), [Ty.T_Int]) => (* int to real conversion *) |
156 |
|
(AST.E_Apply(BasisVars.i2r, [], args, ty), ty) |
157 |
|
| (Ty.T_Tensor(Ty.Shape[]), _) => raise Fail "invalid \"real\" conversion" |
158 |
|
| (Ty.T_Tensor(Ty.Shape dims), _) => let |
159 |
|
fun getDim (Ty.DimConst k) = k |
160 |
|
| getDim _ = raise Fail "unexpected dimension variable" |
161 |
|
val resultArity = List.foldl (fn (dim, a) => getDim dim * a) 1 dims |
162 |
|
val argArity = List.length args |
163 |
|
in |
164 |
|
if (resultArity = argArity) |
165 |
|
then (AST.E_Cons(ty, args), ty) |
166 |
|
else if (resultArity > argArity) |
167 |
|
then let |
168 |
|
val xArgs = List.tabulate (resultArity-argArity, fn _ => realZero) |
169 |
|
in |
170 |
|
(AST.E_Cons(ty, args@xArgs), ty) |
171 |
|
end |
172 |
|
else raise Fail "arity mismatch in tensor construction" |
173 |
|
end |
174 |
|
(* end case *) |
175 |
end |
end |
176 |
(* end case *)) |
(* end case *)) |
177 |
|
|