Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3174 - (view) (download)

1 : cchiw 2859 (*
2 :     * genfn-Does preliminary scan of the body of EIN.EIN for vectorization potential
3 :     * If there is a field then passes to FieldToLow
4 :     * If there is a tensor then passes to handle*() functions to check if indices match
5 :     * i.e. <A_ij+B_ij>_ij vs.<A_ji+B_ij>_ij
6 :     *
7 :     * (1) If indices match then passes to Iter->VecToLow functions.
8 :     * Creates LowIL vector operators.
9 :     * (2) Iter->ScaToLow
10 :     * Creates Low-IL scalar operators
11 :     * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body
12 :     *)
13 :     structure EinToLow = struct
14 :     local
15 :    
16 :     structure Var = LowIL.Var
17 :     structure E = Ein
18 :     structure P=Printer
19 :     structure Iter=Iter
20 :     structure EtoFld= FieldToLow
21 :     structure EtoSca= ScaToLow
22 :     structure EtoVec= VecToLow
23 :     structure H=Helper
24 :    
25 :     in
26 :    
27 :     fun iter e=Iter.prodIter e
28 :     fun evalField e= EtoFld.evalField e
29 : cchiw 2870 fun intToReal n=H.intToReal n
30 : cchiw 2859
31 : cchiw 2870 fun testp p=print(String.concat p)
32 :    
33 : cchiw 2859 (*dropIndex: a list-> int*a*alist
34 :     * alpha::i->returns length of list-1,i,alpha
35 :     *)
36 :     fun dropIndex alpha=let
37 :     val (e1::es)=List.rev(alpha)
38 :     in (length alpha-1,e1,List.rev es)
39 :     end
40 :    
41 :     (*matchLast:E.alpha*int -> (E.alpha) Option
42 :     * Is the last index of alpha E.V n.
43 :     * If so, return the rest of the list
44 :     *)
45 :     fun matchLast(alpha, n)=let
46 :     val (e1::es)=List.rev(alpha)
47 :     in (case e1
48 :     of E.V v =>(case (n=v)
49 :     of true => SOME(List.rev es)
50 :     |_ => NONE
51 :     (*end case*))
52 :     | _ => NONE
53 :     (*end case*))
54 :     end
55 :    
56 :     (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option
57 :     * Is the last index of alpha =n.
58 :     * is n anywhere else?
59 :     *)
60 :     fun matchFindLast(alpha, n)=let
61 :     val es=List.tl(List.rev(alpha))
62 :     val f=List.find(fn E.V e=>e=n|_=>false) es
63 :     in
64 :     (matchLast(alpha,n),f)
65 :     end
66 :    
67 :     (*runGeneralCase:Var*E.EIN*Var-> Var*LowIL.ASSN list
68 :     * does not do vector projections
69 :     * instead approach like a general EIN
70 :     *)
71 :     fun runGeneralCase info=let
72 :     val (lhs,e,args)=info
73 :     val index=Ein.index e
74 :     in
75 :     iter(index,index,EtoSca.generalfn,info)
76 :     end
77 :    
78 :     (*handleNeg:.body* int list*info ->Var*LowIL.ASSN list
79 :     * info:(string*E.EIN*Var list)
80 :     * low-IL code for scaling a vector with negative 1.
81 :     *)
82 :     fun handleNeg(E.Neg(E.Tensor(id ,alpha)),index,info)=let
83 :     val (n,vecIndex,index')=dropIndex index
84 :     in (case (matchLast(alpha,n))
85 :     of SOME ix1 => let
86 : cchiw 2870 val (vA,A)= intToReal ~1
87 : cchiw 2859 val (lhs,e,args)=info
88 :     val nextfnargs=(lhs,Ein.params e,args,vecIndex, vA, id,ix1)
89 :     val (vB,B)=iter(index,index',EtoVec.negV,nextfnargs)
90 :     in
91 :     (vB,A@B)
92 :     end
93 :     | NONE => runGeneralCase info
94 :     (*end case*))
95 :     end
96 :    
97 :     (*handleSub:E.body*int list*info ->Var*LowIL.ASSN list
98 :     * info:(string*E.EIN*Var list)
99 :     * low-IL code for subtracting two vectors
100 :     *)
101 :     fun handleSub(E.Sub(E.Tensor(id1,alpha),E.Tensor(id2,beta)),index,info)=let
102 :     val (n,vecIndex,index')=dropIndex index
103 :     in (case(matchLast(alpha,n) , matchLast(beta,n)) of
104 :     (SOME ix1,SOME ix2)=>let
105 :     val (lhs,e,args)=info
106 :    
107 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,ix1,id2,ix2)
108 :     in
109 :     iter(index,index',EtoVec.subV,nextfnargs)
110 :     end
111 :     | _ => runGeneralCase info
112 :     (*end case*))
113 :     end
114 :    
115 :     (*handleAdd:E.body*int list*info ->Var*LowIL.ASSN list
116 :     * info:(string*E.EIN*Var list)
117 :     * low-IL code for adding two vectors
118 :     *)
119 :     fun handleAdd(E.Add es,index,info)=let
120 :     val (n,vecIndex,index')=dropIndex index
121 :     (*check that each tensor in addition list has matching indices*)
122 :     fun sample([],rest)=let
123 :     val (lhs,e,args)=info
124 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,rest)
125 :     in
126 :     iter(index,index',EtoVec.addV,nextfnargs)
127 :     end
128 :     | sample(E.Tensor(id,alpha)::ts,rest) =(case (matchLast(alpha,n))
129 :     of SOME ix1 => sample(ts,rest@[(id,ix1)])
130 :     | _ => runGeneralCase info
131 :     (*end case*))
132 :     | sample _ = runGeneralCase info
133 :     in
134 :     sample(es,[])
135 :     end
136 :    
137 :     (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info ->Var*LowIL.ASSN list
138 :     * info:(string*E.EIN*Var list)
139 :     * low-IL code for adding scaling a vector
140 :     *)
141 :     fun handleScale(id1,id2,alpha2,index,info)=let
142 :     val (n,vecIndex,index')=dropIndex index
143 :     in (case matchLast(alpha2,n)
144 :     of SOME ix2=> let
145 :     val (lhs,e,args)=info
146 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,[],id2,ix2)
147 :     in
148 :     iter(index,index',EtoVec.scaleV,nextfnargs)
149 :     end
150 :     | _=>runGeneralCase info
151 :     (*end case*))
152 :     end
153 :    
154 :     (*handleProd:E.body*int list*info ->Var*LowIL.ASSN list
155 :     * info:(string*E.EIN*Var list)
156 :     * low-IL code for vector product
157 :     *)
158 :     fun handleProd(E.Prod[E.Tensor(id1 , alpha), E.Tensor(id2, beta)],index,info)=let
159 :     val (lhs,e,args)=info
160 :     val (n,vecIndex,index')=dropIndex index
161 :     in (case(matchFindLast(alpha,n),matchFindLast(beta,n))
162 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
163 :     (*n is the last index of alpha, beta and nowhere else,possible modulate*)
164 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,ix1,id2,ix2)
165 :     in
166 :     iter(index,index',EtoVec.prodV,nextfnargs)
167 :     end
168 :     | ((NONE,NONE),(SOME ix2,NONE)) =>let
169 :     (*n is the last index of beta and nowhere else,possible scaleVector*)
170 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id1,alpha,id2,ix2)
171 :     in
172 :     iter(index,index',EtoVec.scaleV,nextfnargs)
173 :     end
174 :     | ((SOME ix1,NONE),(NONE,NONE)) =>let
175 :     (*n is the last index of alpha and nowhere else,ossile scaleVector*)
176 :     val nextfnargs=(lhs,Ein.params e, args,vecIndex,id2,beta,id1,ix1)
177 :     in
178 :     iter(index,index',EtoVec.scaleV,nextfnargs)
179 :     end
180 :     | _ =>runGeneralCase info
181 :     (*end case*))
182 :     end
183 :    
184 :     (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
185 :     * info:(string*E.EIN*Var list)
186 :     * low-IL code for dot product
187 :     *)
188 :     fun handleSumProd1(E.Sum([(E.V v,_,ub)],E.Prod[E.Tensor(id1 , alpha), E.Tensor(id2, beta)]),index,info)=(case(matchFindLast(alpha,v),matchFindLast(beta,v))
189 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
190 :     (*v is the last index of alpha, beta and nowhere else,possible sumProd*)
191 :     val (lhs,e,args)=info
192 :     val nextfnargs=(lhs,Ein.params e, args,ub+1,id1,ix1,id2,ix2)
193 :     in
194 :     iter(index,index,EtoVec.dotV,nextfnargs)
195 :     end
196 :     | _ =>runGeneralCase info
197 :     (*end case*))
198 :    
199 :     (*handleSumProd:E.body*int list*info ->Var*LowIL.ASSN list
200 :     * info:(string*E.EIN*Var list)
201 :     * low-IL code for double dot product
202 :     * Sigma_{i,j} A_ij B_ij
203 :     *)
204 :     fun handleSumProd2(E.Sum([(E.V v1,lb1,ub1),(E.V v2,lb2,ub2)],E.Prod[E.Tensor(id1 , alpha), E.Tensor(id2, beta)]),index,info)=let
205 :     fun check(v,ub,sx)=(case(matchFindLast(alpha,v),matchFindLast(beta,v))
206 :     of ((SOME ix1,NONE),(SOME ix2,NONE)) => let
207 :     (*v is the last index of alpha, beta and nowhere else,possible sumProd*)
208 :     val (lhs,e,args)=info
209 :     val nextfnargs=(lhs,Ein.params e, args,sx,ub+1,id1,ix1,id2,ix2)
210 :     in
211 :     SOME(iter(index,index,EtoVec.sumDotV,nextfnargs))
212 :     end
213 :     | _=> NONE
214 :     (*end case*))
215 :     in (case check(v1,ub1,(E.V v2,lb2,ub2))
216 :     of SOME e=>e
217 :     | _=> (case check(v2,ub2,(E.V v1,lb1,ub1))
218 :     of SOME e=> e
219 :     |_ =>runGeneralCase info
220 :     (*end case*))
221 :     (*end case*))
222 :     end
223 :    
224 :     (*scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list
225 :     *scans body for vectorization potential
226 :     *)
227 :     fun scan(y,e,args)= let
228 :     val lhs=LowIL.Var.name y
229 :     val b=Ein.body e
230 :     val index=Ein.index e
231 :     val info=(lhs,e,args)
232 :     val all=(b,index,info)
233 : cchiw 3054 fun gen body=(case (index,body)
234 :     of (_::es,E.Neg(E.Tensor(_ ,i::ix))) =>
235 : cchiw 2859 handleNeg all
236 : cchiw 3054 | (_::es,E.Sub(E.Tensor(_,i::ix),E.Tensor(_,j::jx))) =>
237 : cchiw 2859 handleSub all
238 : cchiw 3054 | (_::es, E.Add(E.Tensor(_,i::ix)::_)) =>
239 : cchiw 2859 handleAdd all
240 : cchiw 3054 | (_::es, E.Prod[E.Tensor(s, []), E.Tensor(v, j::jx)]) =>
241 : cchiw 2859 handleScale(s,v,j::jx,index,info)
242 : cchiw 3054 | (_::es,E.Prod[E.Tensor(v, j::jx), E.Tensor(s , [])]) =>
243 : cchiw 2859 handleScale(s,v,j::jx,index,info)
244 : cchiw 3054 | (_::es,E.Prod[E.Tensor(_ , i::ix), E.Tensor(_, j::jx)]) =>
245 : cchiw 2859 handleProd all
246 : cchiw 3054 | ( _,E.Sum([_], E.Prod[E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
247 : cchiw 2859 handleSumProd1 all
248 : cchiw 3054 | ( _ ,E.Sum([_,_], E.Prod[E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
249 : cchiw 2859 handleSumProd2 all
250 : cchiw 3054 | ( _ ,E.Sum(x,E.Prod(E.Img(Vid,_,_)::E.Krn(Hid,_,_)::_))) =>
251 : cchiw 3174 (
252 :     iter(index,index,evalField,(body,info)))
253 : cchiw 3054 | (_,_ )=> runGeneralCase info
254 : cchiw 2859 (*end case*))
255 : cchiw 3174
256 : cchiw 2859 val (_,code) =gen b
257 : cchiw 3174
258 : cchiw 2923 in (case code
259 :     of(* []=> []
260 :     | *) _=> let (*need to reassign the last assgn*)
261 :     val LowIL.ASSGN (a1,A)=List.hd(List.rev(code))
262 :     val c=LowIL.ASSGN (y,A)
263 :     in
264 :     code@[c]
265 :     end
266 :     (*end case*))
267 :     end
268 :    
269 : cchiw 2859
270 :    
271 :    
272 :     end (* local *)
273 :    
274 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0