Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2830 - (view) (download)

1 : cchiw 2615 (*Helper function gen-ein*)
2 : cchiw 2612 structure step3 = struct
3 :     local
4 :    
5 :     structure DstIL = LowIL
6 :     structure DstTy = LowILTypes
7 :     structure DstOp = LowOps
8 :     structure Var = LowIL.Var
9 :     structure SrcIL = MidIL
10 :     structure SrcOp = MidOps
11 :     structure E = Ein
12 :     structure tS= toStringEin
13 : cchiw 2627
14 : cchiw 2615
15 : cchiw 2612 in
16 :    
17 : cchiw 2829 val testing=0
18 : cchiw 2612 val bV= ref 0
19 :     fun err str=raise Fail(str)
20 : cchiw 2827 val iTy=DstTy.IntTy
21 : cchiw 2620 val Sca=DstTy.TensorTy []
22 : cchiw 2676 val addR=DstOp.addSca
23 :    
24 : cchiw 2620 fun lookup k d = d k
25 :     fun q e1=Int.toString e1
26 : cchiw 2612
27 : cchiw 2827 fun testp n= (case testing
28 :     of 0=> 0
29 :     | _ =>((print (String.concat n));1)
30 :     (*end case*))
31 :    
32 : cchiw 2612 fun insert (key, value) d =fn s =>
33 : cchiw 2827 if s = key then (testp[Int.toString(key),"=>",Int.toString(value)];SOME value)
34 : cchiw 2612 else d s
35 :    
36 :    
37 :     (*Get kernel and Image bindings*)
38 :     fun getKernel x = (case SrcIL.Var.binding x
39 : cchiw 2680 of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _ ) ,_ ))=> h
40 : cchiw 2612 | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
41 :     (* end case *))
42 :    
43 : cchiw 2620
44 : cchiw 2680 fun getImageSrc x = (case SrcIL.Var.binding x
45 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage img, _ )) => img
46 : cchiw 2612 | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
47 : cchiw 2680 (* end case *))
48 : cchiw 2612
49 :     (*Make assignment*)
50 :     fun aaV(opss,args,pre,ty)=let
51 :     val a=DstIL.Var.new(pre ,ty)
52 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
53 : cchiw 2827 val _ =testp[tS.toStringAll(ty,code)]
54 : cchiw 2612 in
55 :     (a,[code])
56 :     end
57 :    
58 : cchiw 2827 fun mkReal n=let
59 :     val a=DstIL.Var.new("Real" ,Sca)
60 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
61 :     val _ =testp[tS.toStringAll(Sca,code)]
62 :     in
63 :     (a,[code])
64 :     end
65 : cchiw 2637
66 : cchiw 2624 fun mkInt n=let
67 : cchiw 2827 val a=DstIL.Var.new("Int" ,iTy)
68 : cchiw 2637 val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
69 : cchiw 2827 val _ =testp[tS.toStringAll(iTy,code)]
70 : cchiw 2624 in
71 :     (a,[code])
72 :     end
73 : cchiw 2620
74 :    
75 : cchiw 2624
76 : cchiw 2612 (*mk Multiple, Add Ids on list1*)
77 : cchiw 2789 fun mkMultiple((lhs,_,_),list1,rator,ty)=let
78 :     fun add([],_,_) = err"no element in mkMultiple"
79 :     | add([e1],_,_) = (e1,[])
80 :     | add([e1,e2],code,_) = let
81 :     val (vA,A)=aaV(rator,[e1,e2],lhs^"_2",ty)
82 : cchiw 2612 in (vA,code@A) end
83 : cchiw 2789 | add(e1::e2::es,code,count) = let
84 :     val (vA,A)=aaV(rator,[e1,e2],lhs^"_"^Int.toString count,ty)
85 :     in add(vA::es,code@A,count-1)
86 : cchiw 2612 end
87 :     in
88 : cchiw 2789 add(list1,[],List.length list1)
89 : cchiw 2612 end
90 :    
91 :    
92 :     fun mapIndex(e1,mapp)=(case e1
93 :     of E.V e => (case (lookup e mapp)
94 :     of NONE=> err("Outside Bound:"^Int.toString(e))
95 :     |SOME s => s)
96 :     | E.C c=> c
97 :     (*end case*))
98 :    
99 :    
100 :     (*Integer, or Generic Tensor*)
101 :     fun getTensorTy(params, id)=(case List.nth(params,id)
102 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
103 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
104 :     |_=> err"NONE Tensor Param")
105 :    
106 : cchiw 2631 fun q e=Int.toString(e)
107 : cchiw 2612
108 : cchiw 2827
109 : cchiw 2795 fun mkSca(mapp,(id, [],(lhs,params,args))) = (List.nth(args,id),[])
110 : cchiw 2789 | mkSca(mapp,(id,ix,(lhs,params,args)))= let
111 : cchiw 2612 val nU=List.nth(args,id)
112 : cchiw 2631 val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
113 :     val ix'=DstTy.indexTy ixx
114 : cchiw 2612 val argTy=getTensorTy(params,id)
115 : cchiw 2795 val opp=DstOp.IndexTensor(id,ix',argTy)
116 : cchiw 2612 in
117 : cchiw 2789 aaV(opp,[nU],lhs^"S"^Int.toString(id),Sca)
118 : cchiw 2612 end
119 :    
120 : cchiw 2827 fun mkIntAsn(mapp,(id,ix,(lhs,params,args)))= let
121 :     val nU=List.nth(args,id)
122 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
123 :     val ix'=DstTy.indexTy ixx
124 :     val argTy=getTensorTy(params,id)
125 :     val opp=DstOp.IndexTensor(id,ix',argTy)
126 :     in
127 :     aaV(opp,[nU],lhs^"I"^Int.toString(id),iTy)
128 :     end
129 :    
130 :    
131 :    
132 : cchiw 2612 (*eval Epsilon*)
133 :     fun evalEps(mapp,a,b,c)=let
134 :     val i=mapIndex(E.V a,mapp)
135 :     val j=mapIndex(E.V b,mapp)
136 :     val k=mapIndex(E.V c,mapp)
137 :     in
138 : cchiw 2827 if(i=j orelse j=k orelse i=k) then mkReal 0
139 : cchiw 2612 else
140 :     if(j>i) then
141 : cchiw 2827 if(j>k andalso k>i) then mkReal ~1 else mkReal 1
142 :     else if(i>k andalso k>j) then mkReal 1 else mkReal ~1
143 : cchiw 2612
144 :     end
145 :    
146 :     (*eval Delta*)
147 :     fun evalDelta2(mapp,a,b)= let
148 :     val i=mapIndex(a,mapp)
149 :     val j=mapIndex(b,mapp)
150 :     in
151 : cchiw 2827 if(i=j) then mkReal 1 else mkReal 0
152 : cchiw 2612 end
153 :    
154 :    
155 :     fun evalDels(mapp,dels)=let
156 :     fun m(a,b)=if(a=b) then 1 else 0
157 :     fun ij(i,j)=(case (i,j)
158 :     of (E.V a, E.V b)=>m(mapIndex(i,mapp),mapIndex(j,mapp))
159 :     | (E.C a, E.V b)=>m(a,mapIndex(j,mapp))
160 :     | (E.V a, E.C b)=>m(mapIndex(i,mapp),b)
161 :     | (E.C a, E.C b)=>m(i,j)
162 :     (*end case*))
163 :     val dels'=List.map ij dels
164 :     in
165 :     List.foldl(fn(x,y)=>x+y) 0 dels'
166 :     end
167 :    
168 :    
169 :     (*--------------------Vectorization Helper Functions--------------------*)
170 :     (*val nextfnArgs=(body,params,args,origargs)*)
171 :    
172 : cchiw 2789 fun mkVec(mapp,(id,[],vecIX,(lhs,params,args)))= let
173 : cchiw 2612 val nU=List.nth(args,id)
174 : cchiw 2631 in (nU,[]) end
175 : cchiw 2789 | mkVec(mapp,(id,ix,vecIX,(lhs,params,args)))= let
176 : cchiw 2631 val nU=List.nth(args,id)
177 : cchiw 2612 val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
178 :     val argTy= getTensorTy(params,id)
179 : cchiw 2631 val vecTy=DstTy.TensorTy [vecIX]
180 : cchiw 2791 val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
181 : cchiw 2612 in
182 : cchiw 2789 aaV(opp,[nU],lhs^"V"^Int.toString(id),vecTy)
183 : cchiw 2612 end
184 :    
185 :    
186 : cchiw 2791
187 :    
188 : cchiw 2612 (*product of -1 and 1 projection*)
189 : cchiw 2789 fun mkNegV(mapp,((vA,id,ix),vecIX,info as (lhs,_,_)))=let
190 : cchiw 2612 val (vB, B)= mkVec(mapp,(id,ix,vecIX,info))
191 : cchiw 2789 val (vD, D)=aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy [vecIX])
192 : cchiw 2612 in
193 :     (vD,B@D)
194 :     end
195 :    
196 :     (* Vector Subtraction*)
197 : cchiw 2789 fun mksubVec(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_)))= let
198 : cchiw 2612 val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
199 :     val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
200 : cchiw 2789 val (vD,D)= aaV(DstOp.subVec vecIX ,[vA, vB],lhs^"subVec",DstTy.TensorTy [vecIX])
201 : cchiw 2612 in
202 :     (vD, A@B@D)
203 :     end
204 :    
205 :     (*Vector Addition *)
206 :     fun handleAddVec(mapp,(es,vecIX,info))=let
207 :     fun add([],rest,code)=(rest,code)
208 :     | add((id1,ix1)::es,rest,code)=let
209 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
210 :     in add(es,rest@[vA],code@A)
211 :     end
212 :     val (rest,code)=add(es,[],[])
213 : cchiw 2789 val (vA,A)=mkMultiple(info,rest,DstOp.addVec vecIX,DstTy.TensorTy([vecIX]))
214 : cchiw 2612 in
215 :     (vA,code@A)
216 :     end
217 :    
218 :     (*Vector Scaling*)
219 : cchiw 2789 fun mkprodScaV(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_)))=let
220 : cchiw 2631 val (vA,A)= mkSca(mapp,(id1,ix1,info))
221 : cchiw 2612 val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
222 : cchiw 2789 val (vD,D)= aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy([vecIX]))
223 : cchiw 2612 in
224 :     (vD,A@B@D)
225 :     end
226 :    
227 :     (*Vector Product*)
228 : cchiw 2789 fun mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_)))= let
229 : cchiw 2612 val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
230 :     val (vB, B)= mkVec(mapp,(id2,ix2,vecIX,info))
231 : cchiw 2789 val (vD, D)=aaV(DstOp.prodVec vecIX,[vA, vB],lhs^"prodV",DstTy.TensorTy([vecIX]))
232 : cchiw 2612 in
233 :     (vD, A@B@D)
234 :     end
235 :    
236 :     (*Sum of Vector Product*)
237 :    
238 : cchiw 2680
239 : cchiw 2830 fun mkprodSumVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_)))=let
240 :     val (vD,D)=mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))
241 :     val (vE, E)=aaV(DstOp.sumVec vecIX,[vD], lhs^"sumVec",DstTy.realTy)
242 :     in
243 :     (vE, D @E)
244 :     end
245 : cchiw 2680
246 : cchiw 2830
247 : cchiw 2612 (*Dot Product like summation *)
248 : cchiw 2789 fun sumDot(mapp, ((E.V v,lb,ub),t as (_,_,_,_,_,info) ))=let
249 : cchiw 2612
250 :     fun sumI(a,0,rest,code)=let
251 :     val mapp =insert(v, 0) a
252 :     val (vE, E)=mkprodSumVec(mapp,t)
253 :     val rest'=[vE]@rest
254 : cchiw 2789 val (vF, F)=mkMultiple(info,rest',addR,Sca)
255 : cchiw 2612 in
256 :     (vF,E@code@F)
257 :     end
258 :     | sumI(a,sx,rest',code')=let
259 :     val mapp =insert(v, (sx+lb)) a
260 :     val (vE, E)=mkprodSumVec(mapp,t)
261 :     in
262 :     sumI(a,sx-1,[vE]@rest',E@code')
263 :     end
264 :     in
265 :     sumI(mapp, (ub-lb), [],[])
266 :     end
267 : cchiw 2615 | sumDot _= raise Fail "Non-variable index in summation"
268 : cchiw 2612
269 :     end
270 :    
271 :    
272 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0