Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/simplify/util.sml
ViewVC logotype

View of /branches/vis15/src/compiler/simplify/util.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3482 - (download) (annotate)
Sat Dec 5 14:43:53 2015 UTC (4 years, 2 months ago) by jhr
File size: 6213 byte(s)
  working on merge
(* util.sml
 *
 * Utility code for Simplification.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Util : sig

  (* return information about a reduction operator: the binary primitive used to
   * combine two elements, the identity element used to initialize the
   * accumulator, and the meta-variable arguments needed to apply the primitive.
   * Raises Fail for the reductions that are not yet supported ('mean' and
   * 'variance') and for variables that are not reduction operators.
   *)
    val reductionInfo : Var.t -> {
            rator : Var.t,                      (* primitive operator *)
            init : Literal.t,                   (* identity element to use for initialization *)
            mvs : SimpleTypes.meta_arg list     (* meta-variable arguments for primitive application *)
          }

  (* convert a block into a function by closing over its free variables.  The
   * arguments are the base name for the function, the block that becomes its
   * body, and the function's result type.  The result is the new function
   * together with the free variables of the block, listed in the same order as
   * the function's parameters (i.e., the actual arguments for a call).
   *)
    val makeFunction : string * Simple.block * SimpleTypes.ty -> Simple.func * Simple.var list

  end = struct

    structure S = Simple
    structure BV = BasisVars
    structure L = Literal
    structure R = RealLit
    structure VMap = SimpleVar.Map

  (* Var.t is not a datatype, so the operator is identified by comparing it
   * against each reduction variable in the basis in turn.
   *)
    fun reductionInfo rator =
          if Var.same(BV.red_all, rator)
            then {rator = BV.op_and, init = L.Bool true, mvs = []}
          else if Var.same(BV.red_exists, rator)
            then {rator = BV.op_or, init = L.Bool false, mvs = []}
          else if Var.same(BV.red_max, rator)
            then {rator = BV.fn_max_r, init = L.Real R.negInf, mvs = []}
          else if Var.same(BV.red_mean, rator)
            then raise Fail "FIXME: 'mean' reduction not yet supported"
          else if Var.same(BV.red_min, rator)
            then {rator = BV.fn_min_r, init = L.Real R.posInf, mvs = []}
          else if Var.same(BV.red_product, rator)
            then {rator = BV.mul_rr, init = L.Real R.one, mvs = []}
          else if Var.same(BV.red_sum, rator)
          (* the identity for addition is zero (the original used one, which made
           * every sum come out one too large); `false` selects positive zero.
           *)
            then {rator = BV.add_tt, init = L.Real(R.zero false), mvs = [SimpleTypes.SHAPE[]]}
          else if Var.same(BV.red_variance, rator)
            then raise Fail "FIXME: 'variance' reduction not yet supported"
          else raise Fail(Var.uniqueNameOf rator ^ " is not a reduction operator")

    local
    (* counter used to give each generated function a distinct name *)
      val n = ref 0
    (* create a fresh function variable named `name` followed by the counter value *)
      fun mkFuncId (name, ty) = let val id = !n
            in
              n := id + 1;
              SimpleVar.new(name ^ Int.toString id, SimpleVar.FunVar, ty)
            end
    in
    fun makeFunction (name, blk, resTy) = let
        (* (original, parameter) pairs for the free variables, most recently found first *)
          val freeVars = ref []
        (* rename a use of `x`.  If `x` is already in the environment (either a local
         * binding or a previously-seen free variable), use its renamed copy;
         * otherwise it is free in the block, so create a new function parameter
         * for it and record the pair in freeVars.
         *)
          fun cvtVar (env, x) = (case VMap.find(env, x)
                 of SOME x' => (env, x')
                  | NONE => let
                      val x' = SimpleVar.copy(x, SimpleVar.FunParam)
                      in
                        freeVars := (x, x') :: !freeVars;
                        (VMap.insert(env, x, x'), x')
                      end
                (* end case *))
        (* rename a list of variable uses, preserving their order *)
          fun cvtVars (env, xs) = let
                fun cvt (x, (env, xs')) = let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, x'::xs')
                      end
                in
                  List.foldr cvt (env, []) xs
                end
        (* rename a binding occurrence of `x` to a fresh local variable of the function *)
          fun newVar (env, x) = let
                val x' = SimpleVar.copy(x, SimpleVar.LocalVar)
                in
                  (VMap.insert(env, x, x'), x')
                end
        (* convert the statements of a block in order, threading the environment *)
          fun cvtBlock (env, S.Block stms) = let
                fun cvtStms (env, [], stms') = (env, S.Block(List.rev stms'))
                  | cvtStms (env, stm::stms, stms') = let
                      val (env, stm') = cvtStm (env, stm)
                      in
                        cvtStms (env, stms, stm'::stms')
                      end
                in
                  cvtStms (env, stms, [])
                end
          and cvtStm (env, stm) = (case stm
                 of S.S_Var(x, NONE) => let
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', NONE))
                      end
                  | S.S_Var(x, SOME e) => let
                    (* convert the initializer before `x` comes into scope *)
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = newVar (env, x)
                      in
                        (env, S.S_Var(x', SOME e'))
                      end
                  | S.S_Assign(x, e) => let
                    (* NOTE(review): if `x` is free in the block, it is renamed to a
                     * parameter, so the assignment is not visible to the caller;
                     * presumably callers never pass blocks that assign to outer
                     * variables -- confirm.
                     *)
                      val (env, e') = cvtExp (env, e)
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Assign(x', e'))
                      end
                  | S.S_IfThenElse(x, b1, b2) => let
                      val (env, x') = cvtVar (env, x)
                      val (env, b1') = cvtBlock (env, b1)
                      val (env, b2') = cvtBlock (env, b2)
                      in
                        (env, S.S_IfThenElse(x', b1', b2'))
                      end
                  | S.S_Foreach(x, xs, b) => let
                    (* the sequence `xs` is a use of an existing variable, but the
                     * loop variable `x` is bound here, so it must get a fresh
                     * local copy rather than being treated as a free variable
                     * (which would turn it into a spurious function parameter).
                     *)
                      val (env, xs') = cvtVar (env, xs)
                      val (env, x') = newVar (env, x)
                      val (env, b') = cvtBlock (env, b)
                      in
                        (env, S.S_Foreach(x', xs', b'))
                      end
                  | S.S_New(name, args) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.S_New(name, args'))
                      end
                  | S.S_Continue => (env, stm)
                  | S.S_Die => (env, stm)
                  | S.S_Stabilize => (env, stm)
                  | S.S_Return x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.S_Return x')
                      end
                  | S.S_Print xs => let
                      val (env, xs') = cvtVars (env, xs)
                      in
                        (env, S.S_Print xs')
                      end
                  | S.S_MapReduce _ => raise Fail "unexpected nested MapReduce"
                (* end case *))
        (* rename the variables referenced by an expression *)
          and cvtExp (env, exp) = (case exp
                 of S.E_Var x => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Var x')
                      end
                  | S.E_Lit _ => (env, exp)
                  | S.E_Select(x, fld) => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Select(x', fld))
                      end
                  | S.E_Apply(f, args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Apply(f, args', ty))
                      end
                  | S.E_Prim(f, mvs, args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Prim(f, mvs, args', ty))
                      end
                  | S.E_Tensor(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Tensor(args', ty))
                      end
                  | S.E_Seq(args, ty) => let
                      val (env, args') = cvtVars (env, args)
                      in
                        (env, S.E_Seq(args', ty))
                      end
                  | S.E_Slice(x, indices, ty) => let
                    (* only the present (SOME) indices are variables to rename *)
                      fun cvt (NONE, (env, idxs)) = (env, NONE::idxs)
                        | cvt (SOME x, (env, idxs)) = let
                            val (env, x') = cvtVar (env, x)
                            in
                              (env, SOME x' :: idxs)
                            end
                      val (env, x') = cvtVar (env, x)
                      val (env, indices') = List.foldr cvt (env, []) indices
                      in
                        (env, S.E_Slice(x', indices', ty))
                      end
                  | S.E_Coerce{srcTy, dstTy, x} => let
                      val (env, x') = cvtVar (env, x)
                      in
                        (env, S.E_Coerce{srcTy=srcTy, dstTy=dstTy, x=x'})
                      end
                  | S.E_LoadSeq _ => (env, exp)
                  | S.E_LoadImage _ => (env, exp)
                (* end case *))
          val (env, blk) = cvtBlock (VMap.empty, blk)
        (* freeVars is in reverse discovery order; restore discovery order so the
         * arguments and parameters line up positionally.
         *)
          val (args, params) = ListPair.unzip (List.rev (! freeVars))
          val fnTy = SimpleTypes.T_Fun(List.map SimpleVar.typeOf params, resTy)
          in
            (S.Func{f=mkFuncId(name, fnTy), params=params, body=blk}, args)
          end
    end (* local *)

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0