Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/tree-ir/check-tree.sml
ViewVC logotype

View of /branches/vis15/src/compiler/tree-ir/check-tree.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3862 - (download) (annotate)
Sun May 15 15:44:30 2016 UTC (3 years ago) by jhr
File size: 16704 byte(s)
working on merge
(* check-tree.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *
 * TODO: check global and state variable consistency
 *)

(* FIXME: the checking function should be parameterized over the vector layout of the target *)

structure CheckTree : sig

  (* `check (phase, prog)` checks the TreeIR program `prog` for type errors and
   * for uses of unbound local variables.  Any errors found are written to the
   * log under a header that names `phase` (the compiler phase after which the
   * check was run).  Returns true if errors were detected, false otherwise.
   *)
    val check : string * TreeIR.program -> bool

  end = struct

    structure IR = TreeIR
    structure Op = TreeOps
    structure Ty = TreeTypes
    structure GVar = TreeGlobalVar
    structure SVar = TreeStateVar
    structure Var = TreeVar
    structure VSet = Var.Set

  (* tokens used to assemble error messages *)
    datatype token
      = NL | S of string | A of Atom.atom | V of IR.var
      | TY of Ty.t | TYS of Ty.t list

  (* `error errBuf toks` renders `toks` as a single error message and pushes it
   * onto the front of `errBuf`; messages are therefore stored newest-first and
   * must be reversed before reporting.  `NL` starts an indented continuation line.
   *)
    fun error errBuf toks = let
          fun tok2str NL = "\n  ** "
            | tok2str (S s) = s
            | tok2str (A s) = Atom.toString s
            | tok2str (V x) = Var.toString x
            | tok2str (TY ty) = Ty.toString ty
            | tok2str (TYS []) = "()"
            | tok2str (TYS[ty]) = Ty.toString ty
            | tok2str (TYS tys) = String.concat[
                  "(", String.concatWith " * " (List.map Ty.toString tys), ")"
                ]
          in
            errBuf := concat ("**** Error: " :: List.map tok2str toks)
              :: !errBuf
          end

  (* utility function for synthesizing eigenvector/eigenvalue signature: the result
   * is a pair of a length-`dim` sequence of reals and a length-`dim` sequence of
   * `dim`-vectors; the single argument is a `dim`x`dim` tensor.
   *)
    fun eigenSig dim = let
          val tplTy = Ty.TupleTy[
                  Ty.SeqTy(Ty.realTy, SOME dim),
                  Ty.SeqTy(Ty.TensorTy[dim], SOME dim)
                ]
          in
(* FIXME: what about pieces? *)
            (tplTy, [Ty.TensorTy[dim, dim]])
          end

  (* NOTE(review): not raised anywhere in this file *)
    exception BadVecType of int

  (* Return the signature of a TreeIR operator as a pair `(resultTy, paramTys)`.
   * `vecTy` maps a vector width to the target's vector type for that width.
   * Raises `Fail` for operators that are not expected in TreeIR (and for
   * `Select`/`Subscript`/`ProjectLast` applied at an unexpected type).
   *)
    fun sigOfOp (vecTy, rator) = (case rator
           of Op.IAdd => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
            | Op.ISub => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
            | Op.IMul => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
            | Op.IDiv => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
            | Op.IMod => (Ty.IntTy, [Ty.IntTy, Ty.IntTy])
            | Op.INeg => (Ty.IntTy, [Ty.IntTy])
            | Op.RAdd => (Ty.realTy, [Ty.realTy, Ty.realTy])
            | Op.RSub => (Ty.realTy, [Ty.realTy, Ty.realTy])
            | Op.RMul => (Ty.realTy, [Ty.realTy, Ty.realTy])
            | Op.RDiv => (Ty.realTy, [Ty.realTy, Ty.realTy])
            | Op.RNeg => (Ty.realTy, [Ty.realTy])
            | Op.RClamp => (Ty.realTy, [Ty.realTy, Ty.realTy, Ty.realTy])
            | Op.RLerp => (Ty.realTy, [Ty.realTy, Ty.realTy, Ty.realTy])
            | Op.LT ty => (Ty.BoolTy, [ty, ty])
            | Op.LTE ty => (Ty.BoolTy, [ty, ty])
            | Op.EQ ty => (Ty.BoolTy, [ty, ty])
            | Op.NEQ ty => (Ty.BoolTy, [ty, ty])
            | Op.GT ty => (Ty.BoolTy, [ty, ty])
            | Op.GTE ty => (Ty.BoolTy, [ty, ty])
            | Op.Not => (Ty.BoolTy, [Ty.BoolTy])
            | Op.Abs ty => (ty, [ty])
            | Op.Max ty => (ty, [ty, ty])
            | Op.Min ty => (ty, [ty, ty])
            | Op.VAdd d => (vecTy d, [vecTy d, vecTy d])
            | Op.VSub d => (vecTy d, [vecTy d, vecTy d])
            | Op.VScale d => (vecTy d, [Ty.realTy, vecTy d])
            | Op.VMul d => (vecTy d, [vecTy d, vecTy d])
            | Op.VNeg d => (vecTy d, [vecTy d])
            | Op.VSum d => (Ty.realTy, [vecTy d])
            | Op.VIndex(n, _) => (Ty.realTy, [vecTy n])
            | Op.VClamp d => (vecTy d, [vecTy d, Ty.realTy, Ty.realTy])
            | Op.VMapClamp d => (vecTy d, [vecTy d, vecTy d, vecTy d])
            | Op.VLerp d => (vecTy d, [vecTy d, vecTy d, Ty.realTy])
            | Op.TensorIndex(ty, _) => (Ty.realTy, [ty])
            | Op.ProjectLast(ty as Ty.TensorTy dd, _) => (Ty.TensorTy[List.last dd], [ty])
            | Op.EigenVecs2x2 => eigenSig 2
            | Op.EigenVecs3x3 => eigenSig 3
(* FIXME: what about pieces? *)
            | Op.EigenVals2x2 => (Ty.SeqTy(Ty.realTy, SOME 2), [Ty.TensorTy[2, 2]])
            | Op.EigenVals3x3 => (Ty.SeqTy(Ty.realTy, SOME 3), [Ty.TensorTy[3, 3]])
            | Op.Zero ty => (ty, [])
          (* tuple components are numbered from 1 *)
            | Op.Select(ty as Ty.TupleTy tys, i) => (List.nth(tys, i-1), [ty])
            | Op.Subscript(ty as Ty.SeqTy(elemTy, _)) => (elemTy, [ty, Ty.intTy])
            | Op.MkDynamic(ty, n) => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, SOME n)])
            | Op.Prepend ty => (Ty.SeqTy(ty, NONE), [ty, Ty.SeqTy(ty, NONE)])
            | Op.Append ty => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, NONE), ty])
            | Op.Concat ty => (Ty.SeqTy(ty, NONE), [Ty.SeqTy(ty, NONE), Ty.SeqTy(ty, NONE)])
            | Op.Range => (Ty.SeqTy(Ty.intTy, NONE), [Ty.IntTy, Ty.IntTy])
            | Op.Length ty => (Ty.intTy, [Ty.SeqTy(ty, NONE)])
            | Op.SphereQuery(ptTy, strandTy) => (Ty.SeqTy(strandTy, NONE), [ptTy, Ty.realTy])
            | Op.Sqrt => (Ty.realTy, [Ty.realTy])
            | Op.Cos => (Ty.realTy, [Ty.realTy])
            | Op.ArcCos => (Ty.realTy, [Ty.realTy])
            | Op.Sin => (Ty.realTy, [Ty.realTy])
            | Op.ArcSin => (Ty.realTy, [Ty.realTy])
            | Op.Tan => (Ty.realTy, [Ty.realTy])
            | Op.ArcTan => (Ty.realTy, [Ty.realTy])
            | Op.Exp => (Ty.realTy, [Ty.realTy])
            | Op.Ceiling d => (vecTy d, [vecTy d])
            | Op.Floor d => (vecTy d, [vecTy d])
            | Op.Round d => (vecTy d, [vecTy d])
            | Op.Trunc d => (vecTy d, [vecTy d])
            | Op.IntToReal => (Ty.realTy, [Ty.intTy])
          (* the width-1 case must precede the general case *)
            | Op.RealToInt 1 => (Ty.IntTy, [Ty.realTy])
            | Op.RealToInt d => (Ty.SeqTy(Ty.IntTy, SOME d), [vecTy d])
(* not sure if we will need these
      | R_All of ty
      | R_Exists of ty
      | R_Max of ty
      | R_Min of ty
      | R_Sum of ty
      | R_Product of ty
      | R_Mean of ty
      | R_Variance of ty
*)
(* FIXME: these should probably be compiled down to lower-level operartions at this point!
	    | Op.Transform info => let
                val dim = ImageInfo.dim info
                in
		  if (dim = 1)
		    then (Ty.realTy, [Ty.ImageTy info])
		    else (Ty.matrixTy(dim, dim), [Ty.ImageTy info])
                end
	    | Op.Translate info => let
                val dim = ImageInfo.dim info
                in
		  if (dim = 1)
		    then (Ty.realTy, [Ty.ImageTy info])
		    else (Ty.matrixTy(dim, dim), [Ty.ImageTy info])
                end
*)
            | Op.ControlIndex(info, _, _) => (Ty.IntTy, [Ty.ImageTy info, Ty.IntTy])
            | Op.Inside(info, _) => (Ty.BoolTy, [vecTy(ImageInfo.dim info), Ty.ImageTy info])
            | Op.ImageDim(info, _) => (Ty.IntTy, [Ty.ImageTy info])
            | Op.LoadSeq(ty, _) => (ty, [])
            | Op.LoadImage(ty, _) => (ty, [])
            | Op.MathFn f => MathFns.sigOf (Ty.realTy, f)
            | _ => raise Fail("sigOf: invalid operator " ^ Op.toString rator)
          (* end case *))

    fun check (phase, prog) = let
          val IR.Program{
                  props, target={layout, ...}, consts, inputs, constInit,
                  globals, globalInit, strand, create, update
                } = prog
          val errBuf = ref []
          val errFn = error errBuf
        (* report any accumulated errors (oldest first); returns true if there were any *)
          fun final () = (case !errBuf
                 of [] => false
                  | errs => (
                      Log.msg ["********** IR Errors detected after ", phase, " **********\n"];
                      List.app (fn msg => Log.msg [msg, "\n"]) (List.rev errs);
                      true)
                (* end case *))
        (* signature of an operator, where vector widths are mapped to vector types
         * using the target's vector layout.  A width whose layout is not a single
         * piece is reported as an error and `Ty.VecTy(d, d)` is used as a fallback.
         *)
          fun sigOf rator = let
                fun vecTy d = (case layout d
                       of {padded, pieces=[w], ...} => Ty.VecTy(d, w)
                        | _ => (
                            errFn [
                                S "invalid width ", S(Int.toString d), S " for ", S(Op.toString rator)
                              ];
                            Ty.VecTy(d,d))
                      (* end case *))
                in
                  sigOfOp (vecTy, rator)
                end
        (* check a variable use *)
          fun checkVar (bvs, x) = if VSet.member(bvs, x)
                then ()
                else errFn [S "variable ", V x, S " is not bound"]
        (* check a block: the block's locals are added to the bound-variable set `bvs`
         * and its statements are checked in order, threading the set through them.
         *)
          fun chkBlock (bvs : VSet.set, IR.Block{locals, body}) = let
              (* check an expression and return its type; `cxt` is a thunk that describes
               * the enclosing context and is only evaluated when an error is reported.
               *)
                fun chkExp (cxt, bvs : VSet.set, e) = let
                      fun chk e = (case e
                             of IR.E_Global gv => GVar.ty gv
                              | IR.E_State(NONE, sv) => SVar.ty sv
                              | IR.E_State(SOME e, sv) => (
(* FIXME: check type of e *)
                                  SVar.ty sv)
                              | IR.E_Var x => (checkVar(bvs, x); Var.ty x)
                              | IR.E_Lit(Literal.Int _) => Ty.IntTy
                              | IR.E_Lit(Literal.Real _) => Ty.realTy
                              | IR.E_Lit(Literal.String _) => Ty.StringTy
                              | IR.E_Lit(Literal.Bool _) => Ty.BoolTy
                              | IR.E_Op(rator, args) => let
                                  val (resTy, paramTys) = sigOf rator
                                  val argTys = List.map chk args
                                  in
                                  (* allEq also fails when the argument count is wrong *)
                                    if ListPair.allEq Ty.same (paramTys, argTys)
                                      then ()
                                      else errFn [
                                          S "argument type mismatch in application of ",
                                          S(Op.toString rator), S(cxt()),
                                          NL, S "expected: ", TYS paramTys,
                                          NL, S "found:    ", TYS argTys
                                        ];
                                    resTy
                                  end
                              | IR.E_Vec(_, es) => let
                                  fun chkArg (i, e) = (case chk e
                                         of Ty.VecTy(1, 1) => () (* ok *)
                                          | ty => errFn [
                                                S "component ", S(Int.toString i),
                                                S " of vector has unexpected type ", TY ty, S(cxt())
                                              ])
                                (* check the result vector type *)
                                  val ty = (case layout(List.length es)
(* FIXME: check layout width against E_Vec width *)
                                         of {wid, padded, pieces=[w]} => Ty.VecTy(wid, w)
                                          | {wid, ...} => (
                                              errFn [
                                                  S "invalid width ", S(Int.toString wid),
                                                  S " for E_Vec", S(cxt())
                                                ];
                                              Ty.VecTy(List.length es, wid))
                                        (* end case *))
                                  in
                                    List.appi chkArg es;
                                    ty
                                  end
                              | IR.E_Cons([], ty) => (
                                  errFn [S "empty cons", S(cxt())];
                                  ty)
                              | IR.E_Cons(es, consTy as Ty.TensorTy dd) => let
                                (* a tensor cons lists all of its real elements *)
                                  val nelems = List.foldl Int.* 1 dd
                                  in
                                    if (length es <> nelems)
                                      then errFn [
                                          S "cons has incorrect number of elements", S(cxt()),
                                          NL, S "  expected: ", S(Int.toString nelems),
                                          NL, S "  found:    ", S(Int.toString(length es))
                                        ]
                                      else ();
                                    chkElems ("cons", Ty.realTy, es);
                                    consTy
                                  end
                              | IR.E_Cons(es, ty) => (
                                  errFn [S "unexpected type for cons", S(cxt()), S ": ", TY ty];
                                  ty)
                              | IR.E_Seq([], ty as Ty.SeqTy(_, SOME 0)) => ty
                              | IR.E_Seq([], ty as Ty.SeqTy(_, SOME n)) => (
                                  errFn [S "empty sequence, but expected ", TY ty, S(cxt())];
                                  ty)
                              | IR.E_Seq(es, seqTy as Ty.SeqTy(ty, NONE)) => (
                                  chkElems ("sequence", ty, es);
                                  seqTy)
                              | IR.E_Seq(es, seqTy as Ty.SeqTy(ty, SOME n)) => (
                                  if (length es <> n)
                                    then errFn [
                                        S "sequence has incorrect number of elements", S(cxt()),
                                        NL, S "  expected: ", S(Int.toString n),
                                        NL, S "  found:    ", S(Int.toString(length es))
                                      ]
                                    else ();
                                  chkElems ("sequence", ty, es);
                                  seqTy)
                              | IR.E_Seq(es, ty) => (
                                  errFn [S "unexpected type for sequence", S(cxt()), S ": ", TY ty];
                                  ty)
                              | IR.E_Pack(layout, es) => let
                                  val argTys = List.map chk es
                                  val pieceTys = Ty.piecesOf layout
                                  fun chkOne (i, ty, ty') = if Ty.same(ty, ty')
                                        then ()
                                        else errFn[
                                            S "mismatch in component ", S(Int.toString i),
                                            S " of PACK", S(cxt()),
                                            NL, S "  expected: ", TY ty',
                                            NL, S "  found:    ", TY ty
                                          ]
                                  in
                                  (* ListPair.appi silently ignores excess elements, so the
                                   * component count has to be checked explicitly.
                                   *)
                                    if (length argTys <> length pieceTys)
                                      then errFn [
                                          S "PACK has ", S(Int.toString(length argTys)),
                                          S " components, but its layout has ",
                                          S(Int.toString(length pieceTys)), S " pieces", S(cxt())
                                        ]
                                      else ();
                                    ListPair.appi chkOne (argTys, pieceTys);
                                    Ty.TensorTy[#wid layout]
                                  end
                              | IR.E_VLoad(layout, e, i) => let
                                  val ty = chk e
                                  val expectedTy = Ty.TensorTy[#wid layout]
                                  in
                                    if Ty.same(ty, expectedTy)
                                      then ()
                                      else errFn [
                                          S "type mismatch in E_VLoad", S(cxt()),
                                          NL, S "  expected: ", TY expectedTy,
                                          NL, S "  found:    ", TY ty
                                        ];
                                    Ty.nthVec(layout, i)
                                  end
                            (* end case *))
                    (* check that every element of `es` has type `ty`; `cxt'` names the
                     * kind of aggregate ("cons" or "sequence") for error messages.
                     *)
                      and chkElems (cxt', ty, []) = ()
                        | chkElems (cxt', ty, e::es) = let
                            val ty' = chk e
                            in
                              if Ty.same(ty, ty')
                                then ()
                                else errFn [
                                    S "element of ", S cxt', S " has incorrect type", S(cxt()),
                                    NL, S "expected: ", TY ty,
                                    NL, S "found:    ", TY ty'
                                  ];
                              chkElems (cxt', ty, es)
                            end
                      in
                        chk e
                      end
              (* check a statement and return the bound-variable set for the statements
               * that follow it (only a defining S_Assign extends the set).
               *)
                fun chkStm (stm, bvs : VSet.set) = (case stm
                       of IR.S_Comment _ => bvs
                        | IR.S_Assign(isDef, x, e) => let
                            val ty = chkExp (
                                  fn () => concat[" in assignment to local ", Var.name x],
                                  bvs, e)
                            in
                              if Ty.same(Var.ty x, ty)
                                then ()
                                else errFn[
                                    S "type mismatch in assignment to local ", S(Var.name x),
                                    NL, S "lhs: ", TY(Var.ty x),
                                    NL, S "rhs: ", TY ty
                                  ];
                              if isDef
                                then VSet.add(bvs, x)
                                else (checkVar(bvs, x); bvs)
                            end
                        | IR.S_MAssign(xs, e) => raise Fail "FIXME"
                        | IR.S_GAssign(gv, e) => let
                            val ty = chkExp (
                                  fn () => concat[" assignment to global ", GVar.name gv],
                                  bvs, e)
                            in
                              if Ty.same(GVar.ty gv, ty)
                                then ()
                                else errFn[
                                    S "type mismatch in assignment to global ", S(GVar.name gv),
                                    NL, S "lhs: ", TY(GVar.ty gv),
                                    NL, S "rhs: ", TY ty
                                  ];
                              bvs
                            end
                        | IR.S_IfThen(e, b) => let
                            val ty = chkExp (fn () => " in if-then", bvs, e)
                            in
                              if Ty.same(ty, Ty.BoolTy)
                                then ()
                                else errFn[
                                    S "expected bool for if-then, but found ", TY ty
                                  ];
                              chkBlock (bvs, b);
                              bvs
                            end
                        | IR.S_IfThenElse(e, b1, b2) => let
                            val ty = chkExp (fn () => " in if-then-else", bvs, e)
                            in
                              if Ty.same(ty, Ty.BoolTy)
                                then ()
                                else errFn[
                                    S "expected bool for if-then-else, but found ", TY ty
                                  ];
                              chkBlock (bvs, b1);
                              chkBlock (bvs, b2);
                              bvs
                            end
                        | IR.S_Foreach(x, e, b) => (
                            case chkExp (fn () => " in foreach", bvs, e)
                             of Ty.SeqTy(ty, _) =>
                                  if Ty.same(ty, Var.ty x)
                                    then ()
                                    else errFn [
                                        S "type mismatch in foreach ", V x,
                                        NL, S "variable type: ", TY(Var.ty x),
                                        NL, S "domain type:   ", TY ty
                                      ]
                              | ty => errFn [
                                    S "domain of foreach is not sequence type; found ", TY ty
                                  ]
                            (* end case *);
                          (* the loop variable is bound only within the loop body *)
                            ignore (chkBlock (VSet.add(bvs, x), b));
                            bvs)
                        | IR.S_LoadNrrd(x, name) => bvs (* FIXME: check type of x *)
                        | IR.S_Input(gv, _, _, NONE) => bvs
                        | IR.S_Input(gv, _, _, SOME e) => let
                            val ty = chkExp (fn () => concat[" in input ", GVar.name gv], bvs, e)
                            in
                              if Ty.same(GVar.ty gv, ty)
                                then ()
                                else errFn[
                                    S "type mismatch in default for input ", S(GVar.name gv),
                                    NL, S "expected: ", TY(GVar.ty gv),
                                    NL, S "found:    ", TY ty
                                  ];
                              bvs
                            end
                        | IR.S_InputNrrd(gv, _, _, _) => (
                            case GVar.ty gv
                             of Ty.SeqTy(_, NONE) => ()
                              | Ty.ImageTy _ => ()
                              | ty => errFn [
                                    S "input variable ", S(GVar.name gv), S " has bogus type ",
                                    TY ty, S " for lhs for InputNrrd"
                                  ]
                            (* end case *);
                            bvs)
                        | IR.S_New(_, es) => (
                            List.app (fn e => ignore (chkExp(fn () => concat[" in new"], bvs, e))) es;
                            bvs)
                        | IR.S_Save(sv, e) => let
                            val ty = chkExp (fn () => concat[" in save ", SVar.name sv], bvs, e)
                            in
                              if Ty.same(SVar.ty sv, ty)
                                then ()
                                else errFn[
                                    S "type mismatch in assignment to state variable ",
                                    S(SVar.name sv),
                                    NL, S "lhs: ", TY(SVar.ty sv),
                                    NL, S "rhs: ", TY ty
                                  ];
                              bvs
                            end
                        | IR.S_Exit => bvs
                        | IR.S_Print(tys, es) => (
                            if (length tys <> length es)
                              then errFn [
                                  S "print has ", S(Int.toString(length tys)),
                                  S " argument types, but ", S(Int.toString(length es)),
                                  S " arguments"
                                ]
                              else ();
                          (* ListPair.appi only checks the common prefix of the two lists *)
                            ListPair.appi
                              (fn (i, ty, e) => let
                                val ty' = chkExp(fn () => concat[" in print"], bvs, e)
                                in
                                  if Ty.same(ty, ty')
                                    then ()
                                    else errFn[
                                        S "type mismatch in argument ", S(Int.toString i),
                                        S " of print",
                                        NL, S "expected:  ", TY ty,
                                        NL, S "but found: ", TY ty'
                                      ]
                                end)
                                (tys, es);
                            bvs)
                        | IR.S_Active => bvs
                        | IR.S_Stabilize => bvs
                        | IR.S_Die => bvs
                      (* end case *))
                val bvs = VSet.addList(bvs, !locals)
                in
                  ignore (List.foldl chkStm bvs body)
                end
          fun chkOptBlock (_, NONE) = ()
            | chkOptBlock (bvs, SOME blk) = ignore (chkBlock (bvs, blk))
        (* the strand parameters are in scope only in the state-initialization block *)
          fun chkStrand (IR.Strand{name, params, state, stateInit, initM, updateM, stabilizeM}) = (
                ignore (chkBlock (VSet.fromList params, stateInit));
                chkOptBlock (VSet.empty, initM);
                ignore (chkBlock (VSet.empty, updateM));
                chkOptBlock (VSet.empty, stabilizeM))
          in
            ignore (chkBlock (VSet.empty, constInit));
            ignore (chkBlock (VSet.empty, globalInit));
            chkStrand strand;
            (case create of IR.Create{code, ...} => ignore (chkBlock (VSet.empty, code)));
            chkOptBlock (VSet.empty, update);
            final ()
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0