Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] Diff of /branches/ein16/src/compiler/typechecker/typechecker.sml
 [diderot] / branches / ein16 / src / compiler / typechecker / typechecker.sml

# Diff of /branches/ein16/src/compiler/typechecker/typechecker.sml

revision 3945, Thu Jun 9 20:00:15 2016 UTC revision 3946, Sat Jun 11 00:39:19 2016 UTC
# Line 259  Line 259
259              tryCandidates candidates              tryCandidates candidates
260            end            end
261
262
263
264    (* type check an outer product, which has the constraint:
265    *     ALL[sigma1, sigma2] . tensor[sigma1] * tensor[sigma2] -> tensor[sigma1, sigma2]
266    * and similarly for fields.
267    *)
268    fun chkOuterProduct (cxt, e1, ty1, e2, ty2) = let
269    fun mergeShp (Ty.Shape dd1, Ty.Shape dd2) = SOME(Ty.Shape(dd1@dd2))
270    | mergeShp _ = NONE
271    fun shapeError () = err (cxt, [
272    S "unable to determine result shape of outer product\n",
273    S "  found: ", TYS[ty1, ty2], S "\n"
274    ])
275    fun error () = err (cxt, [
276    S "type error for arguments of binary operator \"⊗\"\n",
277    S "  found: ", TYS[ty1, ty2], S "\n"
278    ])
279    in
280    case (TU.prune ty1, TU.prune ty2)
281    (* tensor * tensor outer product *)
282    of (Ty.T_Tensor s1, Ty.T_Tensor s2) => (case mergeShp(s1, s2)
283    of SOME shp => let
284    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_outer_tt)
285    val resTy = Ty.T_Tensor shp
286    in
287    if U.equalTypes(domTy, [ty1, ty2])
288    andalso U.equalType(rngTy, resTy)
289    then (AST.E_Apply(BV.op_outer_tt, tyArgs, [e1, e2], rngTy), rngTy)
290    else error()
291    end
292    | NONE => shapeError()
293    (* end case *))
294    (* field * tensor outer product *)
295    | (Ty.T_Field{diff, dim, shape=s1}, Ty.T_Tensor s2) => (case mergeShp(s1, s2)
296    of SOME shp => let
297    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_outer_ft)
298    val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
299    in
300    if U.equalTypes(domTy, [ty1, ty2]) andalso U.equalType(rngTy, resTy)
301    then (AST.E_Apply(BV.op_outer_ft, tyArgs, [e1, e2], rngTy), rngTy)
302    else error()
303    end
304    | NONE => shapeError()
305    (* end case *))
306    (* tensor * field outer product *)
307    | (Ty.T_Tensor s1, Ty.T_Field{diff=diff, dim=dim, shape=s2}) => (case mergeShp(s1, s2)
308    of SOME shp => let
309    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = Util.instantiate(Var.typeOf BV.op_outer_tf)
310    val resTy = Ty.T_Field{diff=diff, dim=dim, shape=shp}
311    in
312    if U.equalTypes(domTy, [ty1, ty2]) andalso U.equalType(rngTy, resTy)
313    then (AST.E_Apply(BV.op_outer_tf, tyArgs, [e1, e2], rngTy), rngTy)
314    else error()
315    end
316    | NONE => shapeError()
317    (* end case *))
318    (* field * field outer product *)
319    | (Ty.T_Field{diff=k1, dim=dim1, shape=s1}, Ty.T_Field{diff=k2, dim=dim2, shape=s2}) => (
320    case mergeShp(s1, s2)
321    of SOME shp => let
322    val (tyArgs, Ty.T_Fun(domTy, rngTy)) = (*TU.instantiate*)Util.instantiate(Var.typeOf BV.op_outer_ff)
323    val resTy = Ty.T_Field{diff=k1, dim=dim1, shape=shp}
324    in
325    (* FIXME: the resulting differentiation should be the minimum of k1 and k2 *)
326    if U.equalDim(dim1, dim2)
327    andalso U.equalTypes(domTy, [ty1, ty2])
328    andalso U.equalType(rngTy, resTy)
329    then (AST.E_Apply
330    (BV.op_outer_ff, tyArgs, [e1, e2], rngTy), rngTy)
331    else error()
332    end
333    | NONE => shapeError()
334    (* end case *))
335    | _ => error()
336    (* end case *)
337    end
338
339    (* typecheck an expression and translate it to AST *)    (* typecheck an expression and translate it to AST *)
340      fun checkExpr (env : env, cxt, e) = (case e      fun checkExpr (env : env, cxt, e) = (case e
341             of PT.E_Mark m => checkExpr (withEnvAndContext (env, cxt, m))             of PT.E_Mark m => checkExpr (withEnvAndContext (env, cxt, m))
# Line 482  Line 559
559                                S "  found: ", TYS[ty1, ty2], S "\n"                                S "  found: ", TYS[ty1, ty2], S "\n"
560                              ])                              ])
561                        (* end case *))                        (* end case *))
562                      else (case Env.findFunc (#env env, rator)                      else if Atom.same(rator, BasisNames.op_outer)
563                            then (print "outer found";chkOuterProduct (cxt, e1', ty1, e2', ty2))