Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3627 - (view) (download)

1 : jhr 3626 (* ein-to-low.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     (*
10 :     * genfn-Does preliminary scan of the body of EIN.EIN for vectorization potential
11 :     * If there is a field then passes to FieldToLow
12 :     * If there is a tensor then passes to handle*() functions to check if indices match
13 :     * i.e. <A_ij+B_ij>_ij vs.<A_ji+B_ij>_ij
14 :     *
15 :     * (1) If indices match then passes to Iter->VecToLow functions.
16 :     * Creates LowIL vector operators.
17 :     * (2) Iter->ScaToLow
18 :     * Creates Low-IL scalar operators
19 : jhr 3627 * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body
20 : jhr 3626 *)
21 : jhr 3627
22 : jhr 3626 structure EinToLow : sig
23 :    
24 : jhr 3627 val expand : LowIL.var * Ein.ein * LowIL.var list -> LowIL.assign list
25 :    
26 : jhr 3626 end = struct
27 :    
28 :     structure Var = LowIL.Var
29 :     structure E = Ein
30 :     structure Iter = Iter
31 :     structure EtoSca = ScaToLow
32 :     structure EtoVec = VecToLow
33 :     structure H = Helper
34 :     structure Op = LowOps
35 :    
36 :     fun iter e = Iter.prodIter e
37 :     fun intToReal n = H.intToReal n
38 : jhr 3627
39 :     (* `dropIndex alpha` returns the (len, i, alpha') where
40 :     * len = length(alpha') and alpha = alpha'@[i].
41 :     *)
42 : jhr 3626 fun dropIndex alpha = let
43 : jhr 3627 fun drop ([], _, _) = raise Fail "dropIndex[]"
44 :     | drop ([idx], n, idxs') = (n, idx, List.rev idxs')
45 :     | drop (idx::idxs, n, idx') = drop (idxs, n+1, idx::idx')
46 :     in
47 :     drop (alpha, 0, [])
48 :     end
49 : jhr 3626
50 :     (*matchLast:E.alpha*int -> (E.alpha) Option
51 :     * Is the last index of alpha E.V n.
52 : jhr 3627 * If so, return the rest of the list
53 : jhr 3626 *)
54 : jhr 3627 fun matchLast (alpha, n) = (case List.rev alpha
55 :     of (E.V v)::es => if (n = v) then SOME(List.rev es) else NONE
56 : jhr 3626 | _ => NONE
57 : jhr 3627 (* end case *))
58 : jhr 3626
59 :     (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option
60 :     * Is the last index of alpha = n.
61 :     * is n anywhere else?
62 :     *)
63 :     fun matchFindLast (alpha, n) = let
64 :     val es = List.tl(List.rev(alpha))
65 :     val f = List.find(fn E.V e => e = n | _ => false) es
66 :     in
67 :     (matchLast(alpha, n), f)
68 : jhr 3627 end
69 : jhr 3626
70 :     (*runGeneralCase:Var*E.EIN*Var-> Var*LowIL.ASSN list
71 :     * does not do vector projections
72 :     * instead approach like a general EIN
73 :     *)
74 :     fun runGeneralCase (e:Ein.ein, args:LowIL.var list) = let
75 :     val index = Ein.index e
76 :     in
77 :     iter(AvailRHS.new(), index, index, EtoSca.generalfn, (e, args))
78 :     end
79 :    
80 :     fun createP((E.EIN{params ,...}, args), vecIndex, id, ix) =
81 :     VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Proj vecIndex)
82 : jhr 3627
83 : jhr 3626 fun createI((E.EIN{params ,...}, args), id, ix) =
84 :     VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Indx)
85 : jhr 3627
86 : jhr 3626 (*handleNeg:.body* int list*info ->Var*LowIL.ASSN list
87 :     * info:(string*E.EIN*Var list)
88 :     * low-IL code for scaling a vector with negative 1.
89 :     *)
90 :     fun handleNeg (E.Op1(E.Neg, E.Tensor(id , alpha)), index, info) = let
91 :     val (n, vecIX, index') = dropIndex index
92 :     in case (matchLast(alpha, n))
93 :     of SOME ix1 => let
94 :     val avail = AvailRHS.new()
95 :     val (avail, vA) = intToReal(avail, ~1)
96 :     val vecB = createP(info, vecIX, id, ix1)
97 :     val nextfnargs = (vA, vecIX, vecB)
98 :     in
99 :     iter(avail, index, index', EtoVec.negV, nextfnargs)
100 :     end
101 :     | NONE => runGeneralCase info
102 :     end
103 :    
104 :     (*handleSub:E.body*int list*info ->Var*LowIL.ASSN list
105 :     * info:(string*E.EIN*Var list)
106 :     * low-IL code for subtracting two vectors
107 :     *)
108 :     fun handleSub (E.Op2(E.Sub, E.Tensor(id1, alpha), E.Tensor(id2, beta)), index, info) = let
109 :     val (n, vecIX, index') = dropIndex index
110 :     in case(matchLast(alpha, n) , matchLast(beta, n)) of
111 :     (SOME ix1, SOME ix2) =>let
112 :     val avail = AvailRHS.new()
113 :     val vecA = createP(info, vecIX, id1, ix1)
114 :     val vecB = createP(info, vecIX, id2, ix2)
115 :     val nextfnargs = (vecIX, Op.subVec vecIX, vecA, vecB)
116 :     in
117 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
118 : jhr 3627 end
119 : jhr 3626 | _ => runGeneralCase info
120 :     end
121 :    
122 :     (*handleAdd:E.body*int list*info ->Var*LowIL.ASSN list
123 :     * info:(string*E.EIN*Var list)
124 :     * low-IL code for adding two vectors
125 :     *)
126 :     fun handleAdd (E.Opn(E.Add, es), index, info) = let
127 :     val (n, vecIX, index') = dropIndex index
128 :     (*check that each tensor in addition list has matching indices*)
129 :     fun sample ([], rest) = let
130 :     val avail = AvailRHS.new()
131 :     val nextfnargs = (vecIX, List.rev rest)
132 :     in
133 :     iter(avail, index, index', EtoVec.addV, nextfnargs)
134 : jhr 3627 end
135 : jhr 3626 | sample (E.Tensor(id1, alpha)::ts, rest) = (case (matchLast(alpha, n))
136 :     of SOME ix1 => sample(ts, createP(info, vecIX, id1, ix1)::rest)
137 :     | _ => runGeneralCase info
138 : jhr 3627 (* end case *))
139 : jhr 3626 | sample _ = runGeneralCase info
140 :     in
141 :     sample(es, [])
142 :     end
143 : jhr 3627
144 : jhr 3626 (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info ->Var*LowIL.ASSN list
145 :     * info:(string*E.EIN*Var list)
146 :     * low-IL code for adding scaling a vector
147 :     *)
148 :     fun handleScale (id1, id2, alpha2, index, info) = let
149 :     val (n, vecIX, index') = dropIndex index
150 :     in case matchLast(alpha2, n)
151 :     of SOME ix2 => let
152 :     val avail = AvailRHS.new()
153 :     val vecA = createI(info, id1, [])
154 :     val vecB = createP(info, vecIX, id2, ix2)
155 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
156 :     in
157 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
158 :     end
159 :     | _ =>runGeneralCase info
160 :     end
161 :    
162 :     (*handleProd:E.body*int list*info ->Var*LowIL.ASSN list
163 :     * info:(string*E.EIN*Var list)
164 :     * low-IL code for vector product
165 :     *)
166 :     fun handleProd (E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)]), index, info) = let
167 :     val (e, args) = info
168 :     val (n, vecIX, index') = dropIndex index
169 :     val avail = AvailRHS.new()
170 :     in case(matchFindLast(alpha, n), matchFindLast(beta, n))
171 :     of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
172 :     (*n is the last index of alpha, beta and nowhere else, possible modulate*)
173 :     val vecA = createP(info, vecIX, id1, ix1)
174 :     val vecB = createP(info, vecIX, id2, ix2)
175 :     val nextfnargs = (vecIX, Op.prodVec vecIX, vecA, vecB)
176 :     in
177 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
178 :     end
179 :     | ((NONE, NONE), (SOME ix2, NONE)) =>let
180 :     (*n is the last index of beta and nowhere else, possible scaleVector*)
181 :     val vecA = createI(info, id1, alpha)
182 :     val vecB = createP(info, vecIX, id2, ix2)
183 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
184 :     in
185 :     iter(avail, index, index', EtoVec.op2 , nextfnargs)
186 :     end
187 :     | ((SOME ix1, NONE), (NONE, NONE)) =>let
188 :     (*n is the last index of alpha and nowhere else, ossile scaleVector*)
189 :     val vecA = createI(info, id2, beta)
190 :     val vecB = createP(info, vecIX, id1, ix1)
191 :     val nextfnargs = (vecIX, Op.prodScaV vecIX, vecA, vecB)
192 : jhr 3627
193 : jhr 3626 in
194 :     iter(avail, index, index', EtoVec.op2, nextfnargs)
195 :     end
196 :     | _ => runGeneralCase info
197 :     end
198 : jhr 3627
199 : jhr 3626 (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
200 :     * info:(string*E.EIN*Var list)
201 :     * low-IL code for dot product
202 :     *)
203 : jhr 3627 fun handleSumProd1 (E.Sum([(E.V v, _, ub)], E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])), index, info) =
204 : jhr 3626 case(matchFindLast(alpha, v), matchFindLast(beta, v))
205 :     of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
206 :     (*v is the last index of alpha, beta and nowhere else, possible sumProd*)
207 :     val avail = AvailRHS.new()
208 :     val vecIX= ub+1
209 :     val vecA = createP(info, vecIX, id1, ix1)
210 :     val vecB = createP(info, vecIX, id2, ix2)
211 :     val nextfnargs = (vecIX, vecA, vecB)
212 :     in
213 :     iter(avail, index, index, EtoVec.dotV, nextfnargs)
214 :     end
215 :     | _ => runGeneralCase info
216 :    
217 : jhr 3627
218 : jhr 3626 (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
219 :     * info:(string*E.EIN*Var list)
220 :     * low-IL code for double dot product
221 :     * Sigma_{i, j} A_ij B_ij
222 :     *)
223 :     fun handleSumProd2 (body, index, info) = let
224 : jhr 3627 val E.Sum(
225 :     [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
226 :     E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])
227 :     ) = body
228 :     fun check (v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))
229 : jhr 3626 of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
230 :     (*v is the last index of alpha, beta and nowhere else, possible sumProd*)
231 :     val avail = AvailRHS.new()
232 :     (*val nextfnargs = (Ein.params e, args, sx, ub+1, id1, ix1, id2, ix2)*)
233 :     val vecIX= ub+1
234 :     val vecA = createP(info, vecIX, id1, ix1)
235 :     val vecB = createP(info, vecIX, id2, ix2)
236 :     val nextfnargs = (sx, vecIX, vecA, vecB)
237 :     in
238 :     SOME(iter(avail, index, index, EtoVec.sumDotV, nextfnargs))
239 :     end
240 :     | _ => NONE
241 : jhr 3627 (* end case *))
242 : jhr 3626 in (case check(v1, ub1, (E.V v2, lb2, ub2))
243 :     of SOME e =>e
244 :     | _ => (case check(v2, ub2, (E.V v1, lb1, ub1))
245 :     of SOME e => e
246 :     |_ => runGeneralCase info
247 : jhr 3627 (* end case *))
248 :     (* end case *))
249 : jhr 3626 end
250 :    
251 : jhr 3627 (* scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list
252 :     * scans body for vectorization potential
253 :     *)
254 :     fun expand (y, Ein.EIN{params, index, body}, args : LowIL.var list) = let
255 :     val info = (e, args)
256 :     val all = (b, index, info)
257 :     (* any result type *)
258 :     fun gen () = (case body
259 :     of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
260 :     handleSumProd1 all
261 :     | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
262 :     handleSumProd2 all
263 :     | _ => runGeneralCase info
264 :     (* end case *))
265 :     (* non scalar result for sure *)
266 :     fun nonScalar () = (case body
267 :     of E.Op2(E.Sub, E.Tensor(_, _::_), E.Tensor(_, _::_)) =>
268 :     handleSub all
269 :     | E.Opn(E.Add, (E.Tensor(_, _::_)::_)) =>
270 :     handleAdd all
271 :     | E.Op1(E.Neg, E.Tensor(_ , _::_)) =>
272 :     handleNeg all
273 :     | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, j::jx)]) =>
274 :     handleScale(s, v, j::jx, index, info)
275 :     | E.Opn(E.Prod, [E.Tensor(v, j::jx), E.Tensor(s , [])]) =>
276 :     handleScale(s, v, j::jx, index, info)
277 :     | E.Opn(E.Prod, [E.Tensor(_ , _::_), E.Tensor(_, _::_)]) =>
278 :     handleProd all
279 :     | _ => gen()
280 :     (* end case *))
281 :     (* QUESTION: what is special about "3" here? *)
282 :     val (avail, _) = (case index
283 :     of [3, 3] => runGeneralCase info
284 :     | [3, 3, 3] => runGeneralCase info
285 :     | _::_ => nonScalar()
286 :     | _ => gen()
287 :     (* end case *))
288 :     val (x, asgn):: rest = AvailRHS.getAssignments avail
289 :     in
290 :     List.revMap LowIL.ASSGN ((y, A)::(x, A)::rest)
291 :     end
292 : jhr 3626
293 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0