Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4039, Fri Jun 24 15:35:09 2016 UTC revision 4128, Thu Jun 30 23:19:13 2016 UTC
# Line 78  Line 78 
78                | NONE => let                | NONE => let
79                    val rep = (case V.ty x                    val rep = (case V.ty x
80                           of Ty.TensorTy[d] => VEC(U.newVectorVars(Env.layoutVec env d))                           of Ty.TensorTy[d] => VEC(U.newVectorVars(Env.layoutVec env d))
81                            | ty => VAR(U.newLocalVar x)                            | ty => if AssignTypes.isMemoryVar x
82                                  then VAR(U.newMemVar x)
83                                  else VAR(U.newLocalVar x)
84                          (* end case *))                          (* end case *))
85                    in                    in
86                      setFn (x, rep);                      setFn (x, rep);
# Line 123  Line 125 
125              | _ => e              | _ => e
126           (* end case *))           (* end case *))
127    
128      (* turn an expression of type TensorRefTy to one of TensorRef *)
129        fun mkDeref e = (case TreeTypeOf.exp e
130               of TTy.TensorRefTy(shp as _::_) => T.E_Op(TOp.TensorCopy shp, [e])
131                | _ => e
132             (* end case *))
133    
134      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy
135        | cvtScalarTy Ty.IntTy = TTy.IntTy        | cvtScalarTy Ty.IntTy = TTy.IntTy
136        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy
# Line 130  Line 138 
138    
139    (* define a new local variable and bind x to it in the environment. *)    (* define a new local variable and bind x to it in the environment. *)
140      fun newLocal (env, x) = let      fun newLocal (env, x) = let
141            val x' = U.newLocalVar x            val x' = if AssignTypes.isMemoryVar x then U.newMemVar x else U.newLocalVar x
142              in
143                Env.bindSimple (env, x, T.E_Var x');
144                x'
145              end
146    
147      (* define a new local variable and bind x to it in the environment. *)
148        fun newMemLocal (env, x) = let
149              val x' = U.newMemVar x
150            in            in
151              Env.bindSimple (env, x, T.E_Var x');              Env.bindSimple (env, x, T.E_Var x');
152              x'              x'
# Line 257  Line 273 
273                  in                  in
274                    (Env.TREE(T.E_Op(rator, args)), stms)                    (Env.TREE(T.E_Op(rator, args)), stms)
275                  end                  end
276              fun bindTREE' rator = let
277                    val (args, stms) = singleArgs (env, args)
278                    in
279                      (Env.TREE(T.E_Op(rator, args)), stms)
280                    end
281            fun bindRHS (ty, rator) = let            fun bindRHS (ty, rator) = let
282                  val (args, stms) = simpleArgs (env, args)                  val (args, stms) = simpleArgs (env, args)
283                  in                  in
# Line 264  Line 285 
285                  end                  end
286            fun bindVOp rator = let            fun bindVOp rator = let
287                  val (layout, argss, stms) = vectorArgs (env, args)                  val (layout, argss, stms) = vectorArgs (env, args)
288                  fun mkArgs (_, [], []) = []                  fun mkArgs (w, [p], [args]) = [T.E_Op(rator(w, p), args)]
289                    | mkArgs (w, p::ps, args::argss) =                    | mkArgs (w, p::ps, args::argss) =
290                        T.E_Op(rator(w, p), args) :: mkArgs (w-p, ps, argss)                        T.E_Op(rator(p, p), args) :: mkArgs (w-p, ps, argss)
291                      | mkArgs _ = raise Fail "bindVOp: arity mismatch"
292                  val exps = mkArgs (#wid layout, #pieces layout, argss)                  val exps = mkArgs (#wid layout, #pieces layout, argss)
293                  in                  in
294                    (Env.VEC(layout, exps), stms)                    (Env.VEC(layout, exps), stms)
# Line 323  Line 345 
345                    in                    in
346                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
347                    end                    end
348                  | Op.VDot _ => let
349                      val (layout, argss, stms) = vectorArgs (env, args)
350                      fun mkArgs (w, [p], [args]) = [T.E_Op(TOp.VDot(w, p), args)]
351                        | mkArgs (w, p::ps, args::argss) =
352                            T.E_Op(TOp.VDot(p, p), args) :: mkArgs (w-p, ps, argss)
353                        | mkArgs _ = raise Fail "VDot: arity mismatch"
354                      val e::es = mkArgs (#wid layout, #pieces layout, argss)
355                      in
356                        (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
357                      end
358                | Op.VIndex(_, i) => let                | Op.VIndex(_, i) => let
359                    val [v] = args                    val [v] = args
360                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)
# Line 336  Line 368 
368                      (select (i, wid, pieces, es), stms)                      (select (i, wid, pieces, es), stms)
369                    end                    end
370                | Op.VClamp n => let                | Op.VClamp n => let
371                    val [v, lo, hi] = args                    val [lo, hi, v] = args
372                    val (layout, vs, stms) = vectorArg (env, v)                    val (lo, stms) = simpleArg env (lo, [])
                   val (lo, stms) = simpleArg env (lo, stms)  
373                    val (hi, stms) = simpleArg env (hi, stms)                    val (hi, stms) = simpleArg env (hi, stms)
374                      val (layout, vs, stms') = vectorArg (env, v)
375                    val exps = mkArgs                    val exps = mkArgs
376                          (fn (w, p, x) => (TOp.VClamp(w, p), [x, lo, hi]))                          (fn (w, p, x) => (TOp.VClamp(w, p), [lo, hi, x]))
377                            (layout, vs)                            (layout, vs)
378                    in                    in
379                      (Env.VEC(layout, exps), List.rev stms)                      (Env.VEC(layout, exps), List.revAppend(stms, List.rev stms))
380                    end                    end
381                | Op.VMapClamp n => bindVOp TOp.VMapClamp                | Op.VMapClamp n => bindVOp TOp.VMapClamp
382                | Op.VLerp n => bindVOp TOp.VLerp                | Op.VLerp n => let
383                      val [u, v, t] = args
384                      val (layout, us, stms1) = vectorArg (env, u)
385                      val (_, vs, stms2) = vectorArg (env, v)
386                      val (t, stms) = simpleArg env (t, stms2 @ stms1)
387                      val exps = let
388                            fun mkArgs (w, [p], [u], [v]) = [T.E_Op(TOp.VLerp(w, p), [u, v, t])]
389                              | mkArgs (w, p::ps, u::ur, v::vr) =
390                                  T.E_Op(TOp.VLerp(p, p), [u, v, t]) :: mkArgs (w-p, ps, ur, vr)
391                              | mkArgs _ = raise Fail "VLerp: arity mismatch"
392                            in
393                              mkArgs (#wid layout, #pieces layout, us, vs)
394                            end
395                      in
396                        (Env.VEC(layout, exps), stms)
397                      end
398                | Op.TensorIndex(ty, idxs) => let                | Op.TensorIndex(ty, idxs) => let
399                    val ([arg], stms) = simpleArgs (env, args)                    val ([arg], stms) = simpleArgs (env, args)
400                    val ty = TreeTypeOf.exp arg                    val ty = TreeTypeOf.exp arg
# Line 363  Line 410 
410                | Op.Select(ty, i) => bindTREE (TOp.Select(U.trType ty, i))                | Op.Select(ty, i) => bindTREE (TOp.Select(U.trType ty, i))
411                | Op.Subscript ty => bindTREE (TOp.Subscript(U.trType ty))                | Op.Subscript ty => bindTREE (TOp.Subscript(U.trType ty))
412                | Op.MkDynamic(ty, n) => bindTREE (TOp.MkDynamic(U.trType ty, n))                | Op.MkDynamic(ty, n) => bindTREE (TOp.MkDynamic(U.trType ty, n))
413                | Op.Append ty => bindTREE (TOp.Append(U.trType ty))                | Op.Append ty => bindTREE' (TOp.Append(U.trType ty))
414                | Op.Prepend ty => bindTREE (TOp.Prepend(U.trType ty))                | Op.Prepend ty => bindTREE' (TOp.Prepend(U.trType ty))
415                | Op.Concat ty => bindTREE (TOp.Concat(U.trType ty))                | Op.Concat ty => bindTREE (TOp.Concat(U.trType ty))
416                | Op.Range => bindTREE TOp.Range                | Op.Range => bindTREE TOp.Range
417                | Op.Length ty => bindTREE (TOp.Length(U.trType ty))                | Op.Length ty => bindTREE (TOp.Length(U.trType ty))
418                | Op.SphereQuery(ty1, ty2) => bindTREE (TOp.SphereQuery(U.trType ty1, U.trType ty2))                | Op.SphereQuery(ty1, ty2) => bindTREE' (TOp.SphereQuery(U.trType ty1, U.trType ty2))
419                | Op.Sqrt => bindTREE TOp.Sqrt                | Op.Sqrt => bindTREE TOp.Sqrt
420                | Op.Cos => bindTREE TOp.Cos                | Op.Cos => bindTREE TOp.Cos
421                | Op.ArcCos => bindTREE TOp.ArcCos                | Op.ArcCos => bindTREE TOp.ArcCos
# Line 435  Line 482 
482          rhs is simple          rhs is simple
483          rhs is vector          rhs is vector
484  *)  *)
485      fun trAssign (env, lhs, rhs) = let      fun trAssign (env, lhs, rhs : IR.rhs) = let
486            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)
           fun bindRHS rhs = Env.bindVar (env, getLHS(), Env.RHS(U.trTempType(V.ty lhs), rhs))  
487          (* binding for the lhs variable, where the rhs is a simple expression.  We check to          (* binding for the lhs variable, where the rhs is a simple expression.  We check to
488           * see if it is part of an merged equivalence class, in which case we need to generate           * see if it is part of an merged equivalence class, in which case we need to generate
489           * assigment(s)           * assigment(s)
# Line 462  Line 508 
508                  in                  in
509                    case eqClassRepOf(env, lhs)                    case eqClassRepOf(env, lhs)
510                     of NOEQ => (                     of NOEQ => (
511                          bindRHS (T.E_Op(rator, args));                          Env.bindVar (env, lhs, Env.RHS(U.trTempType(V.ty lhs), T.E_Op(rator, args)));
512                          stms)                          stms)
513                      | VAR x' => stms @ [mkAssign' (x', T.E_Op(rator, args))]                      | VAR x' => stms @ [mkAssign' (x', T.E_Op(rator, args))]
514                      | VEC _ => raise Fail ("unexpected VEC for lhs " ^ V.toString lhs)                      | VEC _ => raise Fail ("unexpected VEC for lhs " ^ V.toString lhs)
515                    (* end case *)                    (* end case *)
516                  end                  end
517            (* bind the lhs to a tensor cons expression (including Op.Zero) *)
518              fun bindCons (args, Ty.TensorTy[d], stms) = let
519                    val layout = Env.layoutVec env d
520                    fun mkVecs (args, w::ws) = let
521                        (* take arguments from args to build a vector value of width w; pad as
522                         * necessary.
523                         *)
524                          fun take (0, args, es) = T.E_Vec(w, w, List.rev es) :: mkVecs (args, ws)
525                            | take (i, [], es) = if #padded layout andalso null ws
526                                then [T.E_Vec(w-i, w, List.rev es)]
527                                else raise Fail "too few arguments for CONS"
528                            | take (i, arg::args, es) = take (i-1, args, arg :: es)
529                          in
530                            take (w, args, [])
531                          end
532                      | mkVecs ([], []) = []
533                      | mkVecs (_, []) = raise Fail "too many arguments for CONS"
534                    val es = mkVecs (args, #pieces layout)
535                    in
536                      case eqClassRepOf(env, lhs)
537                       of NOEQ => if (V.useCount lhs > 1)
538                            then (Env.bindVar(env, lhs, Env.VEC(layout, es)); stms)
539                            else let
540                              val vs = U.newVectorVars layout
541                              in
542                                Env.bindVar (env, lhs, Env.VEC(layout, List.map T.E_Var vs));
543                                ListPair.foldl (fn (v, e, stms) => mkDefn(v, e)::stms) stms (vs, es)
544                              end
545                        | VEC xs =>
546                            ListPair.foldl (fn (x, e, stms) => mkAssign(x, e)::stms) stms (xs, es)
547                        | _ => raise Fail "inconsistent"
548                      (* end case *)
549                    end
550                | bindCons (args, ty as Ty.TensorTy _, stms) = let
551                    val ty = U.trType ty
552                    val cons = T.E_Cons(args, ty)
553                    in
554                      case eqClassRepOf(env, lhs)
555                       of NOEQ => if (V.useCount lhs > 1)
556                            then mkDefn (newMemLocal (env, lhs), cons) :: stms
557                            else (
558                              Env.bindVar (env, lhs, Env.RHS(ty, cons));
559                              stms)
560                        | VAR x => mkAssign (x, cons) :: stms
561                        | VEC xs => raise Fail "inconsistent"
562                      (* end case *)
563                    end
564            in            in
565              case rhs              case rhs
566               of IR.GLOBAL x => bindSimple (T.E_Global(mkGlobalVar x))               of IR.GLOBAL x => bindSimple (T.E_Global(mkGlobalVar x))
# Line 492  Line 585 
585                    val z = T.E_Lit(Literal.Real(RealLit.zero false))                    val z = T.E_Lit(Literal.Real(RealLit.zero false))
586                    val sz = List.foldl Int.* 1 dd                    val sz = List.foldl Int.* 1 dd
587                    in                    in
588                      bindRHS (T.E_Cons(List.tabulate(sz, fn _ => z), U.trType ty));                      bindCons (List.tabulate(sz, fn _ => z), ty, [])
                     []  
589                    end                    end
590                | IR.OP(Op.LoadSeq(ty, file), []) => let                | IR.OP(Op.LoadSeq(ty, file), []) => let
591                    val lhs = newLocal (env, getLHS ())                    val lhs = newLocal (env, getLHS ())
# Line 525  Line 617 
617                        | _ => raise Fail "inconsistent"                        | _ => raise Fail "inconsistent"
618                      (* end case *)                      (* end case *)
619                    end                    end
               | IR.CONS(args, Ty.TensorTy[d]) => let  
                   val layout = Env.layoutVec env d  
                   val (args, stms) = simpleArgs (env, args)  
                   fun mkVecs (args, w::ws) = let  
                       (* take arguments from args to build a vector value of width w; pad as  
                        * necessary.  
                        *)  
                         fun take (0, args, es) = T.E_Vec(w, w, List.rev es) :: mkVecs (args, ws)  
                           | take (i, [], es) = if #padded layout andalso null ws  
                               then [T.E_Vec(w-i, w, List.rev es)]  
                               else raise Fail "too few arguments for CONS"  
                           | take (i, arg::args, es) = take (i-1, args, arg :: es)  
                         in  
                           take (w, args, [])  
                         end  
                     | mkVecs ([], []) = []  
                     | mkVecs (_, []) = raise Fail "too many arguments for CONS"  
                   val es = mkVecs (args, #pieces layout)  
                   in  
                     case (eqClassRepOf(env, lhs), V.useCount lhs > 1)  
                      of (NOEQ, false) => (Env.bindVar(env, lhs, Env.VEC(layout, es)); stms)  
                       | (NOEQ, true) => let  
                           val vs = U.newVectorVars layout  
                           in  
                             Env.bindVar (env, lhs, Env.VEC(layout, List.map T.E_Var vs));  
                             ListPair.foldl (fn (v, e, stms) => mkDefn(v, e)::stms) stms (vs, es)  
                           end  
                       | (VEC xs, _) =>  
                           ListPair.foldl (fn (x, e, stms) => mkAssign(x, e)::stms) stms (xs, es)  
                       | _ => raise Fail "inconsistent"  
                     (* end case *)  
                   end  
620                | IR.CONS(args, ty) => let                | IR.CONS(args, ty) => let
621                    val (es, stms) = singleArgs (env, args)                    val (es, stms) = simpleArgs (env, args)
                   val ty = U.trType ty  
622                    in                    in
623                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Cons(es, ty)));                      bindCons (es, ty, stms)
                     stms  
624                    end                    end
625                | IR.SEQ(args, ty) => let                | IR.SEQ(args, ty) => let
626                    val (es, stms) = singleArgs (env, args)                    val (es, stms) = singleArgs (env, args)
627                    val ty = U.trType ty                    val ty = U.trType ty
628                    (* if we are dealing with a sequence of tensors, then we need to copy references *)
629                      val es = (case ty
630                             of TTy.SeqTy(TTy.TensorTy _, _) => List.map mkDeref es
631                              | _ => es
632                            (* end case *))
633                      val seq = T.E_Seq(es, ty)
634                    in                    in
635                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Seq(es, ty)));                      case eqClassRepOf(env, lhs)
636                      stms                       of NOEQ => if (V.useCount lhs > 1)
637                              then mkDefn (newMemLocal (env, lhs), seq) :: stms
638                              else (
639                                Env.bindVar (env, lhs, Env.RHS(ty, seq));
640                                stms)
641                          | VAR x => mkAssign (x, seq) :: stms
642                          | VEC xs => raise Fail "inconsistent"
643                        (* end case *)
644                    end                    end
645                | rhs => raise Fail(concat["unexpected ", IR.RHS.toString rhs, " in LowIR code"])                | rhs => raise Fail(concat["unexpected ", IR.RHS.toString rhs, " in LowIR code"])
646              (* end case *)              (* end case *)
# Line 599  Line 670 
670                    | Env.TREE e => e                    | Env.TREE e => e
671                    | _ => raise Fail("expected scalar binding for " ^ V.toString x)                    | _ => raise Fail("expected scalar binding for " ^ V.toString x)
672                  (* end case *))                  (* end case *))
673            (* analyze the CFG *)
674            val _ = UnifyVars.analyze cfg            val _ = UnifyVars.analyze cfg
675              val _ = AssignTypes.analyze cfg
676          (* join (stk, stms, k): handle a control-flow join, where env is the          (* join (stk, stms, k): handle a control-flow join, where env is the
677           * current environment, stk is the stack of open ifs (the top of stk specifies           * current environment, stk is the stack of open ifs (the top of stk specifies
678           * which branch we are in), stms are the TreeIL statements preceding the join           * which branch we are in), stms are the TreeIL statements preceding the join
# Line 860  Line 933 
933                  globals = List.map mkGlobalVar globals,                  globals = List.map mkGlobalVar globals,
934                  globInit = trCFG globInit,                  globInit = trCFG globInit,
935                  strand = trStrand info strand,                  strand = trStrand info strand,
936                  create = let                  create = Create.map trCFG create,
                   val IR.Create{dim, code} = create  
                   in  
                     T.Create{dim = dim, code = trCFG code}  
                   end,  
937                  init = Option.map trCFG init,                  init = Option.map trCFG init,
938                  update = Option.map trCFG update                  update = Option.map trCFG update
939                }                }

Legend:
Removed from v.4039  
changed lines
  Added in v.4128

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0