Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step3.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2628 - (view) (download)

1 : cchiw 2615 (*Helper function gen-ein*)
2 : cchiw 2612 structure step3 = struct
3 :     local
4 :    
5 :     structure DstIL = LowIL
6 :     structure DstTy = LowILTypes
7 :     structure DstOp = LowOps
8 :     structure Var = LowIL.Var
9 :     structure SrcIL = MidIL
10 :     structure SrcOp = MidOps
11 :     structure E = Ein
12 :     structure tS= toStringEin
13 : cchiw 2627
14 : cchiw 2615
15 : cchiw 2612 in
16 :    
17 : cchiw 2615 val testing=1
18 : cchiw 2612 val bV= ref 0
19 :     fun err str=raise Fail(str)
20 : cchiw 2620 val Sca=DstTy.TensorTy []
21 :     fun lookup k d = d k
22 :     fun q e1=Int.toString e1
23 : cchiw 2612
24 :     fun insert (key, value) d =fn s =>
25 :     if s = key then (print(String.concat[Int.toString(key),"=>",Int.toString(value)]);SOME value)
26 :     else d s
27 :    
28 :    
29 :     (*Get kernel and Image bindings*)
30 :     fun getKernel x = (case SrcIL.Var.binding x
31 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _),_))=> h
32 :     | vb => (err (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
33 :     (* end case *))
34 :    
35 : cchiw 2620
36 : cchiw 2612 fun getImage x = (case SrcIL.Var.binding x
37 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage(img),_))=> img
38 :     | vb => (err (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
39 :     (* end case *))
40 :    
41 :     (*Make assignment*)
42 :     fun aaV(opss,args,pre,ty)=let
43 :     val a=DstIL.Var.new(pre ,ty)
44 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
45 :     val _ =(case testing
46 :     of 0=> 1
47 : cchiw 2627 | _ => (print(tS.toStringAll(ty,code)); 1)
48 : cchiw 2612 (* end case *))
49 :     in
50 :     (a,[code])
51 :     end
52 :    
53 : cchiw 2628 (*
54 : cchiw 2620 fun mkInt n= aaV(DstOp.C n,[],"Int",Sca)
55 : cchiw 2628 *)
56 : cchiw 2624 fun mkInt n=let
57 :     val a=DstIL.Var.new("Int" ,Sca)
58 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int n))
59 :     val _ =(case testing
60 :     of 0=> 1
61 :     | _ => (print(tS.toStringAll(Sca,code));1)
62 :     (*end case*))
63 :     in
64 :     (a,[code])
65 :     end
66 : cchiw 2620
67 : cchiw 2628 fun mkInt2 n=let
68 :     val a=DstIL.Var.new("Int" ,Sca)
69 :     val code=DstIL.ASSGN (a,DstIL.LIT( n))
70 :     val _ =(case testing
71 :     of 0=> 1
72 :     | _ => (print(tS.toStringAll(Sca,code));1)
73 :     (*end case*))
74 :     in
75 :     (a,[code])
76 :     end
77 : cchiw 2620
78 : cchiw 2624
79 : cchiw 2628
80 :    
81 :    
82 : cchiw 2612 (*mk Multiple, Add Ids on list1*)
83 :     fun mkMultiple(list1,rator,ty)=let
84 :     fun add(e,code)=(case e
85 :     of [] => err"no element in mkMultiple"
86 :     | [e1] => (e1,[])
87 :     | [e1,e2] => let
88 :     val (vA,A)=aaV(rator,[e1,e2],"MO",ty)
89 :     in (vA,code@A) end
90 :     | (e1::e2::es) => let
91 :     val (vA,A)=aaV(rator,[e1,e2],"MO",ty)
92 :     in add(vA::es,code@A)
93 :     end
94 :     (*end case*))
95 :     in
96 :     add(list1,[])
97 :     end
98 :    
99 :    
100 :     fun mapIndex(e1,mapp)=(case e1
101 :     of E.V e => (case (lookup e mapp)
102 :     of NONE=> err("Outside Bound:"^Int.toString(e))
103 :     |SOME s => s)
104 :     | E.C c=> c
105 :     (*end case*))
106 :    
107 :    
108 :     (*Integer, or Generic Tensor*)
109 :     fun getTensorTy(params, id)=(case List.nth(params,id)
110 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
111 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
112 :     |_=> err"NONE Tensor Param")
113 :    
114 :    
115 : cchiw 2628 (*Just added Index options*)
116 :     fun mkSca(mapp,(id, [],(params,args)))=let
117 :     val nU=List.nth(args,id)
118 :     in
119 :     (nU,[])
120 :     end
121 :     | mkSca(mapp,(id, [ix],(params,args)))=let
122 : cchiw 2612 val nU=List.nth(args,id)
123 : cchiw 2628 val ix'=mapIndex(ix,mapp)
124 :     val argTy=getTensorTy(params,id)
125 :     in
126 :     aaV(DstOp.Index(argTy, ix'),[nU],"S-Index-"^Int.toString(id),Sca)
127 :     end
128 :    
129 :     | mkSca(mapp,(id,ix,(params,args)))= let
130 :     val nU=List.nth(args,id)
131 : cchiw 2612 val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
132 :     val argTy=getTensorTy(params,id)
133 :     in
134 :     aaV(DstOp.S(id,ix',argTy),[nU],"S"^Int.toString(id),Sca)
135 :     end
136 :    
137 :     (*eval Epsilon*)
138 :     fun evalEps(mapp,a,b,c)=let
139 :     val i=mapIndex(E.V a,mapp)
140 :     val j=mapIndex(E.V b,mapp)
141 :     val k=mapIndex(E.V c,mapp)
142 :     in
143 :     if(i=j orelse j=k orelse i=k) then mkInt 0
144 :     else
145 :     if(j>i) then
146 :     if(j>k andalso k>i) then mkInt ~1 else mkInt 1
147 :     else if(i>k andalso k>j) then mkInt 1 else mkInt ~1
148 :    
149 :     end
150 :    
151 :     (*eval Delta*)
152 :     fun evalDelta2(mapp,a,b)= let
153 :     val i=mapIndex(a,mapp)
154 :     val j=mapIndex(b,mapp)
155 :     in
156 :     if(i=j) then mkInt 1 else mkInt 0
157 :     end
158 :    
159 :    
160 :     fun evalDels(mapp,dels)=let
161 :     fun m(a,b)=if(a=b) then 1 else 0
162 :     fun ij(i,j)=(case (i,j)
163 :     of (E.V a, E.V b)=>m(mapIndex(i,mapp),mapIndex(j,mapp))
164 :     | (E.C a, E.V b)=>m(a,mapIndex(j,mapp))
165 :     | (E.V a, E.C b)=>m(mapIndex(i,mapp),b)
166 :     | (E.C a, E.C b)=>m(i,j)
167 :     (*end case*))
168 :     val dels'=List.map ij dels
169 :     in
170 :     List.foldl(fn(x,y)=>x+y) 0 dels'
171 :     end
172 :    
173 :    
174 :     (*--------------------Vectorization Helper Functions--------------------*)
175 :     (*val nextfnArgs=(body,params,args,origargs)*)
176 :    
177 :     fun mkVec(mapp,(id,ix,vecIX,(params,args)))= let
178 :     val nU=List.nth(args,id)
179 :     val ix'=DstTy.indexTy(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
180 :     val argTy= getTensorTy(params,id)
181 :     in
182 :     aaV(DstOp.V(id,vecIX, ix',argTy),[nU],"V"^Int.toString(id),DstTy.TensorTy [vecIX])
183 :     end
184 :    
185 :    
186 :     (*product of -1 and 1 projection*)
187 :     fun mkNegV(mapp,((vA,id,ix),vecIX,info))=let
188 :     val (vB, B)= mkVec(mapp,(id,ix,vecIX,info))
189 :     val (vD, D)=aaV(DstOp.prodScaV vecIX,[vA, vB],"prodScaV",DstTy.TensorTy [vecIX])
190 :     in
191 :     (vD,B@D)
192 :     end
193 :    
194 :     (* Vector Subtraction*)
195 :     fun mksubVec(mapp,(id1,ix1,id2,ix2,vecIX,info))= let
196 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
197 :     val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
198 :     val (vD,D)= aaV(DstOp.subVec vecIX ,[vA, vB],"subVec",DstTy.TensorTy [vecIX])
199 :     in
200 :     (vD, A@B@D)
201 :     end
202 :    
203 :     (*Vector Addition *)
204 :     fun handleAddVec(mapp,(es,vecIX,info))=let
205 :     fun add([],rest,code)=(rest,code)
206 :     | add((id1,ix1)::es,rest,code)=let
207 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
208 :     in add(es,rest@[vA],code@A)
209 :     end
210 :     val (rest,code)=add(es,[],[])
211 :     val (vA,A)=mkMultiple(rest,DstOp.addVec(vecIX),DstTy.TensorTy([vecIX]))
212 :     in
213 :     (vA,code@A)
214 :     end
215 :    
216 :     (*Vector Scaling*)
217 :     fun mkprodScaV(mapp,(id1,ix1,id2,ix2,vecIX,info))=let
218 :     val (vA,A)= mkSca(mapp,(id1,ix2,info))
219 :     val (vB,B)= mkVec(mapp,(id2,ix2,vecIX,info))
220 :     val (vD,D)= aaV(DstOp.prodScaV(vecIX),[vA, vB],"prodScaV",DstTy.TensorTy([vecIX]))
221 :     in
222 :     (vD,A@B@D)
223 :     end
224 :    
225 :     (*Vector Product*)
226 :     fun mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))= let
227 :     val (vA,A)= mkVec(mapp,(id1,ix1,vecIX,info))
228 :     val (vB, B)= mkVec(mapp,(id2,ix2,vecIX,info))
229 :     val (vD, D)=aaV(DstOp.prodVec(vecIX),[vA, vB],"prodV",DstTy.TensorTy([vecIX]))
230 :     in
231 :     (vD, A@B@D)
232 :     end
233 :    
234 :     (*Sum of Vector Product*)
235 :     fun mkprodSumVec(mapp,(id1,ix1,id2,ix2,vecIX, info))= let
236 :     val (vD,D)=mkprodVec(mapp,(id1,ix1,id2,ix2,vecIX, info))
237 :     val (vE, E)=aaV(DstOp.sumVec vecIX,[vD],"sumVec",DstTy.realTy)
238 :     in
239 :     (vE, D @E)
240 :     end
241 :    
242 :     (*Dot Product like summation *)
243 :     fun sumDot(mapp, ((E.V v,lb,ub),t))=let
244 :    
245 :     fun sumI(a,0,rest,code)=let
246 :     val mapp =insert(v, 0) a
247 :     val (vE, E)=mkprodSumVec(mapp,t)
248 :     val rest'=[vE]@rest
249 :     val (vF, F)=mkMultiple(rest',DstOp.addSca,Sca)
250 :     in
251 :     (vF,E@code@F)
252 :     end
253 :     | sumI(a,sx,rest',code')=let
254 :     val mapp =insert(v, (sx+lb)) a
255 :     val (vE, E)=mkprodSumVec(mapp,t)
256 :     in
257 :     sumI(a,sx-1,[vE]@rest',E@code')
258 :     end
259 :     in
260 :     sumI(mapp, (ub-lb), [],[])
261 :     end
262 : cchiw 2615 | sumDot _= raise Fail "Non-variable index in summation"
263 : cchiw 2612
264 :     end
265 :    
266 :    
267 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0