SCM Repository
[diderot] / branches / vis15 / src / compiler / high-to-mid / clean-params.sml |
View of /branches/vis15/src/compiler/high-to-mid/clean-params.sml
Parent Directory
|
Revision Log
Revision 3564 -
(download)
(annotate)
Sun Jan 10 17:21:18 2016 UTC (5 years, 1 month ago) by jhr
File size: 3858 byte(s)
Sun Jan 10 17:21:18 2016 UTC (5 years, 1 month ago) by jhr
File size: 3858 byte(s)
working on merge
(* clean-params.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

(* CleanParams removes unused parameters from an EIN operator application.
 *
 * Cleaning proceeds in three steps:
 *   1. collect the ids of the parameters referenced in the EIN body (getIdCount);
 *   2. give the referenced parameters consecutive new ids, keeping only those
 *      parameters and their corresponding MidIR arguments (mkMapp);
 *   3. rewrite the body so that every parameter reference uses its new id
 *      (rewriteParam).
 *)

structure CleanParams : sig

  (* clean (y, body, params, index, args) returns y paired with an EINAPP whose
   * parameter and argument lists have been restricted to the parameters that
   * body actually references (renumbered consecutively from 0).
   * Raises Fail if params and args differ in length, or if body contains a
   * Value, Img, or Krn term.
   *)
    val clean : MidIR.var * Ein.ein_exp * Ein.param_kind list * Ein.index_bind list * MidIR.var list
          -> MidIR.var * MidIR.rhs

  end = struct

    structure E = Ein
    structure DstIR = MidIR
    structure IMap = IntRedBlackMap

  (* increment the use count of a MidIR variable.
   * QUESTION: why +2?
   *)
    fun incUse (DstIR.V{useCnt, ...}) = (useCnt := !useCnt + 2)

  (* lookupSingleIndex (id, mapp, str) returns the new id that mapp assigns to
   * the old id; raises Fail (str ^ id) when mapp has no entry for it.
   *)
    fun lookupSingleIndex (id, mapp, str) = (case IMap.find(mapp, id)
           of SOME newId => newId
            | NONE => raise Fail(str ^ Int.toString id)
          (* end case *))

  (* mkMapp (countmapp, params, args) -> (mapp, params', args')
   *
   * Walks the original parameters and arguments in lockstep.  The parameter at
   * position i is kept iff i is in the domain of countmapp; a kept parameter
   * receives the next new id (mapp maps old id -> new id) and it and its
   * argument are copied to params'/args'.  Unused parameters are dropped
   * together with their arguments.
   * Raises Fail when params and args have different lengths.
   *)
    fun mkMapp (countmapp, params, args) = let
          val nParams = length params
        (* oldId is the position of the current parameter in the original list and
         * newId the id it will receive if kept.  Kept params/args are accumulated
         * in reverse and reversed once at the end (avoids quadratic appends).
         *)
          fun walk (_, _, mapp, keptPs, [], keptXs, []) =
                (mapp, List.rev keptPs, List.rev keptXs)
            | walk (oldId, newId, mapp, keptPs, p::ps, keptXs, x::xs) =
                if IMap.inDomain(countmapp, oldId)
                  then walk (
                    oldId+1, newId+1, IMap.insert(mapp, oldId, newId),
                    p::keptPs, ps, x::keptXs, xs)
                  else walk (oldId+1, newId, mapp, keptPs, ps, keptXs, xs)
            | walk (_, _, _, _, _, _, []) = raise Fail(String.concat[
                  "incorrect number of params:more params:", Int.toString nParams,
                  " args:", Int.toString(length args)
                ])
            | walk (_, _, _, _, [], _, _) = raise Fail "incorrect number of params:more args"
          in
            walk (0, 0, IMap.empty, [], params, [], args)
          end

  (* getIdCount : ein_exp -> int IMap.map
   *
   * Returns a map whose domain is the set of parameter ids referenced by the
   * expression (the associated values are always 1 and are never read).
   * Raises Fail on Value, Img, and Krn terms.
   * NOTE(review): any constructor that falls through to the final wildcard is
   * assumed to reference no parameters; rewriteParam relies on the same
   * assumption.
   *)
    fun getIdCount b = let
          fun collect (b, ids) = (case b
                 of E.Tensor(id, _) => IMap.insert(ids, id, 1)
                  | E.Conv(v, _, h, _) => IMap.insert(IMap.insert(ids, h, 1), v, 1)
                  | E.Probe(e1, e2) => collect (e2, collect (e1, ids))
                  | E.Value _ => raise Fail "unexpected Value"
                  | E.Img _ => raise Fail "unexpected Img"
                  | E.Krn _ => raise Fail "unexpected Krn"
                  | E.Sum(_, e1) => collect (e1, ids)
                  | E.Op1(_, e1) => collect (e1, ids)
                  | E.Op2(_, e1, e2) => collect (e2, collect (e1, ids))
                  | E.Opn(_, es) => List.foldl collect ids es
                  | _ => ids
                (* end case *))
          in
            collect (b, IMap.empty)
          end

  (* rewriteParam (mapp, e) replaces every parameter id in e by its new id in
   * mapp.  It visits the same constructors that getIdCount inspects; every
   * other term is returned unchanged, because getIdCount treats such terms as
   * referencing no parameters.
   * Raises Fail if a referenced id has no entry in mapp.
   *)
    fun rewriteParam (mapp, e) = let
          fun getId id = lookupSingleIndex(id, mapp, "Mapp doesn't have Param Id ")
          fun rewriteExp b = (case b
                 of E.Tensor(id, alpha) => E.Tensor(getId id, alpha)
                  | E.Conv(v, alpha, h, dx) => E.Conv(getId v, alpha, getId h, dx)
                  | E.Probe(e1, e2) => E.Probe(rewriteExp e1, rewriteExp e2)
                  | E.Sum(sx, e1) => E.Sum(sx, rewriteExp e1)
                  | E.Op1(op1, e1) => E.Op1(op1, rewriteExp e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewriteExp e1, rewriteExp e2)
                  | E.Opn(opn, es) => E.Opn(opn, List.map rewriteExp es)
                  | _ => b
                (* end case *))
          in
            rewriteExp e
          end

  (* clean (y, body, params, index, args) : see the signature above.
   * The use count of every kept argument is incremented via incUse.
   *)
    fun clean (y, body, params, index, args) = let
          val countmapp = getIdCount body
          val (mapp, newParams, newArgs) = mkMapp (countmapp, params, args)
          val () = List.app incUse newArgs
          val newBody = rewriteParam (mapp, body)
          in
            (y, DstIR.EINAPP(E.EIN{params=newParams, index=index, body=newBody}, newArgs))
          end

  end (* CleanParams *)
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |