Home My Page Projects Code Snippets Project Openings 3D graphics for Standard ML
Summary Activity SCM

SCM Repository

[sml3d] View of /trunk/sml3d/src/particles/compiler/optimizer.sml
ViewVC logotype

View of /trunk/sml3d/src/particles/compiler/optimizer.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 866 - (download) (annotate)
Thu Apr 29 20:16:27 2010 UTC (8 years, 7 months ago) by pavelk
File size: 19774 byte(s)
IR translation now returns a program datatype, which keeps the emitter block separate from the physics block and also records the requested rendering operation. Also added property fields to variables and blocks so that they can be tracked for UVE.
structure Optimize : sig
  val optimizeIR : PSysIR.block list -> PSysIR.block list
 end = struct

  (* Flag shared by the passes below: a pass clears it before a sweep and sets
   * it when the sweep rewrote something, then uses it to decide whether to
   * sweep again. *)
  val gProgramChanged : bool ref = ref false

  (* Write one line of diagnostic text to stderr. *)
  fun printErr msg = TextIO.output (TextIO.stdErr, String.concat [msg, "\n"])

  (* Short alias for PSysIR. *)
  structure IR = PSysIR

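  (* Unused variable elimination.
   *
   * Removes every statement whose result variable is not read by any
   * statement in the program. optimizeIR runs this pass after constant
   * folding: folding replaces reads of a computed variable with a constant,
   * which leaves the variable's defining statement without readers.
   *)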
  (* Zero the use counter of every variable that appears in the given blocks,
   * so that countVarReferences can recount from scratch.
   *
   * Both the variable defined by a PRIM and the variables it reads are reset.
   * Resetting the defined variable matters for iterated elimination: once the
   * only statement reading a variable has been removed, that variable no
   * longer shows up in any argument list, and its stale non-zero count would
   * otherwise keep its (now dead) definition alive.  RETURN arguments are
   * reset too, mirroring what countVarReferences counts.
   *)
  fun resetVarRefCount(blocks) = let
    fun resetVar(IR.V{useCount, ...}) = useCount := 0
    fun resetStmt(IR.PRIM(v, _, vl, s)) = (
         resetVar(v);
         List.app resetVar vl;
         resetStmt(s))
      | resetStmt(IR.IF(v, s1, s2)) = (
         resetVar(v);
         resetStmt(s1);
         resetStmt(s2))
      | resetStmt(IR.GOTO(_, vl)) = List.app resetVar vl
      | resetStmt(IR.RETURN(vl)) = List.app resetVar vl
      | resetStmt(_) = ()

    fun resetForBlock(IR.BLK{body, ...}) = resetStmt(body)
   in
    List.app resetForBlock blocks
   end
 
  (* Tally variable reads: bump the useCount of a variable each time it
   * appears as a PRIM argument, an IF condition, or a GOTO/RETURN argument.
   * Counters are not cleared here; the caller is expected to reset them. *)
  fun countVarReferences(blocks) = let
    fun bump(IR.V{useCount, ...}) = useCount := !useCount + 1
    val bumpAll = List.app bump

    fun walk(IR.PRIM(_, _, args, rest)) = (bumpAll args; walk rest)
      | walk(IR.IF(cond, thenS, elseS)) = (bump cond; walk thenS; walk elseS)
      | walk(IR.GOTO(_, args)) = bumpAll args
      | walk(IR.RETURN(args)) = bumpAll args
      | walk(_) = ()
   in
    List.app (fn IR.BLK{body, ...} => walk body) blocks
   end (*countVarReferences*)
  
  (* Debug helper: print "<name>: <useCount>" for one variable on stderr. *)
  fun prv(IR.V{name, useCount, ...}) =
    printErr (name ^ ": " ^ Int.toString (!useCount))

  (* Debug helper: print the use count of every PRIM-defined variable in a
   * statement chain, descending into both arms of each IF. *)
  fun prRefCt stmt = (case stmt
    of IR.PRIM(v, _, _, rest) => (prv v; prRefCt rest)
     | IR.IF(_, s1, s2) => (prRefCt s1; prRefCt s2)
     | _ => ()
    (* end case *))

  (* Debug helper: print the use counts for the body of one block. *)
  fun prRefCtForBlk(IR.BLK{body, ...}) = prRefCt body
  
  (* Drop every PRIM whose defined variable has a use count of zero or less.
   * Reference counts are recomputed at the start of each sweep, and the
   * sweep is repeated on its own output until it removes nothing. *)
  fun removeUnused(blocks) = let
    fun isUnused(IR.V{useCount, ...}) = !useCount <= 0

    fun prune(stmt) = (case stmt
      of IR.PRIM(v, p, vl, rest) =>
          if isUnused(v)
            then (gProgramChanged := true; prune(rest))
            else IR.PRIM(v, p, vl, prune(rest))
       | IR.IF(cond, s1, s2) => IR.IF(cond, prune(s1), prune(s2))
       | other => other
      (* end case *))

    fun pruneBlock(IR.BLK{id, params, body, visited}) =
      IR.BLK{id=id, params=params, body=prune(body), visited=visited}

    val () = resetVarRefCount(blocks)
    val () = countVarReferences(blocks)
    (* Cleared only after counting so the flag reflects this sweep alone. *)
    val () = gProgramChanged := false
    val pruned = List.map pruneBlock blocks
   in
    if !gProgramChanged then removeUnused(pruned) else pruned
   end (* removeUnused *)

  (* Constant folding
   *
   * Replaces every statement argument that is a local variable computed by a
   * primitive from constant operands with a fresh constant holding the
   * precomputed value.  The sweep over all blocks is repeated on its own
   * output until a sweep folds nothing.
   *
   * Raises Fail when a constant operand has an unexpected kind for the
   * primitive that consumes it, and Subscript when a primitive has fewer
   * operands than it needs.
   *)
  fun foldConstants(blocks) = let

    (* True when every variable in the list has constant scope. *)
    fun allConsts(vl) =
      List.all (fn IR.V{scope = IR.S_CONST _, ...} => true | _ => false) vl

    (* Payload accessors for constant variables; each raises Fail when the
     * variable is not a constant of the requested kind. *)
    fun extractConstVec(IR.V{scope = IR.S_CONST(IR.C_VEC vec), ...}) = vec
      | extractConstVec(_) = raise Fail ("Expected const vec.")

    fun extractConstInt(IR.V{scope = IR.S_CONST(IR.C_INT i), ...}) = i
      | extractConstInt(_) = raise Fail ("Expected const int.")

    (* Integer constants are accepted here and converted to float. *)
    fun extractConstFloat(IR.V{scope = IR.S_CONST(IR.C_FLOAT f), ...}) = f
      | extractConstFloat(IR.V{scope = IR.S_CONST(IR.C_INT i), ...}) = Float.fromInt i
      | extractConstFloat(_) = raise Fail ("Expected const float.")

    fun extractConstBool(IR.V{scope = IR.S_CONST(IR.C_BOOL b), ...}) = b
      | extractConstBool(_) = raise Fail ("Expected const bool.")

    (* Given a variable, return the variable that should be used in its place:
     * a constant when the variable is a local whose defining primitive has
     * only constant operands, otherwise the variable itself.  Sets
     * gProgramChanged whenever a replacement is produced. *)
    fun createConst(v as IR.V{scope, name, ...}) = (case scope
      of IR.S_LOCAL(rhs) => let
          val (p, vl) = !rhs

          (* Record the change and build the replacement constant. *)
          fun mk(c) = (gProgramChanged := true; IR.newConst(name, c))

          (* Typed access to the n-th operand of the defining primitive. *)
          fun vecArg(n) = extractConstVec(List.nth(vl, n))
          fun floatArg(n) = extractConstFloat(List.nth(vl, n))
          fun boolArg(n) = extractConstBool(List.nth(vl, n))
          fun intArg(n) = extractConstInt(List.nth(vl, n))

          (* Run the folding thunk only when all operands are constant. *)
          fun whenConst(build) = if allConsts(vl) then build() else v
         in
          (case p
            of IR.ADD_VEC =>
                whenConst(fn () => mk(IR.C_VEC(Vec3f.add(vecArg 0, vecArg 1))))
             | IR.SUB_VEC =>
                whenConst(fn () => mk(IR.C_VEC(Vec3f.sub(vecArg 0, vecArg 1))))
             | IR.LEN_SQ =>
                whenConst(fn () => mk(IR.C_FLOAT(Vec3f.lengthSq(vecArg 0))))
             | IR.LEN =>
                whenConst(fn () => mk(IR.C_FLOAT(Vec3f.length(vecArg 0))))
             | IR.NORM =>
                whenConst(fn () => mk(IR.C_VEC(Vec3f.normalize(vecArg 0))))
             | IR.SCALE =>
                whenConst(fn () => mk(IR.C_VEC(Vec3f.scale(floatArg 0, vecArg 1))))
             | IR.DOT =>
                whenConst(fn () => mk(IR.C_FLOAT(Vec3f.dot(vecArg 0, vecArg 1))))
             | IR.CROSS =>
                whenConst(fn () => mk(IR.C_VEC(Vec3f.cross(vecArg 0, vecArg 1))))
             | IR.ADD =>
                whenConst(fn () => mk(IR.C_FLOAT(floatArg 0 + floatArg 1)))
             | IR.SUB =>
                whenConst(fn () => mk(IR.C_FLOAT(floatArg 0 - floatArg 1)))
             | IR.MULT =>
                whenConst(fn () => mk(IR.C_FLOAT(floatArg 0 * floatArg 1)))
             | IR.DIV =>
                whenConst(fn () => mk(IR.C_FLOAT(floatArg 0 / floatArg 1)))
             | IR.SQRT =>
                whenConst(fn () => mk(IR.C_FLOAT(Float.sqrt (floatArg 0))))
             | IR.COS =>
                whenConst(fn () => mk(IR.C_FLOAT(Float.cos (floatArg 0))))
             | IR.SIN =>
                whenConst(fn () => mk(IR.C_FLOAT(Float.sin (floatArg 0))))
             | IR.GT =>
                whenConst(fn () => mk(IR.C_BOOL(floatArg 0 > floatArg 1)))
             | IR.EQUALS =>
                (* Not proper equality, but will likely generate
                 * better results. *)
                whenConst(fn () =>
                  mk(IR.C_BOOL(Float.abs(floatArg 0 - floatArg 1) < Float.epsilon)))
             | IR.AND =>
                (* Both operands are extracted before combining so that a
                 * badly-typed second operand is reported even when the
                 * first one already decides the result. *)
                whenConst(fn () => let
                  val b1 = boolArg 0
                  val b2 = boolArg 1
                 in
                  mk(IR.C_BOOL(b1 andalso b2))
                 end)
             | IR.OR =>
                whenConst(fn () => let
                  val b1 = boolArg 0
                  val b2 = boolArg 1
                 in
                  mk(IR.C_BOOL(b1 orelse b2))
                 end)
             | IR.NOT =>
                whenConst(fn () => mk(IR.C_BOOL(not (boolArg 0))))
             (* A random draw is never folded. *)
             | IR.RAND => v
             | IR.ITOF =>
                whenConst(fn () => mk(IR.C_FLOAT(Float.fromInt (intArg 0))))
             (* A copy of a constant is replaced by that constant itself.
              * The change is flagged like every other replacement so that
              * a sweep whose only effect was a copy propagation still
              * triggers the follow-up sweep. *)
             | IR.COPY =>
                whenConst(fn () => (gProgramChanged := true; List.nth(vl, 0)))
            (* end case *))
         end
       | _ => v
      (* end case *))

    (* Fold the arguments of every statement in a chain.  For a PRIM whose
     * result is a local, the definition recorded in the variable is updated
     * to the folded argument list, so later reads of that variable can be
     * folded in turn. *)
    fun foldStmt(IR.PRIM(v as IR.V{scope, ...}, p, vl, s)) = let
          val foldedArgs = List.map createConst vl
          val () = (case scope
            of IR.S_LOCAL(rhs) => rhs := (p, foldedArgs)
             | _ => ()
            (* end case *))
         in
          IR.PRIM(v, p, foldedArgs, foldStmt(s))
         end
      | foldStmt(IR.IF(v, s1, s2)) = IR.IF(v, foldStmt(s1), foldStmt(s2))
      | foldStmt(IR.GOTO(blk, vl)) = IR.GOTO(blk, List.map createConst vl)
      | foldStmt(IR.RETURN(vl)) = IR.RETURN(List.map createConst vl)
      | foldStmt(IR.DISCARD) = IR.DISCARD

    fun foldConstantsInBlock(IR.BLK{id, params, body, visited}) =
      IR.BLK{id=id, params=params, body=foldStmt(body), visited=visited}

    val () = gProgramChanged := false
    val optimizedBlocks = List.map foldConstantsInBlock blocks

   in
    if !gProgramChanged then
     foldConstants(optimizedBlocks)
    else
     optimizedBlocks
   end

  (* Useless variable elimination
   *
   * Gets rid of all of the variables that won't be needed for computation. This
   * is slightly stronger than unused variable elimination, since we need to do
   * analysis on which variables we actually don't update over the course of
   * the entire program. This gets to be a little tricky when moving between
   * blocks since we don't preserve variable IDs, and basically they are all
   * mapped to entry/exit points of the blocks.
   *
   * A few key assumptions can be used to guide this. Notably, since we're using
   * a single assignment paradigm, we know that if a variable is changed, then
   * it will be a different variable than the one passed into the block. Hence,
   * any variables that get passed through the block without being used to
   * update something can be called useless.
   *)

  (* Ordered sets and maps keyed by IR variables, ordered by IR.compare. *)
  local
   structure VarKey = struct
     type ord_key = IR.var
     val compare = IR.compare
   end
  in
   structure IRSet = RedBlackSetFn(VarKey)
   structure IRMap = RedBlackMapFn(VarKey)
  end

  (* Diagnostic pass: computes the set of "useless" variables of the program
   * and prints their ids to stderr, first the per-block initial estimate and
   * then the result of propagating it across blocks.  Returns unit; the
   * blocks themselves are not rewritten.
   *)
  fun removeUseless(blocks) = let

   (* Split l1 into (variables that also occur in l2, variables that do
    * not), comparing with IR.compare. *)
   fun findDifferentVars [] _ = (IRSet.empty, IRSet.empty)
     | findDifferentVars (l1 :: l1s)  l2 = let
     fun findInList x2 = IR.compare(l1, x2) = EQUAL
     val (ins, outs) = findDifferentVars l1s l2
    in
     if List.exists findInList l2 then
      (IRSet.add (ins, l1), outs)
     else
      (ins, IRSet.add (outs, l1))
    end

   (* Parameters of the block that do not contribute to any value the block
    * computes and hands on. *)
   fun findUselessInBlock(IR.BLK{params, body, ...}) = let

     (* Returns the set of useful variables for the statement chain, built
      * from the end of the block backwards.  A PRIM whose result is useful
      * makes all of its operands useful. *)
     fun trackStmt(IR.PRIM(v, _, vl, s)) = let
          val usefuls = trackStmt(s)
         in
          if IRSet.member (usefuls, v) then
           IRSet.addList(usefuls, vl)
          else
           usefuls
         end

       (* If neither arm has useful variables the condition is useless as
        * well; otherwise the result is the union of both arms plus the
        * condition variable. *)
       | trackStmt(IR.IF(v, s1, s2)) = let
          val usefuls1 = trackStmt(s1)
          val usefuls2 = trackStmt(s2)
         in
          if IRSet.isEmpty(usefuls1) andalso IRSet.isEmpty(usefuls2) then
           IRSet.empty
          else
           IRSet.add(IRSet.union (usefuls1, usefuls2), v)
         end

       (* Block exits seed the useful set: the outgoing values that are not
        * parameters of this block, i.e. that were computed inside it. *)
       | trackStmt(IR.GOTO(_, vl)) = let
          val (_, diffs) = findDifferentVars vl params
         in
          diffs
         end

       | trackStmt(IR.RETURN(vl)) = let
          val (_, diffs) = findDifferentVars vl params
         in
          diffs
         end

       (* Discarding the particle makes nothing useful on this path. *)
       | trackStmt(IR.DISCARD) = IRSet.empty

    in
     IRSet.difference(IRSet.fromList params, trackStmt(body))
    end (* findUselessInBlock *)

   (* Union of the per-block useless parameter sets. *)
   fun findUseless(blocks) =
     List.foldl IRSet.union IRSet.empty (List.map findUselessInBlock blocks)

   (* Print the id of each variable on its own line on stderr. *)
   fun prVarIds(vars) = List.app (fn IR.V{id, ...} => printErr (Int.toString id)) vars

   (* Iterate the per-block propagation rules, feeding each sweep's result
    * into the next, until a sweep reports no change.
    *
    * NOTE(review): the GOTO rule can delete a variable that the PRIM rule
    * adds again on the following sweep, so termination of this loop is not
    * established for every program - confirm before enabling the pass. *)
   fun propagateUseless(blocks, initUVs) = let
     fun propagateUselessInBlock(IR.BLK{body, ...}, uselessVars) = let
       fun hasUseless(uves, vl) = List.exists (fn x => IRSet.member(uves, x)) vl

       (* A PRIM result becomes useless when one of its operands is.  The
        * change flag is raised only when the result is newly added;
        * raising it for a variable that is already in the set would make
        * every sweep report a change and the loop above never stop. *)
       fun checkStmt(IR.PRIM(v, _, vl, s)) = let
            val uvs = checkStmt(s)
           in
            if hasUseless(uvs, vl) andalso not (IRSet.member(uvs, v)) then
             (gProgramChanged := true;
              IRSet.add(uvs, v))
            else
             uvs
           end

         | checkStmt(IR.IF(_, s1, s2)) = let
            val uv1s = checkStmt(s1)
            val uv2s = checkStmt(s2)
           in
            IRSet.intersection(uv1s, uv2s)
           end

         (* Reconcile each outgoing value with the parameter it is bound
          * to in the target block: when exactly one of the two is in the
          * useless set, it is taken out of the set. *)
         | checkStmt(IR.GOTO(IR.BLK{params=args, ...}, vlist)) = let
            fun update(uvs, v :: vl, p :: ps) =
                if IRSet.member(uvs, v) andalso (not (IRSet.member(uvs, p))) then
                 (gProgramChanged := true;
                  update(IRSet.delete(uvs, v), vl, ps))
                else if IRSet.member(uvs, p) andalso (not (IRSet.member(uvs, v))) then
                 (gProgramChanged := true;
                  update(IRSet.delete(uvs, p), vl, ps))
                else
                 update(uvs, vl, ps)
              | update(uvs, [], []) = uvs
              | update(_, _, _) = raise Fail("Mismatch in block parameters")
           in
            update(uselessVars, vlist, args)
           end

         | checkStmt(_) = uselessVars

      in
       checkStmt(body)
      end (* propagateUselessInBlock *)

     val result = (gProgramChanged := false; List.foldr propagateUselessInBlock initUVs blocks)

    in
     if (!gProgramChanged) then
      propagateUseless(blocks, result)
     else
      result
    end

   (* Returns true if the variable is useless in the given block. It
    * checks all of the subsequent blocks to make sure that the variable
    * is useless...
    *
    * Currently not called from anywhere in removeUseless.
    *
    * !SPEED! This will likely benefit from some memoization, since there's
    * a pretty good chance that we're going to be checking the same block
    * multiple times with this method.
    *
    * NOTE(review): the recursion follows GOTO targets without remembering
    * which blocks were already visited - confirm the block graph is
    * acyclic before relying on this.
    *)
   fun isVarUselessInBlock(v, blk as IR.BLK{body, ...}) = let

     fun uselessInNextBlock(IR.PRIM(_, _, _, s)) = uselessInNextBlock(s)
       | uselessInNextBlock(IR.IF(_, s1, s2)) =
         uselessInNextBlock(s1) andalso uselessInNextBlock(s2)
       | uselessInNextBlock(IR.GOTO(blk2 as IR.BLK{params=ps, ...}, vl)) = let
          (* Find the first outgoing position holding v and check the
           * parameter bound at that position in the target block. *)
          fun searchNth(_, []) = true
            | searchNth(n, pv :: vls) =
              if IR.compare(pv, v) = EQUAL then
                isVarUselessInBlock(List.nth (ps, n), blk2)
              else
                searchNth(n + 1, vls)
         in
          searchNth(0, vl)
         end

       | uselessInNextBlock(_) = true

    in
     if IRSet.member(findUselessInBlock(blk), v) then
      uselessInNextBlock(body)
     else
      false
    end (* isVarUselessInBlock *)

   (* The initial estimate is computed once and reused for both reports. *)
   val initialUseless = findUseless(blocks)
   val initialVars = IRSet.listItems initialUseless

   in
    printErr ("Initial useless variables (" ^ (Int.toString (List.length initialVars)) ^"):");
    prVarIds(initialVars);
    printErr "\nFinal useless variables:";
    prVarIds(IRSet.listItems (propagateUseless(blocks, initialUseless)))
   end (* removeUseless *)
  
  (* Entry point of the optimizer: fold constants first, then strip the
   * definitions that are left without readers.  The useless-variable pass
   * (removeUseless) is not part of the pipeline. *)
  fun optimizeIR(blocks) = removeUnused(foldConstants(blocks))
 
 end (* structure *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0