Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/ein16/src/compiler/low-il/ein-to-low.sml
ViewVC logotype

Annotation of /branches/ein16/src/compiler/low-il/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5244 - (view) (download)

1 : cchiw 3541 (*
2 :     * genfn-Does preliminary scan of the body of EIN.EIN for vectorization potential
3 :     * If there is a field then passes to FieldToLow
4 :     * If there is a tensor then passes to handle*() functions to check if indices match
5 :     * i.e. <A_ij+B_ij>_ij vs.<A_ji+B_ij>_ij
6 :     *
7 :     * (1) If indices match then passes to Iter->VecToLow functions.
8 :     * Creates LowIL vector operators.
9 :     * (2) Iter->ScaToLow
10 :     * Creates Low-IL scalar operators
11 :     * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body
12 :     *)
13 :     structure EinToLowSet = struct
14 :     local
15 :    
16 :     structure Var = LowIL.Var
17 :     structure E = Ein
18 :     structure P=Printer
19 :     structure Iter=IterSet
20 :     structure EtoFld= FieldToLowSet
21 :     structure EtoSca= ScaToLowSet
22 :     structure EtoVec= VecToLowSet
23 :     structure H=HelperSet
24 :    
25 :     in
26 :    
27 :     fun iter e=Iter.prodIter e
28 :     fun evalField e= EtoFld.evalField e
29 :     fun intToReal n=H.intToReal n
30 :     fun testp p=(String.concat p)
31 : cchiw 3687 val scaFlag= ref false
32 : cchiw 3541
33 : cchiw 3687 val controls = [("scaFlag",scaFlag,"scaFlag")]
34 :    
35 : cchiw 4698 fun alpha_toStr([]) = ""
36 :     | alpha_toStr(E.V(e1)::es) = String.concat["V", Int.toString(e1), alpha_toStr(es)]
37 :     | alpha_toStr(E.C(e1,_)::es) = String.concat["C", Int.toString(e1), alpha_toStr(es)]
38 :    
39 : cchiw 3541 (*dropIndex: a list-> int*a*alist
40 :     * alpha::i->returns length of list-1,i,alpha
41 :     *)
42 :     fun dropIndex alpha=let
43 :     val (e1::es)=List.rev(alpha)
44 :     in (length alpha-1,e1,List.rev es)
45 :     end
46 :    
47 :     (*matchLast:E.alpha*int -> (E.alpha) Option
48 :     * Is the last index of alpha E.V n.
49 :     * If so, return the rest of the list
50 :     *)
51 :     fun matchLast(alpha, n)=let
52 : cchiw 4698 val _ = (String.concat["alpha:", alpha_toStr(alpha), "n:", Int.toString(n)])
53 : cchiw 3541 val (e1::es)=List.rev(alpha)
54 :     in (case e1
55 :     of E.V v =>(case (n=v)
56 :     of true => SOME(List.rev es)
57 :     |_ => NONE
58 :     (*end case*))
59 :     | _ => NONE
60 :     (*end case*))
61 :     end
62 :    
63 :     (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option
64 : cchiw 4698 * Is the last index of alpha = n.
65 : cchiw 3541 * is n anywhere else?
66 :     *)
67 :     fun matchFindLast(alpha, n)=let
68 :     val es=List.tl(List.rev(alpha))
69 :     val f=List.find(fn E.V e=>e=n|_=>false) es
70 :     in
71 :     (matchLast(alpha,n),f)
72 :     end
73 :    
74 :     (*runGeneralCase:Var*E.EIN*Var-> Var*LowIL.ASSN list
75 :     * does not do vector projections
76 :     * instead approach like a general EIN
77 :     *)
78 :    
79 :     fun runGeneralCase(lhs:string,e:Ein.ein,args:LowIL.var list)=let
80 :     val info=(lhs,e,args)
81 :     val index=Ein.index e
82 :     val opset= lowSet.LowSet.empty
83 :    
84 :     val rtn= iter(opset,index,index,EtoSca.generalfn,info)
85 :     val (_,_,code)=rtn
86 :     val n= length(code)
87 : cchiw 3655
88 : cchiw 3541 (*
89 :     val _ =if (n>10) then print(String.concat["\n Gen(",Int.toString(n),")",P.printerE(e)]) else print""*)
90 :     in rtn end
91 :    
92 :    
93 :     (*handleNeg:.body* int list*info ->Var*LowIL.ASSN list
94 :     * info:(string*E.EIN*Var list)
95 :     * low-IL code for scaling a vector with negative 1.
96 :     *)
97 :     fun handleNeg(E.Op1(E.Neg,E.Tensor(id ,alpha)),index,info)=let
98 :     val (n,vecIndex,index')=dropIndex index
99 :     in (case (matchLast(alpha,n))
100 :     of SOME ix1 => let
101 :     val setT= lowSet.LowSet.empty
102 :     val (setA,vA,A)= intToReal( setT, ~1)
103 :     val (lhs,e,args)=info
104 :     val nextfnargs=(lhs,Ein.params e,args,vecIndex, vA, id,ix1)
105 :     val (setB,vB,B)=iter(setA,index,index',EtoVec.negV,nextfnargs)
106 :     in
107 :     (setB,vB,A@B)
108 :     end
109 :     | NONE => runGeneralCase info
110 :     (*end case*))
111 :     end
112 :    
113 :     (*handleSub:E.body*int list*info ->Var*LowIL.ASSN list
114 :     * info:(string*E.EIN*Var list)
115 :     * low-IL code for subtracting two vectors
116 :     *)
117 :     fun handleSub(E.Op2(E.Sub,E.Tensor(id1,alpha),E.Tensor(id2,beta)),index,info)=let
118 : cchiw 5244
119 : cchiw 3541 val (n,vecIndex,index')=dropIndex index
120 :     in (case(matchLast(alpha,n) , matchLast(beta,n)) of
121 :     (SOME ix1,SOME ix2)=>let
122 :     val (lhs,e,args)=info
123 :     val setT= lowSet.LowSet.empty
124 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,ix1,id2,ix2)
125 : cchiw 5244 val _ =(String.concat["\nsubtraction:",P.printerE(e),String.concatWith","(List.map LowIL.Var.toString args)])
126 : cchiw 3541 in
127 :     iter(setT,index,index',EtoVec.subV,nextfnargs)
128 :     end
129 :     | _ => runGeneralCase info
130 :     (*end case*))
131 :     end
132 :    
133 :    
134 :     (*handleAdd:E.body*int list*info ->Var*LowIL.ASSN list
135 :     * info:(string*E.EIN*Var list)
136 :     * low-IL code for adding two vectors
137 :     *)
138 :     fun handleAdd(E.Opn(E.Add, es),index,info)=let
139 :     val (n,vecIndex,index')=dropIndex index
140 :     (*check that each tensor in addition list has matching indices*)
141 :     fun sample([],rest)=let
142 :     val (lhs,e,args)=info
143 :     val setT= lowSet.LowSet.empty
144 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,rest)
145 :     in
146 :     iter(setT,index,index',EtoVec.addV,nextfnargs)
147 :     end
148 :     | sample(E.Tensor(id,alpha)::ts,rest) =(case (matchLast(alpha,n))
149 :     of SOME ix1 => sample(ts,rest@[(id,ix1)])
150 :     | _ => runGeneralCase info
151 :     (*end case*))
152 :     | sample _ = runGeneralCase info
153 :     in
154 :     sample(es,[])
155 :     end
156 :    
157 :    
158 :    
159 :     (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info ->Var*LowIL.ASSN list
160 :     * info:(string*E.EIN*Var list)
161 :     * low-IL code for adding scaling a vector
162 :     *)
163 :     fun handleScale(id1,id2,alpha2,index,info)=let
164 : cchiw 4698 val _ = "\n *******inside handle scale"
165 : cchiw 3541 val (n,vecIndex,index')=dropIndex index
166 : cchiw 4698 val _ = "post drop index"
167 :     (*in (case matchLast(alpha2,n)*)
168 :     (* can we vectorize?
169 :     * check to see if n is the last index and is not repeated in alpha
170 :     *)
171 :     in (case matchFindLast(alpha2, n)
172 :     of (SOME ix2, NONE) => let
173 :     val _ = "post match last\n"
174 : cchiw 3541 val (lhs,e,args)=info
175 :     val setT= lowSet.LowSet.empty
176 : cchiw 4698 val _ = (P.printbody(Ein.body e))
177 : cchiw 3541 val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,[],id2,ix2)
178 : cchiw 4698
179 : cchiw 3541 in
180 :     iter(setT,index,index',EtoVec.scaleV,nextfnargs)
181 :     end
182 :     | _=>runGeneralCase info
183 :     (*end case*))
184 :     end
185 :    
186 :     (*handleProd:E.body*int list*info ->Var*LowIL.ASSN list
187 :     * info:(string*E.EIN*Var list)
188 :     * low-IL code for vector product
189 :     *)
190 :     fun handleProd(E.Opn(E.Prod,[E.Tensor(id1 , alpha), E.Tensor(id2, beta)]),index,info)=let
191 :     val (lhs,e,args)=info
192 :     val (n,vecIndex,index')=dropIndex index
193 :     val setT= lowSet.LowSet.empty
194 : cchiw 3663 (*val _ =print(String.concat["\nproduct:",P.printerE(e),String.concatWith","(List.map LowIL.Var.toString args)])*)
195 : cchiw 3662
196 : cchiw 3541 in (case(matchFindLast(alpha,n),matchFindLast(beta,n))
197 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
198 :     (*n is the last index of alpha, beta and nowhere else,possible modulate*)
199 :    
200 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,ix1,id2,ix2)
201 :     in
202 :     iter(setT,index,index',EtoVec.prodV,nextfnargs)
203 :     end
204 :     | ((NONE,NONE),(SOME ix2,NONE)) =>let
205 :     (*n is the last index of beta and nowhere else,possible scaleVector*)
206 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,alpha,id2,ix2)
207 :     in
208 :     iter(setT,index,index',EtoVec.scaleV,nextfnargs)
209 :     end
210 :     | ((SOME ix1,NONE),(NONE,NONE)) =>let
211 : cchiw 4698 (*n is the last index of alpha and nowhere else,possible scaleVector*)
212 : cchiw 3541 val nextfnargs=(lhs,Ein.params e, args,vecIndex,id2,beta,id1,ix1)
213 :     in
214 :     iter(setT,index,index',EtoVec.scaleV,nextfnargs)
215 :     end
216 :     | _ =>runGeneralCase info
217 :     (*end case*))
218 :     end
219 :    
220 :     (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
221 :     * info:(string*E.EIN*Var list)
222 :     * low-IL code for dot product
223 :     *)
224 :     fun (*handleSumProd1(E.Sum([(E.V 1,_,ub)],E.Opn(E.Prod,[E.Tensor(id1 , [E.V 1]), E.Tensor(id2, [E.V 1,E.V 0])]) ),[i],info)=let
225 :     val (lhs,e,args)=info
226 :     val setT= lowSet.LowSet.empty
227 :     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,[],id2,[E.V 0])
228 :     in
229 :     iter(setT,[i],[i],EtoVec.VM,nextfnargs)
230 :     end
231 :    
232 :     | handleSumProd1(E.Sum([(E.V 2,_,ub)],E.Prod[E.Tensor(id1 , [E.V 0,E.V 2]), E.Tensor(id2, [E.V 2,E.V 1])]),index as [_,_],info)=let
233 :     val (lhs,e,args)=info
234 :     val setT= lowSet.LowSet.empty
235 :     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,[E.V 0],id2,[E.V 1],index)
236 :     val _ ="\nuses projFirst"
237 :     in
238 :     iter(setT,index,index,EtoVec.MM3,nextfnargs)
239 :     end
240 :    
241 :     |*) handleSumProd1(E.Sum([(E.V v,_,ub)],E.Opn(E.Prod,[E.Tensor(id1 , alpha), E.Tensor(id2, beta)])),index,info)=
242 :     (case(matchFindLast(alpha,v),matchFindLast(beta,v))
243 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
244 :     (*v is the last index of alpha, beta and nowhere else,possible sumProd*)
245 :     val (lhs,e,args)=info
246 :     val setT= lowSet.LowSet.empty
247 :     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,ix1,id2,ix2)
248 :     in
249 :     iter(setT,index,index,EtoVec.dotV,nextfnargs)
250 :     end
251 :     | _ =>runGeneralCase info
252 :     (*end case*))
253 :    
254 :     (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
255 :     * info:(string*E.EIN*Var list)
256 :     * low-IL code for double dot product
257 :     * Sigma_{i,j} A_ij B_ij
258 :     *)
259 :     fun handleSumProd2(E.Sum([(E.V v1,lb1,ub1),(E.V v2,lb2,ub2)],E.Opn(E.Prod,[E.Tensor(id1 , alpha), E.Tensor(id2, beta)])),index,info)=let
260 :     fun check(v,ub,sx)=(case(matchFindLast(alpha,v),matchFindLast(beta,v))
261 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
262 :     (*v is the last index of alpha, beta and nowhere else,possible sumProd*)
263 :     val (lhs,e,args)=info
264 :     val setT= lowSet.LowSet.empty
265 :     val nextfnargs=(lhs,Ein.params e, args,sx,ub+1,id1,ix1,id2,ix2)
266 :     in
267 :     SOME(iter(setT,index,index,EtoVec.sumDotV,nextfnargs))
268 :     end
269 :     | _=> NONE
270 :     (*end case*))
271 :     in (case check(v1,ub1,(E.V v2,lb2,ub2))
272 :     of SOME e=>e
273 :     | _=> (case check(v2,ub2,(E.V v1,lb1,ub1))
274 :     of SOME e=> e
275 :     |_ =>runGeneralCase info
276 :     (*end case*))
277 :     (*end case*))
278 :     end
279 :    
280 :    
281 :     (*scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list
282 :     *scans body for vectorization potential
283 :     *)
284 :     fun q(y,e:Ein.ein,args:LowIL.var list)= let
285 :     val lhs=LowIL.Var.name y
286 :     val b=Ein.body e
287 :     val index=Ein.index e
288 :     val info=(lhs,e,args)
289 :     val all=(b,index,info)
290 : cchiw 4321 (* val _ =print(String.concat["\n\n*** ", lhs,"=", P.printerE(e), String.concatWith"," (List.map (fn e=> LowIL.Var.name(e))args )])*)
291 : cchiw 3655 fun gen body=(case ([3,4],body)
292 : cchiw 4321 of (_::es,E.Op2(E.Sub,E.Tensor(_,(E.V _ )::ix),E.Tensor(_,(E.V _)::jx)))
293 : cchiw 3541 => handleSub all
294 : cchiw 3655
295 : cchiw 4215 | (_::es, E.Opn(E.Add,(E.Tensor(_,E.V _::ix)::_)))
296 : cchiw 3541 => handleAdd all
297 : cchiw 3655
298 : cchiw 4321 | (_::es,E.Op1(E.Neg,E.Tensor(_ ,(E.V _)::ix)))
299 : cchiw 3541 => handleNeg all
300 : cchiw 3655
301 : cchiw 4236 | (_::es, E.Opn(E.Prod,[E.Tensor(s, []), E.Tensor(v, (j as E.V _)::jx)]))
302 : cchiw 3541 => handleScale(s,v,j::jx,index,info)
303 : cchiw 3655
304 : cchiw 4236 | (_::es,E.Opn(E.Prod,[E.Tensor(v, (j as E.V _)::jx), E.Tensor(s , [])]))
305 : cchiw 3541 => handleScale(s,v,j::jx,index,info)
306 : cchiw 3655
307 : cchiw 4236 | (_::es,E.Opn(E.Prod,[E.Tensor(_ , (E.V _)::ix), E.Tensor(_, (E.V _)::jx)]))
308 : cchiw 3541 => handleProd all
309 : cchiw 3662
310 : cchiw 3663
311 : cchiw 4236 | ( _,E.Sum([_], E.Opn(E.Prod,[E.Tensor(_ , ( E.V _)::_), E.Tensor(_, ( E.V _)::_)])))
312 : cchiw 3541 => handleSumProd1 all
313 : cchiw 4236 | ( _ ,E.Sum([_,_],E.Opn( E.Prod,[E.Tensor(_ , ( E.V _)::_), E.Tensor(_, ( E.V _)::_)])))
314 : cchiw 3541 => handleSumProd2 all
315 : cchiw 3663
316 : cchiw 3596 | (_,_ )=> runGeneralCase info
317 : cchiw 3541 (*end case*))
318 :    
319 :    
320 : cchiw 3655 fun scanSize body=(case (List.rev index,body)
321 :     of (3::_,_) => runGeneralCase info
322 : cchiw 4236 | ( _,E.Sum([(_,0,2)], E.Opn(E.Prod,[E.Tensor(_ , (E.V _)::_), E.Tensor(_, (E.V _)::_)])))
323 : cchiw 3655 => runGeneralCase info
324 : cchiw 4236 | ( _ ,E.Sum([(_,0,2),(_,0,2)],E.Opn( E.Prod,[E.Tensor(_ , (E.V _)::_), E.Tensor(_, (E.V _)::_)])))
325 : cchiw 3655 => runGeneralCase info
326 :    
327 : cchiw 3541 | (_, E.Opn(E.Prod,E.Tensor(_,_::_::_)::_))=> runGeneralCase info
328 :     | (_,E.Sum(_,E.Opn(E.Prod,(E.Tensor(_,_::_::_)::_))))=> runGeneralCase info
329 :     | (_,E.Sum(_,E.Opn(E.Prod,(_::E.Tensor(_,_::_::_)::_))))=> runGeneralCase info
330 :     | (_,_ )=> gen b
331 :     (*end case*))
332 :     in
333 : cchiw 3687 if (!scaFlag) then ( "ein-to-low sca";runGeneralCase info) else ("ein-to-low vec:"; gen b)
334 : cchiw 3541 end
335 :    
336 :     end (* local *)
337 :    
338 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0