Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/clean-params.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/clean-params.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5570 - (view) (download)

1 : jhr 3561 (* clean-params.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     (*
10 :     * clean-params.sml removes the unused parameters of an EIN expression.
11 :     * Cleaning parameters takes three steps:
12 :     * we collect the parameter ids that are used in the body (getFreeParams),
13 :     * renumber the used parameter ids while keeping only the matching params and mid-il args (mkMapp),
14 :     * and lastly rewrite the body so that it uses the new ids (rewriteParam).
15 :     *)
16 :     structure CleanParams : sig
17 :    
18 : cchiw 3978 val clean : Ein.ein_exp * Ein.param_kind list * Ein.index_bind list * MidIR.var list -> MidIR.rhs
19 : jhr 3561
20 :     end = struct
21 :    
22 :     structure E = Ein
23 : jhr 3562 structure DstIR = MidIR
24 :     structure DstV = DstIR.Var
25 : jhr 3564 structure IMap = IntRedBlackMap
26 : jhr 3574 structure ISet = IntRedBlackSet
27 : jhr 3561
28 :     (* look up the key e1 in the integer map mapp and return its binding; when e1 is unbound, raise Fail with the message str ^ Int.toString e1 *)
29 : jhr 3564 fun lookupSingleIndex (e1, mapp, str) = (case IMap.find(mapp, e1)
30 :     of SOME l => l
31 :     | _ => raise Fail(str ^ Int.toString e1)
32 : jhr 4317 (* end case *))
33 : jhr 3561
34 : jhr 3575 (* walk the ein expression and compute the set of free parameter indices (i.e., tensor,
35 :     * image, and kernel variables) in the expression.  The result is an ISet of ids; the accumulator named mapp in walk is that set, not a map.
36 :     * Value, Img, and Krn nodes raise Fail; any constructor without its own case below adds no ids. *)
37 :     fun getFreeParams b = let
38 :     fun walk (b, mapp) = (case b
39 : jhr 4317 of E.Tensor(id, _) => ISet.add(mapp, id)
40 :     | E.Conv(v, _, h, _) => ISet.add(ISet.add(mapp, h), v) (* both ids of the convolution are recorded *)
41 :     | E.Probe(e1, e2) => walk (e2, walk (e1, mapp))
42 : jhr 5570 | E.OField(E.CFExp es, e2, dx) => let
43 :     val es = List.map (fn (id, _) => E.Tensor(id, [])) es (* wrap each CFExp id as a Tensor so that walk records it *)
44 :     in
45 :     walk(dx, walk (e2, List.foldl walk mapp es))
46 :     end
47 : jhr 4317 | E.Value _ => raise Fail "unexpected Value"
48 :     | E.Img _ => raise Fail "unexpected Img"
49 :     | E.Krn _ => raise Fail "unexpected Krn"
50 :     | E.Sum(_, e1) => walk (e1, mapp)
51 :     | E.Op1(_, e1) => walk (e1, mapp)
52 :     | E.Op2(_, e1, e2) => walk (e2, walk (e1, mapp))
53 : cchiw 5241 | E.Op3(_, e1, e2, e3) => walk(e3, walk(e2, walk(e1, mapp)))
54 : jhr 4317 | E.Opn(_, es) => List.foldl walk mapp es
55 :     | _ => mapp (* remaining constructors are not traversed *)
56 : jhr 3575 (* end case *))
57 :     in
58 :     walk (b, ISet.empty)
59 :     end
60 :    
61 : jhr 3561 (* mkMapp : set * params * var list -> dict * params * var list
62 :     * freeParams is the set of parameter ids that are used; the id of a parameter is its position (from 0) in params
63 :     * mapp is the dictionary from old ids to new ids; only the used params and their args are returned, in their original order
64 :     *)
65 : jhr 3574 fun mkMapp (freeParams, params, args) = let
66 : jhr 4317 fun m (_, _, mapp, p, [], a, []) = (mapp, rev p, rev a) (* p and a were built in reverse *)
67 :     | m (i, j, mapp, p, p1::params, a, a1::arg) = (* i is the old id, j is the next new id *)
68 :     if ISet.member(freeParams, i)
69 :     then let
70 :     val mapp2 = IMap.insert(mapp, i, j)
71 :     in
72 :     m (i+1, j+1, mapp2, p1::p, params, a1::a, arg)
73 :     end
74 :     else m (i+1, j, mapp, p, params, a, arg) (* unused parameter: drop both the param and its arg *)
75 :     | m (_, _, _, _, _, _, []) = raise Fail "too many parameters"
76 :     | m (_, _, _, _, [], _, _) = raise Fail "too many args"
77 :     in
78 :     m (0, 0, IMap.empty, [], params, [], args)
79 : jhr 5570 end
80 : jhr 3561
81 :     (* rewriteParam : dict * ein_exp -> ein_exp
82 :     * rewrite the parameter ids in the expression e using mapp (old id -> new id); an id that is missing from mapp raises Fail
83 :     *)
84 : jhr 3562 fun rewriteParam (mapp, e) = let
85 : jhr 5570 fun getId id = lookupSingleIndex(id, mapp, "Mapp doesn't have Param Id ")
86 : jhr 3574 fun rewrite b = (case b
87 : jhr 3562 of E.Tensor(id, alpha) => E.Tensor(getId id, alpha)
88 : jhr 5570 | E.Conv(v, alpha, h, dx) =>
89 :     E.Conv(getId v, alpha, getId h, dx)
90 :     | E.Probe(f, t) => E.Probe(rewrite f, rewrite t)
91 :     | E.OField(E.CFExp es, e2, dx) => E.OField(
92 :     E.CFExp(List.map (fn (id, inputTy) => (getId id, inputTy)) es),
93 :     rewrite e2,
94 :     rewrite dx)
95 : jhr 4317 | E.Sum(sx ,e1) => E.Sum(sx, rewrite e1)
96 :     | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
97 :     | E.Op2(op2, e1,e2) => E.Op2(op2, rewrite e1, rewrite e2)
98 : cchiw 5241 | E.Op3(op3, e1, e2, e3) => E.Op3(op3, rewrite e1, rewrite e2, rewrite e3)
99 : jhr 4317 | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
100 :     | _ => b (* every other constructor is returned unchanged *)
101 :     (* end case *))
102 : jhr 3562 in
103 : jhr 5570 rewrite e
104 : jhr 3562 end
105 : jhr 3561
106 :     (* clean : ein_exp * param_kind list * index_bind list * var list -> rhs
107 :     * drops the params (and their args) that body does not use, renumbers the ids in body, and returns the EINAPP of the new EIN to the kept args; index is passed through unchanged
108 :     *)
109 : jhr 3576 fun clean (body, params, index, args) = let
110 : jhr 4317 val freeParams = getFreeParams body
111 :     val (mapp, Nparams, Nargs) = mkMapp (freeParams, params, args)
112 :     val Nbody = rewriteParam (mapp, body)
113 :     in
114 : jhr 3576 DstIR.EINAPP(Ein.EIN{params=Nparams, index=index, body=Nbody}, Nargs)
115 : jhr 3562 end
116 : jhr 3561
117 :     end (* CleanParam *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0