Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/typechecker/check-expr.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3402, Wed Nov 11 02:54:23 2015 UTC revision 3405, Wed Nov 11 14:46:13 2015 UTC
# Line 19  Line 19 
19      structure E = Env      structure E = Env
20      structure Ty = Types      structure Ty = Types
21      structure BV = BasisVars      structure BV = BasisVars
22        structure TU = TypeUtil
23    
24    (* an expression to return when there is a type error *)    (* an expression to return when there is a type error *)
25      val bogusExp = (AST.E_Lit(L.Int 0), Ty.T_Error)      val bogusExp = AST.E_Lit(L.Int 0)
26        val bogusExpTy = (bogusExp, Ty.T_Error)
27    
28      fun err arg = (TypeError.error arg; bogusExp)      fun err arg = (TypeError.error arg; bogusExpTy)
29      val warn = TypeError.warning      val warn = TypeError.warning
30    
31      datatype token = datatype TypeError.token      datatype token = datatype TypeError.token
# Line 36  Line 38 
38              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (L.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
39            (* end case *))            (* end case *))
40    
41      (* check a tensor shape *)
42        fun checkShape (cxt, shape) =  let
43              fun checkDim d = (
44                    if (d <= 1)
45                      then TypeError.error (cxt, [S "invalid tensor-shape dimension; must be > 1"])
46                      else ();
47                    Ty.DimConst(IntInf.toInt d))
48              in
49                Ty.Shape(List.map checkDim shape)
50              end
51    
52      (* type check a dot product, which has the constraint:
53       *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]
54       * and similarly for fields.
55       *)
56        fun chkInnerProduct (cxt, e1, ty1, e2, ty2) = let
57            (* check the shape of the two arguments to verify that the inner constraint matches *)
58              fun chkShape (Ty.Shape(dd1 as _::_), Ty.Shape(d2::dd2)) = let
59                    val (dd1, d1) = let
60                          fun splitLast (prefix, [d]) = (List.rev prefix, d)
61                            | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)
62                            | splitLast (_, []) = raise Fail "impossible"
63                          in
64                            splitLast ([], dd1)
65                          end
66                    in
67                      if Unify.equalDim(d1, d2)
68                        then SOME(Ty.Shape(dd1@dd2))
69                        else NONE
70                    end
71                | chkShape _ = NONE
72              fun error () = err (cxt, [
73                      S "type error for arguments of binary operator '•'\n",
74                      S "  found: ", TYS[ty1, ty2], S "\n"
75                    ])
76              in
77                case (TU.prune ty1, TU.prune ty2)
78                (* tensor * tensor inner product *)
79                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
80                       of SOME shp => let
81                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tt)
82                            val resTy = Ty.T_Tensor shp
83                            in
84                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
85                                then (AST.E_Apply(BV.op_inner_tt, tyArgs, [e1, e2], rngTy), rngTy)
86                                else error()
87                            end
88                        | NONE => error()
89                      (* end case *))
90                (* tensor * field inner product *)
91                  | (Ty.T_Tensor s1, Ty.T_Field{diff, dim, shape=s2}) => (case chkShape(s1, s2)
92                       of SOME shp => let
93                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_tf)
94                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
95                            in
96                              if Unify.equalTypes(domTy, [ty1, ty2])
97                              andalso Unify.equalType(rngTy, resTy)
98                                then (AST.E_Apply(BV.op_inner_tf, tyArgs, [e1, e2], rngTy), rngTy)
99                                else error()
100                            end
101                        | NONE => error()
102                      (* end case *))
103                (* field * tensor inner product *)
104                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
105                       of SOME shp => let
106                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ft)
107                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
108                            in
109                              if Unify.equalTypes(domTy, [ty1, ty2])
110                              andalso Unify.equalType(rngTy, resTy)
111                                then (AST.E_Apply(BV.op_inner_ft, tyArgs, [e1, e2], rngTy), rngTy)
112                                else error()
113                            end
114                        | NONE => error()
115                      (* end case *))
116                (* field * field inner product *)
117                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
118                      case chkShape(s1, s2)
119                       of SOME shp => let
120                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_inner_ff)
121                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
122                            in
123    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
124                              if Unify.equalDim(dim1, dim2)
125                              andalso Unify.equalTypes(domTy, [ty1, ty2])
126                              andalso Unify.equalType(rngTy, resTy)
127                                then (AST.E_Apply(BV.op_inner_ff, tyArgs, [e1, e2], rngTy), rngTy)
128                                else error()
129                            end
130                        | NONE => error()
131                      (* end case *))
132                  | (ty1, ty2) => error()
133                (* end case *)
134              end
135    
136      (* type check a colon product, which has the constraint:
137       *     ALL[sigma1, d1, d2, sigma2] . tensor[sigma1, d1, d2] * tensor[d2, d1, sigma2] -> tensor[sigma1, sigma2]
138       * and similarly for fields.
139       *)
140        fun chkColonProduct (cxt, e1, ty1, e2, ty2) = let
141            (* check the shape of the two arguments to verify that the inner constraint matches *)
142              fun chkShape (Ty.Shape(dd1 as _::_::_), Ty.Shape(d21::d22::dd2)) = let
143                    val (dd1, d11, d12) = let
144                          fun splitLast2 (prefix, [d1, d2]) = (List.rev prefix, d1, d2)
145                            | splitLast2 (prefix, d::dd) = splitLast2 (d::prefix, dd)
146                            | splitLast2 (_, []) = raise Fail "impossible"
147                          in
148                            splitLast2 ([], dd1)
149                          end
150                    in
151                      if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)
152                        then SOME(Ty.Shape(dd1@dd2))
153                        else NONE
154                    end
155                | chkShape _ = NONE
156              fun error () = err (cxt, [
157                      S "type error for arguments of binary operator \":\"\n",
158                      S "  found: ", TYS[ty1, ty2], S "\n"
159                    ])
160              in
161                case (TU.prune ty1, TU.prune ty2)
162                (* tensor * tensor colon product *)
163                 of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case chkShape(s1, s2)
164                       of SOME shp => let
165                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tt)
166                            val resTy = Ty.T_Tensor shp
167                            in
168                              if Unify.equalTypes(domTy, [ty1, ty2])
169                              andalso Unify.equalType(rngTy, resTy)
170                                then (AST.E_Apply(BV.op_colon_tt, tyArgs, [e1, e2], rngTy), rngTy)
171                                else error()
172                            end
173                        | NONE => error()
174                      (* end case *))
175                (* field * tensor colon product *)
176                  | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case chkShape(s1, s2)
177                       of SOME shp => let
178                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ft)
179                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
180                            in
181                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
182                                then (AST.E_Apply(BV.op_colon_ft, tyArgs, [e1, e2], rngTy), rngTy)
183                                else error()
184                            end
185                        | NONE => error()
186                      (* end case *))
187                (* tensor * field colon product *)
188                  | (Ty.T_Tensor s1, Ty.T_Field{diff=diff, dim=dim, shape=s2}) => (case chkShape(s1, s2)
189                       of SOME shp => let
190                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_tf)
191                            val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
192                            in
193                              if Unify.equalTypes(domTy, [ty1, ty2]) andalso Unify.equalType(rngTy, resTy)
194                                then (AST.E_Apply(BV.op_colon_tf, tyArgs, [e1, e2], rngTy), rngTy)
195                                else error()
196                            end
197                        | NONE => error()
198                      (* end case *))
199                (* field * field colon product *)
200                  | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
201                      case chkShape(s1, s2)
202                       of SOME shp => let
203                            val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf BV.op_colon_ff)
204                            val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
205                            in
206    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
207                              if Unify.equalDim(dim1, dim2)
208                              andalso Unify.equalTypes(domTy, [ty1, ty2])
209                              andalso Unify.equalType(rngTy, resTy)
210                                then (AST.E_Apply(BV.op_colon_ff, tyArgs, [e1, e2], rngTy), rngTy)
211                                else error()
212                            end
213                        | NONE => error()
214                      (* end case *))
215                  | (ty1, ty2) => error()
216                (* end case *)
217              end
218    
219    (* check the type of an expression *)    (* check the type of an expression *)
220      fun check (env, cxt, e) = (case e      fun check (env, cxt, e) = (case e
221             of PT.E_Mark m => check (withEnvAndContext (env, cxt, m))             of PT.E_Mark m => check (E.withEnvAndContext (env, cxt, m))
222              | PT.E_Cond(e1, cond, e2) => let              | PT.E_Cond(e1, cond, e2) => let
223                  val eTy1 = check (env, cxt, e1)                  val eTy1 = check (env, cxt, e1)
224                  val eTy2 = check (env, cxt, e2)                  val eTy2 = check (env, cxt, e2)
225                  in                  in
226                    case check(env, cxt, cond)                    case check(env, cxt, cond)
227                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)                     of (cond', Ty.T_Bool) => (case Util.coerceType2(eTy1, eTy2)
228                           of SOME(e1, e2, ty) => (AST.E_Cond(cond', e1', e2', ty), ty)                           of SOME(e1', e2', ty) => (AST.E_Cond(cond', e1', e2', ty), ty)
229                            | NONE => err (cxt, [                            | NONE => err (cxt, [
230                                S "types do not match in conditional expression\n",                                S "types do not match in conditional expression\n",
231                                S "  true branch:  ", TY(#2 eTy1), S "\n",                                S "  true branch:  ", TY(#2 eTy1), S "\n",
# Line 81  Line 261 
261                  val (e2', ty2) = check (env, cxt, e2)                  val (e2', ty2) = check (env, cxt, e2)
262                  in                  in
263                    if Atom.same(rator, BasisNames.op_dot)                    if Atom.same(rator, BasisNames.op_dot)
264                    (* we have to handle inner product as a special case, because our type                      then chkInnerProduct (cxt, e1', ty1, e2', ty2)
                    * system cannot express the constraint that the type is  
                    *     ALL[sigma1, d1, sigma2] . tensor[sigma1, d1] * tensor[d1, sigma2] -> tensor[sigma1, sigma2]  
                    *)  
                     then (case (TU.prune ty1, TU.prune ty2)  
                        of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_)), Ty.T_Tensor(s2 as Ty.Shape(d2::dd2))) => let  
                             val (dd1, d1) = let  
                                   fun splitLast (prefix, [d]) = (List.rev prefix, d)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_inner)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if Unify.equalDim(d1, d2)  
                               andalso Unify.equalTypes(domTy, [ty1, ty2])  
                               andalso Unify.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_inner, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator '•'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator '•'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
265                    else if Atom.same(rator, BasisNames.op_colon)                    else if Atom.same(rator, BasisNames.op_colon)
266                      then (case (TU.prune ty1, TU.prune ty2)                      then chkColonProduct (cxt, e1', ty1, e2', ty2)
267                         of (Ty.T_Tensor(s1 as Ty.Shape(dd1 as _::_::_)), Ty.T_Tensor(s2 as Ty.Shape(d21::d22::dd2))) => let                      else (case Env.findFunc (env, rator)
                             val (dd1, d11, d12) = let  
                                   fun splitLast (prefix, [d1, d2]) = (List.rev prefix, d1, d2)  
                                     | splitLast (prefix, d::dd) = splitLast (d::prefix, dd)  
                                     | splitLast (_, []) = raise Fail "impossible"  
                                   in  
                                     splitLast ([], dd1)  
                                   end  
                             val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_colon)  
                             val resTy = Ty.T_Tensor(Ty.Shape(dd1@dd2))  
                             in  
                               if Unify.equalDim(d11, d21) andalso Unify.equalDim(d12, d22)  
                               andalso Unify.equalTypes(domTy, [ty1, ty2])  
                               andalso Unify.equalType(rngTy, resTy)  
                                 then (AST.E_Apply(BV.op_colon, tyArgs, [e1', e2'], rngTy), rngTy)  
                                 else err (cxt, [  
                                     S "type error for arguments of binary operator ':'\n",  
                                     S "  found: ", TYS[ty1, ty2], S "\n"  
                                   ])  
                             end  
                        | (ty1, ty2) => err (cxt, [  
                               S "type error for arguments of binary operator ':'\n",  
                               S "  found: ", TYS[ty1, ty2], S "\n"  
                             ])  
                       (* end case *))  
                     else (case Env.findFunc (#env env, rator)  
268                         of Env.PrimFun[rator] => let                         of Env.PrimFun[rator] => let
269                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf rator)                              val (tyArgs, Ty.T_Fun(domTy, rngTy)) = TU.instantiate(Var.typeOf rator)
270                              in                              in
271                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])                                case Unify.matchArgs(domTy, [e1', e2'], [ty1, ty2])
272                                 of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)                                 of SOME args => (AST.E_Apply(rator, tyArgs, args, rngTy), rngTy)
# Line 157  Line 283 
283                        (* end case *))                        (* end case *))
284                  end                  end
285              | PT.E_UnaryOp(rator, e) => let              | PT.E_UnaryOp(rator, e) => let
286                  val (e', ty) = check(env, cxt, e)                  val eTy = check(env, cxt, e)
287                  in                  in
288                    case Env.findFunc (#env env, rator)                    case Env.findFunc (env, rator)
289                     of Env.PrimFun[rator] => let                     of Env.PrimFun[rator] => let
290                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = U.instantiate(Var.typeOf rator)                          val (tyArgs, Ty.T_Fun([domTy], rngTy)) = TU.instantiate(Var.typeOf rator)
291                          in                          in
292                            case coerceType (domTy, ty, e')                            case Util.coerceType (domTy, eTy)
293                             of SOME e' => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)                             of SOME(e', ty) => (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
294                              | NONE => err (cxt, [                              | NONE => err (cxt, [
295                                    S "type error for unary operator \"", V rator, S "\"\n",                                    S "type error for unary operator \"", V rator, S "\"\n",
296                                    S "  expected:  ", TY domTy, S "\n",                                    S "  expected:  ", TY domTy, S "\n",
297                                    S "  but found: ", TY ty                                    S "  but found: ", TY (#2 eTy)
298                                  ])                                  ])
299                            (* end case *)                            (* end case *)
300                          end                          end
301                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [ty], [e'], ovldList)                      | Env.PrimFun ovldList => resolveOverload (cxt, rator, [#2 eTy], [#1 eTy], ovldList)
302                      | _ => raise Fail "impossible"                      | _ => raise Fail "impossible"
303                    (* end case *)                    (* end case *)
304                  end                  end
# Line 186  Line 312 
312                        ])                        ])
313                  (* end case *))                  (* end case *))
314              | PT.E_Select(e, field) => (case check(env, cxt, e)              | PT.E_Select(e, field) => (case check(env, cxt, e)
315                   of (e', Ty.T_Strand strand) => (case Env.findStrand(#env env, strand)                   of (e', Ty.T_Named strand) => (case Env.findStrand(env, strand)
316                         of SOME(AST.Strand{name, state, ...}) => let                         of SOME sEnv => (case StrandEnv.findStateVar(sEnv, field)
317                              fun isField (AST.VD_Decl(AST.V{name, ...}, _)) = Atom.same(name, field)                               of SOME x' => let
                             in  
                               case List.find isField state  
                                of SOME(AST.VD_Decl(x', _)) => let  
318                                      val ty = Var.monoTypeOf x'                                      val ty = Var.monoTypeOf x'
319                                      in                                      in
320                                        (AST.E_Selector(e', field, ty), ty)                                      (AST.E_Select(e', x'), ty)
321                                      end                                      end
322                                  | NONE => err(cxt, [                                  | NONE => err(cxt, [
323                                        S "strand ", A name,                                      S "strand '", A strand,
324                                        S " does not have state variable ", A field                                      S "' does not have state variable '", A field, S "'"
325                                      ])                                      ])
326                                (* end case *)                              (* end case *))
327                              end                          | NONE => err(cxt, [S "unknown strand '", A strand, S "'"])
                         | NONE => err(cxt, [S "unknown strand ", A strand])  
328                        (* end case *))                        (* end case *))
329                    | (_, ty) => err (cxt, [                    | (_, ty) => err (cxt, [
330                          S "expected strand type, but found ", TY ty,                          S "expected strand type, but found ", TY ty,
331                          S " in selection of ", A field                          S " in selection of '", A field, S "'"
332                        ])                        ])
333                  (* end case *))                  (* end case *))
334              | PT.E_Real e => (case check (env, cxt, e)              | PT.E_Real e => (case check (env, cxt, e)
# Line 218  Line 340 
340                        ])                        ])
341                  (* end case *))                  (* end case *))
342              | PT.E_Load nrrd => let              | PT.E_Load nrrd => let
343                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_image))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_image))
344                  in                  in
345                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
346                  end                  end
347              | PT.E_Image nrrd => let              | PT.E_Image nrrd => let
348                  val (tyArgs, Ty.T_Fun(_, rngTy)) = Util.instantiate(Var.typeOf(BV.fn_load))                  val (tyArgs, Ty.T_Fun(_, rngTy)) = TU.instantiate(Var.typeOf(BV.fn_load))
349                  in                  in
350                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)                    (AST.E_LoadNrrd(tyArgs, nrrd, rngTy), rngTy)
351                  end                  end
352              | PT.E_Var x => (case E.findVar (#env env, x)              | PT.E_Var x => (case E.findVar (env, x)
353                   of SOME x' => (                   of SOME x' => (
354                        markUsed (x', true);                        markUsed (x', true);
355                        (AST.E_Var x', Var.monoTypeOf x'))                        (AST.E_Var x', Var.monoTypeOf x'))
# Line 237  Line 359 
359              | PT.E_Lit lit => checkLit lit              | PT.E_Lit lit => checkLit lit
360              | PT.E_Id d => let              | PT.E_Id d => let
361                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
362                        Util.instantiate(Var.typeOf(BV.identity))                        TU.instantiate(Var.typeOf(BV.identity))
363                  in                  in
364                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, [d,d])), rngTy)
365                      then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)                      then (AST.E_Apply(BV.identity, tyArgs, [], rngTy), rngTy)
# Line 245  Line 367 
367                  end                  end
368              | PT.E_Zero dd => let              | PT.E_Zero dd => let
369                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
370                        Util.instantiate(Var.typeOf(BV.zero))                        TU.instantiate(Var.typeOf(BV.zero))
371                  in                  in
372                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
373                      then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)                      then (AST.E_Apply(BV.zero, tyArgs, [], rngTy), rngTy)
# Line 253  Line 375 
375                  end                  end
376              | PT.E_NaN dd => let              | PT.E_NaN dd => let
377                  val (tyArgs, Ty.T_Fun(_, rngTy)) =                  val (tyArgs, Ty.T_Fun(_, rngTy)) =
378                        Util.instantiate(Var.typeOf(BV.nan))                        TU.instantiate(Var.typeOf(BV.nan))
379                  in                  in
380                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)                    if Unify.equalType(Ty.T_Tensor(checkShape(cxt, dd)), rngTy)
381                      then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)                      then (AST.E_Apply(BV.nan, tyArgs, [], rngTy), rngTy)
# Line 269  Line 391 
391                         of NONE => Ty.T_Error                         of NONE => Ty.T_Error
392                          | SOME ty => ty                          | SOME ty => ty
393                        (* end case *))                        (* end case *))
394                  in                (* process the arguments checking that they all have the expected type *)
395                    case realType(TU.pruneHead ty)                  fun chkArgs (ty, shape) = let
                    of ty as Ty.T_Tensor shape => let  
396                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)                          val Ty.Shape dd = TU.pruneShape shape (* NOTE: this may fail if we allow user polymorphism *)
397                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))                          val resTy = Ty.T_Tensor(Ty.Shape(Ty.DimConst(List.length args) :: dd))
398                          fun chkArgs (arg::args, argTy::tys, args') = (case coerceType(ty, argTy, arg)                        fun chkArgs (arg::args, argTy::tys, args') = (
399                                case Util.coerceType(ty, (arg, argTy))
400                                 of SOME arg' => chkArgs (args, tys, arg'::args')                                 of SOME arg' => chkArgs (args, tys, arg'::args')
401                                  | NONE => (                                  | NONE => (
402                                      TypeError.error(cxt, [                                      TypeError.error(cxt, [
403                                          S "arguments of tensor construction must have same type"                                          S "arguments of tensor construction must have same type"
404                                        ]);                                        ]);
405                                      ??)                                    chkArgs (args, tys, bogusExp::args'))
406                                (* end case *))                                (* end case *))
407                            | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)                            | chkArgs ([], [], args') = (AST.E_Cons(List.rev args', resTy), resTy)
408                          in                          in
409                            chkArgs (args, tys, [])                            chkArgs (args, tys, [])
410                          end                          end
411                    in
412                      case TU.pruneHead ty
413                       of Ty.T_Int => chkArgs(Ty.realTy, []) (* coerce integers to reals *)
414                        | ty as Ty.T_Tensor shape => chkArgs(ty, shape)
415                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])                      | _ => err(cxt, [S "Invalid argument type for tensor construction"])
416                    (* end case *)                    (* end case *)
417                  end                  end
# Line 299  Line 425 
425            case (check(env, cxt, e1), check(env, cxt, e2))            case (check(env, cxt, e1), check(env, cxt, e2))
426             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)             of ((e1', Ty.T_Bool), (e2', Ty.T_Bool)) => (mk(e1', e2'), Ty.T_Bool)
427              | ((_, Ty.T_Bool), (_, ty2)) =>              | ((_, Ty.T_Bool), (_, ty2)) =>
428                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, "', but found ", TY ty2])                  err (cxt, [S "expected type 'bool' on rhs of '", S rator, S "', but found ", TY ty2])
429              | ((_, ty1), (_, Ty.T_Bool)) =>              | ((_, ty1), (_, Ty.T_Bool)) =>
430                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, "', but found ", TY ty1])                  err (cxt, [S "expected type 'bool' on lhs of '", S rator, S "', but found ", TY ty1])
431              | ((_, ty1), (_, ty2)) => err (cxt, [              | ((_, ty1), (_, ty2)) => err (cxt, [
432                    S "arguments of '", S rator, "' must have type 'bool', but found ",                    S "arguments of '", S rator, S "' must have type 'bool', but found ",
433                    TY ty1, S " and ", TY ty2                    TY ty1, S " and ", TY ty2
434                  ])                  ])
435            (* end case *))            (* end case *))

Legend:
Removed from v.3402  
changed lines
  Added in v.3405

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0