Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/translate-cfexp.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5574 - (view) (download)

1 : jhr 5570 (* translate-cfexp.sml
2 :     *
3 :     * Translation for EIN Term that represents closed form expressions
4 :     *
5 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2016 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 : jhr 5572
structure TranslateCFExp : sig

  (* Translate an EIN operator whose body is the probe of a closed-form expression,
   *     PROBE(CFEXP(cfexp_ids, e, dx), pos)
   * into an EIN body in which the closed-form parameters are replaced by the probed
   * position and the partial derivatives `dx` have been applied.
   *
   * Arguments (as a triple):
   *   y    - a MidIR variable; it is not used by this function
   *   ein  - the EIN operator; only its `body` and `params` fields are read
   *   args - the MidIR variables bound to the EIN parameters
   *
   * Result: the updated argument list, the updated parameter list, and the rewritten
   * EIN body.
   *
   * Raises Fail if the body is not a probe of a closed-form expression at a tensor
   * position, if the number of closed-form ids differs from the number of probe ids,
   * or if a field-like parameter is not a scalar or vector tensor.
   *)
    val transform_CFExp : MidIR.var * Ein.ein * MidIR.var list
          -> MidIR.var list * Ein.param_kind list * Ein.ein_exp

  end = struct

    structure IR = MidIR
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure IMap = IntRedBlackMap
    structure ISet = IntRedBlackSet
    structure SrcIR = HighIR
    structure DstIR = MidIR

    val i2s = Int.toString
    val shp2s = String.concatWithMap " " i2s

  (* printable form of the parameter kind at index `i`; used in error messages *)
    fun paramToString (i, E.TEN(_, shp)) = concat["T", i2s i, "[", shp2s shp, "]"]
      | paramToString (i, E.FLD d) = concat["F", i2s i, "(", i2s d, ")"]
      | paramToString (i, E.KRN) = "H" ^ i2s i
      | paramToString (i, E.IMG(d, shp)) = concat["V", i2s i, "(", i2s d, ")[", shp2s shp, "]"]

  (* Rewrite the tensor terms of an EIN expression whose parameter id is in `mapp`.
   *   body - EIN expression to rewrite
   *   dim  - number of components generated for a tensor term with a single index
   *   mapp - set of parameter ids whose tensor terms are replaced
   * Tensor terms whose id is not in `mapp`, and expression forms not listed in
   * `rewrite`, are returned unchanged.
   *)
    fun replace (body, dim, mapp) = let
        (* rewriteTensor (tid, alpha)
         * The tensor with parameter id `tid` is treated like a field, so the term is
         * replaced by polynomial terms.  A scalar term becomes a single Poly term.
         * A vector term with index i is split into its components
         *     V => [V_0, V_1, V_2]
         * with deltas to turn each component on and off
         *     V => V_0*Delta_0i + V_1*Delta_1i + V_2*Delta_2i
         * Tensors with more than one index are not supported.
         *)
          fun rewriteTensor (tid, alpha) = let
                fun mkPoly alpha' = E.Poly(E.Tensor(tid, alpha'), 1, [])
                in
                  case alpha
                   of [] => mkPoly []
                    | [vx] => let
                        (* the cx'th component, guarded by Delta(cx, vx) *)
                        fun mkComponent cx =
                              E.Opn(E.Prod, [mkPoly [E.C cx], E.Delta(E.C cx, vx)])
                        in
                          E.Opn(E.Add, List.tabulate(dim, mkComponent))
                        end
                    | _ => raise Fail "unhandled size"
                  (* end case *)
                end
        (* search the expression for tensor terms that are meant to be replaced *)
          fun rewrite body = (case body
                 of E.Tensor(id, alpha) =>
                      if ISet.member(mapp, id) then rewriteTensor(id, alpha) else body
                  | E.Lift e => E.Lift(rewrite e)
                  | E.Sum(sx, e) => E.Sum(sx, rewrite e)
                  | E.Op1(E.PowInt n, e) => let
                    (* an integer power is expanded into an n-fold product *)
                    (* NOTE(review): n <= 0 is passed straight to List.tabulate/iterPP;
                     * confirm that such exponents cannot occur here.
                     *)
                      val e' = rewrite e
                      in
                        EinUtil.iterPP (List.tabulate(n, fn _ => e'))
                      end
                  | E.Op1(op1, e) => E.Op1(op1, rewrite e)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                  | E.Opn(E.Prod, E.Opn(E.Add, ps)::es) =>
                    (* distribute the product over a leading sum, then rewrite the result *)
                      rewrite (E.Opn(E.Add, List.map (fn p => EinUtil.iterPP(p::es)) ps))
                  | E.Opn(E.Prod, ps) => EinUtil.iterPP(List.map rewrite ps)
                  | E.Opn(E.Add, ps) => EinUtil.iterAA(List.map rewrite ps)
                  | _ => body
                (* end case *))
          in
            rewrite body
          end

  (* Replace the arguments identified in cfexp_ids with the arguments in probe_ids,
   * for a term of the form PROBE(CFEXP(cfexp_ids), probe_ids).
   *   params    - EIN params
   *   e         - EIN body
   *   args      - vars
   *   cfexp_ids - (parameter id, E.T | E.F) pairs of the closed-form expression
   *   probe_ids - parameter ids of the positions at which the field is probed;
   *               paired with cfexp_ids in order
   * Returns the updated (args, params, e).
   *)
    fun polyArgs (params, e, args, cfexp_ids, probe_ids) = let
        (* Rewrite a single field-like variable: every occurrence of the argument at
         * position `pid` is replaced by the argument (and parameter kind) at position
         * `idx`, and the matching tensor terms in `e` are expanded by `replace`.
         *)
          fun singleTF (pid, args, params, idx, e) = let
                val newParam = List.nth(params, idx)
              (* number of components of the replacement parameter *)
              (* Note: the Dev branch also supports sequence parameters *)
                val dim = (case newParam
                       of E.TEN(_, []) => 1
                        | E.TEN(_, [i]) => i
                        | p => raise Fail("unsupported argument type:" ^ paramToString(idx, p))
                      (* end case *))
                val newArg = List.nth(args, idx)
                val rwArg = List.nth(args, pid)
              (* Walk args and params in parallel; `id` is the current position, and the
               * positions at which rwArg was replaced are collected in `mapp`.
               *)
                fun findArg (_, es, newargs, [], newparams, mapp) =
                      (List.revAppend(newargs, es), List.rev newparams, mapp)
                  | findArg (_, [], _, _::_, _, _) =
                      raise Fail "polyArgs: fewer arguments than parameters"
                  | findArg (id, e1::es, newargs, p1::ps, newparams, mapp) =
                      if IR.Var.same(e1, rwArg)
                        then findArg (
                          id+1, es, newArg::newargs, ps, newParam::newparams,
                          ISet.add(mapp, id))
                        else findArg (id+1, es, e1::newargs, ps, p1::newparams, mapp)
                val (args', params', mapp) = findArg (0, args, [], params, [], ISet.empty)
                in
                (* rewrite the replaced tensor terms with deltas in the body *)
                  (args', params', replace (e, dim, mapp))
                end
        (* iterate over all the input tensor variable expressions *)
          fun iter ([], args, params, _, e) = (args, params, e)
            | iter (_::_, _, _, [], _) =
                raise Fail "polyArgs: fewer probe ids than closed-form ids"
            | iter ((pid, E.T)::es, args, params, idx::idxs, e) = let
              (* variable is treated as a tensor so a simple variable swap is sufficient *)
                val args' = List.take(args, pid) @ [List.nth(args, idx)] @ List.drop(args, pid+1)
                in
                  iter (es, args', params, idxs, e)
                end
            | iter ((pid, E.F)::es, args, params, idx::idxs, e) = let
              (* variable is treated as a field so it needs to be expanded into its components *)
                val (args', params', e') = singleTF (pid, args, params, idx, e)
                in
                  iter (es, args', params', idxs, e')
                end
          in
            iter (cfexp_ids, args, params, probe_ids, e)
          end

  (* Apply differentiation: each Apply(Partial dx, e) node is replaced by the result of
   * differentiating `e` with respect to the indices in `dx`, one index at a time.
   * The traversal descends through Op1, Op2, and Opn nodes only.
   * NOTE(review): Apply nodes nested under other constructors (e.g., Sum or Lift) are
   * returned unchanged; confirm that this is intended.
   *)
    fun rewriteDifferentiate body = (case body
           of E.Apply(E.Partial [], e) => e
            | E.Apply(E.Partial(d1::dx), e) =>
                rewriteDifferentiate (E.Apply(E.Partial dx, DerivativeEin.differentiate ([d1], e)))
            | E.Op1(op1, e1) => E.Op1(op1, rewriteDifferentiate e1)
            | E.Op2(op2, e1, e2) => E.Op2(op2, rewriteDifferentiate e1, rewriteDifferentiate e2)
            | E.Opn(opn, es) => E.Opn(opn, List.map rewriteDifferentiate es)
            | _ => body
          (* end case *))

  (* main function
   * translate probe of cfexp to poly terms
   *)
    fun transform_CFExp (_, E.EIN{body, params, ...}, args) = let
          val (cfexp_ids, e, dx, expProbe) = (case body
                 of E.Probe(E.OField(E.CFExp cfexp_ids, e, E.Partial dx), expProbe) =>
                      (cfexp_ids, e, dx, expProbe)
                  | _ => raise Fail "transform_CFExp: expected probe of a closed-form expression"
                (* end case *))
        (* Note that the Dev branch allows multi-probe which is why we use a list of ids here *)
          val probe_ids = (case expProbe
                 of E.Tensor(tid, _) => [tid]
                  | _ => raise Fail "transform_CFExp: expected a tensor as the probe position"
                (* end case *))
        (* check that the number of closed-form parameters matches the number of probed arguments *)
          val n_pargs = length cfexp_ids
          val n_probe = length probe_ids
          val _ = if (n_pargs <> n_probe)
                then raise Fail(concat[
                    "n_pargs:", i2s n_pargs, ", n_probe:", i2s n_probe
                  ])
                else ()
        (* replace polywrap args/params with probed position(s) args/params *)
          val (args, params, e) = polyArgs (params, e, args, cfexp_ids, probe_ids)
        (* normalize ein by cleaning it up and differentiating *)
          val e = rewriteDifferentiate (E.Apply(E.Partial dx, e))
          in
            (args, params, e)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0