Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2845 - (view) (download)

1 : cchiw 2615 (*Helper function gen-ein*)
2 : cchiw 2612 structure step3 = struct
3 :     local
4 :    
5 :     structure DstIL = LowIL
6 :     structure DstTy = LowILTypes
7 :     structure DstOp = LowOps
8 :     structure Var = LowIL.Var
9 :     structure SrcIL = MidIL
10 :     structure SrcOp = MidOps
11 :     structure E = Ein
12 : cchiw 2845 structure LowToS= LowToString
13 : cchiw 2615
14 : cchiw 2612 in
15 :    
16 : cchiw 2845 val testing=0
17 :     val bV= ref 0
18 :     fun err str=raise Fail(str)
19 :     val iTy=DstTy.IntTy
20 :     val Sca=DstTy.TensorTy []
21 :     val addR=DstOp.addSca
22 :     fun lookup k d = d k
23 :     fun q e1=Int.toString e1
24 : cchiw 2676
25 : cchiw 2845 fun testp n= (case testing
26 :     of 0=> 0
27 :     | _ =>((print (String.concat n));1)
28 :     (*end case*))
29 : cchiw 2612
30 : cchiw 2845 fun insert (key, value) d =fn s =>
31 :     if s = key then (testp[Int.toString(key),"=>",Int.toString(value)];SOME value)
32 :     else d s
33 : cchiw 2827
34 : cchiw 2845 (*Get kernel and Image bindings*)
35 :     fun getKernel x = (case SrcIL.Var.binding x
36 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _ ) ,_ ))=> h
37 :     | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
38 :     (* end case *))
39 : cchiw 2612
40 : cchiw 2845 fun getImageSrc x = (case SrcIL.Var.binding x
41 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage img, _ )) => img
42 :     | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
43 : cchiw 2612 (* end case *))
44 :    
45 : cchiw 2845 (*Make assignment*)
46 :     fun aaV(opss,args,pre,ty)=let
47 :     val a=DstIL.Var.new(pre ,ty)
48 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
49 :     val _ =testp[LowToS.toStringAll(ty,code)]
50 :     in
51 :     (a,[code])
52 :     end
53 : cchiw 2620
54 : cchiw 2845 fun mkReal n=let
55 :     val a=DstIL.Var.new("Real" ,Sca)
56 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
57 :     val _ =testp[LowToS.toStringAll(Sca,code)]
58 :     in
59 :     (a,[code])
60 :     end
61 : cchiw 2612
62 : cchiw 2845 fun mkInt n=let
63 :     val a=DstIL.Var.new("Int" ,iTy)
64 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
65 :     val _ =testp[LowToS.toStringAll(iTy,code)]
66 :     in
67 :     (a,[code])
68 :     end
69 : cchiw 2612
70 : cchiw 2845 (*mk Multiple, Add Ids on list1*)
71 :     fun mkMultiple((lhs,_,_,_),list1,rator,ty)=let
72 :     fun add([],_,_) = err"no element in mkMultiple"
73 :     | add([e1],_,_) = (e1,[])
74 :     | add([e1,e2],code,_) = let
75 :     val (vA,A)=aaV(rator,[e1,e2],lhs^"_2",ty)
76 :     in (vA,code@A) end
77 :     | add(e1::e2::es,code,count) = let
78 :     val (vA,A)=aaV(rator,[e1,e2],lhs^"_"^Int.toString count,ty)
79 :     in add(vA::es,code@A,count-1)
80 :     end
81 :     in
82 :     add(list1,[],List.length list1)
83 :     end
84 : cchiw 2637
85 : cchiw 2845 fun mapIndex(e1,mapp)=(case e1
86 :     of E.V e => (case (lookup e mapp)
87 :     of NONE=> err("Outside Bound:"^Int.toString(e))
88 :     |SOME s => s)
89 :     | E.C c=> c
90 :     (*end case*))
91 : cchiw 2620
92 : cchiw 2845 (*Integer, or Generic Tensor*)
93 :     fun getTensorTy(params, id)=(case List.nth(params,id)
94 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
95 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
96 :     |_=> err"NONE Tensor Param")
97 : cchiw 2620
98 : cchiw 2845 fun q e=Int.toString(e)
99 : cchiw 2624
100 : cchiw 2845 fun mkSca(mapp,(id, [],(lhs,params,_,args))) = (List.nth(args,id),[])
101 :     | mkSca(mapp,(id,ix,(lhs,params,_,args)))= let
102 :     val nU=List.nth(args,id)
103 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
104 :     val ix'=DstTy.indexTy ixx
105 :     val argTy=getTensorTy(params,id)
106 :     val opp=DstOp.IndexTensor(id,ix',argTy)
107 :     in
108 :     aaV(opp,[nU],lhs^"R"^Int.toString(id),Sca)
109 :     end
110 : cchiw 2612
111 : cchiw 2845 fun mkIntAsn(_,(id, [],(_,_,_,args))) = (List.nth(args,id),[])
112 :     | mkIntAsn(mapp,(id,ix,(lhs,params,_,args)))= let
113 :     val nU=List.nth(args,id)
114 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
115 :     val ix'=DstTy.indexTy ixx
116 :     val argTy=getTensorTy(params,id)
117 :     val opp=DstOp.IndexTensor(id,ix',argTy)
118 :     in
119 :     aaV(opp,[nU],lhs^"I"^Int.toString(id),iTy)
120 :     end
121 : cchiw 2612
122 : cchiw 2845 (*eval Epsilon*)
123 :     fun evalEps(mapp,a,b,c)=let
124 :     val i=mapIndex(E.V a,mapp)
125 :     val j=mapIndex(E.V b,mapp)
126 :     val k=mapIndex(E.V c,mapp)
127 :     in
128 :     if(i=j orelse j=k orelse i=k) then mkReal 0
129 :     else
130 :     if(j>i) then
131 :     if(j>k andalso k>i) then mkReal ~1 else mkReal 1
132 :     else if(i>k andalso k>j) then mkReal 1 else mkReal ~1
133 : cchiw 2612
134 : cchiw 2845 end
135 : cchiw 2612
136 : cchiw 2845 (*eval Epsilon*)
137 :     fun evalEps2(mapp,a,b)=let
138 :     val i=mapIndex(E.V a,mapp)
139 :     val j=mapIndex(E.V b,mapp)
140 :     in
141 :     if(i=j) then mkReal 0
142 :     else
143 :     if(j>i) then mkReal 1
144 :     else mkReal ~1
145 :     end
146 : cchiw 2612
147 : cchiw 2845 (*eval Delta*)
148 :     fun evalDelta2(mapp,a,b)= let
149 :     val i=mapIndex(a,mapp)
150 :     val j=mapIndex(b,mapp)
151 :     in
152 :     if(i=j) then mkReal 1 else mkReal 0
153 :     end
154 : cchiw 2612
155 : cchiw 2845 fun evalDels(mapp,dels)=let
156 :     fun m(a,b)=if(a=b) then 1 else 0
157 :     fun ij(i,j)=(case (i,j)
158 :     of (E.V a, E.V b)=>m(mapIndex(i,mapp),mapIndex(j,mapp))
159 :     | (E.C a, E.V b)=>m(a,mapIndex(j,mapp))
160 :     | (E.V a, E.C b)=>m(mapIndex(i,mapp),b)
161 :     | (E.C a, E.C b)=>m(i,j)
162 :     (*end case*))
163 :     val dels'=List.map ij dels
164 :     in
165 :     List.foldl(fn(x,y)=>x+y) 0 dels'
166 :     end
167 : cchiw 2827
168 : cchiw 2845 (*--------------------Vectorization Helper Functions--------------------*)
169 :     (*val nextfnArgs=(body,params,args,origargs)*)
170 :     fun mkVec(mapp,(id,[],vecIX,(lhs,params,_,args)))= let
171 :     val nU=List.nth(args,id)
172 :     in (nU,[]) end
173 :     | mkVec(mapp,(id,ix,vecIX,(lhs,params,_,args)))= let
174 :     val nU=List.nth(args,id)
175 :     val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
176 :     val argTy= getTensorTy(params,id)
177 :     val vecTy=DstTy.TensorTy [vecIX]
178 :     val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
179 :     in
180 :     aaV(opp,[nU],lhs^"V"^Int.toString(id),vecTy)
181 :     end
182 : cchiw 2612
183 : cchiw 2845 (*product of -1 and 1 projection*)
184 :     fun mkNegV(mapp,((vA,id,ix),vecIX,info as (lhs,_,_,_)))=let
185 :     val (vB, B)= mkVec(mapp,(id,ix,vecIX,info))
186 :     val (vD, D)=aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy [vecIX])
187 :     in
188 :     (vD,B@D)
189 :     end
190 : cchiw 2827
191 : cchiw 2845 (* Vector Subtraction*)
192 :     fun mksubVec(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_,_)))= let
193 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
194 :     val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
195 :     val (vD,D)= aaV(DstOp.subVec vecIX ,[vA, vB],lhs^"subVec",DstTy.TensorTy [vecIX])
196 :     in
197 :     (vD, A@B@D)
198 :     end
199 : cchiw 2827
200 : cchiw 2845 (*Vector Addition *)
201 :     fun handleAddVec(mapp,(es,vecIX,info))=let
202 :     fun add([],rest,code)=(rest,code)
203 :     | add((id1,ix1)::es,rest,code)=let
204 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
205 :     in add(es,rest@[vA],code@A)
206 :     end
207 :     val (rest,code)=add(es,[],[])
208 :     val (vA,A)=mkMultiple(info,rest,DstOp.addVec vecIX,DstTy.TensorTy([vecIX]))
209 :     in
210 :     (vA,code@A)
211 :     end
212 : cchiw 2827
213 : cchiw 2845 (*Vector Scaling*)
214 :     fun mkprodScaV(mapp,(id1,ix1,id2,ix2,vecIX,info as (lhs,_,_,_)))=let
215 :     val (vA,A)= mkSca(mapp,(id1,ix1,info))
216 :     val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
217 :     val (vD,D)= aaV(DstOp.prodScaV vecIX,[vA, vB],lhs^"prodScaV",DstTy.TensorTy([vecIX]))
218 :     in
219 :     (vD,A@B@D)
220 :     end
221 : cchiw 2612
222 : cchiw 2845 (*Vector Product*)
223 :     fun mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_,_)))= let
224 : cchiw 2612 val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
225 : cchiw 2845 val (vB, B)= mkVec(mapp,(id2,ix2,vecIX,info))
226 :     val (vD, D)=aaV(DstOp.prodVec vecIX,[vA, vB],lhs^"prodV",DstTy.TensorTy([vecIX]))
227 :     in
228 :     (vD, A@B@D)
229 : cchiw 2612 end
230 :    
231 : cchiw 2845 (*Sum of Vector Product*)
232 :     fun mkprodSumVec(mapp,(id1,ix1,id2,ix2,vecIX, info as (lhs,_,_,_)))=let
233 :     val (vD,D)=mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))
234 :     val (vE, E)=aaV(DstOp.sumVec vecIX,[vD], lhs^"sumVec",DstTy.realTy)
235 :     in
236 :     (vE, D @E)
237 :     end
238 : cchiw 2612
239 : cchiw 2845 (*Dot Product like summation *)
240 :     fun sumDot(mapp, ((E.V v,lb,ub),t as (_,_,_,_,_,info) ))=let
241 : cchiw 2612
242 : cchiw 2845 fun sumI(a,0,rest,code)=let
243 :     val mapp =insert(v, 0) a
244 :     val (vE, E)=mkprodSumVec(mapp,t)
245 :     val rest'=[vE]@rest
246 :     val (vF, F)=mkMultiple(info,rest',addR,Sca)
247 :     in
248 :     (vF,E@code@F)
249 :     end
250 :     | sumI(a,sx,rest',code')=let
251 :     val mapp =insert(v, (sx+lb)) a
252 :     val (vE, E)=mkprodSumVec(mapp,t)
253 :     in
254 :     sumI(a,sx-1,[vE]@rest',E@code')
255 :     end
256 : cchiw 2612 in
257 : cchiw 2845 sumI(mapp, (ub-lb), [],[])
258 : cchiw 2612 end
259 : cchiw 2845 | sumDot _= raise Fail "Non-variable index in summation"
260 : cchiw 2612
261 :     end
262 :    
263 :    
264 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0