Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step1.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step1.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2665 - (view) (download)

1 : cchiw 2615 (*Looks for vectorization potential*)
2 : cchiw 2612 structure step1 = struct
3 :     local
4 :     structure DstTy = LowILTypes
5 :     structure DstOp = LowOps
6 :     structure Var = LowIL.Var
7 :     structure E = Ein
8 :     structure P=Printer
9 :     structure genKrn=genKrn
10 :     structure S2= step2
11 :     structure S3= step3
12 :    
13 :     in
14 :    
15 :    
16 :    
17 :     val Sca=DstTy.TensorTy([])
18 :     fun errS str=raise Fail(str)
19 :    
20 :     (*Returns last index, and rest of list*)
21 :     fun getLast list=let
22 :     val a=List.rev(list)
23 :     val A=List.hd(a)
24 :     val b=List.tl(a)
25 :     val B=List.rev(b)
26 :     in (A,B) end
27 :    
28 :     (*Returns last index, and rest of list*)
29 :     fun getLastAll list=let
30 :     val a=List.rev(list)
31 :     val A=List.hd(a)
32 :     val b=List.tl(a)
33 :     val B=List.rev(b)
34 :     in (case A
35 :     of E.V i =>(i,A,B)
36 :     | _ => errS("Last Index is not Variable")
37 :     (*end case*))
38 :     end
39 :    
40 :     fun findDup(list1,list2)=let
41 :     fun current []=NONE
42 :     | current (v::vs)=(case (List.find (fn x => x =v) list2)
43 :     of NONE =>current vs
44 :     |_=> SOME 1
45 :     (*end case*))
46 :     in
47 :     current list1
48 :     end
49 :    
50 :     (*-----------Handle Functions for vectorization-----------*)
51 :    
52 :    
53 :     (*Negative Tensor*)
54 :     fun handleNeg(id, ix,index, nextfnArgs)=let
55 :     val n=(length index)-1
56 :     fun default _=S2.prodIter(index,index,S2.generalfn, nextfnArgs)
57 :     in (case (List.nth(ix, n))
58 :     of E.V v =>
59 :     if(v=n) then let
60 : cchiw 2624 val (vA,A)= S3.mkInt ~1
61 : cchiw 2612 val ix'=List.take(ix,n)
62 :     val (LastIndex,index')=getLast index
63 :     val (_,_,info)=nextfnArgs
64 :     val (vB,B)=S2.prodIter(index,index',S3.mkNegV,((vA,id,ix'),LastIndex,info))
65 :     in (vB,A@B) end
66 :     else default 0
67 :     |_ => default 0
68 :     (*end case *))
69 :     end
70 :    
71 :     (*Subtract Tensors*)
72 :     fun handleSimpleOp(id1,list1,id2,list2,index,f1,f2,nextfnArgs)=let
73 :     fun default _=S2.prodIter(index,index,f2, nextfnArgs)
74 :     in (case (list1, list2)
75 :     of ([],_)=> default 0
76 :     | (_,[])=> default 0
77 :     | _ => let
78 :     val (i,vi,list1')=getLastAll list1
79 :     val (j,vj,list2')=getLastAll list2
80 :     val n=length(index)-1
81 :     in
82 :     if(j=i andalso j=n)
83 :     then let
84 :     val (LastIndex,index')=getLast index
85 :     val (_,_,info)=nextfnArgs
86 :     in S2.prodIter(index,index',f1,(id1,list1',id2,list2',LastIndex,info)) end
87 :     else default 0
88 :     end
89 :     (*end case*))
90 :     end
91 :    
92 :    
93 :     (*Add Vector *)
94 :     fun handleAdd(terms,index,nextfnArgs)=let
95 :     val n=length(index)-1
96 :     fun default _=S2.prodIter(index,index,S2.generalfn, nextfnArgs)
97 :     fun add (e,rest)=(case e
98 :     of [] => let
99 :     val (LastIndex,index')=getLast index
100 :     val (_,_,info)=nextfnArgs
101 :     in S2.prodIter(index,index',S3.handleAddVec,(rest,LastIndex,info)) end
102 :     | E.Tensor(id,[])::es => default 0
103 :     | E.Tensor(id,alpha)::es => let
104 :     val (i,_,list1')=getLastAll alpha
105 :     in if(i=n) then add(es,rest@[(id,list1')])
106 :     else default 0
107 :     end
108 :    
109 :     | _=> default 0
110 :     (*end case *))
111 :     in
112 :     add(terms,[])
113 :     end
114 :    
115 :    
116 :     (*Scaling*)
117 :     fun handleScVProd(id1,id2,alpha,index,nextfnArgs)=let
118 :     val (j,vj,list2')=getLastAll alpha
119 :     val (LastIndex,index')= getLast index
120 :     val n=length(index)-1
121 :     val (_,_,info)=nextfnArgs
122 :     fun default _=S2.prodIter(index,index,S2.generalfn, nextfnArgs)
123 :     in if(j=n)
124 :     then S2.prodIter(index,index',S3.mkprodScaV,(id1,[],id2,list2',LastIndex,info))
125 :     else default 0
126 :     end
127 :    
128 :    
129 :    
130 :     (*None:{A_.. B_..j}_...j ? i.e, outproduct : s*v otherwise s*s *)
131 :     (*Some: {A_i B_i}_i? i.e. modoulate: v*v otherwise s*s*)
132 :     fun handleProd(id1,list1,id2,list2,index,sx,nextfnArgs)=let
133 :     val (i,vi,list1')=getLastAll list1
134 :     val (j,vj,list2')=getLastAll list2
135 :     val (_,_,info)=nextfnArgs
136 : cchiw 2613 fun default _ = S2.prodIter(index,index,S2.generalfn, nextfnArgs)
137 :    
138 : cchiw 2612 in (case sx
139 :     of []=>let
140 :     val (LastIndex,index')=getLast index
141 :     val n=length(index)-1
142 :     in (case (findDup(list1,list2))
143 :     of NONE =>if(j=n)
144 :     then S2.prodIter(index,index',S3.mkprodScaV,(id1,list1,id2,list2',LastIndex,info))
145 :     else default 0
146 :     | _ => if(i=j andalso i=n)
147 :     then S2.prodIter(index,index',S3.mkprodVec,(id1,list1',id2,list2',LastIndex,info))
148 :     else default 0
149 :     (*end case*))
150 :     end
151 :     | [(sx1,0,ub)]=> if(vi=vj andalso vi=sx1)
152 :     then S2.prodIter(index,index,S3.mkprodSumVec,(id1,list1',id2,list2',ub+1,info))
153 :     else default 0
154 :     | [(sx1,lb1,ub1),(sx2,lb2,ub2)]=>
155 :     if(vi=vj andalso vi=sx1)
156 : cchiw 2637 then S2.prodIter(index,index,S3.sumDot,((sx2,lb2,ub2),(id1,list1',id2,list2',ub1+1, info)))
157 : cchiw 2612 else if(vi=vj andalso vi=sx2)
158 : cchiw 2637 then S2.prodIter(index,index,S3.sumDot,((sx1,lb1,ub1),(id1,list1',id2,list2',ub2+1,info)))
159 : cchiw 2612 else default 0
160 :     | _ => default 0
161 :     (*end case*))
162 :     end
163 :    
164 :    
165 :    
166 :    
167 :     (*General Function looks for vectorization potential*)
168 :     (*A=(body,info)*)
169 :     (*NextFn/GenerAlFn:(mapp, A)*)
170 :     (*ProdIter: index,index, nextfn, A*)
171 :     (*handleFunction:[(id,ix)]'s,index,info *)
172 :     (*params,args,origargs*)
173 :     fun genfn(Ein.EIN{params, index, body},args,origargs)= let
174 :     val info=(params,args)
175 :     val nextfnArgs=(body,origargs,info)
176 :     val iterArgs=(index,index,S2.generalfn,nextfnArgs)
177 :     fun gen body=(case body
178 : cchiw 2665 of(**) E.Neg(E.Tensor(id1,ix1)) =>
179 : cchiw 2612 handleNeg(id1, ix1,index, nextfnArgs)
180 :     | E.Sub(E.Tensor(id1,ix1),E.Tensor(id2,ix2)) =>
181 :     handleSimpleOp(id1,ix1,id2,ix2,index,S3.mksubVec,S2.generalfn,nextfnArgs)
182 :     | E.Add es =>
183 :     handleAdd(es,index,nextfnArgs)
184 :     | E.Prod[e] => gen e
185 :     | E.Prod[E.Tensor(id1,[]), E.Tensor(id2, [])] =>
186 :     S2.prodIter iterArgs
187 :     | E.Prod[E.Tensor(id1, []), E.Tensor(id2, ix2)] =>
188 :     handleScVProd(id1,id2,ix2,index,nextfnArgs)
189 :     | E.Prod[E.Tensor(id2, ix2), E.Tensor(id1, [])] =>
190 :     handleScVProd(id1,id2,ix2,index,nextfnArgs)
191 :     | E.Prod[E.Tensor(id1, ix1), E.Tensor(id2, ix2)] =>
192 :     handleProd(id1,ix1,id2,ix2,index,[],nextfnArgs)
193 :     | E.Sum(sx,E.Prod[E.Tensor(id1,ix1),E.Tensor(id2,ix2)]) =>
194 :     handleProd(id1,ix1,id2,ix2,index,sx,nextfnArgs)
195 :     | E.Sum(x,E.Prod(E.Img(Vid,_,_)::E.Krn(Hid,_,_)::_)) =>
196 :     let
197 :     val harg=List.nth(origargs,Hid)
198 :     val h=S3.getKernel(harg)
199 :     val imgarg=List.nth(origargs,Vid)
200 :     val v=S3.getImage(imgarg)
201 :     in
202 :     S2.prodIter(index,index,genKrn.evalField,(body,v,h,info))
203 :     end
204 : cchiw 2613 | _ => S2.prodIter iterArgs
205 : cchiw 2612 (*end case*))
206 :     in gen body
207 :     end
208 :    
209 :     end (* local *)
210 :    
211 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0