Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3263 - (view) (download)

1 : cchiw 2859 (*Helper functions
2 :     *)
3 :     structure Helper = struct
4 :     local
5 :    
6 :     structure DstIL = LowIL
7 :     structure DstTy = LowILTypes
8 :     structure DstOp = LowOps
9 :     structure Var = LowIL.Var
10 :     structure SrcIL = MidIL
11 :     structure SrcOp = MidOps
12 :     structure E = Ein
13 :     structure LowToS= LowToString
14 : cchiw 2870 structure FL=FloatLit
15 : jhr 2924 structure IMap = IntRedBlackMap
16 : cchiw 2859 in
17 :    
18 : cchiw 3196 val testing=0
19 : cchiw 2859 val bV= ref 0
20 :     fun err str=raise Fail(str)
21 :     val realTy=DstTy.TensorTy []
22 :     val intTy=DstTy.intTy
23 :     fun iTos e1=Int.toString e1
24 :     fun iToss es=String.concat(List.map iTos es)
25 :     fun testp n= (case testing
26 : jhr 2924 of 0 => 0
27 : cchiw 2859 | _ =>((print (String.concat n));1)
28 :     (*end case*))
29 :    
30 : jhr 2924 val empty = IMap.empty
31 :     fun lookup k d = IMap.find(d, k)
32 :     fun insert (k, v) d = IMap.insert(d, k, v)
33 :     fun find (v, mapp) = (case IMap.find(mapp, v)
34 :     of NONE => raise Fail(concat["Outside Bound(", Int.toString v, ")"])
35 :     | SOME s => s
36 :     (* end *))
37 : cchiw 2859
38 :     (*mapIndex:E.mu * dict-> int
39 :     * lookup
40 :     *)
41 : jhr 2924 fun mapIndex (e1, mapp)=(case e1
42 : cchiw 2859 of E.V e => find(e,mapp)
43 : jhr 2924 | E.C c => c
44 : cchiw 2859 (*end case*))
45 :    
46 :    
47 : cchiw 2870 (* intToReal:int->Var*LowIL.ASSGN list
48 : cchiw 2859 *)
49 : cchiw 2870 fun intToReal n=let
50 :     val a=DstIL.Var.new("real" ,realTy)
51 :     val b=DstIL.Var.new("cast" ,realTy)
52 :     val code=[DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n))),
53 :     DstIL.ASSGN (b,DstIL.OP(DstOp.IntToReal,[a]))]
54 : cchiw 2859 in
55 : cchiw 2870 (b,code)
56 : cchiw 2859 end
57 :    
58 :     (* mkINt:int->Var*LowIL.ASSGN list
59 :     *)
60 :     fun mkInt n=let
61 :     val a=DstIL.Var.new("Int" ,intTy)
62 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
63 :     val _ =testp[LowToS.toStringAll(intTy,code)]
64 :     in
65 :     (a,[code])
66 :     end
67 :    
68 :     (*assgnCons int list * Var list->Var*LowIL.ASSGN list
69 :     * cons elements on list
70 :     *)
71 :     fun assgnCons(pre,shape, args)=let
72 :     val ty=DstTy.TensorTy shape
73 : cchiw 2870 val a=DstIL.Var.new("cons"^"_" ,ty)
74 : cchiw 2859 val code=DstIL.ASSGN (a,DstIL.CONS(ty ,args))
75 :     val _ =testp[LowToS.toStringAll(ty,code)]
76 :     in
77 :     (a, [code])
78 :     end
79 :    
80 :     (*LowOps.Op.op * var list*string*LowIL.Ty
81 :     * -> Var*LowIL.ASSGN list
82 :     * Make lowIL assignment
83 :     *)
84 :     fun assgn(opss,args,pre,ty)=let
85 :     val a=DstIL.Var.new(pre ,ty)
86 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
87 : cchiw 3263 val _ =print(String.concat[LowToS.toStringAll(ty,code)])
88 : cchiw 3261
89 : cchiw 2859 in
90 :     (a,[code])
91 :     end
92 :    
93 :     (*getTensorTy:E.params*E.tensor_id-> LowIL.Ty
94 :     * Integer, or Generic Tensor
95 :     *)
96 :     fun getTensorTy(params, id)=(case List.nth(params,id)
97 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
98 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
99 :     |_=> err"NONE Tensor Param"
100 :     (*end case*))
101 :    
102 :     (* indexTensor:dict*string*E.params*Var list*E.tensor_id*E.alpha
103 :     * ->Var*LowIL.ASSGN list
104 :     * Index Tensor at specific indices to give a scalar result
105 :     *)
106 : cchiw 3261 fun indexTensor(_,(_,_,args,id, [],ty)) = (List.nth(args,id),[])
107 :     | indexTensor(_,(_,_,_,_, [_,_,_],DstTy.TensorTy [_,_,_,_] )) = raise Fail "uneven"
108 : cchiw 2859 | indexTensor(mapp,(lhs,params,args,id,ix,ty))= let
109 :     val nU=List.nth(args,id)
110 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
111 :     val ix'=DstTy.indexTy ixx
112 :     val argTy=getTensorTy(params,id)
113 :     val opp=DstOp.IndexTensor(id,ix',argTy)
114 : cchiw 2870 val name=String.concat["Indx_",iToss ixx,"_"]
115 : cchiw 2859 in
116 :     assgn(opp,[nU],name,ty)
117 :     end
118 :    
119 :     (*projTensor:dict*(string* E.params*Var list*int*E.tensor_id*E.alpha)->Var*LowIL.ASSGN list
120 :     * projects tensor to a vector
121 :     *just used by EintoVecOps but made sense to keep it here
122 :     *)
123 : cchiw 3196 fun (* projTensor(_,(_,_,args,3,id,[]))=let
124 :     val opp=DstOp.LdVec 3
125 :     val nU=List.nth(args,id)
126 :     val name=String.concat["LdVec_"]
127 :     val vecTy=DstTy.TensorTy [3]
128 :     in
129 :     assgn(opp,[nU],name,vecTy)
130 :     end
131 :     |*) projTensor(_,(_,_,args,_,id,[]))= (List.nth(args,id),[])
132 : cchiw 2859 | projTensor(mapp,(lhs,params,args,vecIX,id,ix))= let
133 :     val nU=List.nth(args,id)
134 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
135 :     val ix'=DstTy.indexTy ixx
136 :     val argTy= getTensorTy(params,id)
137 :     val vecTy=DstTy.TensorTy [vecIX]
138 : cchiw 3195 val opp=DstOp.ProjectLast(id,vecIX,ix',argTy)
139 :     val name=String.concat["ProjLast_",iToss ixx,"_"]
140 : cchiw 2859 in
141 :     assgn(opp,[nU],name,vecTy)
142 :     end
143 :    
144 : cchiw 3195
145 :     fun projFirst(mapp,(lhs,params,args,vecIX,id,ix))= let
146 :     val nU=List.nth(args,id)
147 :     val argTy= getTensorTy(params,id)
148 :     val vecTy=DstTy.TensorTy [vecIX]
149 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
150 :     fun f cnt = let
151 :     val ix'=DstTy.indexTy ([cnt]@ixx)
152 :     val opp=DstOp.IndexTensor(id,ix',argTy)
153 :     val name=String.concat["IndexTensor_",iToss ixx,"_"]
154 :     in
155 :     assgn(opp,[nU],name,realTy)
156 :     end
157 :     val ops=List.tabulate( vecIX, fn e=> f e)
158 :     fun iter ([],vCs,Cs)=(vCs,Cs)
159 :     | iter((vB,B)::es,vCs,Cs)= iter(es,vCs@[vB],Cs@B)
160 :    
161 :     val (vCs,Cs)=iter(ops,[],[])
162 :     val (vD,D)=assgnCons("projFirstCons", [vecIX], vCs)
163 :     in
164 :     (vD, Cs@D)
165 :     end
166 :    
167 :    
168 :    
169 :     fun projFirst2(mapp,(lhs,params,args,vecIX,id,ix))= let
170 :     val nU=List.nth(args,id)
171 :     val argTy= getTensorTy(params,id)
172 :     val vecTy=DstTy.TensorTy [vecIX]
173 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
174 :     val ix'=DstTy.indexTy ixx
175 :     val opp=DstOp.ProjectFirst(id,vecIX,ix',argTy)
176 :     val name=String.concat["ProjFirst_",iToss ixx,"_"]
177 :     in assgn(opp,[nU],name,vecTy)
178 :     end
179 :    
180 :    
181 :    
182 : cchiw 2867 fun mkSqrt(nU,code)= let
183 :     val opp=DstOp.Sqrt
184 : cchiw 2870 val name=String.concat["_Sqrt_"]
185 : cchiw 2867 val (vA,A)=assgn(opp,[nU],name,realTy)
186 :     in
187 :     (vA,code@A)
188 :     end
189 :    
190 : cchiw 3138 fun mkCosine(nU,code)= let
191 :     val opp=DstOp.Cosine
192 :     val name=String.concat["_Cosine_"]
193 :     val (vA,A)=assgn(opp,[nU],name,realTy)
194 :     in
195 :     (vA,code@A)
196 :     end
197 : cchiw 2867
198 : cchiw 3138 fun mkArcCosine(nU,code)= let
199 :     val opp=DstOp.ArcCosine
200 :     val name=String.concat["_ArcCosine_"]
201 :     val (vA,A)=assgn(opp,[nU],name,realTy)
202 :     in
203 :     (vA,code@A)
204 :     end
205 :    
206 :     fun mkSine(nU,code)= let
207 :     val opp=DstOp.Sine
208 :     val name=String.concat["_Sine_"]
209 :     val (vA,A)=assgn(opp,[nU],name,realTy)
210 :     in
211 :     (vA,code@A)
212 :     end
213 :    
214 :     fun mkArcSine(nU,code)= let
215 :     val opp=DstOp.ArcSine
216 :     val name=String.concat["_ArcSine_"]
217 :     val (vA,A)=assgn(opp,[nU],name,realTy)
218 :     in
219 :     (vA,code@A)
220 :     end
221 :    
222 :    
223 : cchiw 2870 fun mkPowInt((nU,code),n)= let
224 :     val opp=DstOp.powInt
225 :     val name=String.concat["_PowInt_"]
226 :     val (r,rcode)=mkInt n
227 :     val (vA,A)=assgn(opp,[nU,r],name,realTy)
228 :     in
229 :     (vA,code@rcode@A)
230 :     end
231 : cchiw 2867
232 : cchiw 2870 fun mkPowRat((nU,code),rat)= let
233 :     val opp=DstOp.powRat(DstTy.R rat)
234 :     val name=String.concat["_PowRat_"]
235 :    
236 :     val (vA,A)=assgn(opp,[nU],name,realTy)
237 :     in
238 :     (vA,code@A)
239 :     end
240 :    
241 :    
242 :    
243 :    
244 :    
245 :    
246 : cchiw 2859 (* Some shortcuts. Arguements are Low-IL variables already indexed/projected
247 :     * string*Var list ->Var*LowIL.ASSGN list
248 :     *)
249 : cchiw 2870 fun mkAddSca(lhs,args)= assgn(DstOp.addSca,args,"addSca",realTy)
250 :     fun mkAddInt(lhs,args)= assgn(DstOp.addSca,args,"addInt",intTy)
251 :     fun mkAddPtr(lhs,args,ty)= assgn(DstOp.addSca,args,"addPtr",ty)
252 :     fun mkAddVec(lhs,vecIX,args)=assgn(DstOp.addVec vecIX,args,"addV",DstTy.TensorTy([vecIX]))
253 :     fun mkSubSca(lhs,args)= assgn(DstOp.subSca,args,"subSca",realTy)
254 :     fun mkProdSca(lhs,args)=assgn(DstOp.prodSca,args,"prodSca",realTy)
255 :     fun mkProdInt(lhs,args)=assgn(DstOp.prodSca,args,"prodInt",intTy)
256 :     fun mkProdVec(lhs,vecIX,args)=assgn(DstOp.prodVec vecIX,args,"prodV",DstTy.TensorTy([vecIX]))
257 :     fun mkDivSca(lhs,args)= assgn(DstOp.divSca,args,"divSca",realTy)
258 :     fun mkSumVec(lhs,vecIX,args)= assgn(DstOp.sumVec vecIX,args,"sumVec",realTy)
259 : cchiw 2859 fun mkDotVec(lhs,vecIX,args)=let
260 : cchiw 2870 val (vD, D)=mkProdVec("",vecIX,args)
261 :     val (vE, E)=mkSumVec("",vecIX,[vD])
262 : cchiw 2859 in (vE,D@E) end
263 :    
264 :    
265 :     (*mkMultiple:string*Var list*LowOps.Op *ListIL.Ty -> Var*LowIL.ASSGN list
266 :     *apply rator between each items on list1
267 :     *)
268 :     fun mkMultiple(lhs,list1,rator,ty)=let
269 :     fun add([],_,_) = err"no element in mkMultiple"
270 :     | add([e1],_,_) = (e1,[])
271 :     | add([e1,e2],code,_) = let
272 : cchiw 2870 val (vA,A)=assgn(rator,[e1,e2],"mult_2",ty)
273 : cchiw 2859 in (vA,code@A) end
274 :     | add(e1::e2::es,code,count) = let
275 : cchiw 2870 val (vA,A)=assgn(rator,[e1,e2],String.concat["mult_",iTos count],ty)
276 : cchiw 2859 in add(vA::es,code@A,count-1)
277 :     end
278 :     in
279 :     add(list1,[],List.length list1)
280 :     end
281 :    
282 :     (* deltaToInt:dict*E.mu*E.mu->int
283 :     * delta function
284 :     *)
285 :     fun deltaToInt(mapp,a,b)= let
286 :     val i=mapIndex(a,mapp)
287 :     val j=mapIndex(b,mapp)
288 :     in
289 :     if(i=j) then 1 else 0
290 :     end
291 :    
292 : cchiw 2870 fun evalDelta(mapp,a,b)= intToReal(deltaToInt(mapp,a,b))
293 : cchiw 2859
294 :     (*eval Epsilon-2d*)
295 :     fun evalEps2(mapp,a,b)=let
296 :     val i=mapIndex(E.V a,mapp)
297 :     val j=mapIndex(E.V b,mapp)
298 :     in
299 : cchiw 2870 if(i=j) then intToReal 0
300 : cchiw 2859 else
301 : cchiw 2870 if(j>i) then intToReal 1
302 :     else intToReal ~1
303 : cchiw 2859 end
304 :    
305 :     (*eval Epsilon-3d*)
306 :     fun evalEps3(mapp,a,b,c)=let
307 :     val i=mapIndex(E.V a,mapp)
308 :     val j=mapIndex(E.V b,mapp)
309 :     val k=mapIndex(E.V c,mapp)
310 :     in
311 : cchiw 2870 if(i=j orelse j=k orelse i=k) then intToReal 0
312 : cchiw 2859 else
313 :     if(j>i) then
314 : cchiw 2870 if(j>k andalso k>i) then intToReal ~1 else intToReal 1
315 :     else if(i>k andalso k>j) then intToReal 1 else intToReal ~1
316 : cchiw 2859
317 :     end
318 :    
319 :     end
320 :    
321 : jhr 2924 end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0