Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml

Parent Directory | Revision Log


Revision 5570 - (view) (download)

1 : jhr 5570 (* translate-cfexp.sml
2 :     *
3 :     * Translation for EIN Term that represents closed form expressions
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
structure TranslateCFExp : sig

  (* transform_CFExp (y, ein, args)
   * Translate a probe of a closed-form field (an EIN term whose body has the form
   * PROBE(OFIELD(CFEXP ids, e, PARTIAL dx), pos)) into polynomial terms.
   * Returns the new argument variables, EIN parameters, and EIN body.
   *)
    val transform_CFExp : MidIR.var * Ein.ein * MidIR.var list
          -> MidIR.var list * Ein.param_kind list * Ein.ein_exp

  end = struct

    structure IR = MidIR
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure IMap = IntRedBlackMap
    structure ISet = IntRedBlackSet
    structure SrcIR = HighIR
    structure DstIR = MidIR

    val i2s = Int.toString
    val shp2s = String.concatWithMap " " i2s

  (* string representation of the i'th EIN parameter (used in error messages) *)
    fun paramToString (i, E.TEN(_, shp)) = concat["T", i2s i, "[", shp2s shp, "]"]
      | paramToString (i, E.FLD d) = concat["F", i2s i, "(", i2s d, ")"]
      | paramToString (i, E.KRN) = "H" ^ i2s i
      | paramToString (i, E.IMG(d, shp)) = concat["V", i2s i, "(", i2s d, ")[", shp2s shp, "]"]

  (* iterP es
   * Build a simplified product of the terms es:
   *   - nested products are flattened;
   *   - factors equal to 1 are dropped;
   *   - a factor of 0 makes the whole product 0;
   *   - two adjacent deltas Delta(c1, v) and Delta(c2, v) on the same index variable
   *     with different constants c1 <> c2 make the whole product 0 (only adjacent
   *     pairs are checked);
   *   - an empty product (e.g., from PowInt 0, or when every factor is 1) is 1.
   * The surviving factors are accumulated in reverse order.
   *)
    fun iterP es = let
          fun iterPP ([], []) = E.Const 1
            | iterPP ([], [r]) = r
            | iterPP ([], rest) = E.Opn(E.Prod, rest)
            | iterPP (E.Const 0::_, _) = E.Const 0
            | iterPP (E.Const 1::es, rest) = iterPP (es, rest)
            | iterPP ((d1 as E.Delta(E.C c1, E.V v1))::(d2 as E.Delta(E.C c2, E.V v2))::es, rest) =
              (* a single index variable cannot equal two different constants at once *)
                if (c1 = c2 orelse v1 <> v2)
                  then iterPP (es, d1::d2::rest)
                  else E.Const 0
            | iterPP (E.Opn(E.Prod, ys)::es, rest) = iterPP (ys @ es, rest)
            | iterPP (e1::es, rest) = iterPP (es, e1::rest)
          in
            iterPP (es, [])
          end

  (* iterA es
   * Build a simplified sum of the terms es: nested sums are flattened and zero
   * terms are dropped; an empty sum is 0.  The surviving terms are accumulated
   * in reverse order.
   *)
    fun iterA es = let
          fun iterAA ([], []) = E.Const 0
            | iterAA ([], [r]) = r
            | iterAA ([], rest) = E.Opn(E.Add, rest)
            | iterAA (E.Const 0::es, rest) = iterAA (es, rest)
            | iterAA (E.Opn(E.Add, ys)::es, rest) = iterAA (ys @ es, rest)
            | iterAA (e1::es, rest) = iterAA (es, e1::rest)
          in
            iterAA (es, [])
          end

  (* replace (body, dim, mapp)
   * Rewrite body so that every tensor term whose parameter id is in mapp is
   * expanded into its components (see rewriteTensor).  Products, sums, and
   * integer powers are simplified along the way.
   *   body - EIN expression
   *   dim  - number of components of the replaced tensors
   *   mapp - set of parameter ids whose tensor terms are replaced
   *)
    fun replace (body, dim, mapp) = let
        (* rewriteTensor (tid, alpha)
         * The tensor is treated like a field.  It is replaced by its components, each
         * a degree-1 polynomial term, together with deltas that turn each component
         * on and off:
         *     V  =>  V_0*Delta_0i + V_1*Delta_1i + ... + V_{dim-1}*Delta_{dim-1}i
         * A scalar (empty alpha) becomes a single polynomial term.
         *)
          fun rewriteTensor (tid, alpha) = let
                fun mkPoly alpha' = E.Poly(E.Tensor(tid, alpha'), 1, [])
                in
                  case alpha
                   of [] => mkPoly []
                    | [vx] => let
                        fun mkComponent cx = E.Opn(E.Prod, [mkPoly [E.C cx], E.Delta(E.C cx, vx)])
                        in
                          E.Opn(E.Add, List.tabulate (dim, mkComponent))
                        end
                    | _ => raise Fail "unhandled size"
                  (* end case *)
                end
        (* search body for the tensor terms that are meant to be replaced *)
          fun rewrite body = (case body
                 of E.Tensor(id, alpha) =>
                      if ISet.member(mapp, id) then rewriteTensor (id, alpha) else body
                  | E.Lift e => E.Lift(rewrite e)
                  | E.Sum(sx, e1) => E.Sum(sx, rewrite e1)
                  | E.Op1(E.PowInt n, e1) => let
                      val e1' = rewrite e1
                      in
                      (* expand the power into an n-fold product *)
                        iterP (List.tabulate (n, fn _ => e1'))
                      end
                  | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                  | E.Opn(E.Prod, E.Opn(E.Add, terms)::factors) =>
                    (* distribute the remaining factors over the leading sum *)
                      rewrite (E.Opn(E.Add, List.map (fn t => iterP (t::factors)) terms))
                  | E.Opn(E.Prod, ps) => iterP (List.map rewrite ps)
                  | E.Opn(E.Add, ps) => iterA (List.map rewrite ps)
                  | _ => body
                (* end case *))
          in
            rewrite body
          end

  (* polyArgs (params, e, args, cfexp_ids, probe_ids)
   * Replace the arguments identified by cfexp_ids with the arguments identified by
   * probe_ids, for a term of the form PROBE(CFEXP(cfexp_ids), probe_ids).
   *   params    - EIN params
   *   e         - EIN body
   *   args      - argument variables
   *   cfexp_ids - (parameter id, kind) pairs for the closed-form expression's inputs
   *   probe_ids - parameter ids of the positions at which the field is probed
   * Returns the new (args, params, body).
   *)
    fun polyArgs (params, e, args, cfexp_ids, probe_ids) = let
        (* single_TF (pid, args, params, idx, e)
         * Replace every occurrence of the argument at position pid with the argument
         * at position idx.  The tensor terms in e for the replaced positions are
         * expanded into their components.
         *)
          fun single_TF (pid, args, params, idx, e) = let
              (* The new parameter must be a scalar or vector tensor; its length is the
               * number of components to expand into.
               * Note: the Dev branch supports sequence parameters.
               *)
                val param_new = List.nth(params, idx)
                val dim = (case param_new
                       of E.TEN(_, []) => 1
                        | E.TEN(_, [i]) => i
                        | p => raise Fail("unsupported argument type:" ^ paramToString(idx, p))
                      (* end case *))
                val arg_new = List.nth(args, idx)
                val arg_old = List.nth(args, pid)
              (* Walk args and params in parallel.  id is the current position; the
               * positions of the replaced arguments are collected in mapp.
               *)
                fun findArg (_, es, newargs, [], newparams, mapp) =
                      (List.revAppend(newargs, es), List.rev newparams, mapp)
                  | findArg (id, e1::es, newargs, p1::ps, newparams, mapp) =
                      if V.same(e1, arg_old)
                        then findArg (id+1, es, arg_new::newargs, ps, param_new::newparams, ISet.add(mapp, id))
                        else findArg (id+1, es, e1::newargs, ps, p1::newparams, mapp)
                  | findArg (_, [], _, _::_, _, _) =
                      raise Fail "polyArgs: fewer arguments than parameters"
                val (args, params, mapp) = findArg (0, args, [], params, [], ISet.empty)
                in
                  (args, params, replace (e, dim, mapp))
                end
        (* iterate over the closed-form inputs, pairing each with its probe argument *)
          fun iter ([], args, params, _, e) = (args, params, e)
            | iter ((pid, E.T)::es, args, params, idx::idxs, e) = let
              (* the input is treated as a tensor, so a simple argument swap suffices *)
                val args = List.take(args, pid) @ (List.nth(args, idx) :: List.drop(args, pid+1))
                in
                  iter (es, args, params, idxs, e)
                end
            | iter ((pid, E.F)::es, args, params, idx::idxs, e) = let
              (* the input is treated as a field, so it is expanded into its components *)
                val (args, params, e) = single_TF (pid, args, params, idx, e)
                in
                  iter (es, args, params, idxs, e)
                end
            | iter ((pid, _)::_, _, _, [], _) =
                raise Fail(concat["polyArgs: no probe argument for input ", i2s pid])
          in
            iter (cfexp_ids, args, params, probe_ids, e)
          end

  (* rewriteDifferentiate body
   * Eliminate partial-derivative applications by differentiating their operands
   * with DerivativeEin.differentiate, one derivative index at a time (in order).
   * Only Op1/Op2/Opn nodes are searched for nested applications; the result of a
   * differentiation is not searched further.
   *)
    fun rewriteDifferentiate body = (case body
           of E.Apply(E.Partial dx, e) =>
                List.foldl (fn (d, e') => DerivativeEin.differentiate ([d], e')) e dx
            | E.Op1(op1, e1) => E.Op1(op1, rewriteDifferentiate e1)
            | E.Op2(op2, e1, e2) => E.Op2(op2, rewriteDifferentiate e1, rewriteDifferentiate e2)
            | E.Opn(opn, es) => E.Opn(opn, List.map rewriteDifferentiate es)
            | _ => body
          (* end case *))

  (* transform_CFExp (y, ein, args)
   * Main function: translate a probe of a closed-form field into polynomial terms.
   * The body of ein must have the form
   *     PROBE(OFIELD(CFEXP cfexp_ids, e, PARTIAL dx), T_pos)
   * where T_pos is a tensor parameter.  The lhs variable y is not used.
   *)
    fun transform_CFExp (y, Ein.EIN{body, params, ...}, args) = let
          val (cfexp_ids, e, dx, expProbe) = (case body
                 of E.Probe(E.OField(E.CFExp ids, e, E.Partial dx), expProbe) => (ids, e, dx, expProbe)
                  | _ => raise Fail "transform_CFExp: expected a probe of a closed-form field"
                (* end case *))
        (* Ids of the probe-position parameters.  The Dev branch allows multiple probes,
         * which is why this is a list.
         *)
          val probe_ids = (case expProbe
                 of E.Tensor(tid, _) => [tid]
                  | _ => raise Fail "transform_CFExp: probe position must be a tensor"
                (* end case *))
        (* the number of closed-form inputs must match the number of probed positions *)
          val n_pargs = length cfexp_ids
          val n_probe = length probe_ids
          val () = if (n_pargs <> n_probe)
                then raise Fail(concat["n_pargs:", i2s n_pargs, ", n_probe:", i2s n_probe])
                else ()
        (* replace the closed-form inputs' args/params with the probed position(s) *)
          val (args, params, e) = polyArgs (params, e, args, cfexp_ids, probe_ids)
        (* apply the outer differentiation *)
          val e = rewriteDifferentiate (E.Apply(E.Partial dx, e))
          in
            (args, params, e)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0