Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/clean-params.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/clean-params.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3573, Mon Jan 11 18:30:36 2016 UTC revision 3574, Mon Jan 11 23:05:08 2016 UTC
# Line 24  Line 24 
24      structure DstIR = MidIR      structure DstIR = MidIR
25      structure DstV = DstIR.Var      structure DstV = DstIR.Var
26      structure IMap = IntRedBlackMap      structure IMap = IntRedBlackMap
27        structure ISet = IntRedBlackSet
     fun cnt (DstIR.V{useCnt, ...}) = !useCnt  
 (* QUESTION: why +2? *)  
     fun incUse (DstIR.V{useCnt, ...}) = (useCnt := !useCnt + 2)  
28    
29      (*dictionary to lookup mapp*)      (*dictionary to lookup mapp*)
30      fun lookupSingleIndex (e1, mapp, str) = (case IMap.find(mapp, e1)      fun lookupSingleIndex (e1, mapp, str) = (case IMap.find(mapp, e1)
# Line 39  Line 36 
36      * countmapp dictionary keeps track of which ids have been used      * countmapp dictionary keeps track of which ids have been used
37      * mapp id the dictionary of the new ids      * mapp id the dictionary of the new ids
38      *)      *)
39      fun mkMapp (countmapp, params, args) = let      fun mkMapp (freeParams, params, args) = let
40            val n = length params            fun m (_, _, mapp, p, [], a, []) = (mapp, rev p, rev a)
41            val ix = List.tabulate(n, fn e => e)              | m (i, j, mapp, p, p1::params, a, a1::arg) =
42            fun m ([], _, mapp, p, _, a, _) = (mapp, p, a)                  if ISet.member(freeParams, i)
43              | m (i::ix, j, mapp, p, p1::params, a, a1::arg) = (case IMap.find(countmapp, i)                    then let
                  of SOME _ => let  
44                        val mapp2 = IMap.insert(mapp, i, j)                        val mapp2 = IMap.insert(mapp, i, j)
45                         in                         in
46                           m (ix, j+1, mapp2, p@[p1], params, a@[a1], arg)                        m (i+1, j+1, mapp2, p1::p, params, a1::a, arg)
47                         end                         end
48                    | _ => m (ix, j, mapp, p, params, a, arg)                    else m (i+1, j, mapp, p, params, a, arg)
49                  (* end case *))              | m (_, _, _, _, _, _, []) = raise Fail "too many parameters"
50              | m (_, _, _, _, _, _, []) = raise Fail (String.concat[              | m (_, _, _, _, [], _, _) = raise Fail "too many args"
                   "incorrect number of params:more params:", Int.toString n,  
                   " args:", Int.toString(length args)  
                 ])  
             | m (_, _, _, _, [], _, _) = raise Fail ("incorrect number of params:more args")  
51            in            in
52              m (ix, 0, IMap.empty, [], params, [], args)              m (0, 0, IMap.empty, [], params, [], args)
53            end            end
54    
55      (*getIdCount: ein_exp ->dict    (* walk the ein expression and compute the set of free parameter indices (i.e., tensor,
56      *rewrite ids in exp using mapp     * image, and kernel variables) in the expression.
57      *)      *)
58      fun getIdCount b = let      fun getIdCount b = let
59            fun rewrite (b, mapp) = (case b            fun walk (b, mapp) = (case b
60                   of E.Tensor(id, _) => IMap.insert(mapp, id, 1)                   of E.Tensor(id, _) => ISet.add(mapp, id)
61                    | E.Conv(v, _, h, _) => IMap.insert(IMap.insert(mapp, h, 1), v, 1)                    | E.Conv(v, _, h, _) => ISet.add(ISet.add(mapp, h), v)
62                    | E.Probe(e1, e2) => rewrite (e2, rewrite (e1, mapp))                    | E.Probe(e1, e2) => walk (e2, walk (e1, mapp))
63                    | E.Value _ => raise Fail "unexpected Value"                    | E.Value _ => raise Fail "unexpected Value"
64                    | E.Img _ => raise Fail "unexpected Img"                    | E.Img _ => raise Fail "unexpected Img"
65                    | E.Krn _ => raise Fail "unexpected Krn"                    | E.Krn _ => raise Fail "unexpected Krn"
66                    | E.Sum(_, e1) => rewrite (e1, mapp)                    | E.Sum(_, e1) => walk (e1, mapp)
67                    | E.Op1(_, e1) => rewrite (e1, mapp)                    | E.Op1(_, e1) => walk (e1, mapp)
68                    | E.Op2(_, e1, e2) => rewrite (e2, rewrite (e1, mapp))                    | E.Op2(_, e1, e2) => walk (e2, walk (e1, mapp))
69                    | E.Opn(_, es) => List.foldl rewrite mapp es                    | E.Opn(_, es) => List.foldl walk mapp es
70                    | _ => mapp                    | _ => mapp
71                  (* end case *))                  (* end case *))
72            in            in
73              rewrite (b, IMap.empty)              walk (b, ISet.empty)
74            end            end
75    
76      (*rewriteParam:dict*ein_exp ->ein_exp      (*rewriteParam:dict*ein_exp ->ein_exp
# Line 86  Line 78 
78      *)      *)
79      fun rewriteParam (mapp, e) = let      fun rewriteParam (mapp, e) = let
80            fun getId id = lookupSingleIndex(id,mapp,"Mapp doesn't have Param Id ")            fun getId id = lookupSingleIndex(id,mapp,"Mapp doesn't have Param Id ")
81            fun rewriteExp b = (case b            fun rewrite b = (case b
82                   of E.Tensor(id, alpha) => E.Tensor(getId id, alpha)                   of E.Tensor(id, alpha) => E.Tensor(getId id, alpha)
83                    | E.Probe(E.Conv(v, alpha, h, dx), t) =>                    | E.Probe(E.Conv(v, alpha, h, dx), t) =>
84                        E.Probe(E.Conv(getId v, alpha,getId h,dx),rewriteExp t)                        E.Probe(E.Conv(getId v, alpha, getId h, dx), rewrite t)
85                    | E.Sum(sx ,e1) => E.Sum(sx,rewriteExp e1)                    | E.Sum(sx ,e1) => E.Sum(sx, rewrite e1)
86                    | E.Op1(op1, e1) => E.Op1(op1,rewriteExp e1)                    | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
87                    | E.Op2(op2, e1,e2) => E.Op2(op2,rewriteExp e1,rewriteExp e2)                    | E.Op2(op2, e1,e2) => E.Op2(op2, rewrite e1, rewrite e2)
88                    | E.Opn(opn, es) => E.Opn(opn,List.map rewriteExp es)                    | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
89                    | _ => b                    | _ => b
90                  (* end case *))                  (* end case *))
91            in            in
92              rewriteExp e              rewrite e
93            end            end
94    
95      (* cleanParams:var*ein_exp*param*index* var list ->code      (* cleanParams:var*ein_exp*param*index* var list ->code
96      *cleans params      *cleans params
97      *)      *)
98      fun clean (y, body, params, index, args) = let      fun clean (y, body, params, index, args) = let
99            val countmapp = getIdCount body            val freeParams = getIdCount body
100            val (mapp, Nparams, Nargs) = mkMapp (countmapp, params, args)            val (mapp, Nparams, Nargs) = mkMapp (freeParams, params, args)
           val () = List.app incUse Nargs  
101            val Nbody = rewriteParam (mapp, body)            val Nbody = rewriteParam (mapp, body)
102            in            in
103              (y, DstIR.EINAPP(Ein.EIN{params=Nparams, index=index, body=Nbody}, Nargs))              (y, DstIR.EINAPP(Ein.EIN{params=Nparams, index=index, body=Nbody}, Nargs))

Legend:
Removed from v.3573  
changed lines
  Added in v.3574

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0