Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/ein-to-low-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/ein-to-low-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3624 - (view) (download)

1 : jhr 3624 (* ein-to-low-set.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 : cchiw 3444 (*
10 :     * scan: does a preliminary scan of the body of an EIN expression for vectorization potential.
11 :     * If there is a field, it is passed to FieldToLow.
12 :     * If there is a tensor, it is passed to the handle*() functions, which check whether the indices match,
13 :     * e.g., <A_ij+B_ij>_ij (indices match) vs. <A_ji+B_ij>_ij (they do not).
14 :     *
15 :     * (1) If the indices match, the expression is passed to the Iter -> VecToLow functions,
16 :     *     which create LowIL vector operators.
17 :     * (2) Otherwise it is passed to Iter -> ScaToLow,
18 :     *     which creates LowIL scalar operators.
19 :     * Note: the Iter functions create LowIL.CONS and therefore bind the indices in the EIN body.
20 :     *)
structure EinToLow = struct
  local

    structure E = Ein
    structure Iter = Iter
    structure EtoSca = ScaToLow
    structure EtoVec = VecToLow
    structure H = Helper
    structure Op = LowOps

  in

  (* Conventions used by the handle* functions below:
   *   - `index` is the shape (list of dimension sizes) of the EIN result.
   *   - `info` is (e, args) : Ein.ein * LowIL.var list.
   *   - The result is whatever Iter.prodIter returns: a pair whose first
   *     component is the AvailRHS table of generated assignments.
   * Any expression that a handler cannot vectorize is passed to runGeneralCase.
   *)

  (* iter: shorthand for Iter.prodIter *)
    fun iter e = Iter.prodIter e

  (* intToReal: shorthand for Helper.intToReal *)
    fun intToReal n = H.intToReal n

  (* dropIndex : 'a list -> int * 'a * 'a list
   * For alpha = alpha' @ [d], returns (length alpha - 1, d, alpha').
   * Applied to a result shape, this gives the id of the last index variable,
   * the size of the last dimension and the remaining shape.
   * Raises Fail on an empty list.
   *)
    fun dropIndex alpha = (case List.rev alpha
           of [] => raise Fail "EinToLow.dropIndex: empty index"
            | last :: revRest => (length alpha - 1, last, List.rev revRest)
          (* end case *))

  (* matchLast : E.alpha * int -> E.alpha option
   * If the last index of alpha is the variable E.V n, returns SOME of the
   * remaining indices. Otherwise (including for an empty alpha) returns NONE.
   *)
    fun matchLast (alpha, n) = (case List.rev alpha
           of E.V v :: rest => if v = n then SOME (List.rev rest) else NONE
            | _ => NONE
          (* end case *))

  (* matchFindLast : E.alpha * int -> E.alpha option * E.mu option
   * Returns (matchLast (alpha, n), f), where f is SOME (E.V n) if E.V n occurs
   * among the indices of alpha other than the last one, and NONE otherwise.
   * An empty alpha yields (NONE, NONE).
   *)
    fun matchFindLast (alpha, n) = let
          val others = (case List.rev alpha of [] => [] | _ :: es => es)
          val f = List.find (fn E.V x => x = n | _ => false) others
          in
            (matchLast (alpha, n), f)
          end

  (* runGeneralCase : Ein.ein * LowIL.var list -> (AvailRHS table, _)
   * Does not do vector projections. Instead, it iterates over every index of
   * the result and generates scalar code with ScaToLow.generalfn.
   *)
    fun runGeneralCase (e : Ein.ein, args : LowIL.var list) = let
          val index = Ein.index e
          in
            iter (AvailRHS.new(), index, index, EtoSca.generalfn, (e, args))
          end

  (* createP: describes argument `id` (indexed by ix) as a VecToLow.Param tagged
   * VecToLow.Proj vecIndex (presumably a vector projection of width vecIndex).
   *)
    fun createP ((E.EIN{params, ...}, args), vecIndex, id, ix) =
          EtoVec.Param(id, List.nth(args, id), H.getTensorTy(params, id), ix, EtoVec.Proj vecIndex)

  (* createI: describes argument `id` as a VecToLow.Param tagged VecToLow.Indx,
   * i.e. indexed by ix rather than projected.
   *)
    fun createI ((E.EIN{params, ...}, args), id, ix) =
          EtoVec.Param(id, List.nth(args, id), H.getTensorTy(params, id), ix, EtoVec.Indx)

  (* handleNeg: -T. If the last result index is also the last index of T, emit a
   * vector scaled by the constant -1 (EtoVec.negV). Otherwise use the general case.
   *)
    fun handleNeg (E.Op1(E.Neg, E.Tensor(id, alpha)), index, info) = let
          val (n, vecIX, index') = dropIndex index
          in
            case matchLast (alpha, n)
             of SOME ix1 => let
                  val (avail, vA) = intToReal (AvailRHS.new(), ~1)
                  val vecB = createP (info, vecIX, id, ix1)
                  in
                    iter (avail, index, index', EtoVec.negV, (vA, vecIX, vecB))
                  end
              | NONE => runGeneralCase info
            (* end case *)
          end
      | handleNeg (_, _, info) = runGeneralCase info

  (* handleSub: A - B. If the last result index is the last index of both A and
   * B, emit a vector subtraction (Op.subVec). Otherwise use the general case.
   *)
    fun handleSub (E.Op2(E.Sub, E.Tensor(id1, alpha), E.Tensor(id2, beta)), index, info) = let
          val (n, vecIX, index') = dropIndex index
          in
            case (matchLast (alpha, n), matchLast (beta, n))
             of (SOME ix1, SOME ix2) => let
                  val avail = AvailRHS.new()
                  val vecA = createP (info, vecIX, id1, ix1)
                  val vecB = createP (info, vecIX, id2, ix2)
                  in
                    iter (avail, index, index', EtoVec.op2, (vecIX, Op.subVec vecIX, vecA, vecB))
                  end
              | _ => runGeneralCase info
            (* end case *)
          end
      | handleSub (_, _, info) = runGeneralCase info

  (* handleAdd: T1 + ... + Tk. Vectorized (EtoVec.addV) only if every term is a
   * tensor whose last index is the last result index. Otherwise use the general case.
   *)
    fun handleAdd (E.Opn(E.Add, es), index, info) = let
          val (n, vecIX, index') = dropIndex index
        (* collect a projected parameter per term; give up on the first term that
         * does not fit *)
          fun collect ([], params) =
                iter (AvailRHS.new(), index, index', EtoVec.addV, (vecIX, List.rev params))
            | collect (E.Tensor(id, alpha) :: rest, params) = (case matchLast (alpha, n)
                 of SOME ix => collect (rest, createP (info, vecIX, id, ix) :: params)
                  | NONE => runGeneralCase info
                (* end case *))
            | collect (_, _) = runGeneralCase info
          in
            collect (es, [])
          end
      | handleAdd (_, _, info) = runGeneralCase info

  (* handleScale: scalar tensor id1 (empty index list) times tensor id2[alpha2].
   * If the last result index is the last index of alpha2, emit a vector-scalar
   * product (Op.prodScaV). Otherwise use the general case.
   *)
    fun handleScale (id1, id2, alpha2, index, info) = let
          val (n, vecIX, index') = dropIndex index
          in
            case matchLast (alpha2, n)
             of SOME ix2 => let
                  val avail = AvailRHS.new()
                  val vecA = createI (info, id1, [])
                  val vecB = createP (info, vecIX, id2, ix2)
                  in
                    iter (avail, index, index', EtoVec.op2, (vecIX, Op.prodScaV vecIX, vecA, vecB))
                  end
              | NONE => runGeneralCase info
            (* end case *)
          end

  (* handleProd: A * B, where n is the last result index.
   *   - n is the last index of both A and B, and nowhere else: vector product (Op.prodVec).
   *   - n occurs only as the last index of one operand and not at all in the
   *     other: vector-scalar product (Op.prodScaV), with the other operand indexed.
   *   - anything else: general case.
   *)
    fun handleProd (E.Opn(E.Prod, [E.Tensor(id1, alpha), E.Tensor(id2, beta)]), index, info) = let
          val (n, vecIX, index') = dropIndex index
          fun vecOp (opr, vecA, vecB) =
                iter (AvailRHS.new(), index, index', EtoVec.op2, (vecIX, opr, vecA, vecB))
          in
            case (matchFindLast (alpha, n), matchFindLast (beta, n))
             of ((SOME ix1, NONE), (SOME ix2, NONE)) =>
                  vecOp (Op.prodVec vecIX, createP (info, vecIX, id1, ix1), createP (info, vecIX, id2, ix2))
              | ((NONE, NONE), (SOME ix2, NONE)) =>
                  vecOp (Op.prodScaV vecIX, createI (info, id1, alpha), createP (info, vecIX, id2, ix2))
              | ((SOME ix1, NONE), (NONE, NONE)) =>
                  vecOp (Op.prodScaV vecIX, createI (info, id2, beta), createP (info, vecIX, id1, ix1))
              | _ => runGeneralCase info
            (* end case *)
          end
      | handleProd (_, _, info) = runGeneralCase info

  (* handleSumProd1: dot product Sigma_v A_..v B_..v.
   * Vectorized (EtoVec.dotV) only when the sum starts at 0 and v is the last
   * index of both operands and appears nowhere else. The vector width ub+1 is
   * only correct for a 0-based range. Otherwise use the general case.
   *)
    fun handleSumProd1 (E.Sum([(E.V v, 0, ub)], E.Opn(E.Prod, [E.Tensor(id1, alpha), E.Tensor(id2, beta)])), index, info) =
          (case (matchFindLast (alpha, v), matchFindLast (beta, v))
             of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
                  val avail = AvailRHS.new()
                  val vecIX = ub + 1
                  val vecA = createP (info, vecIX, id1, ix1)
                  val vecB = createP (info, vecIX, id2, ix2)
                  in
                    iter (avail, index, index, EtoVec.dotV, (vecIX, vecA, vecB))
                  end
              | _ => runGeneralCase info
            (* end case *))
      | handleSumProd1 (_, _, info) = runGeneralCase info

  (* handleSumProd2: double dot product Sigma_{i, j} A_ij B_ij.
   * Tries to vectorize over the first summation variable, then over the second.
   * The other summation range is passed on to EtoVec.sumDotV. The vectorized
   * range must start at 0 (the width is ub+1). Otherwise use the general case.
   *)
    fun handleSumProd2 (E.Sum([sx1 as (E.V _, _, _), sx2 as (E.V _, _, _)],
                              E.Opn(E.Prod, [E.Tensor(id1, alpha), E.Tensor(id2, beta)])), index, info) = let
          fun check ((E.V v, 0, ub), sx) = (case (matchFindLast (alpha, v), matchFindLast (beta, v))
                 of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
                      val avail = AvailRHS.new()
                      val vecIX = ub + 1
                      val vecA = createP (info, vecIX, id1, ix1)
                      val vecB = createP (info, vecIX, id2, ix2)
                      in
                        SOME (iter (avail, index, index, EtoVec.sumDotV, (sx, vecIX, vecA, vecB)))
                      end
                  | _ => NONE
                (* end case *))
            | check _ = NONE
          in
            case check (sx1, sx2)
             of SOME result => result
              | NONE => (case check (sx2, sx1)
                   of SOME result => result
                    | NONE => runGeneralCase info
                  (* end case *))
            (* end case *)
          end
      | handleSumProd2 (_, _, info) = runGeneralCase info

  (* scan : LowIL.var * Ein.ein * LowIL.var list -> LowIL.ASSGN statements
   * Scans the EIN body for vectorization potential, generates code, and returns
   * the generated assignments in program order, followed by an assignment of
   * the final right-hand side to y.
   * Raises Fail if no assignments were generated.
   *)
    fun scan (y, e : Ein.ein, args : LowIL.var list) = let
          val b = Ein.body e
          val index = Ein.index e
          val info = (e, args)
          val all = (b, index, info)
        (* patterns that may be vectorized for any result shape *)
          fun gen () = (case b
                 of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_, _::_), E.Tensor(_, _::_)])) => handleSumProd1 all
                  | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_, _::_), E.Tensor(_, _::_)])) => handleSumProd2 all
                  | _ => runGeneralCase info
                (* end case *))
        (* patterns that are only tried when the result is not a scalar *)
          fun nonScalar () = (case b
                 of E.Op2(E.Sub, E.Tensor(_, _::_), E.Tensor(_, _::_)) => handleSub all
                  | E.Opn(E.Add, E.Tensor(_, _::_) :: _) => handleAdd all
                  | E.Op1(E.Neg, E.Tensor(_, _::_)) => handleNeg all
                  | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, alpha as _::_)]) => handleScale (s, v, alpha, index, info)
                  | E.Opn(E.Prod, [E.Tensor(v, alpha as _::_), E.Tensor(s, [])]) => handleScale (s, v, alpha, index, info)
                  | E.Opn(E.Prod, [E.Tensor(_, _::_), E.Tensor(_, _::_)]) => handleProd all
                  | _ => gen ()
                (* end case *))
        (* NOTE(review): 3x3 and 3x3x3 results always take the general (scalar)
         * path; the reason is not recorded here. *)
          val (avail, _) = (case index
                 of [3, 3] => runGeneralCase info
                  | [3, 3, 3] => runGeneralCase info
                  | _::_ => nonScalar ()
                  | [] => gen ()
                (* end case *))
          in
            case AvailRHS.getAssignments avail
             of (x, rhs) :: rest => let
                (* getAssignments yields the most recent assignment first: restore
                 * program order, then bind y to the final right-hand side *)
                  val code = List.rev ((x, rhs) :: rest) @ [(y, rhs)]
                  in
                    List.map LowIL.ASSGN code
                  end
              | [] => raise Fail "EinToLow.scan: no assignments generated"
            (* end case *)
          end

  end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0