Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/split.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3552 - (download) (annotate)
Wed Jan 6 18:48:59 2016 UTC (4 years, 2 months ago) by jhr
File size: 11702 byte(s)
working on merge
(* split.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)
 
 (*
  During the transition from high-IR to mid-IR, complicated EIN expressions are split into simpler ones in order to better identify methods for code generation and common subexpressions. Combining EIN operators in the optimization phase can lead to large and complicated EIN operators. A general code generator would need to expand every operation to work on scalars, which could miss the opportunity for vectorization and lead to poor code generation. Instead, every EIN operator is split into a set of simple EIN operators.  Each EIN expression then only has one operation working on  constants, tensors, deltas, epsilons, images and kernels.
 
 (1) When the outer EIN operator is one of $\{--, +, -, *, /, \sum\}$, analyze each subexpression to see whether it needs to be rewritten.
 
 (1a) When a subexpression is a field expression (e.g., $\circledast$ or $\nabla$), it is replaced by 0. When it is another operation (one of $\{--, +, -, *, /, \sum\}$), we lift that subexpression out into a new EIN operator and replace it with a tensor expression whose shape matches the lifted result.
 
 (1b) Call cleanIndex.sml to clean the indices in the subexpression, and get the shape for the tensor replacement.
 
 (1c) Call cleanParams.sml to clean the params in the subexpression.
 *)

structure Split : sig

  (* NOTE(review): the signature is empty, so none of the functions defined
   * below (e.g., splitEinApp, limitSplit) are visible outside this structure.
   * Presumably the exports still need to be added as part of the merge;
   * confirm the intended interface before relying on this module.
   *)

  end = struct

    local

    structure E = Ein
    structure DstIR = MidIR
    structure DstTy = MidTypes
    structure DstV = DstIR.Var

    structure cleanP = cleanParams
    structure cleanI = cleanIndex

    in

  (* passed as the `flag` argument of `split`; when true, `lift` asks
   * einSet.rtnVar whether an equivalent lifted expression already exists
   * so that common subexpressions are shared
   *)
    val numFlag = true

  (* when true, the driver functions print a trace of every split *)
    val debug = false

    fun mkEin e = E.mkEin e
    val einappzero = DstIR.EINAPP(mkEin([], [], E.Const 0), [])
    fun setEinZero y = (y, einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e

  (* raise an internal-error exception carrying `msg` *)
    fun err msg = raise Fail msg

  (* print the strings produced by `mkMsg` when `debug` is set.  The message
   * is built lazily so that no string conversion happens otherwise.
   *)
    fun trace mkMsg = if debug then print (String.concat (mkMsg ())) else ()

  (* lift (name, e, params, index, sx, args, fieldset, flag)
   *     = (Re, Rparams, Rargs, newbies, fieldset')
   *
   * Floats the subexpression `e` out into its own EIN application:
   *   - cleanIndex computes the shape/sizes of `e` and a reindexed body;
   *   - a fresh tensor parameter of that size is appended to `params`;
   *   - `Re` is the tensor expression that replaces `e` in the original;
   *   - a fresh variable (named from `name` and the new parameter id) is
   *     bound to the cleaned EIN application for the lifted body.
   * When `flag` is true, einSet.rtnVar is consulted: if it reports an
   * existing variable `v`, that variable is used as the new argument and no
   * new assignment is emitted; otherwise the new assignment is returned in
   * `newbies` (and likewise when `flag` is false).
   *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(true, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (concat[name, "_l_", Int.toString id], DstTy.TensorTy sizes)
          val Margs = args @ [M]
          val einapp as (_, einapp0) = cleanParams (M, body, Rparams, sizes, Margs)
          val (Rargs, newbies, fieldset) = if flag
                then (case einSet.rtnVar(fieldset, M, einapp0)
                   of (fieldset, NONE) => (Margs, [einapp], fieldset)
                    | (fieldset, SOME v) => (args @ [v], [], fieldset)
                  (* end case *))
                else (Margs, [einapp], fieldset)
          in
            (Re, Rparams, Rargs, newbies, fieldset)
          end

  (* some Ein expressions get replaced by zero (e.g., fields), others get floated to
   * top level, and the rest remain the same.
   *)
    datatype op_replace = ZERO | FLOAT | SAME

  (* decide how a sub-expression of an operator is treated when splitting:
   * field-level terms become zero, operations and probes are floated out,
   * basic terms stay in place, and terms that should not exist at this
   * stage raise Fail.
   *)
    fun shouldFloat e = (case e
           of E.Field _   => ZERO
            | E.Conv _    => ZERO
            | E.Apply _   => ZERO
            | E.Lift _    => ZERO
            | E.Op1 _     => FLOAT
            | E.Op2 _     => FLOAT
            | E.Opn _     => FLOAT
            | E.Sum _     => FLOAT
            | E.Probe _   => FLOAT
            | E.Partial _ => err "Partial used after normalize"
            | E.Krn _     => err "Krn used before expand"
            | E.Value _   => err "Value used before expand"
            | E.Img _     => err "Probe used before expand"
            | _           => SAME
          (* end case *))

    (* *************************************** helpers ******************************** *)

  (* rewrite one operand `e1` according to shouldFloat, returning
   * (replacement, params', args', newly lifted assignments, fieldset')
   *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case shouldFloat e1
           of ZERO => (E.Const 0, params, args, [], fieldset)
            | FLOAT => lift(name, e1, params, index, sx, args, fieldset, flag)
            | SAME => (e1, params, args, [], fieldset)
          (* end case *))

  (* rewrite the single operand `e1` of the EIN application in `x` *)
    fun unaryOp (name, sx, e1, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp (name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

  (* rewrite each operand in `list1` from left to right, threading the
   * parameter list, argument list, and fieldset through; the lifted code of
   * the operands is concatenated in operand order.
   *)
    fun multOp (name, sx, list1, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          val index = Ein.index ein
          fun loop ([], revEs, params, args, code, fieldset) =
                (List.rev revEs, params, args, code, fieldset)
            | loop (e1::es, revEs, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset) =
                      rewriteOp (name, e1, params, index, sx, args, fieldset, flag)
                in
                  loop (es, e1'::revEs, params', args', code @ code', fieldset)
                end
          in
            loop (list1, [], Ein.params ein, args, [], fieldset)
          end

  (* rebuild the original assignment's EIN application around the new body,
   * cleaning its params/args; the index comes from the original operator
   *)
    fun cleanOrig (body, params, args, x) = let
          val ((y, DstIR.EINAPP(ein, _)), _, _) = x
          in
            cleanParams (y, body, params, Ein.index ein, args)
          end

    (* *************************************** general handle Ops ******************************** *)

    fun handleUnaryOp (name, opp, x, e1) = let
          val (e1', params', args', code, fieldset) = unaryOp(name, [], e1, x)
          in
            (cleanOrig(E.Op1(opp, e1'), params', args', x), code, fieldset)
          end

    fun handleBinaryOp (name, opp, x, es) = (case multOp(name, [], es, x)
           of ([e1', e2'], params', args', code, fieldset) =>
                (cleanOrig(E.Op2(opp, e1', e2'), params', args', x), code, fieldset)
            | _ => raise Fail "handleBinaryOp: expected exactly two operands"
          (* end case *))

    fun handleMultOp (name, opp, x, es) = let
          val (es', params', args', code, fieldset) = multOp(name, [], es, x)
          in
            (cleanOrig(E.Opn(opp, es'), params', args', x), code, fieldset)
          end

    (* ***************************************specific handle Ops ******************************** *)

  (* division: numerator and denominator are rewritten separately so that
   * lifted variables get distinct name prefixes
   *)
    fun handleDiv (e1, e2, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          val (e1', params1', args1', code1', fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2', args2', code2', fieldset) =
                rewriteOp("div-denom", e2, params1', index, [], args1', fieldset, flag)
          val einapp = cleanOrig(E.Op2(E.Div, e1', e2'), params2', args2', x)
          in
            (einapp, code1' @ code2', fieldset)
          end

  (* summation over a product: the factors are rewritten with the summation
   * indices `sx` in scope, and the summation is kept around the product
   *)
    fun handleSumProd (e1, sx, x) = let
          val (e1', params', args', code, fieldset) = multOp("sumprod", sx, e1, x)
          val einapp = cleanOrig(E.Sum(sx, E.Opn(E.Prod, e1')), params', args', x)
          in
            (einapp, code, fieldset)
          end

    (* *************************************** Split ******************************** *)

  (* split ((y, rhs), fieldset, flag) = (((y, rhs'), newbies), fieldset')
   *
   * Splits the EIN application bound to `y` into a simpler application
   * plus the list `newbies` of lifted assignments it now depends on.
   * Basic terms and probes of convolutions (including a summation directly
   * around such a probe) are returned unchanged.  Non-EIN right-hand sides
   * are returned unchanged with no new code.
   *)
    fun split ((y, einapp as DstIR.EINAPP(Ein.EIN{body, ...}, _)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val default = ((y, einapp), [], fieldset)
          fun error () = raise Fail("Poorly formed EIN operator: " ^ EinPP.expToString body)
          fun rewrite b = (case b
                 of E.Const _ => default
                  | E.ConstR _ => default
                  | E.Tensor _ => default
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Field _ => raise Fail "should have been swept"
                  | E.Lift _ => raise Fail "should have been swept"
                  | E.Conv _ => raise Fail "should have been swept"
                  | E.Partial _ => raise Fail "Partial used after normalize"
                  | E.Apply _ => raise Fail "should have been swept"
                  | E.Probe(E.Conv _, _) => default
                  | E.Probe _ => error()
                  | E.Value _ => raise Fail "Value used before expand"
                  | E.Img _ => raise Fail "Probe used before expand"
                  | E.Krn _ => raise Fail "Krn used before expand"
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Tensor _) => default
                  | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd (e1, sx, x)
                  | E.Sum(sx, E.Delta d) => handleSumProd ([E.Delta d], sx, x)
                  | E.Sum _ => raise Fail("summation not distributed: " ^ EinPP.expToString b)
                  | E.Op1(op1, e1) => let
                      val name = (case op1
                             of E.Neg => "neg"
                              | E.Sqrt => "sqrt"
                              | E.Exp => "exp"
                              | E.PowInt _ => "PowInt"
                              | _ => "Trig"
                            (* end case *))
                      in
                        handleUnaryOp (name, op1, x, e1)
                      end
                  | E.Op2(E.Sub, e1, e2) => handleBinaryOp ("subtract", E.Sub, x, [e1, e2])
                  | E.Op2(E.Div, e1, e2) => handleDiv (e1, e2, x)
                  | E.Opn(E.Add, es) => handleMultOp ("add", E.Add, x, es)
                (* scalar * vector * scalar: regroup so that the two scalars are
                 * multiplied first, then split the regrouped product
                 *)
                  | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                      rewrite (E.Opn(E.Prod, [
                          E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]),
                          E.Tensor(id1, [i])
                        ]))
                  | E.Opn(E.Prod, es) => handleMultOp ("prod", E.Prod, x, es)
                (* end case *))
          val (einapp2, newbies, fieldset) = rewrite body
          in
            ((einapp2, newbies), fieldset)
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

    (* *************************************** main  ******************************** *)

  (* limitSplit (einapp2, fields2, splitlimit)
   *
   * Repeatedly splits `einapp2` and the assignments lifted out of it.  At
   * each level, once the number of assignments collected so far plus those
   * still pending exceeds `splitlimit`, the pending ones are emitted
   * without further splitting.  The result is `fields2`, then the lifted
   * code, then the rewritten assignments.
   *
   * NOTE(review): the fieldset returned by `split` is discarded, so
   * common-subexpression sharing only happens within a single `split` call.
   *)
    fun limitSplit (einapp2, fields2, splitlimit) = let
          val fieldset = einSet.EinSet.empty
          val _ = trace (fn () => ["\nSPLit with limit", Int.toString splitlimit])
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [])
                val _ = trace (fn () => [
                        toStringBind e1, "\n\t===>\n", toStringBind einapp3, "\nand\n",
                        String.concatWith ", \n\t" (List.map toStringBind (code4 @ rest4))
                      ])
                val pending = length rest + length newbies + length code
                in
                  if pending > splitlimit
                    then (rest @ [einapp3], code4 @ rest4 @ code @ newbies)
                    else itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode([einapp2], [], [])
          in
            fields2 @ code @ rest
          end

  (* splitEinApp einapp0
   *
   * Recursively splits `einapp0` and every assignment lifted out of it,
   * returning the lifted code followed by the rewritten assignments.
   *
   * NOTE(review): as in limitSplit, the fieldset returned by `split` is
   * discarded between calls.
   *)
    fun splitEinApp einapp0 = let
          val fieldset = einSet.EinSet.empty
          fun itercode ([], rest, code) = (rest, code)
            | itercode (e1::newbies, rest, code) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [])
                val _ = trace (fn () => [
                        toStringBind e1, "\n\t===>\n", toStringBind einapp3, "\nand\n",
                        String.concatWith ", \n\t" (List.map toStringBind (code4 @ rest4))
                      ])
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code)
                end
          val (rest, code) = itercode([einapp0], [], [])
          in
            code @ rest
          end

    end (* local *)

  end (* struct *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0