Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/cxx-util/gen-tys-and-ops.sml
ViewVC logotype

View of /branches/vis15/src/compiler/cxx-util/gen-tys-and-ops.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3931 - (download) (annotate)
Sun Jun 5 14:13:21 2016 UTC (2 years, 9 months ago) by jhr
File size: 7298 byte(s)
  working on merge: code generation
(* gen-tys-and-ops.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure GenTysAndOps : sig

  (* gen (env, info) returns the C++ declarations needed by the generated program:
   * wrapper structs and operator<< printers for printed tensor types, typedefs for
   * vector types, and inline functions for the operations recorded in info.
   *)
    val gen : CodeGenEnv.t * CollectInfo.t -> CLang.decl list

  end = struct

    structure IR = TreeIR
    structure Ty = TreeTypes
    structure CL = CLang
    structure RN = CxxNames
    structure Env = CodeGenEnv

  (* real literal used to fill the padding lanes of vector values (see mkVec in doOp) *)
    val zero = RealLit.zero false

    fun mkReturn exp = CL.mkReturn(SOME exp)
    fun mkInt i = CL.mkInt(IntInf.fromInt i)

  (* genWrapperStruct (spec, ty, dcls) prepends a wrapper-struct declaration for ty
   * onto dcls.  For a tensor type the struct has one public member "_data" (a pointer
   * to real values) and a constructor taking an array of len reals, where len is the
   * product of the shape's dimensions.  Because a C++ array parameter decays to a
   * pointer, "_data" refers to the caller's storage rather than a copy of it.
   * Tuple types are not yet supported; all other types add nothing.
   *)
    fun genWrapperStruct (spec : TargetSpec.t, ty, dcls) = (case ty
           of Ty.TensorTy shape => let
                val realTy = if #double spec then CL.double else CL.float
              (* number of scalar elements; 1 for the shape [] *)
                val len = List.foldl Int.* 1 shape
                val name = RN.tensorStruct shape
                val constrDcl = CL.D_Constr(
                      [], NONE, name,
                      [CL.PARAM([], CL.T_Array(realTy, SOME len), "data")],
                      [CL.mkApply("_data", [CL.mkVar "data"])],
                      SOME(CL.mkBlock[]))
                val structDcl = CL.D_ClassDef{
                        name = name,
                        from = NONE,
                        public = [
                            CL.D_Var([], CL.T_Ptr realTy, "_data", NONE),
                            constrDcl
                          ],
                        protected = [],
                        private = []
                      }
                in
                  structDcl :: dcls
                end
            | Ty.TupleTy tys => raise Fail "FIXME: TupleTy"
(* TODO
            | Ty.SeqTy(ty, NONE) =>
            | Ty.SeqTy(ty, SOME n) =>
*)
            | ty => dcls
          (* end case *))

    val ostreamRef = CL.T_Named "std::ostream&"

  (* the C++ expression "e << e'" *)
    fun output (e, e') = CL.mkBinOp(e, CL.#<<, e')

  (* generate code for the expression "e << s", where "s" is string literal.  When e
   * already ends in "<< s1" for a string literal s1, the two literals are merged
   * into a single "<< s1s" so the generated code does not chain adjacent literals.
   *)
    fun outString (CL.E_BinOp(e, CL.#<<, CL.E_Str s1), s2) =
          output (e, CL.mkStr(s1 ^ String.toCString s2))
      | outString (e, s) = output (e, CL.mkStr(String.toCString s))

  (* generate a printing function for tensors with the given shape.  The result is
   *
   *    static std::ostream& operator<< (std::ostream& outs, <tensor-type> ten)
   *
   * which writes the elements of "ten._data" as nested bracketed, comma-separated
   * lists; e.g., a shape [2,2] tensor prints as "[[a,b],[c,d]]".
   *)
    fun genTensorPrinter shape = let
        (* the i'th element of the tensor's flat "_data" array *)
          fun ten i = CL.mkSubscript(CL.mkSelect(CL.mkVar "ten", "_data"), mkInt i)
        (* emit a "," separator before every element of a row except the first *)
          fun prefix (true, lhs) = lhs
            | prefix (false, lhs) = outString(lhs, ",")
        (* lp (lhs, i, shp) extends the output expression lhs with the elements of a
         * sub-tensor of shape shp that begins at flat index i.  It returns the flat
         * index following the sub-tensor paired with the extended expression.
         *)
          fun lp (lhs, i, []) =
              (* shape []: a single value (genWrapperStruct gives it len = 1) *)
                (i+1, output (lhs, ten i))
            | lp (lhs, i, [d]) = let
                fun lp' (_, lhs, i, 0) = (i, outString(lhs, "]"))
                  | lp' (isFirst, lhs, i, n) =
                      lp' (false, output (prefix (isFirst, lhs), ten i), i+1, n-1)
                in
                  lp' (true, outString(lhs, "["), i, d)
                end
            | lp (lhs, i, d::dd) = let
                fun lp' (_, lhs, i, 0) = (i, outString(lhs, "]"))
                  | lp' (isFirst, lhs, i, n) = let
                      val (i, lhs) = lp (prefix (isFirst, lhs), i, dd)
                      in
                        lp' (false, lhs, i, n-1)
                      end
                in
                  lp' (true, outString(lhs, "["), i, d)
                end
          val params = [
                  CL.PARAM([], ostreamRef, "outs"),
                  CL.PARAM([], RN.tensorTy shape, "ten")
                ]
          val (_, exp) = lp (CL.mkVar "outs", 0, shape)
          in
          (* the body refers to "outs" and "ten", so the function must declare them *)
            CL.D_Func(["static"], ostreamRef, "operator<<", params, mkReturn exp)
          end

  (* prepend the printing function for ty (if any) onto dcls.  Only tensor types are
   * handled here; tuples are not yet supported, and all other types add nothing.
   *)
    fun genPrinter (ty, dcls) = (case ty
           of Ty.TensorTy shape => genTensorPrinter shape :: dcls
            | Ty.TupleTy tys => raise Fail "FIXME: printer for tuples"
(* the following two types will be handled by template expansion
            | Ty.SeqTy(ty, NONE) =>
            | Ty.SeqTy(ty, SOME n) =>
*)
            | ty => dcls
          (* end case *))

  (* prepend a GCC vector-extension typedef for the vector type of width w (padded to
   * pw lanes) onto dcls:
   *
   *    typedef <real> <vecTyName w> __attribute__ ((vector_size (<sizeof(real)*pw>)));
   *)
    fun genVecTyDecl (env, w, pw, dcls) = let
          val (realTyName, realTySz) = if #double(Env.target env)
                then ("double", 8)
                else ("float", 4)
          val cTyName = RN.vecTyName w
          val typedefDcl = CL.D_Verbatim[concat[
                  "typedef ", realTyName, " ", cTyName, " __attribute__ ((vector_size (",
                  Int.toString(realTySz * pw), ")));\n"
                ]]
          in
            typedefDcl :: dcls
          end

    datatype operation = datatype CollectInfo.operation

  (* mkLerp (ty, name, realTy, mkT) generates
   *
   *    inline ty name (ty a, ty b, realTy t) { return a + T * (b - a); }
   *
   * where T = mkT t.  mkT lets the vector version broadcast the scalar t to a vector.
   * With this form, t = 0 yields a and t = 1 yields b.
   *)
    fun mkLerp (ty, name, realTy, mkT) = CL.D_Func(
          ["inline"], ty, name,
          [CL.PARAM([], ty, "a"), CL.PARAM([], ty, "b"), CL.PARAM([], realTy, "t")],
          mkReturn (
            CL.mkBinOp(
              CL.mkVar "a",
              CL.#+,
              CL.mkBinOp(
                mkT(CL.mkVar "t"),
                CL.#*,
                CL.mkBinOp(CL.mkVar "b", CL.#-, CL.mkVar "a")))))

  (* doOp env (rator, dcls) prepends the inline C++ function that implements the
   * operation rator onto dcls.  Vector operations take the logical width w and the
   * padded width pw.
   *)
    fun doOp env (rator, dcls) = let
          val realTy = Env.realTy env
        (* a vector literal of pw lanes whose first w lanes are f 0, ..., f(w-1) and
         * whose remaining (padding) lanes are zero
         *)
          fun mkVec (w, pw, f) = CL.mkVec(
                RN.vecTy w,
                List.tabulate(pw, fn i => if i < w then f i else CL.mkFlt(zero, realTy)))
          val dcl = (case rator
                 of RClamp => raise Fail "FIXME: RClamp"
                  | RLerp => mkLerp (realTy, "lerp", realTy, fn x => x)
                  | VSum(w, pw) => let
                      val name = RN.vsum w
                      val params = [CL.PARAM([], RN.vecTy w, "v")]
                    (* v[0] + v[1] + ... + v[i]; only the w logical lanes are summed *)
                      fun mkSum 0 = CL.mkSubscript(CL.mkVar "v", mkInt 0)
                        | mkSum i = CL.mkBinOp(mkSum(i-1), CL.#+, CL.mkSubscript(CL.mkVar "v", mkInt i))
                      in
                        CL.D_Func(["inline"], realTy, name, params, mkReturn(mkSum(w-1)))
                      end
                  | VClamp(w, pw) => raise Fail "FIXME: VClamp"
                  | VMapClamp(w, pw) => raise Fail "FIXME: VMapClamp"
                  | VLerp(w, pw) =>
                    (* NOTE(review): every width uses the name "vlerp" and relies on C++
                     * overloading; confirm that vector typedefs with the same padded
                     * width are distinct enough types for this to compile.
                     *)
                      mkLerp (RN.vecTy w, "vlerp", realTy, fn x => mkVec(w, pw, fn i => x))
                  | VScale(w, pw) => let
                      val cTy = RN.vecTy w
                      in
                      (* s * v, with the scalar s broadcast to a vector (see VLerp note
                       * about the shared name "vscale") *)
                        CL.D_Func(["inline"], cTy, "vscale",
                          [CL.PARAM([], realTy, "s"), CL.PARAM([], cTy, "v")],
                          mkReturn(
                            CL.mkBinOp(mkVec(w, pw, fn _ => CL.mkVar "s"), CL.#*, CL.mkVar "v")))
                      end
                  | VLoad(w, pw) => let
                      val name = RN.vload w
                      val cTy = RN.vecTy w
                    (* load w consecutive reals from the pointer vp *)
                      fun arg i = CL.mkSubscript(CL.mkVar "vp", mkInt i)
                      in
                        CL.D_Func(["inline"], cTy, name,
                          [CL.PARAM([], CL.T_Ptr realTy, "vp")],
                          mkReturn(mkVec (w, pw, arg)))
                      end
                  | VCons(w, pw) => let
                      val name = RN.vcons w
                      val cTy = RN.vecTy w
                    (* one real parameter r0, ..., r(w-1) per logical lane *)
                      val params = List.tabulate(w, fn i => CL.PARAM([], realTy, "r"^Int.toString i))
                      fun arg i = CL.mkVar("r"^Int.toString i)
                      in
                        CL.D_Func(["inline"], cTy, name,
                          params,
                          mkReturn(mkVec (w, pw, arg)))
                      end
                  | VPack layout => let
                    (* void vpack (real dst[wid], v0, v1, ...) copies the logical lanes
                     * of each vector piece of the layout, in order, into dst.
                     *)
                      val name = RN.vpack (#wid layout)
                      val vParamTys = Ty.piecesOf layout
                      val vParams = List.mapi
                            (fn (i, Ty.VecTy(w, _)) => CL.PARAM([], RN.vecTy w, "v"^Int.toString i))
                              vParamTys
                      val dstTy = CL.T_Array(realTy, SOME(#wid layout))
                    (* dst[i] = v[j]; *)
                      fun mkAssign (i, v, j) =
                            CL.mkAssign(
                              CL.mkSubscript(CL.mkVar "dst", mkInt i),
                              CL.mkSubscript(v, mkInt j))
                    (* prepend (in reverse) the assignments dst[dstStart+j] = v<pieceIdx>[j]
                     * for 0 <= j < wid onto stms *)
                      fun mkAssignsForPiece (dstStart, pieceIdx, wid, stms) = let
                            val piece = CL.mkVar("v"^Int.toString pieceIdx)
                            fun mk (j, stms) = if (j < wid)
                                  then mk (j+1, mkAssign (dstStart+j, piece, j) :: stms)
                                  else stms
                            in
                              mk (0, stms)
                            end
                    (* walk the pieces, tracking the destination offset; statements are
                     * accumulated in reverse and flipped once at the end *)
                      fun mkAssigns (_, [], _, stms) = CL.mkBlock(List.rev stms)
                        | mkAssigns (i, Ty.VecTy(w, _)::tys, offset, stms) =
                            mkAssigns (i+1, tys, offset+w, mkAssignsForPiece(offset, i, w, stms))
                      in
                        CL.D_Func(["inline"], CL.voidTy, name,
                          CL.PARAM([], dstTy, "dst") :: vParams,
                          mkAssigns (0, vParamTys, 0, []))
                      end
                (* end case *))
          in
            dcl :: dcls
          end

  (* gen (env, info): first fold the operation declarations over info, then fold the
   * type declarations over info on top of them, so the type declarations end up
   * ahead of the operation declarations in the resulting list.
   *)
    fun gen (env, info) = let
          val spec = Env.target env
        (* for a printed type, prepend its printer and then its wrapper struct (so the
         * struct precedes the printer in the list); for a vector type, prepend its
         * typedef.
         *)
          fun doType (ty, isPrint, dcls) = let
                val dcls = if isPrint
                      then genWrapperStruct (spec, ty, genPrinter (ty, dcls))
                      else dcls
                val dcls = (case ty
                       of Ty.VecTy(w, pw) => genVecTyDecl (env, w, pw, dcls)
                        | _ => dcls
                      (* end case *))
                in
                  dcls
                end
          in
            CollectInfo.foldOverTypes
              doType
                (CollectInfo.foldOverOps (doOp env) [] info)
                  info
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0