SCM Repository
[diderot] / branches / charisee_dev / src / compiler / high-to-mid / float-ein.sml |
View of /branches/charisee_dev/src/compiler/high-to-mid/float-ein.sml
Parent Directory
|
Revision Log
Revision 3680 -
(download)
(annotate)
Thu Feb 18 20:07:38 2016 UTC (6 years, 3 months ago) by cchiw
File size: 6457 byte(s)
Thu Feb 18 20:07:38 2016 UTC (6 years, 3 months ago) by cchiw
File size: 6457 byte(s)
add avail
(* float-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

(* FloatEin: splits one EIN application into a sequence of simpler assignments.
 * Probes in the body, and operator subexpressions that appear as arguments of
 * Op1/Op2/Opn nodes, are lifted into their own EIN applications bound to fresh
 * variables (registered through an AvailRHS table); in the original body they
 * are replaced by references to new tensor parameters.
 *)
structure FloatEin : sig

    val transform : MidIL.var * (*Ein.ein * MidIL.var list*) MidIL.rhs -> MidIL.assignment list

  end = struct

    structure IR = MidIL
    structure V = IR.Var
    structure Ty = MidILTypes
    structure E = Ein
    structure cleanP = cleanParams
    structure cleanI = cleanIndex
    structure AvailRHS = AvailRHSCnt

    fun mkEin e = Ein.mkEin e
    fun mkEinApp (rator, args) = IR.EINAPP(rator, args)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e

  (* cut (name, origProbe, params, index, sx, argsOrig, avail, newvx)
   *   Lifts a probe whose field is indexed by a single constant (c1).  As called
   *   from `rewrite`, the probe's derivative indices are E.V 0 .. E.V (newvx-1).
   *   The lifted assignment replaces the constant component index with the
   *   variable E.V newvx, adding a trailing index dimension of size 3, and the
   *   returned replacement expression selects component c1 of the new tensor
   *   parameter.
   *   Returns (replacement expression, params @ [new TEN param], argsOrig @ [new var]).
   *   Raises Bind if cleanParams does not produce an EINAPP whose body is a
   *   probe of a field with exactly one index.
   *)
    fun cut (name, origProbe, params, index, sx, argsOrig, avail, newvx) = let
        (* clean and rewrite current body *)
          val (tshape, sizes, body) = cleanIndex (origProbe, index, sx)
          val id = length params
          val tmpParams = params @ [E.TEN(1, sizes)]
          val M = V.new (concat[name, "_l_", Int.toString id], Ty.TensorTy sizes)
          val (_, IR.EINAPP(ein, args)) = cleanParams (M, body, tmpParams, sizes, argsOrig @ [M])
        (* shift indices in probe body from constant to variable *)
          val Ein.EIN{
                  body = E.Probe(E.Conv(fld, [c1], h, dx), pos),
                  index = index0,
                  params = params0
                } = ein
        (* FIXME: this code is specialized to 3D (the new dimension is hard-coded to 3) *)
          val index1 = index0 @ [3]
          val unshiftedBody = E.Probe(E.Conv(fld, [E.V newvx], h, dx), pos)
        (* clean to get body indices in order *)
          val (_, _, body1) = cleanIndex (unshiftedBody, index1, [])
          val lhs1 = V.new ("L", Ty.TensorTy index1)
          val ein1 = mkEin (params0, index1, body1)
        (* addAssign returns the variable to reference; presumably an existing one
         * when an equivalent rhs is already available -- see AvailRHSCnt *)
          val lhs2 = AvailRHS.addAssign avail (lhs1, mkEinApp(ein1, args))
          in
          (* probe that tensor at the constant position c1 *)
            (E.Tensor(id, c1 :: tshape), params @ [E.TEN(1, index1)], argsOrig @ [lhs2])
          end

  (* lift (name, e, params, index, sx, args, avail) -> (ein_exp * params * args)
   *   Lifts the subexpression e into its own assignment: cleans the index and
   *   params of e, binds it to a fresh variable named name ^ "_l_" ^ id via
   *   AvailRHS.addAssign, and returns a replacement tensor expression
   *   E.Tensor(id, tshape) together with the extended params and args, where
   *   id is the position of the new TEN parameter.
   *)
    fun lift (name, e, params, index, sx, args, avail) = let
          val (tshape, sizes, body) = cleanIndex (e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = V.new (concat[name, "_l_", Int.toString id], Ty.TensorTy sizes)
          val (_, einapp) = cleanParams (M, body, Rparams, sizes, args @ [M])
          val var = AvailRHS.addAssign avail (M, einapp)
          in
            (Re, Rparams, args @ [var])
          end

  (* isOp e: true for operator-like nodes (Op1, Op2, Opn, Sum, Probe) that must
   * be lifted when they appear as arguments of an operator.
   *)
    fun isOp e = (case e
           of E.Op1 _ => true
            | E.Op2 _ => true
            | E.Opn _ => true
            | E.Sum _ => true
            | E.Probe _ => true
            | _ => false
          (* end case *))

  (* transform (y, rhs)
   *   If rhs is an EIN application whose body is a probe, or a summation over a
   *   probe, the assignment y = rhs is returned unchanged.  Otherwise the body is
   *   rewritten: probes are cut or lifted, and operator arguments that are
   *   themselves operators are lifted; each lifted piece is recorded in a fresh
   *   AvailRHS table.  Returns the list
   *   List.rev (ASSGN(y, rewritten) :: AvailRHS.getAssignments avail),
   *   so the assignment to y comes last.
   *   Only EINAPP right-hand sides are handled; any other rhs raises Match.
   *)
    fun transform (y, IR.EINAPP(ein as Ein.EIN{body=E.Probe _, ...}, args)) =
          [IR.ASSGN(y, IR.EINAPP(ein, args))]
      | transform (y, IR.EINAPP(ein as Ein.EIN{body=E.Sum(_, E.Probe _), ...}, args)) =
          [IR.ASSGN(y, IR.EINAPP(ein, args))]
      | transform (y, IR.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val avail = AvailRHS.new()
        (* lift every element of es that is an operator (see isOp); other
         * elements are kept in place.  Order of es is preserved. *)
          fun filterOps (es, params, args, index, sx) = let
                fun filter ([], es', params, args) = (rev es', params, args)
                  | filter (e::es, es', params, args) =
                      if isOp e
                        then let
                          val (e', params', args') =
                                lift("op1_e3", e, params, index, sx, args, avail)
                          in
                            filter (es, e'::es', params', args')
                          end
                        else filter (es, e::es', params, args)
                in
                  filter (es, [], params, args)
                end
        (* rewrite (sx, exp, params, args): bottom-up rewrite of exp; sx accumulates
         * the summation indices of the enclosing E.Sum nodes. *)
          fun rewrite (sx, exp, params, args) = (case exp
                 of E.Probe(E.Conv(_, [E.C _], _, []), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 0)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 1)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0, E.V 1]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 2)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0, E.V 1, E.V 2]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 3)
                  | E.Probe _ => lift ("probe", exp, params, index, sx, args, avail)
                  | E.Sum(_, E.Probe _) => lift ("probe", exp, params, index, sx, args, avail)
                  (*keep eps *)
                  (*
                  | E.Opn(E.Prod,[E.G _, E.Probe _]) =>
                      lift ("probe", exp, params, index, sx, args, avail)
                  | E.Sum(_, E.Opn(E.Prod,[E.G _, E.Probe _])) =>
                      lift ("probe", exp, params, index, sx, args, avail)
                  *)
                  (* skipping lift then replace probe in place
                  E.Probe _ => let
                      val (params',body',args',code) =
                            ProbeEin.replaceProbeF(params,index,sx,exp,args)
                      val _ = List.map (fn e=> AvailRHS.addAssignNoSearch avail e) code
                      in
                        (body',params',args')
                      end
                  *)
                  | E.Op1(op1, e1) => let
                      val (e1', params', args') = rewrite (sx, e1, params, args)
                      val ([e1''], params', args') = filterOps ([e1'], params', args', index, sx)
                      in
                        (E.Op1(op1, e1''), params', args')
                      end
                  | E.Op2(op2, e1, e2) => let
                      val (e1', params', args') = rewrite (sx, e1, params, args)
                      val (e2', params', args') = rewrite (sx, e2, params', args')
                      val ([e1', e2'], params', args') =
                            filterOps ([e1', e2'], params', args', index, sx)
                      in
                        (E.Op2(op2, e1', e2'), params', args')
                      end
                  | E.Opn(opn, es) => let
                      fun iter ([], acc, params, args) = (List.rev acc, params, args)
                        | iter (e::rest, acc, params, args) = let
                            val (e', params', args') = rewrite (sx, e, params, args)
                            in
                              iter (rest, e'::acc, params', args')
                            end
                      val (es, params, args) = iter (es, [], params, args)
                      val (es, params, args) = filterOps (es, params, args, index, sx)
                      in
                        (E.Opn(opn, es), params, args)
                      end
                  | E.Sum(sx1, e) => let
                      val (e', params', args') = rewrite (sx1 @ sx, e, params, args)
                      in
                        (E.Sum(sx1, e'), params', args')
                      end
                  | _ => (exp, params, args)
                (* end case *))
          val (body', params', args') = rewrite ([], body, params, args)
          val einapp = cleanParams (y, body', params', index, args')
          in
            List.rev (IR.ASSGN einapp :: AvailRHS.getAssignments avail)
          end

  end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |