Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/simplify/simple-contract.sml
ViewVC logotype

View of /branches/vis15/src/compiler/simplify/simple-contract.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3502 - (download) (annotate)
Thu Dec 17 23:13:35 2015 UTC (3 years, 6 months ago) by jhr
File size: 8897 byte(s)
working on merge
(* simple-contract.sml
 *
 * This is a limited contraction phase for the SimpleAST representation.  The purpose is
 * to eliminate unused variables and dead code.  Specifically, the following transformations
 * are performed:
 *
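 *   -- unused constant and global variables are eliminated
 *
 *   -- unused strand state variables are eliminated (but not outputs)
 *
 *   -- unused local variables and dead assignments are eliminated
 *
 *   -- conditionals whose branches are both empty, and loops whose bodies are empty,
 *      are eliminated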
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure SimpleContract : sig

  (* contract a SimpleAST program by eliminating unused variables and dead code *)
    val transform : Simple.program -> Simple.program

  end = struct

    structure S = Simple
    structure SV = SimpleVar
    structure ST = Stats

  (********** Counters for statistics **********)
    val cntUnusedConst          = ST.newCounter "simple-contract:unused-constant"
    val cntUnusedGlobal         = ST.newCounter "simple-contract:unused-global-var"
    val cntUnusedState          = ST.newCounter "simple-contract:unused-state-var"
    val cntUnusedLocal          = ST.newCounter "simple-contract:unused-local-var"
    val cntDeadAssign           = ST.newCounter "simple-contract:dead-assign"
    val cntDeadIf               = ST.newCounter "simple-contract:dead-if"
    val cntDeadForeach          = ST.newCounter "simple-contract:dead-foreach"
    val firstCounter            = cntUnusedConst
    val lastCounter             = cntDeadForeach

  (* total of this pass's counters; the contraction loops compare successive values of
   * this sum to detect when a pass made no further changes
   *)
    fun sumChanges () = ST.sum {from=firstCounter, to=lastCounter}

  (* Use counts.  Only constant, global, strand-state, strand-output, and local variables
   * are counted; markUsed/markUnused ignore every other kind of variable.
   *)
    local
      val {clrFn, getFn, peekFn, ...} = SV.newProp (fn _ => ref 0)
    (* is x one of the kinds of variable whose uses we count? *)
      fun isCounted x = (case SV.kindOf x
             of SV.ConstVar => true
              | SV.GlobalVar => true
              | SV.StrandStateVar => true
              | SV.StrandOutputVar => true
              | SV.LocalVar => true
              | _ => false
            (* end case *))
    in
    (* unconditionally add one use to x *)
      fun use x = let val r = getFn x in r := !r + 1 end
    (* record a use of x (a no-op for variables that are not counted) *)
      fun markUsed x = if isCounted x then use x else ()
    (* retract a use of x that was previously recorded by markUsed.  This must remain the
     * exact inverse of markUsed; otherwise counts drift and stray properties are created
     * on variables that are never counted.
     *)
      fun markUnused x = if isCounted x
            then let val r = getFn x in r := !r - 1 end
            else ()
    (* does x have at least one remaining use? *)
      fun isUsed x = (case peekFn x of SOME(ref n) => (n > 0) | _ => false)
    (* remove the use-count property from x *)
      fun clrUsedMark x = clrFn x
    end (* local *)

  (* apply f to each variable argument of the expression exp (the function variable of an
   * E_Apply is not included).  Both the analysis and the deletion of expressions go
   * through this function, so the uses they add and retract always agree.
   *)
    fun appExpVars f exp = (case exp
           of S.E_Var x => f x
            | S.E_Lit _ => ()
            | S.E_Select(x, fld) => (f x; f fld)
            | S.E_Apply(_, xs, _) => List.app f xs
            | S.E_Prim(_, _, xs, _) => List.app f xs
            | S.E_Tensor(xs, _) => List.app f xs
            | S.E_Seq(xs, _) => List.app f xs
            | S.E_Slice(x, indices, _) => (f x; List.app (Option.app f) indices)
            | S.E_Coerce{x, ...} => f x
            | S.E_LoadSeq _ => ()
            | S.E_LoadImage _ => ()
          (* end case *))

  (* count the variable uses in a block *)
    fun analyzeBlock blk = let
          fun analyzeBlk (S.Block{code, ...}) = List.app analyzeStm code
          and analyzeStm stm = (case stm
                 of S.S_Var(_, NONE) => ()
                  | S.S_Var(_, SOME e) => appExpVars markUsed e
                  | S.S_Assign(_, e) => appExpVars markUsed e
                  | S.S_IfThenElse(x, b1, b2) => (markUsed x; analyzeBlk b1; analyzeBlk b2)
                (* the sequence being iterated over is a use; contractBlock retracts it when
                 * it deletes a loop whose body became empty, so it must be counted here
                 *)
                  | S.S_Foreach(_, xs, body) => (markUsed xs; analyzeBlk body)
                  | S.S_New(_, xs) => List.app markUsed xs
                  | S.S_Continue => ()
                  | S.S_Die => ()
                  | S.S_Stabilize => ()
                  | S.S_Return x => markUsed x
                  | S.S_Print xs => List.app markUsed xs
                  | S.S_MapReduce{args, ...} => List.app markUsed args
                (* end case *))
          in
            analyzeBlk blk
          end

  (* count the variable uses in a strand *)
    fun analyzeStrand (S.Strand{state, stateInit, initM, updateM, stabilizeM, ...}) = (
        (* mark all outputs as being used, so they are never removed *)
          List.app
            (fn x => (case SV.kindOf x of SV.StrandOutputVar => use x | _ => ()))
              state;
          analyzeBlock stateInit;
          Option.app analyzeBlock initM;
          analyzeBlock updateM;
          Option.app analyzeBlock stabilizeM)

  (* an initial pass to count the variable uses over the entire program.
   * NOTE(review): the bodies in `funcs` are not analyzed; if a function body can refer to
   * constants or globals, those uses are missed here -- confirm.
   *)
    fun analyze (S.Program{constInit, init, strand, create, update, ...}) = let
          val S.Create{code=createCode, ...} = create
          in
            analyzeBlock constInit;
            analyzeBlock init;
            analyzeStrand strand;
            analyzeBlock createCode;
            Option.app analyzeBlock update
          end

  (* rewrite a block, deleting unused local variables, dead assignments, and conditionals or
   * loops whose bodies became empty, until a pass over the block makes no further changes
   *)
    fun contractBlock blk = let
        (* retract the uses made by an expression that is being deleted *)
          fun delete exp = appExpVars markUnused exp
          fun contractBlk (S.Block{props, code}) = let
                fun contractStms [] = []
                  | contractStms (stm::stms) = (case stm
                       of S.S_Var(x, NONE) => if isUsed x
                            then stm :: contractStms stms
                            else (ST.tick cntUnusedLocal; contractStms stms)
                        | S.S_Var(x, SOME e) => if isUsed x
                            then stm :: contractStms stms
                            else (ST.tick cntUnusedLocal; delete e; contractStms stms)
                        | S.S_Assign(x, e) => if isUsed x
                            then stm :: contractStms stms
                            else (ST.tick cntDeadAssign; delete e; contractStms stms)
                        | S.S_IfThenElse(x, b1, b2) => (
                            case (contractBlk b1, contractBlk b2)
                             of (S.Block{code=[], ...}, S.Block{code=[], ...}) => (
                                  ST.tick cntDeadIf; markUnused x; contractStms stms)
                              | (b1, b2) => S.S_IfThenElse(x, b1, b2) :: contractStms stms
                            (* end case *))
                        | S.S_Foreach(x, xs, body) => (
                            case contractBlk body
                             of S.Block{code=[], ...} => (
                                  ST.tick cntDeadForeach; markUnused xs; contractStms stms)
                              | body => S.S_Foreach(x, xs, body) :: contractStms stms
                            (* end case *))
                        | _ => stm :: contractStms stms
                      (* end case *))
                in
                  S.Block{props = props, code = contractStms code}
                end
        (* repeat until the change counters stop moving; a deletion late in a block can make
         * an earlier statement dead, which is only seen on a later pass
         *)
          fun loop (nChanges, blk) = let
                val blk = contractBlk blk
                val n = sumChanges()
                in
                  if (n <> nChanges) then loop (n, blk) else blk
                end
          in
            loop (sumChanges(), blk)
          end

  (* contract a strand's initialization and method bodies *)
    fun contractStrand strand = let
          val S.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand
          in
            S.Strand{
                name = name, params = params, state = state,
                stateInit = contractBlock stateInit,
                initM = Option.map contractBlock initM,
                updateM = contractBlock updateM,
                stabilizeM = Option.map contractBlock stabilizeM
              }
          end

  (* contract a program; returns the new change total together with the program (the
   * original program is returned when nothing changed).  The use counts of constants and
   * globals must NOT be cleared here: later passes and finishProg still rely on them.
   *)
    fun contractProg (nChanges, prog) = let
          val S.Program{
                  props, consts, inputs, constInit, globals, funcs, init, strand, create, update
                } = prog
          val constInit = contractBlock constInit
          val init = contractBlock init
          val strand = contractStrand strand
          val update = Option.map contractBlock update
          val n = sumChanges()
          in
            if n = nChanges
              then (n, prog)
              else (n, S.Program{
                  props = props, consts = consts, inputs = inputs,
                  constInit = constInit, globals = globals, funcs = funcs,
                  init = init, strand = strand, create = create, update = update
                })
          end

  (* partition xs into (used, unused), tick cnt once for each unused variable, and clear
   * the use-count property of every variable in xs.  The partition is computed before the
   * properties are cleared, because clearing makes every variable look unused.
   *)
    fun pruneVars cnt xs = let
          val (used, unused) = List.partition isUsed xs
          in
            List.app (fn _ => ST.tick cnt) unused;
            List.app clrUsedMark xs;
            (used, unused)
          end

  (* remove unused state variables from a strand and clear their properties *)
    fun finishStrand strand = let
          val S.Strand{name, params, state, stateInit, initM, updateM, stabilizeM} = strand
          val (used, unused) = pruneVars cntUnusedState state
          in
            if List.null unused
              then strand
              else S.Strand{
                  name = name, params = params, state = used,
                  stateInit = stateInit,
                  initM = initM, updateM = updateM, stabilizeM = stabilizeM
                }
          end

  (* remove unused constant, global, and state variables from the program and clear their
   * use-count properties
   *)
    fun finishProg prog = let
          val S.Program{
                  props, consts, inputs, constInit, globals, funcs, init, strand, create, update
                } = prog
          val (consts, _) = pruneVars cntUnusedConst consts
          val (globals, _) = pruneVars cntUnusedGlobal globals
          in
            S.Program{
                props = props,
                consts = consts,
                inputs = inputs,
                constInit = constInit,
                globals = globals,
                funcs = funcs,
                init = init,
                strand = finishStrand strand,
                create = create,
                update = update
              }
          end

    fun transform prog = let
        (* first we count the variable uses over the entire program *)
          val () = analyze prog
        (* then contract until a pass over the program makes no further changes *)
          fun loop (nChanges, prog) = let
                val (n, prog) = contractProg (nChanges, prog)
                in
                  if (n <> nChanges) then loop (n, prog) else prog
                end
          val prog = loop (sumChanges(), prog)
          in
          (* finally remove the unused constant, global, and state variables and clear the
           * use-count properties.  This is done even when no contraction occurred: a
           * variable can be unused without there being any code to delete, and the
           * properties should be cleared either way.
           *)
            finishProg prog
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0