Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /trunk/src/typechecker/typechecker.sml
ViewVC logotype

Diff of /trunk/src/typechecker/typechecker.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 82, Wed May 26 18:20:49 2010 UTC revision 85, Wed May 26 19:51:10 2010 UTC
# Line 14  Line 14 
14      structure Ty = Types      structure Ty = Types
15      structure U = Util      structure U = Util
16    
17        val realZero = AST.E_Lit(Literal.Float(FloatLit.zero true))
18    
19    (* check a differentiation level, which muse be >= 0 *)    (* check a differentiation level, which muse be >= 0 *)
20      fun checkDiff (cxt, k) =      fun checkDiff (cxt, k) =
21            if (k < 0)            if (k < 0)
# Line 65  Line 67 
67              | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)              | (Literal.Bool _) => (AST.E_Lit lit, Ty.T_Bool)
68            (* end case *))            (* end case *))
69    
70      (* resolve overloading: we use a simple scheme that selects the first operator in the
71       * list that matches the argument types.
72       *)
73        fun resolveOverload (rator, argTys, args, candidates) = let
74              fun tryCandidates [] = raise Fail(concat[
75                      "unable to resolve overloaded operator \"", Atom.toString rator, "\""
76                    ])
77                | tryCandidates (x::xs) = let
78                    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf x)
79                    in
80                      if U.tryMatchTypes(domTy, argTys)
81                        then (AST.E_Apply(x, tyArgs, args, rngTy), rngTy)
82                        else tryCandidates xs
83                    end
84              in
85                tryCandidates candidates
86              end
87    
88    (* typecheck an expression and translate it to AST *)    (* typecheck an expression and translate it to AST *)
89      fun checkExpr (env, cxt, e) = (case e      fun checkExpr (env, cxt, e) = (case e
90             of PT.E_Mark m => checkExpr (env, #span m, #tree m)             of PT.E_Mark m => checkExpr (env, #span m, #tree m)
# Line 107  Line 127 
127                          in                          in
128                            if U.matchTypes(domTy, [ty1, ty2])                            if U.matchTypes(domTy, [ty1, ty2])
129                              then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy)                              then (AST.E_Apply(rator, tyArgs, [e1', e2'], rngTy), rngTy)
130                              else raise Fail "type error for binary operator"                              else raise Fail(concat[
131                                    "type error for binary operator \"", Var.nameOf rator, "\""
132                                  ])
133                          end                          end
134                      | ovldList => raise Fail "unimplemented" (* FIXME *)                      | ovldList => resolveOverload (rator, [ty1, ty2], [e1', e2'], ovldList)
135                    (* end case *)                    (* end case *)
136                  end                  end
137              | PT.E_UnaryOp(rator, e) => let              | PT.E_UnaryOp(rator, e) => let
# Line 121  Line 143 
143                          in                          in
144                            if U.matchType(domTy, ty)                            if U.matchType(domTy, ty)
145                              then (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)                              then (AST.E_Apply(rator, tyArgs, [e'], rngTy), rngTy)
146                              else raise Fail "type error for binary operator"                              else raise Fail(concat[
147                                    "type error for unary operator \"", Var.nameOf rator, "\""
148                                  ])
149                          end                          end
150                      | ovldList => raise Fail "unimplemented" (* FIXME *)                      | ovldList => resolveOverload (rator, [ty], [e'], ovldList)
151                    (* end case *)                    (* end case *)
152                  end                  end
153              | PT.E_Tuple args => let              | PT.E_Tuple args => let
# Line 149  Line 173 
173                  val ty = checkTy(cxt, ty)                  val ty = checkTy(cxt, ty)
174                  val (args, tys) = checkExprList (env, cxt, args)                  val (args, tys) = checkExprList (env, cxt, args)
175                  in                  in
176                    raise Fail "E_Cons unimplemented" (* FIXME *)                    case (ty, tys)
177                       of (Ty.T_Tensor(Ty.Shape[]), [Ty.T_Int]) => (* int to real conversion *)
178                            (AST.E_Apply(BasisVars.i2r, [], args, ty), ty)
179                        | (Ty.T_Tensor(Ty.Shape[]), _) => raise Fail "invalid \"real\" conversion"
180                        | (Ty.T_Tensor(Ty.Shape dims), _) => let
181                            fun getDim (Ty.DimConst k) = k
182                              | getDim _ = raise Fail "unexpected dimension variable"
183                            val resultArity = List.foldl (fn (dim, a) => getDim dim * a) 1 dims
184                            val argArity = List.length args
185                            in
186                              if (resultArity = argArity)
187                                then (AST.E_Cons(ty, args), ty)
188                              else if (resultArity > argArity)
189                                then let
190                                  val xArgs = List.tabulate (resultArity-argArity, fn _ => realZero)
191                                  in
192                                    (AST.E_Cons(ty, args@xArgs), ty)
193                                  end
194                              else raise Fail "arity mismatch in tensor construction"
195                            end
196                      (* end case *)
197                  end                  end
198            (* end case *))            (* end case *))
199    

Legend:
Removed from v.82  
changed lines
  Added in v.85

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0