Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3743 - (view) (download)

1 : jhr 3626 (* ein-to-low.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     (*
10 :     * expand: does a preliminary scan of the body of an EIN.EIN for vectorization potential.
11 :     * If there is a field, it is passed to FieldToLow.
12 :     * If there is a tensor, it is passed to the handler functions, which check whether the indices match,
13 :     * e.g., <A_ij+B_ij>_ij (indices match) vs. <A_ji+B_ij>_ij (indices do not match).
14 :     *
15 : jhr 3666 * (1) If the indices match, the expression is passed to the EinToVector (ToVec) functions,
16 : jhr 3632 * which create LowIR vector operators.
17 : jhr 3626 * (2) Otherwise, it is passed to EinToScalar,
18 :     * which creates LowIR scalar operators.
19 : jhr 3632 * Note: the unroll function creates LowIR.CONS and therefore binds the indices in the EIN body.
20 : jhr 3626 *)
21 : jhr 3627
structure EinToLow : sig

  (* `expand (y, ein, args)` generates the LowIR assignments that compute the Ein
   * operator `ein` applied to the variables `args`.  The last assignment in the
   * returned list binds the result to `y`.
   *)
    val expand : LowIR.var * Ein.ein * LowIR.var list -> LowIR.assignment list

  end = struct

    structure Var = LowIR.Var
    structure E = Ein
    structure Op = LowOps
    structure Mk = MkLowIR
    structure ToVec = EinToVector
    structure IMap = IntRedBlackMap

  (* `dropIndex alpha` returns the triple (len, i, alpha') where
   *   len = length(alpha') and alpha = alpha'@[i].
   * Raises Fail if alpha is empty.
   *)
    fun dropIndex alpha = let
          fun drop ([], _, _) = raise Fail "dropIndex[]"
            | drop ([idx], n, idxs') = (n, idx, List.rev idxs')
            | drop (idx::idxs, n, idxs') = drop (idxs, n+1, idx::idxs')
          in
            drop (alpha, 0, [])
          end

  (* matchFindLast : E.alpha * int -> E.alpha option * E.mu option
   * The first component is SOME alpha' when alpha = alpha'@[E.V n], and NONE
   * otherwise.  The second component is SOME(E.V n) when `E.V n` occurs in alpha
   * at a position other than the last one, and NONE otherwise.
   *)
    fun matchFindLast (alpha, n) = let
          fun find es = List.find (fn (E.V idx') => (n = idx') | _ => false) es
          in
            case List.rev alpha
             of (E.V v)::es => if (n = v)
                  then (SOME(List.rev es), find es)
                  else (NONE, find es)
              | _::es => (NONE, find es)
              | [] => (NONE, NONE)
            (* end case *)
          end

  (* matchLast : E.alpha * int -> E.alpha option
   * Returns SOME alpha' when alpha = alpha'@[E.V n] and `E.V n` does not occur in
   * alpha'; otherwise returns NONE.  The vectorized paths that use this check only
   * unroll over the non-vector indices, so an occurrence of `E.V n` inside alpha'
   * would refer to an index that `unroll` never binds in `mapp`.  This is the same
   * "last index and nowhere else" test that the product and inner-product cases
   * perform with `matchFindLast`.
   *)
    fun matchLast (alpha, n) = (case matchFindLast (alpha, n)
           of (SOME ix, NONE) => SOME ix
            | _ => NONE
          (* end case *))

  (* unroll the body of an Ein expression. The arguments are
   *   shape  -- the shape of the tensor computed by the expression
   *   index  -- the shape of the iteration structure
   *   bodyFn -- the function for generating the body; it is called with the shared
   *             AvailRHS table and a map that binds index variables 0, 1, ... (one
   *             per dimension of `index`) to concrete index values
   * The results for each dimension are combined with Mk.cons.  Returns the AvailRHS
   * table that holds the generated assignments.
   *)
    fun unroll (shape, index, bodyFn) = let
          val avail = AvailRHS.new()
          fun bodyFn' (mapp, n, m) = bodyFn (avail, IMap.insert(mapp, n, m))
          fun iter (mapp, xs, ys, shape, n, zs) = (case (xs, ys)
                 of ([], []) => bodyFn (avail, mapp)
                  | ([x], []) =>
                      Mk.cons (avail, shape, List.rev (bodyFn' (mapp, n, x) :: zs))
                  | (x::xr, []) =>
                      iter (mapp, xr, [], shape, n, bodyFn' (mapp, n, x) :: zs)
                  | ([], [y]) =>
                      iter (mapp, List.tabulate (y, Fn.id), [], shape, n, zs)
                  | ([], y::yr) => let
                      val shape' = (case shape
                             of _ :: rest => rest
                              | [] => raise Fail "unroll: shape is smaller than index"
                            (* end case *))
                    (* bind index variable `n` to each of 0..y-1 in turn; the
                     * recursive calls are made in increasing order of `i`
                     *)
                      fun lp (i, ws) = if (i < y)
                            then let
                              val w = iter (IMap.insert (mapp, n, i), [], yr, shape', n+1, [])
                              in
                                lp (i+1, w::ws)
                              end
                            else Mk.cons (avail, shape, List.rev ws)
                      in
                        lp (0, [])
                      end
                  | _ => raise Fail "unroll: shape is larger than index"
                (* end case *))
          in
            ignore (iter (IMap.empty, [], index, shape, 0, []));
            avail
          end

  (* in the general case, we expand the body to scalar code *)
    fun scalarExpand (params, body, index, lowArgs) =
          unroll (
            index, index,
            fn (avail, mapp) => EinToScalar.expand {
              (* FIXME: do we need the params? *)
                avail=avail, mapp=mapp, params=params, body=body, lowArgs=lowArgs
              })

  (* createP builds a vector parameter that projects the last axis (of width
   * `vecIndex`) of argument `id`; `ix` are the remaining (non-vector) indices.
   * createI builds a parameter that is indexed by `ix` only.
   *)
    fun createP (args, vecIndex, id, ix) =
          ToVec.Param{id = id, arg = List.nth(args, id), ix = ix, kind = ToVec.Proj vecIndex}
    fun createI (args, id, ix) =
          ToVec.Param{id = id, arg = List.nth(args, id), ix = ix, kind = ToVec.Indx}

  (* generate low-IL code for scaling a non-scalar tensor; `sId` is the scalar
   * parameter's ID and `vId` is the tensor parameter's ID.
   *)
    fun expandScale (sId, vId, shape, params, body, index, args) = let
          val (n, vecIX, index') = dropIndex index
          in
            case matchLast(shape, n)
             of SOME ix => let
                  val vecA = createI (args, sId, [])
                  val vecB = createP (args, vecIX, vId, ix)
                  val binop = ToVec.binopV (Op.VScale vecIX, vecIX)
                  in
                    unroll (
                      index, index',
                      fn (avail, mapp) => binop (avail, mapp, vecA, vecB))
                  end
              | NONE => scalarExpand (params, body, index, args)
            (* end case *)
          end

  (* handle potential sum-of-products (i.e., inner products); otherwise fall back to the
   * general scalar case.  The vector width used for the summation index is ub+1 and
   * the lower bound is not passed to the vector operators, so the vector paths are
   * only taken when the summation starts at 0.
   *)
    fun expandInner (params, body, index, args) = (case body
           of E.Sum(
                [(E.V v, lb, ub)],
                E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
              ) => (case (lb, matchFindLast(alpha, v), matchFindLast(beta, v))
                 of (0, (SOME ix1, NONE), (SOME ix2, NONE)) => let
                    (* v is the last index of alpha, beta and nowhere else *)
                      val vecIX = ub+1
                      val vecA = createP (args, vecIX, id1, ix1)
                      val vecB = createP (args, vecIX, id2, ix2)
                      in
                        unroll (
                          index, index,
                          fn (avail, mapp) => ToVec.dotV (avail, mapp, vecA, vecB))
                      end
                  | _ => scalarExpand (params, body, index, args)
                (* end case *))
            | E.Sum(
                [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
                E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
              ) => let
                (* try to vectorize over summation index `v` (range lb..ub); the other
                 * summation index `i` and its range lb'..ub' are passed on to
                 * ToVec.sumDotV.
                 *)
                  fun check (v, lb, ub, i, lb', ub') = (
                        case (lb, matchFindLast(alpha, v), matchFindLast(beta, v))
                         of (0, (SOME ix1, NONE), (SOME ix2, NONE)) => let
                            (* v is the last index of alpha, beta and nowhere else *)
                              val vecIX = ub+1
                              val vecA = createP (args, vecIX, id1, ix1)
                              val vecB = createP (args, vecIX, id2, ix2)
                              in
                                SOME(unroll (
                                  index, index,
                                  fn (avail, mapp) =>
                                    ToVec.sumDotV (avail, mapp, i, lb', ub', vecA, vecB)))
                              end
                          | _ => NONE
                        (* end case *))
                  in
                    case check (v1, lb1, ub1, v2, lb2, ub2)
                     of SOME e => e
                      | NONE => (case check (v2, lb2, ub2, v1, lb1, ub1)
                           of SOME e => e
                            | NONE => scalarExpand (params, body, index, args)
                          (* end case *))
                    (* end case *)
                  end
            | _ => scalarExpand (params, body, index, args)
          (* end case *))

  (* expand an Ein expression that has a non-scalar result.  Each vector path
   * vectorizes over the last result index and unrolls over the remaining ones;
   * anything that does not match falls back to scalar code.
   *)
    fun nonScalar (params, body, index, args) = (case body
           of E.Op2(E.Sub, E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)) => let
                val (n, vecIX, index') = dropIndex index
                in
                  case (matchLast(alpha, n), matchLast(beta, n))
                   of (SOME ix1, SOME ix2) => let
                        val vecA = createP (args, vecIX, id1, ix1)
                        val vecB = createP (args, vecIX, id2, ix2)
                        val binop = ToVec.binopV (Op.VSub vecIX, vecIX)
                        in
                          unroll (
                            index, index',
                            fn (avail, mapp) => binop (avail, mapp, vecA, vecB))
                        end
                    | _ => scalarExpand (params, body, index, args)
                  (* end case *)
                end
            | E.Opn(E.Add, es as E.Tensor(_, _::_)::_) => let
                val (n, vecIX, index') = dropIndex index
              (* vectorize only when every summand is a tensor whose last index is the
               * result's last index (and which does not use that index elsewhere);
               * otherwise fall back to scalar code.
               *)
                fun sample ([], rest) =
                      unroll (
                        index, index',
                        fn (avail, mapp) => ToVec.addV (avail, mapp, List.rev rest))
                  | sample (E.Tensor(id1, alpha)::ts, rest) = (case matchLast(alpha, n)
                       of SOME ix1 => sample(ts, createP (args, vecIX, id1, ix1)::rest)
                        | NONE => scalarExpand (params, body, index, args)
                      (* end case *))
                  | sample _ = scalarExpand (params, body, index, args)
                in
                  sample (es, [])
                end
            | E.Op1(E.Neg, E.Tensor(id, alpha as (_::_))) => let
                val (n, vecIX, index') = dropIndex index
                in
                  case matchLast (alpha, n)
                   of SOME ix1 => unroll (
                        index, index',
                        fn (avail, mapp) => ToVec.negV (avail, mapp, createP (args, vecIX, id, ix1)))
                    | NONE => scalarExpand (params, body, index, args)
                  (* end case *)
                end
            | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, shp as _::_)]) =>
                expandScale (s, v, shp, params, body, index, args)
            | E.Opn(E.Prod, [E.Tensor(v, shp as _::_), E.Tensor(s , [])]) =>
                expandScale (s, v, shp, params, body, index, args)
            | E.Opn(E.Prod, [E.Tensor(id1 , alpha as _::_), E.Tensor(id2, beta as _::_)]) => let
                val (n, vecIX, index') = dropIndex index
                in
                  case (matchFindLast(alpha, n), matchFindLast(beta, n))
                   of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
                      (* n is the last index of alpha, beta and nowhere else, possible modulate *)
                        val vecA = createP (args, vecIX, id1, ix1)
                        val vecB = createP (args, vecIX, id2, ix2)
                        val binop = ToVec.binopV (Op.VMul vecIX, vecIX)
                        in
                          unroll (
                            index, index',
                            fn (avail, mapp) => binop (avail, mapp, vecA, vecB))
                        end
                    | ((NONE, NONE), (SOME ix2, NONE)) => let
                      (* n is the last index of beta and nowhere else, possible scaleVector *)
                        val vecA = createI (args, id1, alpha)
                        val vecB = createP (args, vecIX, id2, ix2)
                        val binop = ToVec.binopV (Op.VScale vecIX, vecIX)
                        in
                          unroll (
                            index, index',
                            fn (avail, mapp) => binop (avail, mapp, vecA, vecB))
                        end
                    | ((SOME ix1, NONE), (NONE, NONE)) => let
                      (* n is the last index of alpha and nowhere else, possible scaleVector *)
                        val vecA = createI (args, id2, beta)
                        val vecB = createP (args, vecIX, id1, ix1)
                        val binop = ToVec.binopV (Op.VScale vecIX, vecIX)
                        in
                          unroll (
                            index, index',
                            fn (avail, mapp) => binop (avail, mapp, vecA, vecB))
                        end
                    | _ => scalarExpand (params, body, index, args)
                  (* end case *)
                end
            | _ => expandInner (params, body, index, args)
          (* end case *))

  (* expand : LowIR.var * Ein.ein * LowIR.var list -> LowIR.assignment list
   * Scans the body for vectorization potential and generates the LowIR
   * assignments.  The last assignment in the result binds `y`.
   *)
    fun expand (y, E.EIN{params, index, body}, args) = let
        (* FIXME: We really only care if the last dimension of a non-scalar shape is "3", since it
         * causes poor performance in the generated code. Really, this should be handled by the LowIR
         * to TreeIR transform, where we take into account machine vector widths.
         *)
          val avail = (case index
                 of [3, 3] => scalarExpand (params, body, index, args)
                  | [3, 3, 3] => scalarExpand (params, body, index, args)
                  | _::_ => nonScalar (params, body, index, args)
                  | [] => expandInner (params, body, index, args)
                (* end case *))
          in
          (* the head of the list returned by getAssignments is the final assignment; it is
           * rebound to `y`, and revMap puts it last in the returned list
           *)
            case AvailRHS.getAssignments avail
             of (_, rhs) :: rest => List.revMap LowIR.ASSGN ((y, rhs) :: rest)
              | [] => raise Fail "EinToLow.expand: no assignments generated"
            (* end case *)
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0