Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3628 - (view) (download)

1 : jhr 3626 (* ein-to-low.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     (*
10 :     * genfn-Does preliminary scan of the body of EIN.EIN for vectorization potential
11 :     * If there is a field then passes to FieldToLow
12 :     * If there is a tensor then passes to handle*() functions to check if indices match
13 :     * i.e. <A_ij+B_ij>_ij vs.<A_ji+B_ij>_ij
14 :     *
15 :     * (1) If indices match then passes to Iter->VecToLow functions.
16 :     * Creates LowIL vector operators.
17 :     * (2) Iter->ScaToLow
18 :     * Creates Low-IL scalar operators
19 : jhr 3627 * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body
20 : jhr 3626 *)
21 : jhr 3627
22 : jhr 3626 structure EinToLow : sig
23 :    
24 : jhr 3627 val expand : LowIL.var * Ein.ein * LowIL.var list -> LowIL.assign list
25 :    
26 : jhr 3626 end = struct
27 :    
28 :     structure Var = LowIL.Var
29 :     structure E = Ein
30 :     structure Iter = Iter
31 :     structure EtoSca = ScaToLow
32 :     structure EtoVec = VecToLow
33 :     structure H = Helper
34 :     structure Op = LowOps
35 :    
36 :     fun iter e = Iter.prodIter e
37 :     fun intToReal n = H.intToReal n
38 : jhr 3627
39 :     (* `dropIndex alpha` returns the (len, i, alpha') where
40 :     * len = length(alpha') and alpha = alpha'@[i].
41 :     *)
42 : jhr 3626 fun dropIndex alpha = let
43 : jhr 3627 fun drop ([], _, _) = raise Fail "dropIndex[]"
44 :     | drop ([idx], n, idxs') = (n, idx, List.rev idxs')
45 :     | drop (idx::idxs, n, idx') = drop (idxs, n+1, idx::idx')
46 :     in
47 :     drop (alpha, 0, [])
48 :     end
49 : jhr 3626
50 :     (*matchLast:E.alpha*int -> (E.alpha) Option
51 :     * Is the last index of alpha E.V n.
52 : jhr 3627 * If so, return the rest of the list
53 : jhr 3626 *)
54 : jhr 3627 fun matchLast (alpha, n) = (case List.rev alpha
55 :     of (E.V v)::es => if (n = v) then SOME(List.rev es) else NONE
56 : jhr 3626 | _ => NONE
57 : jhr 3627 (* end case *))
58 : jhr 3626
59 :     (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option
60 :     * Is the last index of alpha = n.
61 :     * is n anywhere else?
62 :     *)
63 :     fun matchFindLast (alpha, n) = let
64 : jhr 3628 fun find es = List.find (fn (E.V idx') => (n = idx') | _ => false) es
65 :     in
66 :     case List.rev alpha
67 :     of (E.V v)::es => if (n = v)
68 :     then (SOME(List.rev es), find es)
69 :     else (NONE, find es)
70 :     | _::es => (NONE, find es)
71 :     | [] => (NONE, NONE)
72 :     (* end case *)
73 :     end
74 : jhr 3626
75 :     (*runGeneralCase:Var*E.EIN*Var-> Var*LowIL.ASSN list
76 :     * does not do vector projections
77 :     * instead approach like a general EIN
78 :     *)
79 :     fun runGeneralCase (e:Ein.ein, args:LowIL.var list) = let
80 : jhr 3628 val index = Ein.index e
81 :     in
82 :     iter (AvailRHS.new(), index, index, EtoSca.generalfn, e, args)
83 :     end
84 : jhr 3626
85 :     fun createP((E.EIN{params ,...}, args), vecIndex, id, ix) =
86 :     VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Proj vecIndex)
87 : jhr 3627
88 : jhr 3626 fun createI((E.EIN{params ,...}, args), id, ix) =
89 :     VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Indx)
90 : jhr 3627
91 : jhr 3626 (*handleNeg:.body* int list*info ->Var*LowIL.ASSN list
92 :     * info:(string*E.EIN*Var list)
93 :     * low-IL code for scaling a vector with negative 1.
94 :     *)
95 :     fun handleNeg (E.Op1(E.Neg, E.Tensor(id , alpha)), index, info) = let
96 :     val (n, vecIX, index') = dropIndex index
97 :     in case (matchLast(alpha, n))
98 :     of SOME ix1 => let
99 :     val avail = AvailRHS.new()
100 :     val (avail, vA) = intToReal(avail, ~1)
101 :     val vecB = createP(info, vecIX, id, ix1)
102 :     val nextfnargs = (vA, vecIX, vecB)
103 :     in
104 :     iter(avail, index, index', EtoVec.negV, nextfnargs)
105 :     end
106 :     | NONE => runGeneralCase info
107 :     end
108 :    
109 :     (*handleSub:E.body*int list*info ->Var*LowIL.ASSN list
110 :     * info:(string*E.EIN*Var list)
111 :     * low-IL code for subtracting two vectors
112 :     *)
113 :     fun handleSub (E.Op2(E.Sub, E.Tensor(id1, alpha), E.Tensor(id2, beta)), index, info) = let
114 :     val (n, vecIX, index') = dropIndex index
115 :     in case(matchLast(alpha, n) , matchLast(beta, n)) of
116 :     (SOME ix1, SOME ix2) =>let
117 :     val avail = AvailRHS.new()
118 :     val vecA = createP(info, vecIX, id1, ix1)
119 :     val vecB = createP(info, vecIX, id2, ix2)
120 :     val nextfnargs = (vecIX, Op.subVec vecIX, vecA, vecB)
121 :     in
122 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
123 : jhr 3627 end
124 : jhr 3626 | _ => runGeneralCase info
125 :     end
126 :    
127 :     (*handleAdd:E.body*int list*info ->Var*LowIL.ASSN list
128 :     * info:(string*E.EIN*Var list)
129 :     * low-IL code for adding two vectors
130 :     *)
131 :     fun handleAdd (E.Opn(E.Add, es), index, info) = let
132 :     val (n, vecIX, index') = dropIndex index
133 :     (*check that each tensor in addition list has matching indices*)
134 :     fun sample ([], rest) = let
135 :     val avail = AvailRHS.new()
136 :     val nextfnargs = (vecIX, List.rev rest)
137 :     in
138 :     iter(avail, index, index', EtoVec.addV, nextfnargs)
139 : jhr 3627 end
140 : jhr 3626 | sample (E.Tensor(id1, alpha)::ts, rest) = (case (matchLast(alpha, n))
141 :     of SOME ix1 => sample(ts, createP(info, vecIX, id1, ix1)::rest)
142 :     | _ => runGeneralCase info
143 : jhr 3627 (* end case *))
144 : jhr 3626 | sample _ = runGeneralCase info
145 :     in
146 :     sample(es, [])
147 :     end
148 : jhr 3627
149 : jhr 3626 (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info ->Var*LowIL.ASSN list
150 :     * info:(string*E.EIN*Var list)
151 :     * low-IL code for adding scaling a vector
152 :     *)
153 :     fun handleScale (id1, id2, alpha2, index, info) = let
154 :     val (n, vecIX, index') = dropIndex index
155 :     in case matchLast(alpha2, n)
156 :     of SOME ix2 => let
157 :     val avail = AvailRHS.new()
158 :     val vecA = createI(info, id1, [])
159 :     val vecB = createP(info, vecIX, id2, ix2)
160 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
161 :     in
162 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
163 :     end
164 :     | _ =>runGeneralCase info
165 :     end
166 :    
167 :     (*handleProd:E.body*int list*info ->Var*LowIL.ASSN list
168 :     * info:(string*E.EIN*Var list)
169 :     * low-IL code for vector product
170 :     *)
171 :     fun handleProd (E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)]), index, info) = let
172 :     val (e, args) = info
173 :     val (n, vecIX, index') = dropIndex index
174 :     val avail = AvailRHS.new()
175 :     in case(matchFindLast(alpha, n), matchFindLast(beta, n))
176 :     of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
177 :     (*n is the last index of alpha, beta and nowhere else, possible modulate*)
178 :     val vecA = createP(info, vecIX, id1, ix1)
179 :     val vecB = createP(info, vecIX, id2, ix2)
180 :     val nextfnargs = (vecIX, Op.prodVec vecIX, vecA, vecB)
181 :     in
182 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
183 :     end
184 :     | ((NONE, NONE), (SOME ix2, NONE)) =>let
185 :     (*n is the last index of beta and nowhere else, possible scaleVector*)
186 :     val vecA = createI(info, id1, alpha)
187 :     val vecB = createP(info, vecIX, id2, ix2)
188 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
189 :     in
190 :     iter(avail, index, index', EtoVec.op2 , nextfnargs)
191 :     end
192 :     | ((SOME ix1, NONE), (NONE, NONE)) =>let
193 :     (*n is the last index of alpha and nowhere else, ossile scaleVector*)
194 :     val vecA = createI(info, id2, beta)
195 :     val vecB = createP(info, vecIX, id1, ix1)
196 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
197 : jhr 3627
198 : jhr 3626 in
199 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
200 :     end
201 :     | _ => runGeneralCase info
202 :     end
203 : jhr 3627
204 : jhr 3626 (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
205 :     * info:(string*E.EIN*Var list)
206 :     * low-IL code for dot product
207 :     *)
208 : jhr 3627 fun handleSumProd1 (E.Sum([(E.V v, _, ub)], E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])), index, info) =
209 : jhr 3626 case(matchFindLast(alpha, v), matchFindLast(beta, v))
210 :     of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
211 :     (*v is the last index of alpha, beta and nowhere else, possible sumProd*)
212 :     val avail = AvailRHS.new()
213 :     val vecIX= ub+1
214 :     val vecA = createP(info, vecIX, id1, ix1)
215 :     val vecB = createP(info, vecIX, id2, ix2)
216 :     val nextfnargs = (vecIX, vecA, vecB)
217 :     in
218 :     iter(avail, index, index, EtoVec.dotV, nextfnargs)
219 :     end
220 :     | _ => runGeneralCase info
221 :    
222 : jhr 3627
223 : jhr 3626 (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
224 :     * info:(string*E.EIN*Var list)
225 :     * low-IL code for double dot product
226 :     * Sigma_{i, j} A_ij B_ij
227 :     *)
228 :     fun handleSumProd2 (body, index, info) = let
229 : jhr 3627 val E.Sum(
230 :     [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
231 :     E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])
232 :     ) = body
233 :     fun check (v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))
234 : jhr 3626 of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
235 :     (*v is the last index of alpha, beta and nowhere else, possible sumProd*)
236 :     val avail = AvailRHS.new()
237 :     (*val nextfnargs = (Ein.params e, args, sx, ub+1, id1, ix1, id2, ix2)*)
238 :     val vecIX= ub+1
239 :     val vecA = createP(info, vecIX, id1, ix1)
240 :     val vecB = createP(info, vecIX, id2, ix2)
241 :     val nextfnargs = (sx, vecIX, vecA, vecB)
242 :     in
243 :     SOME(iter(avail, index, index, EtoVec.sumDotV, nextfnargs))
244 :     end
245 :     | _ => NONE
246 : jhr 3627 (* end case *))
247 : jhr 3626 in (case check(v1, ub1, (E.V v2, lb2, ub2))
248 :     of SOME e =>e
249 :     | _ => (case check(v2, ub2, (E.V v1, lb1, ub1))
250 :     of SOME e => e
251 :     |_ => runGeneralCase info
252 : jhr 3627 (* end case *))
253 :     (* end case *))
254 : jhr 3626 end
255 :    
256 : jhr 3627 (* scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list
257 :     * scans body for vectorization potential
258 :     *)
259 :     fun expand (y, Ein.EIN{params, index, body}, args : LowIL.var list) = let
260 :     val info = (e, args)
261 :     val all = (b, index, info)
262 :     (* any result type *)
263 :     fun gen () = (case body
264 :     of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
265 :     handleSumProd1 all
266 :     | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
267 :     handleSumProd2 all
268 :     | _ => runGeneralCase info
269 :     (* end case *))
270 :     (* non scalar result for sure *)
271 :     fun nonScalar () = (case body
272 :     of E.Op2(E.Sub, E.Tensor(_, _::_), E.Tensor(_, _::_)) =>
273 :     handleSub all
274 :     | E.Opn(E.Add, (E.Tensor(_, _::_)::_)) =>
275 :     handleAdd all
276 :     | E.Op1(E.Neg, E.Tensor(_ , _::_)) =>
277 :     handleNeg all
278 :     | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, j::jx)]) =>
279 :     handleScale(s, v, j::jx, index, info)
280 :     | E.Opn(E.Prod, [E.Tensor(v, j::jx), E.Tensor(s , [])]) =>
281 :     handleScale(s, v, j::jx, index, info)
282 :     | E.Opn(E.Prod, [E.Tensor(_ , _::_), E.Tensor(_, _::_)]) =>
283 :     handleProd all
284 :     | _ => gen()
285 :     (* end case *))
286 :     (* QUESTION: what is special about "3" here? *)
287 :     val (avail, _) = (case index
288 :     of [3, 3] => runGeneralCase info
289 :     | [3, 3, 3] => runGeneralCase info
290 :     | _::_ => nonScalar()
291 :     | _ => gen()
292 :     (* end case *))
293 :     val (x, asgn):: rest = AvailRHS.getAssignments avail
294 :     in
295 :     List.revMap LowIL.ASSGN ((y, A)::(x, A)::rest)
296 :     end
297 : jhr 3626
298 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0