Home My Page Projects Code Snippets Project Openings diderot

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml
 [diderot] / branches / vis15 / src / compiler / low-to-tree / low-to-tree.sml

Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml

revision 4039, Fri Jun 24 15:35:09 2016 UTC revision 4163, Thu Jul 7 09:45:19 2016 UTC
# Line 65  Line 65
65      val {getFn = getStateVar, ...} = SV.newProp mkStateVar      val {getFn = getStateVar, ...} = SV.newProp mkStateVar
66      end      end
67
68      (* associate Tree IL function variables with Low IL variables using properties *)
69        local
70          fun mkFuncVar f = let
71                val (resTy, paramTys) = IR.Func.ty f
72                in
73    (* QUESTION: what about vector result/arguments? *)
74                  TreeFunc.new (IR.Func.name f, U.trType resTy, List.map U.trType paramTys)
75                end
76        in
77        val {getFn = getFuncVar, ...} = IR.Func.newProp mkFuncVar
78        end
79
80    (* for variables that are in an equivalence class (see UnifyVars), we use a single    (* for variables that are in an equivalence class (see UnifyVars), we use a single
81     * TreeIR variable (or vector of variables) to represent them.     * TreeIR variable (or vector of variables) to represent them.
82     *)     *)
# Line 78  Line 90
90                | NONE => let                | NONE => let
91                    val rep = (case V.ty x                    val rep = (case V.ty x
92                           of Ty.TensorTy[d] => VEC(U.newVectorVars(Env.layoutVec env d))                           of Ty.TensorTy[d] => VEC(U.newVectorVars(Env.layoutVec env d))
93                            | ty => VAR(U.newLocalVar x)                            | ty => if AssignTypes.isMemoryVar x
94                                  then VAR(U.newMemVar x)
95                                  else VAR(U.newLocalVar x)
96                          (* end case *))                          (* end case *))
97                    in                    in
98                      setFn (x, rep);                      setFn (x, rep);
# Line 123  Line 137
137              | _ => e              | _ => e
138           (* end case *))           (* end case *))
139
140      (* turn an expression of type TensorRefTy to one of TensorRef *)
141        fun mkDeref e = (case TreeTypeOf.exp e
142               of TTy.TensorRefTy(shp as _::_) => T.E_Op(TOp.TensorCopy shp, [e])
143                | _ => e
144             (* end case *))
145
146      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy
147        | cvtScalarTy Ty.IntTy = TTy.IntTy        | cvtScalarTy Ty.IntTy = TTy.IntTy
148        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy
# Line 130  Line 150
150
151    (* define a new local variable and bind x to it in the environment. *)    (* define a new local variable and bind x to it in the environment. *)
152      fun newLocal (env, x) = let      fun newLocal (env, x) = let
153            val x' = U.newLocalVar x            val x' = if AssignTypes.isMemoryVar x then U.newMemVar x else U.newLocalVar x
154              in
155                Env.bindSimple (env, x, T.E_Var x');
156                x'
157              end
158
159      (* define a new local variable and bind x to it in the environment. *)
160        fun newMemLocal (env, x) = let
161              val x' = U.newMemVar x
162            in            in
163              Env.bindSimple (env, x, T.E_Var x');              Env.bindSimple (env, x, T.E_Var x');
164              x'              x'
# Line 257  Line 285
285                  in                  in
286                    (Env.TREE(T.E_Op(rator, args)), stms)                    (Env.TREE(T.E_Op(rator, args)), stms)
287                  end                  end
288              fun bindTREE' rator = let
289                    val (args, stms) = singleArgs (env, args)
290                    in
291                      (Env.TREE(T.E_Op(rator, args)), stms)
292                    end
293            fun bindRHS (ty, rator) = let            fun bindRHS (ty, rator) = let
294                  val (args, stms) = simpleArgs (env, args)                  val (args, stms) = simpleArgs (env, args)
295                  in                  in
# Line 264  Line 297
297                  end                  end
298            fun bindVOp rator = let            fun bindVOp rator = let
299                  val (layout, argss, stms) = vectorArgs (env, args)                  val (layout, argss, stms) = vectorArgs (env, args)
300                  fun mkArgs (_, [], []) = []                  fun mkArgs (w, [p], [args]) = [T.E_Op(rator(w, p), args)]
301                    | mkArgs (w, p::ps, args::argss) =                    | mkArgs (w, p::ps, args::argss) =
302                        T.E_Op(rator(w, p), args) :: mkArgs (w-p, ps, argss)                        T.E_Op(rator(p, p), args) :: mkArgs (w-p, ps, argss)
303                      | mkArgs _ = raise Fail "bindVOp: arity mismatch"
304                  val exps = mkArgs (#wid layout, #pieces layout, argss)                  val exps = mkArgs (#wid layout, #pieces layout, argss)
305                  in                  in
306                    (Env.VEC(layout, exps), stms)                    (Env.VEC(layout, exps), stms)
# Line 323  Line 357
357                    in                    in
358                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
359                    end                    end
360                  | Op.VDot _ => let
361                      val (layout, argss, stms) = vectorArgs (env, args)
362                      fun mkArgs (w, [p], [args]) = [T.E_Op(TOp.VDot(w, p), args)]
363                        | mkArgs (w, p::ps, args::argss) =
364                            T.E_Op(TOp.VDot(p, p), args) :: mkArgs (w-p, ps, argss)
365                        | mkArgs _ = raise Fail "VDot: arity mismatch"
366                      val e::es = mkArgs (#wid layout, #pieces layout, argss)
367                      in
368                        (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
369                      end
370                | Op.VIndex(_, i) => let                | Op.VIndex(_, i) => let
371                    val [v] = args                    val [v] = args
372                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)
# Line 336  Line 380
380                      (select (i, wid, pieces, es), stms)                      (select (i, wid, pieces, es), stms)
381                    end                    end
382                | Op.VClamp n => let                | Op.VClamp n => let
383                    val [v, lo, hi] = args                    val [lo, hi, v] = args
384                    val (layout, vs, stms) = vectorArg (env, v)                    val (lo, stms) = simpleArg env (lo, [])
val (lo, stms) = simpleArg env (lo, stms)
385                    val (hi, stms) = simpleArg env (hi, stms)                    val (hi, stms) = simpleArg env (hi, stms)
386                      val (layout, vs, stms') = vectorArg (env, v)
387                    val exps = mkArgs                    val exps = mkArgs
388                          (fn (w, p, x) => (TOp.VClamp(w, p), [x, lo, hi]))                          (fn (w, p, x) => (TOp.VClamp(w, p), [lo, hi, x]))
389                            (layout, vs)                            (layout, vs)
390                    in                    in
391                      (Env.VEC(layout, exps), List.rev stms)                      (Env.VEC(layout, exps), List.revAppend(stms, List.rev stms))
392                    end                    end
393                | Op.VMapClamp n => bindVOp TOp.VMapClamp                | Op.VMapClamp n => bindVOp TOp.VMapClamp
394                | Op.VLerp n => bindVOp TOp.VLerp                | Op.VLerp n => let
395                      val [u, v, t] = args
396                      val (layout, us, stms1) = vectorArg (env, u)
397                      val (_, vs, stms2) = vectorArg (env, v)
398                      val (t, stms) = simpleArg env (t, stms2 @ stms1)
399                      val exps = let
400                            fun mkArgs (w, [p], [u], [v]) = [T.E_Op(TOp.VLerp(w, p), [u, v, t])]
401                              | mkArgs (w, p::ps, u::ur, v::vr) =
402                                  T.E_Op(TOp.VLerp(p, p), [u, v, t]) :: mkArgs (w-p, ps, ur, vr)
403                              | mkArgs _ = raise Fail "VLerp: arity mismatch"
404                            in
405                              mkArgs (#wid layout, #pieces layout, us, vs)
406                            end
407                      in
408                        (Env.VEC(layout, exps), stms)
409                      end
410                | Op.TensorIndex(ty, idxs) => let                | Op.TensorIndex(ty, idxs) => let
411                    val ([arg], stms) = simpleArgs (env, args)                    val ([arg], stms) = simpleArgs (env, args)
412                    val ty = TreeTypeOf.exp arg                    val ty = TreeTypeOf.exp arg
# Line 363  Line 422
422                | Op.Select(ty, i) => bindTREE (TOp.Select(U.trType ty, i))                | Op.Select(ty, i) => bindTREE (TOp.Select(U.trType ty, i))
423                | Op.Subscript ty => bindTREE (TOp.Subscript(U.trType ty))                | Op.Subscript ty => bindTREE (TOp.Subscript(U.trType ty))
424                | Op.MkDynamic(ty, n) => bindTREE (TOp.MkDynamic(U.trType ty, n))                | Op.MkDynamic(ty, n) => bindTREE (TOp.MkDynamic(U.trType ty, n))
425                | Op.Append ty => bindTREE (TOp.Append(U.trType ty))                | Op.Append ty => bindTREE' (TOp.Append(U.trType ty))
426                | Op.Prepend ty => bindTREE (TOp.Prepend(U.trType ty))                | Op.Prepend ty => bindTREE' (TOp.Prepend(U.trType ty))
427                | Op.Concat ty => bindTREE (TOp.Concat(U.trType ty))                | Op.Concat ty => bindTREE (TOp.Concat(U.trType ty))
428                | Op.Range => bindTREE TOp.Range                | Op.Range => bindTREE TOp.Range
429                | Op.Length ty => bindTREE (TOp.Length(U.trType ty))                | Op.Length ty => bindTREE (TOp.Length(U.trType ty))
430                | Op.SphereQuery(ty1, ty2) => bindTREE (TOp.SphereQuery(U.trType ty1, U.trType ty2))                | Op.SphereQuery(ty1, ty2) => bindTREE' (TOp.SphereQuery(U.trType ty1, U.trType ty2))
431                | Op.Sqrt => bindTREE TOp.Sqrt                | Op.Sqrt => bindTREE TOp.Sqrt
432                | Op.Cos => bindTREE TOp.Cos                | Op.Cos => bindTREE TOp.Cos
433                | Op.ArcCos => bindTREE TOp.ArcCos                | Op.ArcCos => bindTREE TOp.ArcCos
# Line 410  Line 469
469                | Op.Translate info => bindTREE (TOp.Translate info)                | Op.Translate info => bindTREE (TOp.Translate info)
470                | Op.ControlIndex(info, ctl, d) => bindTREE (TOp.ControlIndex(info, ctl, d))                | Op.ControlIndex(info, ctl, d) => bindTREE (TOp.ControlIndex(info, ctl, d))
472    (*
473                | Op.Inside(info, s) => bindTREE (TOp.Inside(info, s))                | Op.Inside(info, s) => bindTREE (TOp.Inside(info, s))
474    *)
475                  | Op.Inside(info, s) => (case ImageInfo.dim info
476                       of 1 => bindTREE (TOp.Inside(VectorLayout.realLayout, info, s))
477                        | d => let
478                            val [x, img] = args
479                            val (layout, args, stms) = vectorArg (env, x)
480                            val (img, stms) = simpleArg env (img, stms)
481                            in
482                              (Env.TREE(T.E_Op(TOp.Inside(layout, info, s), args@[img])), stms)
483                            end
484                      (* end case *))
485                | Op.ImageDim(info, d) => bindTREE(TOp.ImageDim(info, d))                | Op.ImageDim(info, d) => bindTREE(TOp.ImageDim(info, d))
486                | Op.MathFn f => bindTREE (TOp.MathFn f)                | Op.MathFn f => bindTREE (TOp.MathFn f)
487                | rator => raise Fail("bogus operator " ^ Op.toString srcRator)                | rator => raise Fail("bogus operator " ^ Op.toString srcRator)
# Line 435  Line 506
506          rhs is simple          rhs is simple
507          rhs is vector          rhs is vector
508  *)  *)
509      fun trAssign (env, lhs, rhs) = let      fun trAssign (env, lhs, rhs : IR.rhs) = let
510            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)
fun bindRHS rhs = Env.bindVar (env, getLHS(), Env.RHS(U.trTempType(V.ty lhs), rhs))
511          (* binding for the lhs variable, where the rhs is a simple expression.  We check to          (* binding for the lhs variable, where the rhs is a simple expression.  We check to
512           * see if it is part of an merged equivalence class, in which case we need to generate           * see if it is part of an merged equivalence class, in which case we need to generate
513           * assigment(s)           * assigment(s)
# Line 462  Line 532
532                  in                  in
533                    case eqClassRepOf(env, lhs)                    case eqClassRepOf(env, lhs)
534                     of NOEQ => (                     of NOEQ => (
535                          bindRHS (T.E_Op(rator, args));                          Env.bindVar (env, lhs, Env.RHS(U.trTempType(V.ty lhs), T.E_Op(rator, args)));
536                          stms)                          stms)
537                      | VAR x' => stms @ [mkAssign' (x', T.E_Op(rator, args))]                      | VAR x' => stms @ [mkAssign' (x', T.E_Op(rator, args))]
538                      | VEC _ => raise Fail ("unexpected VEC for lhs " ^ V.toString lhs)                      | VEC _ => raise Fail ("unexpected VEC for lhs " ^ V.toString lhs)
539                    (* end case *)                    (* end case *)
540                  end                  end
541            (* bind the lhs to a tensor cons expression (including Op.Zero) *)
542              fun bindCons (args, Ty.TensorTy[d], stms) = let
543                    val layout = Env.layoutVec env d
544                    fun mkVecs (args, w::ws) = let
545                        (* take arguments from args to build a vector value of width w; pad as
546                         * necessary.
547                         *)
548                          fun take (0, args, es) = T.E_Vec(w, w, List.rev es) :: mkVecs (args, ws)
549                            | take (i, [], es) = if #padded layout andalso null ws
550                                then [T.E_Vec(w-i, w, List.rev es)]
551                                else raise Fail "too few arguments for CONS"
552                            | take (i, arg::args, es) = take (i-1, args, arg :: es)
553                          in
554                            take (w, args, [])
555                          end
556                      | mkVecs ([], []) = []
557                      | mkVecs (_, []) = raise Fail "too many arguments for CONS"
558                    val es = mkVecs (args, #pieces layout)
559                    in
560                      case eqClassRepOf(env, lhs)
561                       of NOEQ => if (V.useCount lhs > 1)
562                            then (Env.bindVar(env, lhs, Env.VEC(layout, es)); stms)
563                            else let
564                              val vs = U.newVectorVars layout
565                              in
566                                Env.bindVar (env, lhs, Env.VEC(layout, List.map T.E_Var vs));
567                                ListPair.foldl (fn (v, e, stms) => mkDefn(v, e)::stms) stms (vs, es)
568                              end
569                        | VEC xs =>
570                            ListPair.foldl (fn (x, e, stms) => mkAssign(x, e)::stms) stms (xs, es)
571                        | _ => raise Fail "inconsistent"
572                      (* end case *)
573                    end
574                | bindCons (args, ty as Ty.TensorTy _, stms) = let
575                    val ty = U.trType ty
576                    val cons = T.E_Cons(args, ty)
577                    in
578                      case eqClassRepOf(env, lhs)
579                       of NOEQ => if (V.useCount lhs > 1)
580                            then mkDefn (newMemLocal (env, lhs), cons) :: stms
581                            else (
582                              Env.bindVar (env, lhs, Env.RHS(ty, cons));
583                              stms)
584                        | VAR x => mkAssign (x, cons) :: stms
585                        | VEC xs => raise Fail "inconsistent"
586                      (* end case *)
587                    end
588            in            in
589              case rhs              case rhs
590               of IR.GLOBAL x => bindSimple (T.E_Global(mkGlobalVar x))               of IR.GLOBAL x => bindSimple (T.E_Global(mkGlobalVar x))
# Line 492  Line 609
609                    val z = T.E_Lit(Literal.Real(RealLit.zero false))                    val z = T.E_Lit(Literal.Real(RealLit.zero false))
610                    val sz = List.foldl Int.* 1 dd                    val sz = List.foldl Int.* 1 dd
611                    in                    in
612                      bindRHS (T.E_Cons(List.tabulate(sz, fn _ => z), U.trType ty));                      bindCons (List.tabulate(sz, fn _ => z), ty, [])
[]
613                    end                    end
614                | IR.OP(Op.LoadSeq(ty, file), []) => let                | IR.OP(Op.LoadSeq(ty, file), []) => let
615                    val lhs = newLocal (env, getLHS ())                    val lhs = newLocal (env, getLHS ())
# Line 525  Line 641
641                        | _ => raise Fail "inconsistent"                        | _ => raise Fail "inconsistent"
642                      (* end case *)                      (* end case *)
643                    end                    end
644                | IR.CONS(args, Ty.TensorTy[d]) => let                | IR.CONS(args, ty) => let
645                    val layout = Env.layoutVec env d                    val (es, stms) = simpleArgs (env, args)
val (args, stms) = simpleArgs (env, args)
fun mkVecs (args, w::ws) = let
(* take arguments from args to build a vector value of width w; pad as
* necessary.
*)
fun take (0, args, es) = T.E_Vec(w, w, List.rev es) :: mkVecs (args, ws)
| take (i, [], es) = if #padded layout andalso null ws
then [T.E_Vec(w-i, w, List.rev es)]
else raise Fail "too few arguments for CONS"
| take (i, arg::args, es) = take (i-1, args, arg :: es)
in
take (w, args, [])
end
| mkVecs ([], []) = []
| mkVecs (_, []) = raise Fail "too many arguments for CONS"
val es = mkVecs (args, #pieces layout)
in
case (eqClassRepOf(env, lhs), V.useCount lhs > 1)
of (NOEQ, false) => (Env.bindVar(env, lhs, Env.VEC(layout, es)); stms)
| (NOEQ, true) => let
val vs = U.newVectorVars layout
646                            in                            in
647                              Env.bindVar (env, lhs, Env.VEC(layout, List.map T.E_Var vs));                      bindCons (es, ty, stms)
ListPair.foldl (fn (v, e, stms) => mkDefn(v, e)::stms) stms (vs, es)
end
| (VEC xs, _) =>
ListPair.foldl (fn (x, e, stms) => mkAssign(x, e)::stms) stms (xs, es)
| _ => raise Fail "inconsistent"
(* end case *)
648                    end                    end
649                | IR.CONS(args, ty) => let                | IR.SEQ(args, ty) => let
650                    val (es, stms) = singleArgs (env, args)                    val (es, stms) = singleArgs (env, args)
651                    val ty = U.trType ty                    val ty = U.trType ty
652                    (* if we are dealing with a sequence of tensors, then we need to copy references *)
653                      val es = (case ty
654                             of TTy.SeqTy(TTy.TensorTy _, _) => List.map mkDeref es
655                              | _ => es
656                            (* end case *))
657                      val seq = T.E_Seq(es, ty)
658                    in                    in
659                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Cons(es, ty)));                      case eqClassRepOf(env, lhs)
660                      stms                       of NOEQ => if (V.useCount lhs > 1)
661                              then mkDefn (newMemLocal (env, lhs), seq) :: stms
662                              else (
663                                Env.bindVar (env, lhs, Env.RHS(ty, seq));
664                                stms)
665                          | VAR x => mkAssign (x, seq) :: stms
666                          | VEC xs => raise Fail "inconsistent"
667                        (* end case *)
668                    end                    end
669                | IR.SEQ(args, ty) => let                | IR.APPLY(f, args) => let
670                    val (es, stms) = singleArgs (env, args)                    val (es, stms) = singleArgs (env, args)
val ty = U.trType ty
671                    in                    in
672                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Seq(es, ty)));                      Env.bindVar (env, lhs, Env.TREE(T.E_Apply(getFuncVar f, es)));
673                      stms                      stms
674                    end                    end
675                | rhs => raise Fail(concat["unexpected ", IR.RHS.toString rhs, " in LowIR code"])                | rhs => raise Fail(concat["unexpected ", IR.RHS.toString rhs, " in LowIR code"])
# Line 599  Line 700
700                    | Env.TREE e => e                    | Env.TREE e => e
701                    | _ => raise Fail("expected scalar binding for " ^ V.toString x)                    | _ => raise Fail("expected scalar binding for " ^ V.toString x)
702                  (* end case *))                  (* end case *))
703            (* analyze the CFG *)
704            val _ = UnifyVars.analyze cfg            val _ = UnifyVars.analyze cfg
705              val _ = AssignTypes.analyze cfg
706          (* join (stk, stms, k): handle a control-flow join, where env is the          (* join (stk, stms, k): handle a control-flow join, where env is the
707           * current environment, stk is the stack of open ifs (the top of stk specifies           * current environment, stk is the stack of open ifs (the top of stk specifies
708           * which branch we are in), stms are the TreeIL statements preceding the join           * which branch we are in), stms are the TreeIL statements preceding the join
# Line 749  Line 852
852
853      fun trCFG info cfg = ScopeVars.assignScopes ([], trCFGWithEnv (Env.new info, cfg))      fun trCFG info cfg = ScopeVars.assignScopes ([], trCFGWithEnv (Env.new info, cfg))
854
855        fun trFunc info (IR.Func{name, params, body}) = let
856              val name' = getFuncVar name
857              val params' = List.map U.newParamVar params
858              val body' = trCFG info body
859              in
860                T.Func{name = name', params = params', body = body'}
861              end
862
863    (* Build a strand method from a TreeIR block.  We need to check for language features    (* Build a strand method from a TreeIR block.  We need to check for language features
864     * that require the world pointer (e.g., printing) and for references to global variables.     * that require the world pointer (e.g., printing) and for references to global variables.
865     *)     *)
# Line 847  Line 958
958            val prog = Flatten.transform prog            val prog = Flatten.transform prog
959            val LowIR.Program{            val LowIR.Program{
960                    props, consts, inputs, constInit, globals,                    props, consts, inputs, constInit, globals,
961                    globInit, strand, create, init, update                    funcs, globInit, strand, create, init, update
962                  } = prog                  } = prog
963            val trCFG = trCFG info            val trCFG = trCFG info
964            in            in
# Line 858  Line 969
969                  inputs = List.map (Inputs.map mkGlobalVar) inputs,                  inputs = List.map (Inputs.map mkGlobalVar) inputs,
970                  constInit = trCFG constInit,                  constInit = trCFG constInit,
971                  globals = List.map mkGlobalVar globals,                  globals = List.map mkGlobalVar globals,
972                    funcs = List.map (trFunc info) funcs,
973                  globInit = trCFG globInit,                  globInit = trCFG globInit,
974                  strand = trStrand info strand,                  strand = trStrand info strand,
975                  create = let                  create = Create.map trCFG create,
val IR.Create{dim, code} = create
in
T.Create{dim = dim, code = trCFG code}
end,
976                  init = Option.map trCFG init,                  init = Option.map trCFG init,
977                  update = Option.map trCFG update                  update = Option.map trCFG update
978                }                }

Legend:
 Removed from v.4039 changed lines Added in v.4163