Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/gen-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2525 - (view) (download)

1 : cchiw 2522 (*hashs Ein Function after substitution*)
2 :     structure genEin = struct
3 :     local
4 :     structure E = Ein
5 :     structure genHelper=genHelper
6 :     structure genKrn=genKrn
7 :    
8 :    
9 :    
10 :     structure DstIL = LowIL
11 :     structure DstTy = LowILTypes
12 :     structure DstOp = LowOps
13 :     structure Var = LowIL.Var
14 :     in
15 :    
16 :    
17 :    
18 :     (*Iterate over the outside index*)
19 :     (*nextfn is the next function*)
20 :     (*m are arguements*)
21 : cchiw 2525 fun prodIter(origIndex,index,nextfn,args)=let
22 : cchiw 2522
23 :     val index'=List.map (fn (e)=>(e-1)) index
24 : cchiw 2525 fun M(mapp,[],rest,code,shape)=let
25 : cchiw 2522 val (vF,code')=nextfn(mapp,args)
26 :     in (vF, code'@code)
27 :     end
28 : cchiw 2525 | M(a,[0], rest, code,shape)=let
29 : cchiw 2522 val mapp=a@[0]
30 :     val (vF,code')=nextfn(mapp,args)
31 : cchiw 2525 val(vE,E)=genHelper.aaV(DstOp.cons(DstTy.TensorTy shape),[vF]@rest,"Cons",DstTy.TensorTy(shape))
32 : cchiw 2522 in (vE, code'@code@E)
33 :     end
34 : cchiw 2525 | M(a,[c],rest,code,shape)=let
35 : cchiw 2522 val mapp=a@[c]
36 :     val (vE,E)=nextfn(mapp,args)
37 : cchiw 2525 in M(a, [c-1], [vE]@rest,E@code,shape) end
38 :     | M (a,b::c,rest,ccode,s::shape)=let
39 : cchiw 2522 fun S(0, rest,code)=let
40 : cchiw 2525 val (v',code')=M(a@[0],c,[],[],shape)
41 :     val(vA,A)=genHelper.aaV(DstOp.cons(DstTy.TensorTy (s::shape)),[v']@rest,"Cons",DstTy.TensorTy(s::shape))
42 : cchiw 2522 in (vA, code'@code@A) end
43 :     | S(i, rest, code)= let
44 : cchiw 2525 val (v',code')=M(a@[i],c,[],[],shape)
45 : cchiw 2522 in S(i-1,[v']@rest,code'@code) end
46 :     val (vA,code')=S(b, [],[])
47 :     in (vA,code'@ccode) end
48 : cchiw 2525 val (rest',code')= M([],index',[],[],origIndex)
49 : cchiw 2522 in (rest',code')
50 :     end
51 :    
52 :    
53 :    
54 :     (* general expressions*)
55 : cchiw 2525 fun generalfn(ap,(body,_,origargs, args))= let
56 : cchiw 2522 val a=print "in general fn"
57 :     val mappA= ref ap
58 :    
59 :     fun gen body=(case body
60 :     of E.Field _ =>raise Fail(concat["Invalid Field here "] )
61 :     | E.Partial _ =>raise Fail(concat["Invalid FieldPartial here "] )
62 :     | E.Apply _ =>raise Fail(concat["Invalid FieldApply here "] )
63 :     | E.Probe _ =>raise Fail(concat["Invalid FieldProbe here "] )
64 :     | E.Conv _ =>raise Fail(concat["Invalid FieldConv here "] )
65 :     | E.Krn _ =>raise Fail(concat["Invalid FieldKrn here "] )
66 :     | E.Img _=> raise Fail(concat["Invalid FieldImg here "] )
67 :     | E.Value v =>let
68 :     val ref mapp=mappA
69 :     val n=List.nth( mapp,v)
70 :     in genHelper.mkC n end
71 :    
72 :     (*| E.Const c=> []*)
73 :     | E.Tensor(id,[])=> genHelper.mkSca([],(id,[],args))
74 :     | E.Tensor(id,ix)=> let val ref mapp=mappA
75 :     in genHelper.mkSca(mapp,(id,ix,args)) end
76 :     | E.Delta(i,j)=> let val ref mapp=mappA
77 :     in genHelper.evalDelta2(i,j,mapp) end
78 :     | E.Epsilon(i,j,k)=> let
79 :     val ref mapp=mappA
80 :     val n=genHelper.evalEps(i,j,k,mapp)
81 : cchiw 2525 in genHelper.aaV(DstOp.C(n),[],"Const",DstTy.TensorTy([])) end
82 : cchiw 2522 | E.Neg e => let
83 :     val (vA,A)=gen e
84 :     val s=genHelper.skeleton A
85 :     in (case s
86 :     of 0 => (vA,A)
87 :     | ~1 => genHelper.mkC 1
88 :     | 1 => genHelper.mkC ~1
89 :     | _=> let
90 : cchiw 2525 val (vB,B)=genHelper.aaV(DstOp.C (~1),[],"Const",DstTy.TensorTy([]))
91 :     val (vD,D)=genHelper.aaV(DstOp.prodSca,[vB,vA],"prodSca",DstTy.TensorTy([]))
92 : cchiw 2522 in (vD,A@B@D) end
93 :     (*end case*))
94 :     end
95 :    
96 :     | E.Add e=> let
97 :     (*check0 function removes 0 from list *)
98 : cchiw 2525 fun checkO ([],[],[])=let val (vA,A)=genHelper.aaV(DstOp.C(1),[],"Const",DstTy.TensorTy([])) in ([vA],A) end
99 : cchiw 2522 | checkO(ids,code,[])=(ids,code)
100 :     | checkO(ids,code, e1::es)=let
101 :     val (a,b)=gen e1
102 :     val s=genHelper.skeleton b
103 :     in (case s
104 :     of 0 => checkO(ids,code,es)
105 :     | _ => checkO(ids@[a],code@b,es)
106 :     (*end case*))
107 :     end
108 :     val (ids,code)=checkO([],[],e)
109 :     in (case ids
110 :     of [id1]=> (id1,code)
111 : cchiw 2525 | _=>let val (vB,B)=genHelper.mkMultiple(ids,DstOp.addSca,DstTy.TensorTy([]))
112 : cchiw 2522 in (vB,code@B) end
113 :     (*end case*))
114 :     end
115 :    
116 :     | E.Sub (e1,e2)=>let
117 :     val (vA,A)=gen e1
118 :     val (vB,B)=gen e2
119 :     val sA=genHelper.skeleton A
120 :     val sB=genHelper.skeleton B
121 :    
122 :     (* checks if either expression evaluates to 0*)
123 :     in (case (sA,sB)
124 :     of (0,0)=> genHelper.mkC 0
125 :     |(0,_)=> let
126 :     val (vD,D)= genHelper.mkC ~1
127 : cchiw 2525 val (vE,E)= genHelper.aaV(DstOp.prodSca,[vD,vB],"prodSca",DstTy.TensorTy([]))
128 : cchiw 2522 in (vE,B@D@E) end
129 :    
130 :     | (_,0)=> (vA,A)
131 :     | _ => let
132 : cchiw 2525 val (vD,D)= genHelper.aaV(DstOp.subSca,[vA,vB],"subSca",DstTy.TensorTy([]))
133 : cchiw 2522 in (vD, A@B@D) end
134 :     (*end case*))
135 :     end
136 : cchiw 2525 | E.Sum(sx,E.Prod(E.Img im::E.Krn(id,del,pos)::es))=>let
137 : cchiw 2522 val ref mapp=mappA
138 : cchiw 2525 val harg=List.nth(origargs,id)
139 :     val h=genHelper.getKernel(harg)
140 :     in genKrn.mkkrns(mapp,(body,h,args))
141 : cchiw 2522 end
142 : cchiw 2525
143 : cchiw 2522 | E.Prod e => let
144 :     (*checkO removes 1 from list, and returns 0 if there is one*)
145 : cchiw 2525 fun checkO ([],[],[])=let val (vA,A)=genHelper.aaV(DstOp.C 1,[],"Const",DstTy.TensorTy([])) in ([vA],A) end
146 : cchiw 2522 | checkO(ids,code,[])=(ids,code)
147 :     | checkO(ids,code, e1::es)=let
148 :     val (a,b)=gen e1
149 :     val sB=genHelper.skeleton b
150 :     in (case sB
151 :     of 0 => ([a],b)
152 :     | 1 => checkO(ids,code,es)
153 :     | _ => checkO(ids@[a],code@b,es)
154 :     (*end case*))
155 :     end
156 :     val (ids,code)=checkO([],[],e)
157 :     in (case ids
158 :     of [id1]=> (id1,code)
159 :     | _=>let
160 : cchiw 2525 val (vB,B)=genHelper.mkMultiple(ids,DstOp.prodSca,DstTy.TensorTy([]))
161 : cchiw 2522 in (vB,code@B) end
162 :     (*end case*))
163 :     end
164 :     | E.Div(e1,e2)=>(let
165 :     val (vA,A)=gen e1
166 :     val sA=genHelper.skeleton A
167 :     in (case sA
168 :     of 0=> genHelper.mkC 0
169 :     | _=> let
170 :     val (vB,B)=gen e2
171 : cchiw 2525 val (vD,D)= genHelper.aaV(DstOp.divSca,[vA,vB],"divSca",DstTy.TensorTy([]))
172 : cchiw 2522 in (vD, A@B@D) end
173 :     (*end case*))
174 :    
175 :     end)
176 :     | E.Sum(sumx, e)=> let
177 :     val m=print "in general sum"
178 :     val ref orig=mappA
179 :     fun sumloop(mapsum)= (mappA:=(orig@mapsum); let
180 :     val(vA,A)=gen e
181 :     val sA=genHelper.skeleton A
182 :     in (case sA
183 :     of 0 =>([],[])
184 :     | _=>([vA],A)
185 :     (*end case*))
186 :    
187 :     end )
188 :     fun sumI1(left,(0,lb1),[],rest,code)=let
189 :     val mapp=left@[lb1]
190 :     val (vD,pre)= sumloop(mapp)
191 :     in (vD@rest,pre@code) end
192 :     | sumI1(left,(i,lb1),[],rest,code)=let
193 :     val mapp=left@[i+lb1]
194 :     val (vD,pre)=sumloop(mapp)
195 :     in sumI1(left,(i-1,lb1),[],vD@rest,pre@code) end
196 :     | sumI1(left,(0,lb1),(a,lb2,ub)::sx,rest,code)=
197 :     sumI1(left@[lb1],(ub-lb2,lb2),sx,rest,code)
198 :     | sumI1(left,(s,lb1),(a,lb2,ub)::sx,rest,code)=let
199 :     val (rest',code')=sumI1(left@[s+lb1],(ub-lb2,lb2),sx,rest,code)
200 :     in sumI1(left,(s-1,lb1),(E.V 0,lb2,ub)::sx,rest',code') end
201 :     val (_,lb,ub)=hd(sumx)
202 :     val(li, code)=sumI1([],(ub-lb,lb),tl(sumx),[],[])
203 :     in (case li
204 :     of [l1] => (l1,code)
205 : cchiw 2525 |_=>let val(vF,F)=genHelper.mkMultiple(li,DstOp.addSca,DstTy.TensorTy([]))
206 : cchiw 2522 in (vF,code@F) end
207 :     (*end case*))
208 :     end
209 :    
210 :     (*end case*))
211 :    
212 :     in gen body end
213 :    
214 :     (*Below functions are used to check for vectorization*)
215 :     (*Check Addition expression *)
216 : cchiw 2525 fun handleSimpleAdd(E.Add body,index,origargs, args)=let
217 : cchiw 2522 val n=length(index)-1
218 :     fun add (lft,[])=let
219 :     val index'=List.take(index,n)
220 : cchiw 2525 val m=List.nth(index,n)
221 :     in prodIter(index,index',genHelper.handleAddVec,(lft,index,m,args)) end
222 :     | add(lft,E.Tensor(id,[])::es) =prodIter(index,index,generalfn,(E.Add body,[],origargs, args))
223 : cchiw 2522 | add(lft,E.Tensor(id,list1)::es) =let
224 : cchiw 2525
225 : cchiw 2522 val n1=length(list1)-1
226 : cchiw 2525
227 : cchiw 2522 val E.V i=List.nth(list1, n1)
228 :     in if(i=n) then let
229 :     val list1'=List.take(list1,n1)
230 :     in add(lft@[(id,list1')],es) end
231 : cchiw 2525 else prodIter(index,index,generalfn,(E.Add body,[],origargs, args))
232 : cchiw 2522 end
233 : cchiw 2525 | add(lft,e::es)=prodIter(index,index,generalfn,(E.Add body,[],origargs, args))
234 : cchiw 2522 in add([],body)
235 :     end
236 :    
237 :    
238 :     (*Addition and sutraction, with just two tensors *)
239 :     fun handleSimpleOp (orig,index,f1,f2,args)=let
240 :     val [(id1,list1),(id2,list2)]=orig
241 :     val n1=length(list1)-1
242 :     val n2=length(list2)-1
243 :     val n=length(index)-1
244 :     val vi=List.nth(list1, n1)
245 :     val vj=List.nth(list2,n2)
246 :     val E.V i=vi
247 :     val E.V j=vj
248 :     in
249 :     if(j=n andalso i=j)
250 :     then let
251 :     val list1'=List.take(list1,n1)
252 :     val list2'=List.take(list2,n2)
253 :     val index'=List.take(index,n)
254 : cchiw 2525 val m=List.nth(index,n)
255 :     in prodIter(index,index',f1,([(id1,list1'),(id2,list2')],[],m,args)) end
256 :     else prodIter(index,index,f2,(orig,[],args))
257 : cchiw 2522 end
258 :    
259 :     (*Need to double check here*)
260 : cchiw 2525 fun handleNeg(orig,index,id, ix,origargs, args)=let
261 : cchiw 2522 (*Create a Vector from Tensor*)
262 :     val n=(length index)-1
263 :     val i=List.nth(ix, n)
264 :     in (case i
265 :     of E.V v =>
266 :     if(v=n) (*can use vectorization*)
267 :     then let
268 :     val (vA,A)= genHelper.mkC(0)
269 : cchiw 2525 val uuu=print "post genHelper Call"
270 : cchiw 2522 val index'=List.take(index,n)
271 :     val ix'=List.take(ix,n)
272 : cchiw 2525 val m=List.nth(index,n)
273 :     val g=print "IN HANDLE NEG-- PUPPY\n"
274 :     val ggg=print(Int.toString(m))
275 :     val (vB,B)=prodIter(index,index',genHelper.mkNegV,((vA,id,ix'),[],m,args))
276 :     in (vB,A@B)
277 :     end
278 :     else prodIter(index,index,generalfn,(orig,[],origargs, args))
279 :     |_ => prodIter(index,index,generalfn,(orig,[],origargs, args))
280 : cchiw 2522 (*end case *))
281 :     end
282 :    
283 :    
284 :     (*Prodduct of two tensors*)
285 :    
286 :    
287 :    
288 : cchiw 2525 fun handleProd(orig,index,sx,origargs, args)=let
289 : cchiw 2522 val [(id1,list1),(id2,list2)]=orig
290 :     val n1=length(list1)-1
291 :     val n2=length(list2)-1
292 :     val vi=List.nth(list1, n1)
293 :     val vj=List.nth(list2,n2)
294 :     val list1'=List.take(list1,n1)
295 :     val list2'=List.take(list2,n2)
296 :     val ns=length(sx)
297 :     in
298 :     if(ns=0)
299 :     then let
300 :     val m=genHelper.findDup(list1,list2)
301 :     val n=length(index)-1
302 :     val E.V i=vi
303 :     val E.V j=vj
304 :     val index'=List.take(index,n)
305 : cchiw 2525 val mm=List.nth(index,n)
306 : cchiw 2522 in (case m
307 :     of NONE =>
308 :     (*{A_.. B_..j}_...j ? i.e, outproduct*)
309 :     (* s*v otherwise s*s *)
310 : cchiw 2525 if(j=n) then
311 :     prodIter(index,index',genHelper.mkprodScaV,([(id1,list1),(id2,list2')],[],mm,args))
312 :     else prodIter(index,index,genHelper.mkprodSca,(orig,[],args))
313 : cchiw 2522 | _ =>
314 :     (*{A_i B_i}_i? i.e. modoulate*)
315 :     (* v*v otherwise s*s*)
316 :     if(i=j andalso i=n)
317 : cchiw 2525 then prodIter(index,index',genHelper.mkprodVec,([(id1,list1'),(id2,list2')],[],mm,args))
318 :     else prodIter(index,index,genHelper.mkprodSca,(orig,[],args))
319 : cchiw 2522 (*end case*)) end
320 :     else if (ns=1) then let
321 :     val [(sx1,lb,ub)]=sx
322 : cchiw 2525
323 : cchiw 2522 in
324 :     if(vi=vj andalso vi=sx1)
325 : cchiw 2525 then prodIter(index,index,genHelper.mkprodSumVec,([(id1,list1'),(id2,list2')],[],ub,args)) (*v,v*)
326 :     else prodIter(index,index,genHelper.sum,(orig,sx,args)) (*s,s *)
327 : cchiw 2522 end
328 :     else if (ns=2)
329 :     then let val [(sx1,lb1,ub1),(sx2,lb2,ub2)]=sx
330 : cchiw 2525
331 : cchiw 2522 in if(vi=vj andalso vi=sx1)
332 : cchiw 2525
333 :     then prodIter(index,index,genHelper.sumDot,([(id1,list1'),(id2,list2')],[(sx2,lb2,ub2)],ub1,args))
334 : cchiw 2522 else if(vi=vj andalso vi=sx2)
335 : cchiw 2525 then prodIter(index,index,genHelper.sumDot,([(id1,list1'),(id2,list2')],[(sx1,lb1,ub1)],ub2,args))
336 :     else prodIter(index,index,genHelper.sum,(orig,sx,args))
337 : cchiw 2522 end
338 : cchiw 2525 else prodIter(index,index,genHelper.sum,(orig,sx,args))
339 : cchiw 2522 end
340 :    
341 :    
342 :    
343 : cchiw 2525 fun handleScVProd(body,orig,index,sx,origargs, args)=let
344 : cchiw 2522 val (id1,id2,list2)=orig
345 :     val n2=length(list2)-1
346 :     val vj=List.nth(list2,n2)
347 :     val E.V j=vj
348 :     val nsx=length sx
349 :     val n=length(index)-1
350 : cchiw 2525 val m=List.nth(index,n)
351 : cchiw 2522 in if(j=n andalso nsx=0)
352 :     then let
353 :     val index'=List.take(index,n) val list2'=List.take(list2,n2)
354 : cchiw 2525 val q=print(String.concat["Puppy-Make Vector index",Int.toString(m)])
355 :     in prodIter(index,index',genHelper.mkprodScaV,([(id1,[]),(id2,list2')],[],m, args)) end
356 :     else prodIter(index,index,generalfn,(body,[],origargs, args))
357 : cchiw 2522 end
358 :    
359 :    
360 :     (*Simple Operators on two tensors, examine to see if we could use vectors *)
361 : cchiw 2525 (*Have to pass orig args to everyone in case we have a kernel or image in a later stage of iteration*)
362 : cchiw 2522
363 : cchiw 2525 fun genfn(y,Ein.EIN{params, index, body},origargs,a)= let
364 : cchiw 2522 val sx= ref[]
365 :     val n=length index
366 : cchiw 2525 val args=(a,params)
367 : cchiw 2522 (*Potential for Vectorization here *)
368 :     fun gen b=(case b
369 :     of E.Field _ =>raise Fail(concat["Invalid Field here "] )
370 :     | E.Partial _ =>raise Fail(concat["Invalid Field here "] )
371 :     | E.Apply _ =>raise Fail(concat["Invalid Field here "] )
372 :     | E.Probe _ =>raise Fail(concat["Invalid Field here "] )
373 :     | E.Conv _ =>raise Fail(concat["Invalid Field here "] )
374 :    
375 :     (*| E.Const _=>[]*)
376 :    
377 :    
378 :    
379 : cchiw 2525 | E.Neg(E.Tensor(id,ix))=> handleNeg(body,index,id, ix,origargs, args)
380 :     | E.Add _ => handleSimpleAdd(body,index,origargs, args)
381 : cchiw 2522 | E.Sub(E.Tensor(id1, ix1), E.Tensor(id2, ix2)) =>
382 :     handleSimpleOp([(id1,ix1),(id2,ix2)],index,genHelper.mksubVec,genHelper.mksubSca,args)
383 :     | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)] =>let
384 :     val ref x=sx
385 : cchiw 2525 in handleScVProd(body,(id1,id2,ix2),index,x,origargs, args) end
386 : cchiw 2522 | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, [])] =>let
387 :     val ref x=sx
388 : cchiw 2525 in handleScVProd(body,(id2,id1,ix1),index,x, origargs, args) end
389 : cchiw 2523 (* | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)] =>
390 : cchiw 2522 let
391 :     val ref x=sx
392 :     in
393 :     handleProd([(id1,ix1),(id2,ix2)],index,x,args)
394 : cchiw 2523 end*)
395 : cchiw 2522 (*| E.Div(E.Tensor _,E.Tensor _ )=>[]*)
396 :    
397 : cchiw 2525 | E.Sum(ss,E.Prod(E.Img im::E.Krn(id,del,pos)::es))=>let
398 : cchiw 2522 val ref x=sx
399 :     val m=print "\n match img"
400 :    
401 :     in
402 : cchiw 2525 if(length x=0) then let
403 :     val harg=List.nth(origargs,id)
404 :     val h=genHelper.getKernel(harg)
405 :    
406 :     in prodIter(index,index,genKrn.mkkrns,(b, h,args)) end
407 :     else prodIter(index,index,generalfn,(body,[],origargs, args))
408 : cchiw 2522 end
409 :     | E.Sum(sx', e)=> (let
410 :     val ref x=sx
411 :     in sx:=x@sx' end ;gen e)
412 : cchiw 2525 | _ => prodIter(index,index,generalfn,(body,[],origargs, args))
413 : cchiw 2522 (*end case*))
414 :    
415 :     (*Scalars only, not vectorization potential*)
416 :     fun single b=(case b
417 : cchiw 2525 of E.Tensor(id,[]) =>genHelper.mkSca([],(id,[], args))
418 :     | E.Const _=>generalfn([],(b,[],origargs, args))
419 :     | E.Neg _ => generalfn([],(b,[],origargs, args))
420 :     | E.Add _ =>generalfn([],(b,[],origargs, args))
421 :     | E.Sub _=> generalfn([],(b,[],origargs, args))
422 :     | E.Div _=> generalfn([],(b,[],origargs, args))
423 :     | E.Prod [E.Tensor(_,[]),E.Tensor(_,[])] => generalfn([],(b,[],origargs, args))
424 : cchiw 2522 | _=> gen b
425 :    
426 :     (*end case*))
427 :    
428 :     in (case n of 0 =>single body | _=> gen body) end
429 :    
430 :     end (* local *)
431 :    
432 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0