Home My Page Projects Code Snippets Project Openings 3D graphics for Standard ML
Summary Activity SCM

SCM Repository

[sml3d] View of /trunk/sml3d/src/particles/compiler/optimizer.sml
ViewVC logotype

View of /trunk/sml3d/src/particles/compiler/optimizer.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 864 - (download) (annotate)
Wed Apr 28 17:21:18 2010 UTC (10 years, 3 months ago) by pavelk
File size: 19753 byte(s)
Updated work on UVE. It now works the way I intended it to, although I think I may need to revisit my intentions.
structure Optimize : sig
  val optimizeIR : PSysIR.block list -> PSysIR.block list
 end = struct

  (* Shared "a rewrite happened" flag. The passes below clear it before
   * transforming the program, set it whenever they change something, and
   * re-run themselves while it is set. *)
  val gProgramChanged = ref false
 
  (* Write s to standard error, followed by a newline. *)
  fun printErr s = TextIO.output(TextIO.stdErr, String.concat [s, "\n"])

  structure IR = PSysIR

  (* Unused variable elimination.
   * 
   * Get rid of all of the variables that aren't being used by any of the
   * succeeding statements. Should be done before constant folding to 
   * remove redundancies.
   *)
  (* Zero the use count of every variable that appears in the given blocks,
   * so that countVarReferences can recompute the counts from scratch.
   *
   * Both the operands of a statement and the variable a PRIM defines are
   * reset. Resetting the defined variable matters: removeUnused re-runs
   * itself after deleting statements, and a variable whose only user was
   * just deleted no longer shows up in any operand list. Without the reset
   * it would keep its stale non-zero count and never be seen as unused.
   * RETURN arguments are reset too, since countVarReferences increments
   * them on every pass.
   *)
  fun resetVarRefCount(blocks) = let
    fun resetVar(IR.V{useCount, ...}) = useCount := 0
    fun resetStmt(IR.PRIM(v, _, vl, s)) = (
         resetVar(v);
         List.app resetVar vl;
         resetStmt(s))
      | resetStmt(IR.IF(v, s1, s2)) = (
         resetVar(v);
         resetStmt(s1);
         resetStmt(s2))
      | resetStmt(IR.GOTO(_, vl)) = List.app resetVar vl
      | resetStmt(IR.RETURN(vl)) = List.app resetVar vl
      | resetStmt(_) = ()

    fun resetForBlock(IR.BLK{body, ...}) = resetStmt(body)
   in
    List.app resetForBlock blocks
   end
 
  (* Add one to a variable's use count for each operand position that
   * mentions it: PRIM arguments, IF conditions, and the argument lists of
   * GOTO and RETURN. Counts accumulate on top of whatever is already stored,
   * so callers reset them first.
   *)
  fun countVarReferences(blocks) = let
    fun bump(IR.V{useCount, ...}) = useCount := 1 + !useCount
    fun bumpAll vars = List.app bump vars
    fun walk stmt = (case stmt
      of IR.PRIM(_, _, args, rest) => (bumpAll args; walk rest)
       | IR.IF(cond, thenS, elseS) => (bump cond; walk thenS; walk elseS)
       | IR.GOTO(_, args) => bumpAll args
       | IR.RETURN(args) => bumpAll args
       | _ => ()
      (* end case *))
   in
    List.app (fn IR.BLK{body, ...} => walk body) blocks
   end (*countVarReferences*)
  
  (* Debugging aid: print "<name>: <use count>" for a variable on stderr. *)
  fun prv(IR.V{name, useCount, ...}) =
    printErr (name ^ ": " ^ Int.toString (!useCount))

  (* Debugging aid: print the use count of every variable defined by a PRIM
   * in the statement, descending into both arms of an IF. *)
  fun prRefCt stmt = (case stmt
    of IR.PRIM(v, _, _, rest) => (prv v; prRefCt rest)
     | IR.IF(_, thenS, elseS) => (prRefCt thenS; prRefCt elseS)
     | _ => ()
    (* end case *))

  (* Debugging aid: print the use counts for the body of one block. *)
  fun prRefCtForBlk(IR.BLK{body = stmt, ...}) = prRefCt stmt
  
  (* Drop every PRIM statement whose defined variable has no recorded uses,
   * then repeat on the result for as long as a pass removed something.
   * Use counts are recomputed at the start of each pass.
   *)
  fun removeUnused(blocks) = let
    fun isLive(IR.V{useCount, ...}) = !useCount > 0

    (* Rebuild a statement chain without the dead PRIMs. *)
    fun prune stmt = (case stmt
      of IR.PRIM(v, p, vl, rest) =>
           if isLive v then
             IR.PRIM(v, p, vl, prune rest)
           else
             (gProgramChanged := true; prune rest)
       | IR.IF(cond, thenS, elseS) => IR.IF(cond, prune thenS, prune elseS)
       | other => other
      (* end case *))

    fun pruneBlock(IR.BLK{id, params, body}) =
      IR.BLK{id=id, params=params, body=prune body}

    val () = resetVarRefCount(blocks)
    val () = countVarReferences(blocks)
    val () = gProgramChanged := false
    val pruned = List.map pruneBlock blocks
   in
    if !gProgramChanged then removeUnused(pruned) else pruned
   end (* removeUnused *)

  (* Constant folding
   *
   * Condense most of our constants so that we have a smaller code size. We
   * already know what most of these operations do...
   *)
  
  (* Replace operands whose defining primitive has only constant arguments
   * with a freshly computed constant, and repeat while a pass changed
   * something.
   *
   * Operands are folded where they are *used*: for each argument of a
   * PRIM / GOTO / RETURN, createConst looks at the (primitive, args) pair
   * stored in the argument's S_LOCAL scope and, if every arg is a constant,
   * returns a new constant variable in its place.
   *)
  fun foldConstants(blocks) = let

    (* True when every variable in the list has constant scope. *)
    fun allConsts([]) = true
      | allConsts(IR.V{scope = IR.S_CONST _, ...} :: rest) = allConsts(rest)
      | allConsts(_) = false

    (* Accessors for the payload of a constant variable; each raises Fail
     * when the variable is not a constant of the expected kind. *)
    fun extractConstVec(IR.V{scope = IR.S_CONST(IR.C_VEC vec), ...}) = vec
      | extractConstVec(_) = raise Fail ("Expected const vec.")

    fun extractConstInt(IR.V{scope = IR.S_CONST(IR.C_INT i), ...}) = i
      | extractConstInt(_) = raise Fail ("Expected const int.")

    (* Integer constants are accepted here and converted. *)
    fun extractConstFloat(IR.V{scope = IR.S_CONST(IR.C_FLOAT f), ...}) = f
      | extractConstFloat(IR.V{scope = IR.S_CONST(IR.C_INT i), ...}) =
          Float.fromInt i
      | extractConstFloat(_) = raise Fail ("Expected const float.")

    fun extractConstBool(IR.V{scope = IR.S_CONST(IR.C_BOOL b), ...}) = b
      | extractConstBool(_) = raise Fail ("Expected const bool.")

    (* Return the variable that should stand in for v at a use site: a new
     * constant when v is a local computed purely from constants, otherwise
     * v itself. Sets gProgramChanged whenever it substitutes. *)
    fun createConst(v as IR.V{scope, name, ...}) = (case scope
      of IR.S_LOCAL(rhs) => let
        val (p, vl) = !rhs
        fun arg(i) = List.nth(vl, i)

        (* Record the change first, then build the replacement constant.
         * The payload is passed as a thunk so that it is computed after
         * the flag is set, matching the order of the hand-written arms
         * this helper replaces. *)
        fun replaceWith(mkConst) = (
              gProgramChanged := true;
              IR.newConst(name, mkConst()))

        (* Fold a one-argument primitive. *)
        fun fold1(get, f) =
              if allConsts(vl) then let
                val a = get(arg 0)
               in
                replaceWith(fn () => f a)
               end
              else v

        (* Fold a two-argument primitive. *)
        fun fold2(getA, getB, f) =
              if allConsts(vl) then let
                val a = getA(arg 0)
                val b = getB(arg 1)
               in
                replaceWith(fn () => f(a, b))
               end
              else v

        val vec = extractConstVec
        val flt = extractConstFloat
        val bool = extractConstBool
       in
        (case p
          of IR.ADD_VEC => fold2(vec, vec, IR.C_VEC o Vec3f.add)
           | IR.SUB_VEC => fold2(vec, vec, IR.C_VEC o Vec3f.sub)
           | IR.LEN_SQ => fold1(vec, IR.C_FLOAT o Vec3f.lengthSq)
           | IR.LEN => fold1(vec, IR.C_FLOAT o Vec3f.length)
           | IR.NORM => fold1(vec, IR.C_VEC o Vec3f.normalize)
           | IR.SCALE => fold2(flt, vec, IR.C_VEC o Vec3f.scale)
           | IR.DOT => fold2(vec, vec, IR.C_FLOAT o Vec3f.dot)
           | IR.CROSS => fold2(vec, vec, IR.C_VEC o Vec3f.cross)
           | IR.ADD => fold2(flt, flt, fn (a, b) => IR.C_FLOAT(a + b))
           | IR.SUB => fold2(flt, flt, fn (a, b) => IR.C_FLOAT(a - b))
           | IR.MULT => fold2(flt, flt, fn (a, b) => IR.C_FLOAT(a * b))
           | IR.DIV => fold2(flt, flt, fn (a, b) => IR.C_FLOAT(a / b))
           | IR.SQRT => fold1(flt, IR.C_FLOAT o Float.sqrt)
           | IR.COS => fold1(flt, IR.C_FLOAT o Float.cos)
           | IR.SIN => fold1(flt, IR.C_FLOAT o Float.sin)
           | IR.GT => fold2(flt, flt, fn (a, b) => IR.C_BOOL(a > b))
           (* Not proper equality, but will likely generate better
            * results. *)
           | IR.EQUALS => fold2(flt, flt, fn (a, b) =>
               IR.C_BOOL(Float.abs(a - b) < Float.epsilon))
           | IR.AND => fold2(bool, bool, fn (a, b) => IR.C_BOOL(a andalso b))
           | IR.OR => fold2(bool, bool, fn (a, b) => IR.C_BOOL(a orelse b))
           | IR.NOT => fold1(bool, IR.C_BOOL o not)
           (* A random value is never a compile-time constant. *)
           | IR.RAND => v
           | IR.ITOF => fold1(extractConstInt, IR.C_FLOAT o Float.fromInt)
           (* A copy of a constant is replaced by the constant itself. This
            * is a change to the program like any other fold, so it is
            * recorded in gProgramChanged too. *)
           | IR.COPY =>
               if allConsts(vl) then
                 (gProgramChanged := true; arg 0)
               else v
         (* end case *))
       end
       | _ => v
      (* end case *))

    (* Fold the operands of every statement in a chain. For a PRIM the
     * folded argument list is also written back into the defined
     * variable's S_LOCAL record, so later uses of that variable see it. *)
    fun foldStmt(IR.PRIM(v as IR.V{scope, ...}, p, vl, s)) = let
          val foldedArgs = List.map createConst vl
         in
          (case scope
            of IR.S_LOCAL(rhs) => rhs := (p, foldedArgs)
             | _ => ()
           (* end case *));
          IR.PRIM(v, p, foldedArgs, foldStmt(s))
         end
      | foldStmt(IR.IF(v, s1, s2)) =
          IR.IF(v, foldStmt(s1), foldStmt(s2))
      | foldStmt(IR.GOTO(blk, vl)) =
          IR.GOTO(blk, List.map createConst vl)
      | foldStmt(IR.RETURN(vl)) =
          IR.RETURN(List.map createConst vl)
      | foldStmt(IR.DISCARD) =
          IR.DISCARD

    fun foldConstantsInBlock(IR.BLK{id, params, body}) =
      IR.BLK{id=id, params=params, body=foldStmt(body)}

    val () = gProgramChanged := false
    val optimizedBlocks = List.map foldConstantsInBlock blocks

   in
    if !gProgramChanged then
     foldConstants(optimizedBlocks)
    else
     optimizedBlocks
   end

  (* Useless variable elimination
   *
   * Gets rid of all of the variables that won't be needed for computation. This
   * is slightly stronger than unused variable elimination, since we need to do
   * analysis on which variables we actually don't update over the course of
   * the entire program. This gets to be a little tricky when moving between
   * blocks since we don't preserve variable IDs, and basically they are all
   * mapped to entry/exit points of the blocks.
   *
   * A few key assumptions can be used to guide this. Notably, since we're using
   * a single assignment paradigm, we know that if a variable is changed, then
   * it will be a different variable than the one passed into the block. Hence,
   * any variables that get passed through the block without being used to
   * update something can be called useless.
   *)

  (* Ordered sets and maps keyed by IR variables, ordered by IR.compare. *)
  local
   structure VarKey = struct
     type ord_key = IR.var
     val compare = IR.compare
   end
  in
   structure IRSet = RedBlackSetFn(VarKey)
   structure IRMap = RedBlackMapFn(VarKey)
  end

  (* Useless-variable analysis. In its current form this only *reports*:
   * it prints the ids of the initially useless block parameters and of the
   * set obtained after propagation to stderr, and returns unit. The blocks
   * themselves are not rewritten.
   *)
  fun removeUseless(blocks) = let

   (* Split l1 into (members that also occur in l2, members that do not),
    * comparing variables with IR.compare. *)
   fun findDifferentVars [] _ = (IRSet.empty, IRSet.empty)
     | findDifferentVars (x :: xs) l2 = let
         fun sameAsX y = IR.compare(x, y) = EQUAL
         val (ins, outs) = findDifferentVars xs l2
        in
         if List.exists sameAsX l2 then
          (IRSet.add (ins, x), outs)
         else
          (ins, IRSet.add (outs, x))
        end

   (* The parameters of a block that are not needed to compute anything the
    * block hands on: params minus the "useful" set found by trackStmt. *)
   fun findUselessInBlock(IR.BLK{params, body, ...}) = let

     (* Walk the statement chain from the end backwards, returning the
      * variables that are useful from this statement to the end of the
      * block. If a PRIM defines a useful variable, its operands are added
      * to the useful set. *)
     fun trackStmt(IR.PRIM(v, _, vl, s)) = let
           val usefuls = trackStmt(s)
          in
           if IRSet.member (usefuls, v) then
            IRSet.addList(usefuls, vl)
           else
            usefuls
          end

       (* If neither arm has useful variables the condition is useless as
        * well; otherwise the result is the union of both arms plus the
        * condition variable. *)
       | trackStmt(IR.IF(v, s1, s2)) = let
           val usefuls1 = trackStmt(s1)
           val usefuls2 = trackStmt(s2)
          in
           if IRSet.isEmpty(usefuls1) andalso IRSet.isEmpty(usefuls2) then
            IRSet.empty
           else
            IRSet.add(IRSet.union (usefuls1, usefuls2), v)
          end

       (* GOTO and RETURN are the meaningful ends of a block. The seed of
        * the useful set is every outgoing argument that is not one of the
        * block's own parameters, i.e. something computed inside the block
        * for use outside it. *)
       | trackStmt(IR.GOTO(_, vl)) = #2 (findDifferentVars vl params)
       | trackStmt(IR.RETURN(vl)) = #2 (findDifferentVars vl params)

       (* Discarding the particle makes all input to the block useless. *)
       | trackStmt(IR.DISCARD) = IRSet.empty

    in
     IRSet.difference(IRSet.fromList params, trackStmt(body))
    end (* findUselessInBlock *)

   (* Union of the per-block useless parameter sets. *)
   fun findUseless(blocks) =
     List.foldl IRSet.union IRSet.empty (List.map findUselessInBlock blocks)

   (* Print one variable id per line on stderr. *)
   fun prVarIds vars =
     List.app (fn IR.V{id, ...} => printErr (Int.toString id)) vars

   (* Repeatedly push the useless set through every block until a full pass
    * leaves it unchanged.
    *
    * The fixpoint is detected by comparing the set before and after the
    * pass. The previous version used gProgramChanged, which was set
    * whenever any PRIM had a useless operand -- even if its variable was
    * already in the set -- so the flag stayed set on every pass and the
    * recursion could not end once such a PRIM existed.
    *)
   fun propagateUseless(blocks, initUVs) = let
     fun propagateUselessInBlock(IR.BLK{body, ...}, uselessVars) = let
       fun hasUseless(uves, vl) =
         List.exists (fn v => IRSet.member(uves, v)) vl

       (* A PRIM with a useless operand marks the variable it defines. *)
       fun checkStmt(IR.PRIM(v, _, vl, s)) = let
             val uvs = checkStmt(s)
            in
             if hasUseless(uvs, vl) then IRSet.add(uvs, v) else uvs
            end

         (* Only what both arms agree on survives an IF. *)
         | checkStmt(IR.IF(_, s1, s2)) = let
             val uv1s = checkStmt(s1)
             val uv2s = checkStmt(s2)
            in
             IRSet.intersection(uv1s, uv2s)
            end

         (* Pair each outgoing argument with the target block's parameter.
          * When exactly one of the two is in the useless set, it is
          * removed from the set. *)
         | checkStmt(IR.GOTO(IR.BLK{params=args, ...}, vlist)) = let
             fun update(uvs, v :: vl, p :: ps) =
                   if IRSet.member(uvs, v)
                      andalso not (IRSet.member(uvs, p)) then
                     update(IRSet.delete(uvs, v), vl, ps)
                   else if IRSet.member(uvs, p)
                           andalso not (IRSet.member(uvs, v)) then
                     update(IRSet.delete(uvs, p), vl, ps)
                   else
                     update(uvs, vl, ps)
               | update(uvs, [], []) = uvs
               | update(_, _, _) = raise Fail("Mismatch in block parameters")
            in
             update(uselessVars, vlist, args)
            end

         | checkStmt(_) = uselessVars
      in
       checkStmt(body)
      end (* propagateUselessInBlock *)

     val result = List.foldr propagateUselessInBlock initUVs blocks
    in
     if IRSet.equal(result, initUVs) then
      result
     else
      propagateUseless(blocks, result)
    end

   (* Returns true if the variable is useless in the given block and in the
    * blocks its GOTOs lead to. Not called anywhere in this function at the
    * moment.
    *
    * !SPEED! This will likely benefit from some memoization, since there's
    * a pretty good chance that we're going to be checking the same block
    * multiple times with this method.
    *
    * NOTE(review): GOTO targets are followed without tracking visited
    * blocks, so a cycle of blocks may recurse indefinitely -- confirm
    * before enabling.
    *)
   fun isVarUselessInBlock(v, blk as IR.BLK{body, ...}) = let

     fun uselessInNextBlock(IR.PRIM(_, _, _, s)) = uselessInNextBlock(s)
       | uselessInNextBlock(IR.IF(_, s1, s2)) =
           uselessInNextBlock(s1) andalso uselessInNextBlock(s2)
       | uselessInNextBlock(IR.GOTO(blk2 as IR.BLK{params=ps, ...}, vl)) = let
           (* Find the first argument position holding v and continue with
            * the target block's parameter at that position. *)
           fun searchNth(_, []) = true
             | searchNth(n, pv :: vls) =
                 if IR.compare(pv, v) = EQUAL then
                   isVarUselessInBlock(List.nth (ps, n), blk2)
                 else
                   searchNth(n + 1, vls)
          in
           searchNth(0, vl)
          end
       | uselessInNextBlock(_) = true

    in
     if IRSet.member(findUselessInBlock(blk), v) then
      uselessInNextBlock(body)
     else
      false
    end (* isVarUselessInBlock *)

   (* Computed once; it is needed for the count, the listing and as the
    * starting point of the propagation. *)
   val initialUVs = findUseless(blocks)

   in
    printErr ("Initial useless variables (" ^ (Int.toString (List.length (IRSet.listItems initialUVs))) ^"):");
    prVarIds(IRSet.listItems initialUVs);
    printErr "\nFinal useless variables:";
    prVarIds(IRSet.listItems (propagateUseless(blocks, initialUVs)))
   end (* removeUseless *)
  
  (* Entry point of the optimizer: fold constants first, then strip the
   * variables that are left without uses. removeUseless is not part of the
   * pipeline. *)
  fun optimizeIR(blocks) = removeUnused(foldConstants(blocks))
 
 end (* structure *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0