Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/low-to-tree/low-to-tree.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 4025, Wed Jun 22 15:01:29 2016 UTC revision 4077, Tue Jun 28 14:56:14 2016 UTC
# Line 117  Line 117 
117      fun mkDefn (x, e) = T.S_Assign(true, x, e)      fun mkDefn (x, e) = T.S_Assign(true, x, e)
118      val zero = T.E_Lit(Literal.Real(RealLit.zero false))      val zero = T.E_Lit(Literal.Real(RealLit.zero false))
119    
120      (* turn an expression of type TensorTy to one of TensorTyRef *)
121        fun mkRef e = (case TreeTypeOf.exp e
122               of TTy.TensorTy(shp as _::_) => T.E_Op(TOp.TensorRef shp, [e])
123                | _ => e
124             (* end case *))
125    
126      (* turn an expression of type TensorRefTy to one of TensorRef *)
127        fun mkDeref e = (case TreeTypeOf.exp e
128               of TTy.TensorRefTy(shp as _::_) => T.E_Op(TOp.TensorCopy shp, [e])
129                | _ => e
130             (* end case *))
131    
132      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy      fun cvtScalarTy Ty.BoolTy = TTy.BoolTy
133        | cvtScalarTy Ty.IntTy = TTy.IntTy        | cvtScalarTy Ty.IntTy = TTy.IntTy
134        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy        | cvtScalarTy (Ty.TensorTy[]) = TTy.realTy
# Line 189  Line 201 
201            fun expToArg (e, stms) = (case V.ty x            fun expToArg (e, stms) = (case V.ty x
202                   of Ty.TensorTy[d] => let                   of Ty.TensorTy[d] => let
203                        val layout = Env.layoutVec env d                        val layout = Env.layoutVec env d
204  (* QUESTION: can "e" be a complicated expression or are we guaranteed that it will just                        val e = mkRef e
  * be a memory reference?  
  *)  
205                        val es = List.tabulate (                        val es = List.tabulate (
206                              List.length(#pieces layout),                              List.length(#pieces layout),
207                              fn i => T.E_VLoad(layout, e, i))                              fn i => T.E_VLoad(layout, e, i))
# Line 248  Line 258 
258            end            end
259    
260      fun trOp (env, srcRator, args) = let      fun trOp (env, srcRator, args) = let
261            fun bindOp rator = let            fun bindTREE rator = let
262                  val (args, stms) = simpleArgs (env, args)                  val (args, stms) = simpleArgs (env, args)
263                  in                  in
264                    (Env.TREE(T.E_Op(rator, args)), stms)                    (Env.TREE(T.E_Op(rator, args)), stms)
265                  end                  end
266              fun bindTREE' rator = let
267                    val (args, stms) = singleArgs (env, args)
268                    in
269                      (Env.TREE(T.E_Op(rator, args)), stms)
270                    end
271            fun bindRHS (ty, rator) = let            fun bindRHS (ty, rator) = let
272                  val (args, stms) = simpleArgs (env, args)                  val (args, stms) = simpleArgs (env, args)
273                  in                  in
# Line 260  Line 275 
275                  end                  end
276            fun bindVOp rator = let            fun bindVOp rator = let
277                  val (layout, argss, stms) = vectorArgs (env, args)                  val (layout, argss, stms) = vectorArgs (env, args)
278                  fun mkArgs (_, [], []) = []                  fun mkArgs (w, [p], [args]) = [T.E_Op(rator(w, p), args)]
279                    | mkArgs (w, p::ps, args::argss) =                    | mkArgs (w, p::ps, args::argss) =
280                        T.E_Op(rator(w, p), args) :: mkArgs (w-p, ps, argss)                        T.E_Op(rator(p, p), args) :: mkArgs (w-p, ps, argss)
281                      | mkArgs _ = raise Fail "bindVOp: arity mismatch"
282                  val exps = mkArgs (#wid layout, #pieces layout, argss)                  val exps = mkArgs (#wid layout, #pieces layout, argss)
283                  in                  in
284                    (Env.VEC(layout, exps), stms)                    (Env.VEC(layout, exps), stms)
# Line 276  Line 292 
292                  end                  end
293            in            in
294              case srcRator              case srcRator
295               of Op.IAdd => bindOp TOp.IAdd               of Op.IAdd => bindTREE TOp.IAdd
296                | Op.ISub => bindOp TOp.ISub                | Op.ISub => bindTREE TOp.ISub
297                | Op.IMul => bindOp TOp.IMul                | Op.IMul => bindTREE TOp.IMul
298                | Op.IDiv => bindOp TOp.IDiv                | Op.IDiv => bindTREE TOp.IDiv
299                | Op.IMod => bindOp TOp.IMod                | Op.IMod => bindTREE TOp.IMod
300                | Op.INeg => bindOp TOp.INeg                | Op.INeg => bindTREE TOp.INeg
301  (* QUESTION: should we just use VAdd 1, etc ?*)  (* QUESTION: should we just use VAdd 1, etc ?*)
302                | Op.RAdd => bindOp TOp.RAdd                | Op.RAdd => bindTREE TOp.RAdd
303                | Op.RSub => bindOp TOp.RSub                | Op.RSub => bindTREE TOp.RSub
304                | Op.RMul => bindOp TOp.RMul                | Op.RMul => bindTREE TOp.RMul
305                | Op.RDiv => bindOp TOp.RDiv                | Op.RDiv => bindTREE TOp.RDiv
306                | Op.RNeg => bindOp TOp.RNeg                | Op.RNeg => bindTREE TOp.RNeg
307                | Op.LT ty => bindOp (TOp.LT (cvtScalarTy ty))                | Op.LT ty => bindTREE (TOp.LT (cvtScalarTy ty))
308                | Op.LTE ty => bindOp (TOp.LTE (cvtScalarTy ty))                | Op.LTE ty => bindTREE (TOp.LTE (cvtScalarTy ty))
309                | Op.EQ ty => bindOp (TOp.EQ (cvtScalarTy ty))                | Op.EQ ty => bindTREE (TOp.EQ (cvtScalarTy ty))
310                | Op.NEQ ty => bindOp (TOp.NEQ (cvtScalarTy ty))                | Op.NEQ ty => bindTREE (TOp.NEQ (cvtScalarTy ty))
311                | Op.GT ty => bindOp (TOp.GT (cvtScalarTy ty))                | Op.GT ty => bindTREE (TOp.GT (cvtScalarTy ty))
312                | Op.GTE ty => bindOp (TOp.GTE (cvtScalarTy ty))                | Op.GTE ty => bindTREE (TOp.GTE (cvtScalarTy ty))
313                | Op.Not => bindOp TOp.Not                | Op.Not => bindTREE TOp.Not
314                | Op.Abs ty => bindOp (TOp.Abs (cvtScalarTy ty))                | Op.Abs ty => bindTREE (TOp.Abs (cvtScalarTy ty))
315                | Op.Max ty => bindOp (TOp.Max (cvtScalarTy ty))                | Op.Max ty => bindTREE (TOp.Max (cvtScalarTy ty))
316                | Op.Min ty => bindOp (TOp.Min (cvtScalarTy ty))                | Op.Min ty => bindTREE (TOp.Min (cvtScalarTy ty))
317                | Op.RClamp => bindOp TOp.RClamp                | Op.RClamp => bindTREE TOp.RClamp
318                | Op.RLerp => bindOp TOp.RLerp                | Op.RLerp => bindTREE TOp.RLerp
319                | Op.VAdd _ => bindVOp TOp.VAdd                | Op.VAdd _ => bindVOp TOp.VAdd
320                | Op.VSub _ => bindVOp TOp.VSub                | Op.VSub _ => bindVOp TOp.VSub
321                | Op.VScale _ => let                | Op.VScale _ => let
# Line 319  Line 335 
335                    in                    in
336                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)                      (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
337                    end                    end
338                  | Op.VDot _ => let
339                      val (layout, argss, stms) = vectorArgs (env, args)
340                      fun mkArgs (w, [p], [args]) = [T.E_Op(TOp.VDot(w, p), args)]
341                        | mkArgs (w, p::ps, args::argss) =
342                            T.E_Op(TOp.VDot(p, p), args) :: mkArgs (w-p, ps, argss)
343                        | mkArgs _ = raise Fail "VDot: arity mismatch"
344                      val e::es = mkArgs (#wid layout, #pieces layout, argss)
345                      in
346                        (Env.TREE(List.foldr (fn (e, es) => T.E_Op(TOp.RAdd, [e, es])) e es), stms)
347                      end
348                | Op.VIndex(_, i) => let                | Op.VIndex(_, i) => let
349                    val [v] = args                    val [v] = args
350                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)                    val ({wid, pieces, ...}, es, stms) = vectorArg (env, v)
# Line 343  Line 369 
369                      (Env.VEC(layout, exps), List.rev stms)                      (Env.VEC(layout, exps), List.rev stms)
370                    end                    end
371                | Op.VMapClamp n => bindVOp TOp.VMapClamp                | Op.VMapClamp n => bindVOp TOp.VMapClamp
372                | Op.VLerp n => bindVOp TOp.VLerp                | Op.VLerp n => let
373                | Op.TensorIndex(Ty.TensorTy dd, idxs) => let                    val [u, v, t] = args
374                    val (args, stms) = simpleArgs (env, args)                    val (layout, us, stms1) = vectorArg (env, u)
375                      val (_, vs, stms2) = vectorArg (env, v)
376                      val (t, stms) = simpleArg env (t, stms2 @ stms1)
377                      val exps = let
378                            fun mkArgs (w, [p], [u], [v]) = [T.E_Op(TOp.VLerp(w, p), [u, v, t])]
379                              | mkArgs (w, p::ps, u::ur, v::vr) =
380                                  T.E_Op(TOp.VLerp(p, p), [u, v, t]) :: mkArgs (w-p, ps, ur, vr)
381                              | mkArgs _ = raise Fail "VLerp: arity mismatch"
382                    in                    in
383                      (Env.TREE(T.E_Op(TOp.TensorIndex(TTy.TensorRefTy dd, idxs), args)), stms)                            mkArgs (#wid layout, #pieces layout, us, vs)
384                    end                    end
               | Op.ProjectLast(Ty.TensorTy dd, idxs) => let  
                   val (args, stms) = simpleArgs (env, args)  
385                    in                    in
386                      (Env.TREE(T.E_Op(TOp.ProjectLast(TTy.TensorRefTy dd, idxs), args)), stms)                      (Env.VEC(layout, exps), stms)
387                    end                    end
388                | Op.Select(ty, i) => bindOp (TOp.Select(U.trType ty, i))                | Op.TensorIndex(ty, idxs) => let
389                | Op.Subscript ty => bindOp (TOp.Subscript(U.trType ty))                    val ([arg], stms) = simpleArgs (env, args)
390                | Op.MkDynamic(ty, n) => bindOp (TOp.MkDynamic(U.trType ty, n))                    val ty = TreeTypeOf.exp arg
391                | Op.Append ty => bindOp (TOp.Append(U.trType ty))                    in
392                | Op.Prepend ty => bindOp (TOp.Prepend(U.trType ty))                      (Env.TREE(T.E_Op(TOp.TensorIndex(ty, idxs), [arg])), stms)
393                | Op.Concat ty => bindOp (TOp.Concat(U.trType ty))                    end
394                | Op.Range => bindOp TOp.Range                | Op.ProjectLast(_, idxs) => let
395                | Op.Length ty => bindOp (TOp.Length(U.trType ty))                    val ([arg], stms) = simpleArgs (env, args)
396                | Op.SphereQuery(ty1, ty2) => bindOp (TOp.SphereQuery(U.trType ty1, U.trType ty2))                    val ty = TreeTypeOf.exp arg
397                | Op.Sqrt => bindOp TOp.Sqrt                    in
398                | Op.Cos => bindOp TOp.Cos                      (Env.TREE(T.E_Op(TOp.ProjectLast(ty, idxs), [arg])), stms)
399                | Op.ArcCos => bindOp TOp.ArcCos                    end
400                | Op.Sin => bindOp TOp.Sin                | Op.Select(ty, i) => bindTREE (TOp.Select(U.trType ty, i))
401                | Op.ArcSin => bindOp TOp.ArcSin                | Op.Subscript ty => bindTREE (TOp.Subscript(U.trType ty))
402                | Op.Tan => bindOp TOp.Tan                | Op.MkDynamic(ty, n) => bindTREE (TOp.MkDynamic(U.trType ty, n))
403                | Op.ArcTan => bindOp TOp.ArcTan                | Op.Append ty => bindTREE' (TOp.Append(U.trType ty))
404                | Op.Exp  => bindOp TOp.Exp                | Op.Prepend ty => bindTREE' (TOp.Prepend(U.trType ty))
405                | Op.Ceiling 1 => bindOp TOp.RCeiling                | Op.Concat ty => bindTREE (TOp.Concat(U.trType ty))
406                  | Op.Range => bindTREE TOp.Range
407                  | Op.Length ty => bindTREE (TOp.Length(U.trType ty))
408                  | Op.SphereQuery(ty1, ty2) => bindTREE' (TOp.SphereQuery(U.trType ty1, U.trType ty2))
409                  | Op.Sqrt => bindTREE TOp.Sqrt
410                  | Op.Cos => bindTREE TOp.Cos
411                  | Op.ArcCos => bindTREE TOp.ArcCos
412                  | Op.Sin => bindTREE TOp.Sin
413                  | Op.ArcSin => bindTREE TOp.ArcSin
414                  | Op.Tan => bindTREE TOp.Tan
415                  | Op.ArcTan => bindTREE TOp.ArcTan
416                  | Op.Exp  => bindTREE TOp.Exp
417                  | Op.Ceiling 1 => bindTREE TOp.RCeiling
418                | Op.Ceiling d => bindVOp TOp.VCeiling                | Op.Ceiling d => bindVOp TOp.VCeiling
419                | Op.Floor 1 => bindOp TOp.RFloor                | Op.Floor 1 => bindTREE TOp.RFloor
420                | Op.Floor d => bindVOp TOp.VFloor                | Op.Floor d => bindVOp TOp.VFloor
421                | Op.Round 1 => bindOp TOp.RRound                | Op.Round 1 => bindTREE TOp.RRound
422                | Op.Round d => bindVOp TOp.VRound                | Op.Round d => bindVOp TOp.VRound
423                | Op.Trunc 1 => bindOp TOp.RTrunc                | Op.Trunc 1 => bindTREE TOp.RTrunc
424                | Op.Trunc d => bindVOp TOp.VTrunc                | Op.Trunc d => bindVOp TOp.VTrunc
425                | Op.IntToReal => bindOp TOp.IntToReal                | Op.IntToReal => bindTREE TOp.IntToReal
426                | Op.RealToInt 1 => bindOp TOp.RealToInt                | Op.RealToInt 1 => bindTREE TOp.RealToInt
427                | Op.RealToInt d => let                | Op.RealToInt d => let
428                    val [v] = args                    val [v] = args
429                    val (layout, args, stms) = vectorArg (env, v)                    val (layout, args, stms) = vectorArg (env, v)
# Line 400  Line 443 
443                | Op.R_Mean ty => ??                | Op.R_Mean ty => ??
444                | Op.R_Variance ty => ??                | Op.R_Variance ty => ??
445  *)  *)
446                | Op.Transform info => bindOp (TOp.Transform info)                | Op.Transform info => bindTREE (TOp.Transform info)
447                | Op.Translate info => bindOp (TOp.Translate info)                | Op.Translate info => bindTREE (TOp.Translate info)
448                | Op.ControlIndex(info, ctl, d) => bindOp (TOp.ControlIndex(info, ctl, d))                | Op.ControlIndex(info, ctl, d) => bindTREE (TOp.ControlIndex(info, ctl, d))
449                | Op.LoadVoxel info => bindOp (TOp.LoadVoxel info)                | Op.LoadVoxel info => bindTREE (TOp.LoadVoxel info)
450                | Op.Inside(info, s) => bindOp (TOp.Inside(info, s))                | Op.Inside(info, s) => bindTREE (TOp.Inside(info, s))
451                | Op.ImageDim(info, d) => bindOp(TOp.ImageDim(info, d))                | Op.ImageDim(info, d) => bindTREE(TOp.ImageDim(info, d))
452                | Op.MathFn f => bindOp (TOp.MathFn f)                | Op.MathFn f => bindTREE (TOp.MathFn f)
453                | rator => raise Fail("bogus operator " ^ Op.toString srcRator)                | rator => raise Fail("bogus operator " ^ Op.toString srcRator)
454              (* end case *)              (* end case *)
455            end            end
# Line 431  Line 474 
474  *)  *)
475      fun trAssign (env, lhs, rhs) = let      fun trAssign (env, lhs, rhs) = let
476            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)            fun getLHS () = (case UnifyVars.eqClassOf lhs of SOME x => x | _ => lhs)
477            fun bindRHS rhs = Env.bindVar (env, getLHS(), Env.RHS(U.trType(V.ty lhs), rhs))            fun bindRHS rhs = Env.bindVar (env, getLHS(), Env.RHS(U.trTempType(V.ty lhs), rhs))
478          (* binding for the lhs variable, where the rhs is a simple expression.  We check to          (* binding for the lhs variable, where the rhs is a simple expression.  We check to
479           * see if it is part of an merged equivalence class, in which case we need to generate           * see if it is part of an merged equivalence class, in which case we need to generate
480           * assigment(s)           * assigment(s)
# Line 442  Line 485 
485                    | VEC xs' => (case V.ty lhs                    | VEC xs' => (case V.ty lhs
486                         of Ty.TensorTy[d] => let                         of Ty.TensorTy[d] => let
487                              val layout = Env.layoutVec env d                              val layout = Env.layoutVec env d
488                                val rhs = mkRef rhs
489                              in                              in
490                                List.mapi                                List.mapi
491                                  (fn (i, x') => mkAssign(x', T.E_VLoad(layout, rhs, i)))                                  (fn (i, x') => mkAssign(x', T.E_VLoad(layout, rhs, i)))
# Line 504  Line 548 
548                    in                    in
549                      case (rhs, eqClassRepOf(env, lhs), emitBind)                      case (rhs, eqClassRepOf(env, lhs), emitBind)
550                       of (_, NOEQ, false) => (Env.bindVar (env, lhs, rhs); stms)                       of (_, NOEQ, false) => (Env.bindVar (env, lhs, rhs); stms)
551    (* FIXME: if the rhs has TensorRef type, then we should make the lhs TensorRef too! *)
552                        | (Env.TREE e, NOEQ, true) => mkDefn'(newLocal(env, lhs), e) :: stms                        | (Env.TREE e, NOEQ, true) => mkDefn'(newLocal(env, lhs), e) :: stms
553                        | (Env.TREE e, VAR x', _) => mkAssign'(x', e) :: stms                        | (Env.TREE e, VAR x', _) => mkAssign'(x', e) :: stms
554                        | (Env.VEC(layout, es), NOEQ, true) => let                        | (Env.VEC(layout, es), NOEQ, true) => let
# Line 559  Line 604 
604                | IR.SEQ(args, ty) => let                | IR.SEQ(args, ty) => let
605                    val (es, stms) = singleArgs (env, args)                    val (es, stms) = singleArgs (env, args)
606                    val ty = U.trType ty                    val ty = U.trType ty
607                    (* if we are dealing with a sequence of tensors, then we need to copy references *)
608                      val es = (case ty
609                             of TTy.SeqTy(TTy.TensorTy _, _) => List.map mkDeref es
610                              | _ => es
611                            (* end case *))
612                    in                    in
613                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Seq(es, ty)));                      Env.bindVar (env, getLHS (), Env.RHS(ty, T.E_Seq(es, ty)));
614                      stms                      stms
# Line 671  Line 721 
721                        end                        end
722                    | IR.MASSIGN{stm=([], Op.Print tys, xs), succ, ...} => let                    | IR.MASSIGN{stm=([], Op.Print tys, xs), succ, ...} => let
723                        val (es, stms') = singleArgs (env, xs)                        val (es, stms') = singleArgs (env, xs)
724                        (* translate TensorTy to TensorRefTy in the type list *)
725                        fun trType (Ty.TensorTy(shp as _::_)) = TTy.TensorRefTy shp                        fun trType (Ty.TensorTy(shp as _::_)) = TTy.TensorRefTy shp
726                          | trType ty = U.trType ty                          | trType ty = U.trType ty
727                        val stm = T.S_Print(List.map trType tys, es)                        val tys = List.map trType tys
728                          val stm = T.S_Print(tys, List.map mkRef es)
729                        in                        in
730                          doNode (!succ, ifStk, stm :: List.revAppend (stms', stms))                          doNode (!succ, ifStk, stm :: List.revAppend (stms', stms))
731                        end                        end
# Line 850  Line 902 
902                  globals = List.map mkGlobalVar globals,                  globals = List.map mkGlobalVar globals,
903                  globInit = trCFG globInit,                  globInit = trCFG globInit,
904                  strand = trStrand info strand,                  strand = trStrand info strand,
905                  create = let                  create = Create.map trCFG create,
                   val IR.Create{dim, code} = create  
                   in  
                     T.Create{dim = dim, code = trCFG code}  
                   end,  
906                  init = Option.map trCFG init,                  init = Option.map trCFG init,
907                  update = Option.map trCFG update                  update = Option.map trCFG update
908                }                }

Legend:
Removed from v.4025  
changed lines
  Added in v.4077

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0