Home My Page Projects Code Snippets Project Openings diderot

# SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/split.sml
 [diderot] / branches / vis15 / src / compiler / high-to-mid / split.sml

# View of /branches/vis15/src/compiler/high-to-mid/split.sml

Wed Jan 6 18:48:59 2016 UTC (4 years, 2 months ago) by jhr
File size: 11702 byte(s)
working on merge
(* split.sml
*
* This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
*
* COPYRIGHT (c) 2015 The University of Chicago
*)

(*
During the transition from high-IR to mid-IR, complicated EIN expressions are split into simpler ones so that code-generation strategies and common subexpressions are easier to identify. Combining EIN operators during the optimization phase can produce large, complicated EIN operators; a general code generator would have to expand every such operation down to scalars, which can miss opportunities for vectorization and lead to poor code. Instead, every EIN operator is split into a set of simple EIN operators, each of which applies only a single operation to constants, tensors, deltas, epsilons, images, and kernels.

(1) When the outer EIN operator is one of $\{\mathrm{neg}, +, -, *, /, \sum\}$ (where $\mathrm{neg}$ denotes unary negation), analyze each subexpression to see whether it needs to be rewritten.

(1a) When a subexpression is a field expression ($\circledast$, $\nabla$), it is replaced by 0. When it is another operation ($\{\mathrm{neg}, +, -, *, /, \sum\}$), we lift that subexpression into a new EIN operator and replace it with a tensor expression whose shape matches that of the lifted subexpression.

(1b) Call cleanIndex.sml to clean the indices in the subexpression and to compute the shape of the replacement tensor.

(1c) Call cleanParams.sml to clean the params in the subexpression.
*)

structure Split : sig

  end = struct

  local

    structure E = Ein
    structure DstIR = MidIR
    structure DstTy = MidTypes
    structure DstV = DstIR.Var

    structure cleanP = cleanParams
    structure cleanI = cleanIndex

  in

  (* when true, every lifted subexpression is looked up with einSet.rtnVar; if that
   * returns an existing variable, the variable is reused and no new code is emitted
   * (common-subexpression removal)
   *)
    val numFlag = true

  (* when true, print a trace of every split step to stdout *)
    val debugSplit = false

    fun mkEin e = E.mkEin e
    val einappzero = DstIR.EINAPP(mkEin([], [], E.Const 0), [])
    fun setEinZero y = (y, einappzero)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e
    fun toStringBind e = MidToString.toStringBind e

  (* print the concatenation of strs, but only when debugSplit is enabled *)
    fun testp strs = if debugSplit then print (String.concat strs) else ()

  (* trace one split step: the original binding, its replacement, and the extra
   * bindings produced for it.  The strings are only built when debugging, because
   * converting bindings to strings is not free.
   *)
    fun traceStep (orig, replacement, extra) = if debugSplit
          then testp [
              toStringBind orig, "\n\t===>\n", toStringBind replacement, "\nand\n",
              String.concatWith ", \n\t" (List.map toStringBind extra)
            ]
          else ()

  (* lift : string * ein_exp * params * index * sum_id * args * fieldset * bool
   *          -> ein_exp * params * args * code * fieldset
   *
   * Floats the subexpression e out into its own EIN application and returns the
   * tensor expression that replaces it in the enclosing operator.
   *   - cleanIndex computes the replacement tensor's shape, the sizes of the new
   *     variable, and the re-indexed body;
   *   - a new tensor parameter (id = length params) and a fresh variable M named
   *     "<name>_l_<id>" are appended to the enclosing params/args;
   *   - cleanParams builds the new EIN application bound to M.
   * When flag is true, einSet.rtnVar is consulted: if it yields SOME v, v is used as
   * the argument instead of M and no new code is emitted; otherwise the new binding
   * is returned in the code list.  The possibly-updated fieldset is returned as well.
   *)
    fun lift (name, e, params, index, sx, args, fieldset, flag) = let
          val (tshape, sizes, body) = cleanIndex(e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(true, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = DstV.new (concat[name, "_l_", Int.toString id], DstTy.TensorTy sizes)
          val Rargs = args @ [M]
          val einapp as (_, einapp0) = cleanParams (M, body, Rparams, sizes, Rargs)
          val (Rargs, newbies, fieldset) = if flag
                then let
                  val (fieldset, var) = einSet.rtnVar(fieldset, M, einapp0)
                  in
                    case var
                     of NONE => (Rargs, [einapp], fieldset)
                      | SOME v => (args @ [v], [], fieldset)
                    (* end case *)
                  end
                else (Rargs, [einapp], fieldset)
          in
            (Re, Rparams, Rargs, newbies, fieldset)
          end

  (* some Ein expressions get replaced by zero (e.g., fields), others get floated to
   * top level, and the rest remain the same.
   *)
    datatype op_replace = ZERO | FLOAT | SAME

  (* decide how a subexpression of an operator is treated when splitting; raises
   * Fail for forms that must not occur at this stage of the compiler
   *)
    fun shouldFloat e = (case e
           of E.Field _   => ZERO
            | E.Conv _    => ZERO
            | E.Apply _   => ZERO
            | E.Lift _    => ZERO
            | E.Op1 _     => FLOAT
            | E.Op2 _     => FLOAT
            | E.Opn _     => FLOAT
            | E.Sum _     => FLOAT
            | E.Probe _   => FLOAT
            | E.Partial _ => raise Fail "Partial used after normalize"
            | E.Krn _     => raise Fail "Krn used before expand"
            | E.Value _   => raise Fail "Value used before expand"
            | E.Img _     => raise Fail "Img used before expand"
            | _           => SAME
          (* end case *))

  (* *************************************** helpers ******************************** *)

  (* rewrite one operand e1: zero it, lift it into its own EIN application, or keep
   * it; returns (operand', params', args', new code, fieldset')
   *)
    fun rewriteOp (name, e1, params, index, sx, args, fieldset, flag) = (case shouldFloat e1
           of ZERO => (E.Const 0, params, args, [], fieldset)
            | FLOAT => lift(name, e1, params, index, sx, args, fieldset, flag)
            | SAME => (e1, params, args, [], fieldset)
          (* end case *))

  (* rewrite the single operand of a unary operator; x is the
   * ((lhs, EINAPP(ein, args)), fieldset, flag) triple being split
   *)
    fun unaryOp (name, sx, e1, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          in
            rewriteOp (name, e1, Ein.params ein, Ein.index ein, sx, args, fieldset, flag)
          end

  (* rewrite every operand in list1, left to right, threading params, args, and the
   * fieldset through the calls and concatenating the generated code
   *)
    fun multOp (name, sx, list1, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          val index = Ein.index ein
          fun m ([], rest, params, args, code, fieldset) = (rest, params, args, code, fieldset)
            | m (e1::es, rest, params, args, code, fieldset) = let
                val (e1', params', args', code', fieldset) =
                      rewriteOp (name, e1, params, index, sx, args, fieldset, flag)
                in
                  m (es, rest @ [e1'], params', args', code @ code', fieldset)
                end
          in
            m (list1, [], Ein.params ein, args, [], fieldset)
          end

  (* rebuild the original binding with the rewritten body, cleaning its params *)
    fun cleanOrig (body, params, args, x) = let
          val ((y, DstIR.EINAPP(ein, _)), _, _) = x
          in
            cleanParams (y, body, params, Ein.index ein, args)
          end

  (* *************************************** general handle Ops ******************************** *)

    fun handleUnaryOp (name, opp, x, e1) = let
          val (e1', params', args', code, fieldset) = unaryOp(name, [], e1, x)
          val einapp = cleanOrig(E.Op1(opp, e1'), params', args', x)
          in
            (einapp, code, fieldset)
          end

    fun handleBinaryOp (name, opp, x, es) = (case multOp(name, [], es, x)
           of ([e1', e2'], params', args', code, fieldset) => let
                val einapp = cleanOrig(E.Op2(opp, e1', e2'), params', args', x)
                in
                  (einapp, code, fieldset)
                end
            | _ => raise Fail "handleBinaryOp: expected exactly two operands"
          (* end case *))

    fun handleMultOp (name, opp, x, es) = let
          val (es', params', args', code, fieldset) = multOp(name, [], es, x)
          val einapp = cleanOrig(E.Opn(opp, es'), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* ***************************************specific handle Ops ******************************** *)

  (* division: the numerator and denominator are rewritten separately so that the
   * lifted variables get distinct "div-num"/"div-denom" names
   *)
    fun handleDiv (e1, e2, x) = let
          val ((_, DstIR.EINAPP(ein, args)), fieldset, flag) = x
          val params = Ein.params ein
          val index = Ein.index ein
          val (e1', params1', args1', code1', fieldset) =
                rewriteOp("div-num", e1, params, index, [], args, fieldset, flag)
          val (e2', params2', args2', code2', fieldset) =
                rewriteOp("div-denom", e2, params1', index, [], args1', fieldset, flag)
          val einapp = cleanOrig(E.Op2(E.Div, e1', e2'), params2', args2', x)
          in
            (einapp, code1' @ code2', fieldset)
          end

  (* summation over a product: operands are rewritten with the summation indices sx
   * in scope, and the summation is kept around the rebuilt product
   *)
    fun handleSumProd (e1, sx, x) = let
          val (e1', params', args', code, fieldset) = multOp("sumprod", sx, e1, x)
          val einapp = cleanOrig(E.Sum(sx, E.Opn(E.Prod, e1')), params', args', x)
          in
            (einapp, code, fieldset)
          end

  (* *************************************** Split ******************************** *)

  (* split : (var * rhs) * fieldset * bool -> ((var * rhs) * code) * fieldset
   * Splits one EIN application into a simpler binding plus the bindings lifted out
   * of it.  Summations around probe expressions are left in place; non-EIN
   * right-hand sides are returned unchanged.
   *)
    fun split ((y, einapp as DstIR.EINAPP(Ein.EIN{body, ...}, _)), fieldset, flag) = let
          val x = ((y, einapp), fieldset, flag)
          val default = ((y, einapp), [], fieldset)
          fun error () = raise Fail("Poorly formed EIN operator: " ^ EinPP.expToString body)
          fun rewrite b = (case b
                 of E.Const _ => default
                  | E.ConstR _ => default
                  | E.Tensor _ => default
                  | E.Delta _ => default
                  | E.Epsilon _ => default
                  | E.Eps2 _ => default
                  | E.Field _ => raise Fail "should have been swept"
                  | E.Lift _ => raise Fail "should have been swept"
                  | E.Conv _ => raise Fail "should have been swept"
                  | E.Partial _ => raise Fail "Partial used after normalize"
                  | E.Apply _ => raise Fail "should have been swept"
                  | E.Probe(E.Conv _, _) => default
                  | E.Probe _ => error()
                  | E.Value _ => raise Fail "Value used before expand"
                  | E.Img _ => raise Fail "Img used before expand"
                  | E.Krn _ => raise Fail "Krn used before expand"
                  | E.Sum(_, E.Probe(E.Conv _, _)) => default
                  | E.Sum(_, E.Tensor _) => default
                  | E.Sum(sx, E.Opn(E.Prod, e1)) => handleSumProd (e1, sx, x)
                  | E.Sum(sx, E.Delta d) => handleSumProd ([E.Delta d], sx, x)
                  | E.Sum _ => raise Fail("summation not distributed: " ^ EinPP.expToString b)
                  | E.Op1(op1, e1) => let
                      val name = (case op1
                             of E.Neg => "neg"
                              | E.Sqrt => "sqrt"
                              | E.Exp => "exp"
                              | E.PowInt _ => "PowInt"
                              | _ => "Trig"
                            (* end case *))
                      in
                        handleUnaryOp (name, op1, x, e1)
                      end
                  | E.Op2(E.Sub, e1, e2) => handleBinaryOp ("subtract", E.Sub, x, [e1, e2])
                  | E.Op2(E.Div, e1, e2) => handleDiv (e1, e2, x)
                (* scalar * vector * scalar: regroup so that the two scalars are
                 * multiplied together first, then rewrite the regrouped product.
                 * The operator must be matched as E.Prod; an unqualified name would be
                 * a variable pattern that also captures other n-ary operators.
                 *)
                  | E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id1, [i]), E.Tensor(id2, [])]) =>
                      rewrite (E.Opn(E.Prod, [
                          E.Opn(E.Prod, [E.Tensor(id0, []), E.Tensor(id2, [])]),
                          E.Tensor(id1, [i])
                        ]))
                  | E.Opn(E.Prod, es) => handleMultOp ("prod", E.Prod, x, es)
                  | E.Opn(E.Add, es) => handleMultOp ("add", E.Add, x, es)
                (* end case *))
          val (einapp2, newbies, fieldset) = rewrite body
          in
            ((einapp2, newbies), fieldset)
          end
      | split ((y, app), fieldset, _) = (((y, app), []), fieldset)

  (* *************************************** main  ******************************** *)

  (* limitSplit : (var * rhs) * code * int -> code
   * Repeatedly splits einapp2 and the bindings produced from it.  Once the number of
   * pending, finished, and generated bindings exceeds splitlimit, the remaining
   * pending bindings are emitted unsplit.  Returns fields2 @ generated code @ results.
   * NOTE(review): the fieldset returned by split is discarded, so reuse via
   * einSet.rtnVar only happens within a single split call; confirm this is intended.
   *)
    fun limitSplit (einapp2, fields2, splitlimit) = let
          val fieldset = einSet.EinSet.empty
          val _ = testp ["\nSPLit with limit", Int.toString splitlimit]
          fun itercode ([], rest, code, cnt) = (
                testp ["\n Empty-SplitCount: ", Int.toString cnt];
                (rest, code))
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                val _ = traceStep (e1, einapp3, code4 @ rest4)
                in
                  if length rest + length newbies + length code > splitlimit
                    then (
                      testp ["\n SplitCount: ", Int.toString cnt];
                      (rest @ [einapp3], code4 @ rest4 @ code @ newbies))
                    else itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode([einapp2], [], [], 0)
          in
            fields2 @ code @ rest
          end

  (* splitEinApp : (var * rhs) -> code
   * Splits einapp0 and, recursively, every binding lifted out of it; returns the
   * generated bindings followed by the rewritten original.
   *)
    fun splitEinApp einapp0 = let
          val fieldset = einSet.EinSet.empty
          fun itercode ([], rest, code, _) = (rest, code)
            | itercode (e1::newbies, rest, code, cnt) = let
                val ((einapp3, code3), _) = split(e1, fieldset, numFlag)
                val (rest4, code4) = itercode(code3, [], [], cnt+1)
                val _ = traceStep (e1, einapp3, code4 @ rest4)
                in
                  itercode(newbies, rest @ [einapp3], code4 @ rest4 @ code, cnt+2)
                end
          val (rest, code) = itercode([einapp0], [], [], 0)
          in
            code @ rest
          end

  end (* local *)

end (* Split *)