Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5572 - (view) (download)

1 : jhr 5570 (* translate-cfexp.sml
2 :     *
3 :     * Translation for EIN Term that represents closed form expressions
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 : jhr 5572
11 : jhr 5570 structure TranslateCFExp : sig
12 :    
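13 : jhr 5572 (* transform_CFExp (y, ein, args): ein's body must be a probe of a closed-form-expression field at a tensor position; returns the updated argument list, the updated parameter list, and the rewritten and differentiated body (y is currently unused) *)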
14 : jhr 5570 val transform_CFExp : MidIR.var * Ein.ein * MidIR.var list
15 :     -> MidIR.var list * Ein.param_kind list * Ein.ein_exp
16 :    
17 :     end = struct
18 :    
19 :     structure IR = MidIR
20 :     structure V = IR.Var
21 :     structure Ty = MidTypes
22 :     structure E = Ein
23 :     structure IMap = IntRedBlackMap
24 :     structure ISet = IntRedBlackSet
25 :     structure SrcIR = HighIR
26 :     structure DstIR = MidIR
27 :    
val i2s = Int.toString
val shp2s = String.concatWithMap " " i2s

(* paramToString (i, p)
 * Render the EIN parameter p, located at position i, as a short string:
 *   tensor  => T<i>[<shape>]
 *   field   => F<i>(<dim>)
 *   kernel  => H<i>
 *   image   => V<i>(<dim>)[<shape>]
 * Shape entries are separated by single spaces.
 *)
fun paramToString (i, param) = let
      val pos = i2s i
      in
        case param
         of E.TEN(_, shape) => String.concat ["T", pos, "[", shp2s shape, "]"]
          | E.FLD dim => String.concat ["F", pos, "(", i2s dim, ")"]
          | E.KRN => "H" ^ pos
          | E.IMG(dim, shape) => String.concat ["V", pos, "(", i2s dim, ")[", shp2s shape, "]"]
        (* end case *)
      end
35 :    
(* iterP es
 * Build a simplified product from the list of factors es.
 *   - nested products are flattened
 *   - a factor of (Const 1) is dropped
 *   - a factor of (Const 0) makes the whole product (Const 0)
 *   - two adjacent deltas that tie the same index variable to two different
 *     constants make the whole product (Const 0), since the variable cannot
 *     take both values; otherwise both deltas are kept
 *   - an empty list of remaining factors yields (Const 1), the multiplicative
 *     identity (this mirrors iterA, which yields (Const 0) for an empty sum)
 *   - a single remaining factor is returned unwrapped
 * The surviving factors are accumulated on a stack, so they appear in the result
 * in the reverse of the order in which they were visited.
 *)
fun iterP es = let
      fun iterPP ([], []) = E.Const 1
        | iterPP ([], [r]) = r
        | iterPP ([], rest) = E.Opn(E.Prod, rest)
        | iterPP (E.Const 0 :: _, _) = E.Const 0
        | iterPP (E.Const 1 :: es, rest) = iterPP (es, rest)
        | iterPP ((d1 as E.Delta(E.C c1, E.V v1)) :: (d2 as E.Delta(E.C c2, E.V v2)) :: es, rest) =
            (* a variable can't be both c1 and c2 when they differ *)
            if (c1 = c2 orelse v1 <> v2)
              then iterPP (es, d1 :: d2 :: rest)
              else E.Const 0
        | iterPP (E.Opn(E.Prod, ys) :: es, rest) = iterPP (ys @ es, rest)
        | iterPP (e1 :: es, rest) = iterPP (es, e1 :: rest)
      in
        iterPP (es, [])
      end
51 :    
(* iterA es
 * Build a simplified sum from the list of summands es.
 * Nested sums are flattened and (Const 0) summands are dropped.  The result is
 * (Const 0) when nothing is left, the summand itself when exactly one is left,
 * and an n-ary Add otherwise.  Summands are pushed onto an accumulator, so they
 * appear in the result in the reverse of the order in which they were visited.
 *)
fun iterA es = let
      (* collect (terms, acc): push the non-zero, flattened summands of terms onto acc *)
      fun collect ([], acc) = acc
        | collect (E.Const 0 :: terms, acc) = collect (terms, acc)
        | collect (E.Opn(E.Add, inner) :: terms, acc) = collect (terms, collect (inner, acc))
        | collect (term :: terms, acc) = collect (terms, term :: acc)
      in
        case collect (es, [])
         of [] => E.Const 0
          | [single] => single
          | summands => E.Opn(E.Add, summands)
        (* end case *)
      end
62 :    
63 :     (* The tensor terms whose param_id is a member of mapp are replaced
64 :     * body - ein expression
65 :     * args - variable arguments (not a parameter of replace)
66 :     * dim - dimension
67 :     * SeqId_current - current sequential option (not a parameter of replace)
68 :     * mapp - set of param_ids to be replaced
69 :     *)
(* replace (body, dim, mapp)
 * Walk the EIN expression body and expand every tensor term reached by the walk
 * whose id is a member of the integer set mapp.  Such a tensor is treated like a
 * field and is split into its components:
 *   - no index:    V    ==>  Poly(V, 1, [])
 *   - one index i: V_i  ==>  V_0*Delta_0i + V_1*Delta_1i + ... (dim summands),
 *                            where each V_c is wrapped as Poly(V_c, 1, [])
 *   - more than one index raises Fail "unhandled size"
 * Along the way an integer power is unrolled into a product of n copies, a product
 * whose first factor is a sum is distributed over that sum, and products and sums
 * are rebuilt with iterP and iterA.  Expression forms not listed below are
 * returned unchanged.
 *)
fun replace (body, dim, mapp) = let
      (* expand (tid, indices): the component expansion of tensor tid *)
      fun expand (tid, indices) = let
            fun poly ixs = E.Poly(E.Tensor(tid, ixs), 1, [])
            (* one summand: component cx, switched on by a delta against index vx *)
            fun component vx cx = E.Opn(E.Prod, [poly [E.C cx], E.Delta(E.C cx, vx)])
            in
              case indices
               of [] => poly []
                | [vx] => E.Opn(E.Add, List.tabulate (dim, component vx))
                | _ => raise Fail "unhandled size"
              (* end case *)
            end
      (* search the expression for tensor terms that are meant to be replaced *)
      fun rewrite (t as E.Tensor(tid, indices)) =
            if ISet.member(mapp, tid) then expand (tid, indices) else t
        | rewrite (E.Lift e) = E.Lift(rewrite e)
        | rewrite (E.Sum(sx, e)) = E.Sum(sx, rewrite e)
        | rewrite (E.Op1(E.PowInt n, e)) = let
            val base = rewrite e
            in
              (* e^n becomes the product of n copies of the rewritten e *)
              iterP (List.tabulate (n, fn _ => base))
            end
        | rewrite (E.Op1(rator, e)) = E.Op1(rator, rewrite e)
        | rewrite (E.Op2(rator, e1, e2)) = E.Op2(rator, rewrite e1, rewrite e2)
        | rewrite (E.Opn(E.Prod, E.Opn(E.Add, summands) :: factors)) =
            (* distribute: (s1 + s2 + ...) * fs  ==>  s1*fs + s2*fs + ..., then rewrite the sum *)
            rewrite (E.Opn(E.Add, List.map (fn s => iterP (s :: factors)) summands))
        | rewrite (E.Opn(E.Prod, factors)) = iterP (List.map rewrite factors)
        | rewrite (E.Opn(E.Add, summands)) = iterA (List.map rewrite summands)
        | rewrite e = e
      in
        rewrite body
      end
118 :    
119 :     (* Replace the arguments identified in cfexp-ids with the arguments in probe-ids
120 :     * params - EIN params
121 :     * e - EIN body
122 :     * args - vars
123 :     * SeqId - optional sequence index variable (not a parameter of polyArgs)
124 :     * cfexp_ids - closed-form expression has ids
125 :     * probe_ids - field is probed at position with ids
126 :     * PROBE(CFEXP (cfexp_ids), probe_ids)
127 :     *)
(* polyArgs (params, e, args, cfexp_ids, probe_ids)
 * For PROBE(CFEXP(cfexp_ids), probe_ids): pair each entry (pid, kind) of cfexp_ids
 * with the probe id idx at the same position of probe_ids and substitute the
 * argument at idx for the argument at pid.
 *   - kind E.T: the variable is treated as a tensor, so only the argument list
 *     changes (args[pid] := args[idx]).
 *   - kind E.F: the variable is treated as a field; every argument that is the
 *     same variable as args[pid] is replaced by args[idx] (and its parameter by
 *     params[idx]), and the matching tensor terms in the body are expanded into
 *     components by replace.
 * Returns the updated (args, params, body).
 * Raises Fail if the parameter at idx is not a scalar or vector tensor, if params
 * is longer than args, or if probe_ids runs out before cfexp_ids.
 *)
fun polyArgs (params, e, args, cfexp_ids, probe_ids) = let
      (* rewrite a single variable:
       * replace instances of the arg at position pid with the arg at position idx
       *)
      fun singleTF (pid, args, params, idx, e) = let
            val newParam = List.nth(params, idx)
            (* dimension of the vector that is being broken into components *)
            val dim = (case newParam
                   of E.TEN(_, []) => 1
                    | E.TEN(_, [i]) => i
                    | p => raise Fail("unsupported argument type:"^paramToString(idx, p))
                  (* end case *))
            val newArg = List.nth(args, idx)
            val rwArg = List.nth(args, pid)
            (* id is the position of the current argument; the positions of the
             * replaced arguments are collected in mapp
             *)
            fun findArg (_, es, newargs, [], newparams, mapp) =
                  (List.revAppend(newargs, es), List.rev newparams, mapp)
              | findArg (_, [], _, _::_, _, _) =
                  raise Fail "polyArgs: more parameters than arguments"
              | findArg (id, e1::es, newargs, p1::ps, newparams, mapp) =
                  if IR.Var.same(e1, rwArg)
                    then findArg (id+1, es, newArg::newargs, ps, newParam::newparams, ISet.add(mapp, id))
                    else findArg (id+1, es, e1::newargs, ps, p1::newparams, mapp)
            val (args, params, mapp) = findArg (0, args, [], params, [], ISet.empty)
            in
              (* rewrite the position tensor with deltas in the body *)
              (args, params, replace (e, dim, mapp))
            end
      (* iterate over all the input tensor variable expressions *)
      fun iter ([], args, params, _, e) = (args, params, e)
        | iter (_::_, _, _, [], _) =
            raise Fail "polyArgs: fewer probe ids than closed-form expression ids"
        | iter ((pid, E.T)::es, args, params, idx::idxs, e) = let
            (* variable is treated as a tensor so a simple variable swap is sufficient *)
            val args = List.take(args, pid) @ [List.nth(args, idx)] @ List.drop(args, pid+1)
            in
              iter (es, args, params, idxs, e)
            end
        | iter ((pid, E.F)::es, args, params, idx::idxs, e) = let
            (* variable is treated as a field so it needs to be expanded into its components *)
            val (args, params, e) = singleTF (pid, args, params, idx, e)
            in
              iter (es, args, params, idxs, e)
            end
      in
        iter (cfexp_ids, args, params, probe_ids, e)
      end
178 :    
179 :     (* apply differentiation *)
(* rewriteDifferentiate body
 * Push pending differentiation through the expression.
 *   - Apply(Partial dx, e): e is differentiated once per index in dx, in list
 *     order, by DerivativeEin.differentiate; the result is returned as is.
 *   - Op1 / Op2 / Opn: the operands are processed recursively.
 *   - anything else is returned unchanged.
 *)
fun rewriteDifferentiate (E.Apply(E.Partial dx, e)) =
      List.foldl (fn (d, acc) => DerivativeEin.differentiate ([d], acc)) e dx
  | rewriteDifferentiate (E.Op1(rator, e)) = E.Op1(rator, rewriteDifferentiate e)
  | rewriteDifferentiate (E.Op2(rator, e1, e2)) =
      E.Op2(rator, rewriteDifferentiate e1, rewriteDifferentiate e2)
  | rewriteDifferentiate (E.Opn(rator, es)) = E.Opn(rator, List.map rewriteDifferentiate es)
  | rewriteDifferentiate e = e
193 :    
194 :     (* main function
195 :     * translate probe of cfexp to poly terms
196 :     *)
(* transform_CFExp (y, ein, args)
 * Translate the probe of a closed-form expression into poly terms.
 *   y    - not used by this function
 *   ein  - EIN operator whose body must have the form
 *          Probe(OField(CFExp cfexp_ids, e, Partial dx), T) with T a tensor term
 *   args - the arguments of the EIN operator
 * Returns (args', params', e'): the argument and parameter lists after the
 * closed-form-expression arguments were replaced by the probed position, and the
 * body e after that replacement and after differentiating with respect to dx.
 * Raises Fail when the body does not have the expected form or when the number of
 * closed-form ids differs from the number of probe ids.
 *)
fun transform_CFExp (y, ein as Ein.EIN{body, index, params}, args) = (case body
       of E.Probe(E.OField(E.CFExp cfexp_ids, e, E.Partial dx), expProbe) => let
            (* Note that the Dev branch allows multi-probe, which is why a list of ids is used here *)
            val probe_ids = (case expProbe
                   of E.Tensor(tid, _) => [tid]
                    | _ => raise Fail "transform_CFExp: probe position is not a tensor term"
                  (* end case *))
            (* check that the number of input parameters matches the number of probed arguments *)
            val n_pargs = length cfexp_ids
            val n_probe = length probe_ids
            val _ = if (n_pargs <> n_probe)
                  then raise Fail(concat[
                      "n_pargs:", Int.toString n_pargs, " n_probe:", Int.toString n_probe
                    ])
                  else ()
            (* replace polywrap args/params with probed position(s) args/params *)
            val (args, params, e) = polyArgs (params, e, args, cfexp_ids, probe_ids)
            (* normalize ein by cleaning it up and differentiating *)
            val e = rewriteDifferentiate (E.Apply(E.Partial dx, e))
            in
              (args, params, e)
            end
        | _ => raise Fail "transform_CFExp: expected a probe of a closed-form expression"
      (* end case *))
216 :    
217 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0