Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee_dev/src/compiler/low-il/helper.sml
ViewVC logotype

Annotation of /branches/charisee_dev/src/compiler/low-il/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3606 - (view) (download)

1 : cchiw 3541 (*Helper functions
2 :     *)
3 :     structure HelperSet = struct
4 :     local
5 :    
6 :     structure DstIL = LowIL
7 :     structure DstTy = LowILTypes
8 :     structure DstOp = LowOps
9 :     structure Var = LowIL.Var
10 :     structure SrcIL = MidIL
11 :     structure SrcOp = MidOps
12 :     structure E = Ein
13 :     structure LowToS= LowToString
14 :     structure FL=FloatLit
15 :     structure IMap = IntRedBlackMap
16 :     in
17 :    
18 :     val testing=0
19 : cchiw 3606 val valnum = true
20 : cchiw 3541 val bV= ref 0
21 :     fun err str=raise Fail(str)
22 :     val realTy=DstTy.TensorTy []
23 :     val intTy=DstTy.intTy
24 :     fun iTos e1=Int.toString e1
25 :     fun iToss es=String.concat(List.map iTos es)
26 :     fun testp n= (case testing
27 :     of 0 => 0
28 :     | _ =>((print (String.concat n));1)
29 :     (*end case*))
30 :     fun incUse (LowIL.V{useCnt, ...}) = (useCnt := !useCnt + 1)
31 :     val empty = IMap.empty
32 :     fun lookup k d = IMap.find(d, k)
33 :     fun insert (k, v) d = ((String.concat["\n\t",Int.toString(k),"==>",Int.toString(v)]) ;IMap.insert(d, k, v))
34 :     fun insertP (k, v) d = ((String.concat["\n\t",Int.toString(k),"==>",Int.toString(v)]) ;IMap.insert(d, k, v))
35 :     fun find (v, mapp) = (case IMap.find(mapp, v)
36 :     of NONE => raise Fail(concat["Outside Bound(", Int.toString v, ")"])
37 :     | SOME s => s
38 :     (* end *))
39 :    
40 :     (*mapIndex:E.mu * dict-> int
41 :     * lookup
42 :     *)
43 :     fun mapIndex (e1, mapp)=(case e1
44 :     of E.V e => find(e,mapp)
45 :     | E.C(c,_) => c
46 :     (*end case*))
47 :    
48 :     (* *************************** DstIL.ASSGN **************************** *)
49 :     fun noplugin(opset,ty,lhs,rhs)=let
50 :     (*val _=List.map incUse [lhs]*)
51 :     val codeo=DstIL.ASSGN (lhs,rhs)
52 :     val (a,code)=(lhs,[codeo])
53 :     in
54 :     (opset,a,code)
55 :     end
56 :    
57 :    
58 :     fun plugin(opset,ty,lhs,rhs)=(case valnum
59 :     of false => noplugin(opset,ty,lhs,rhs)
60 :     | _ =>
61 :     let
62 :     val (opset,var) = lowSet.filter(opset,(lhs,rhs))
63 :     val _=List.map incUse [lhs] (*FIXME*)
64 :     val codeo=DstIL.ASSGN (lhs,rhs)
65 :     val _=(LowToS.toStringAll(ty,codeo))
66 :     val (a,code)=(case var
67 :     of SOME v=> (testp["\n Found(",DstIL.Var.toString(v),"):",LowToS.toStringAll(ty,codeo)]; (v,[]))
68 :     | NONE => (testp["\n Inserting:",LowToS.toStringAll(ty,codeo)];(lhs,[codeo]))
69 :    
70 :     (*end case*))
71 :     in
72 :     (opset,a,code)
73 :     end
74 :     (*end case *))
75 :    
76 :     (* *************************** DstIL.LIT **************************** *)
77 :    
78 :     (* mkINt:int->Var*code list*)
79 :     fun mkInt (opset,n)=let
80 :     val lhs=DstIL.Var.new("Int" ,intTy)
81 :     val rhs=DstIL.LIT(Literal.Int(IntInf.fromInt n))
82 :     in
83 :     plugin(opset,intTy,lhs,rhs)
84 :     end
85 :    
86 :     fun mkReal (opset,n)=let
87 :     val lhs=DstIL.Var.new("real" ,realTy)
88 :     val rhs=DstIL.LIT(Literal.Int(IntInf.fromInt n))
89 :     in
90 :     plugin(opset,realTy,lhs,rhs)
91 :     end
92 :    
93 :     (* *************************** DstIL.CONS **************************** *)
94 :     fun assgnCons(opset,pre,shape, args)=let
95 :     val ty=DstTy.TensorTy shape
96 :     val lhs = DstIL.Var.new("cons"^"_" ,ty)
97 :     val rhs = DstIL.CONS(ty ,args)
98 :     in
99 :     plugin(opset,ty,lhs,rhs)
100 :     end
101 :    
102 :     fun assgnConsV(opset,pre,ty, args)=let
103 :     val lhs = DstIL.Var.new("cons"^"_" ,ty)
104 :     val rhs = DstIL.CONS(ty ,args)
105 :     in plugin(opset,ty,lhs,rhs) end
106 :    
107 :    
108 :     (* *************************** DstIL.OP **************************** *)
109 :     fun assignOP(opset,opss,args,pre,ty)=let
110 :     val lhs=DstIL.Var.new(pre ,ty)
111 :     val rhs=DstIL.OP(opss,args)
112 :     in
113 :     plugin(opset,ty,lhs,rhs)
114 :     end
115 :    
116 :     fun mkSingle(opp,name,(opset,nU,code))=let
117 :     val (opset,vA,A)=assignOP(opset,opp,[nU],name,realTy)
118 :     in
119 :     (opset,vA,code@A)
120 :     end
121 :    
122 :     (* *************************** DstOp.IndexTensor **************************** *)
123 :     (*getTensorTy:E.params*E.tensor_id-> LowIL.Ty
124 :     * Integer, or Generic Tensor
125 :     *)
126 :     fun getTensorTy(params, id)=(case List.nth(params,id)
127 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
128 :     | E.TEN(2,shape) =>DstTy.indexTy shape
129 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
130 :     |_=> err"NONE Tensor Param"
131 :     (*end case*))
132 :    
133 :     (* indexTensor:dict*string*E.params*Var list*E.tensor_id*E.alpha
134 :     * ->Var*code list
135 :     * Index Tensor at specific indices to give a scalar result
136 :     *)
137 :     fun indexTensor(opset,_,(lhs,params,args,id, [] ,ty)) = (opset,List.nth(args,id),[])
138 :     | indexTensor(_,_,(lhs,params,args,id, [_,_,_],DstTy.TensorTy [_,_,_,_] )) = raise Fail "uneven"
139 :     | indexTensor(opset,mapp,(lhs,params,args,id,ix,ty))= let
140 :     val nU=List.nth(args,id)
141 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
142 : cchiw 3547 val ix'= ixx
143 : cchiw 3541 val argTy=getTensorTy(params,id)
144 :     val opp=DstOp.IndexTensor(id,ix',argTy)
145 :     val name=String.concat["Indx_",iToss ixx,"_"]
146 :     in
147 :     assignOP(opset,opp,[nU],name,ty)
148 :     end
149 :    
150 :    
151 :     (* *************************** DstOp._ Shortcuts **************************** *)
152 :    
153 :     (* Some shortcuts. Arguements are Low-IL variables already indexed/projected
154 :     * string*Var list ->Var*code list
155 :     *)
156 :     fun mkAddSca(opset,args)= assignOP(opset,DstOp.addSca,args,"addSca",realTy)
157 :     fun mkAddInt(opset,args)= assignOP(opset,DstOp.addSca,args,"addInt",intTy)
158 :     fun mkAddPtr(opset,args,ty)= assignOP(opset,DstOp.addSca,args,"addPtr",ty)
159 :     fun mkAddVec(opset,vecIX,args)=assignOP(opset,DstOp.addVec vecIX,args,"addV",DstTy.TensorTy([vecIX]))
160 :     fun mkSubSca(opset,args)= assignOP(opset,DstOp.subSca,args,"subSca",realTy)
161 :     fun mkProdSca(opset,args)=assignOP(opset,DstOp.prodSca,args,"prodSca",realTy)
162 :     fun mkProdInt(opset,args)=assignOP(opset,DstOp.prodSca,args,"prodInt",intTy)
163 :     fun mkProdVec(opset,vecIX,args)=assignOP(opset,DstOp.prodVec vecIX,args,"prodV",DstTy.TensorTy([vecIX]))
164 :     fun mkDivSca(opset,args)= assignOP(opset,DstOp.divSca,args,"divSca",realTy)
165 :     fun mkSumVec(opset,vecIX,args)= assignOP(opset,DstOp.sumVec vecIX,args,"sumVec",realTy)
166 :    
167 :    
168 :     (* *************************** DstOp. Other **************************** *)
169 :     fun mkDotVec(opset,vecIX,args)=let
170 :     val (opsetD,vD, D)=mkProdVec(opset,vecIX,args)
171 :     val (opsetE,vE, E)=mkSumVec(opsetD,vecIX,[vD])
172 :     in (opsetE,vE,D@E) end
173 :    
174 :     fun intToReal(setA,n)=let
175 :     val (setC,vC,C)=mkReal(setA,n)
176 :     val (setD,vD,D)=assignOP(setC,DstOp.IntToReal,[vC],"cast",realTy)
177 :     in (setD,vD,C@D)end
178 :    
179 :     fun mkPowInt((opset,nU,code),nn)= let
180 :     fun pow(1,setA)=(setA,nU,[])
181 :     | pow(2,setA)=let
182 :     val opp=DstOp.prodSca
183 :     val name=String.concat["_Pow2_"]
184 :     val (setB,vB,B)=assignOP(setA,opp,[nU,nU],name,intTy)
185 :     in
186 :     (setB,vB,B)
187 :     end
188 :     | pow(n,setA)=let
189 :     fun half m= let
190 :     val (setB,vB,B)= pow(m div 2,setA)
191 :     val opp=DstOp.prodSca
192 :     val name=String.concat["_Pow",Int.toString(m),"_"]
193 :     val (setC,vC,C)= assignOP(setB,opp,[vB,vB],name,intTy)
194 :     in (setC,vC,B@C) end
195 :     in if ((n mod 2) = 0)
196 :     then half n
197 :     else let
198 :     val (setC,vC,C)=half(n-1)
199 :     val opp=DstOp.prodSca
200 :     val name=String.concat["_Pow",Int.toString(n),"_"]
201 :     val (setD,vD,D)=assignOP(setC,opp,[nU,vC],name,intTy)
202 :     in
203 :     (setD,vD,C@D)
204 :     end
205 :     end
206 :    
207 :     val (setA,vA,A)=pow(nn,opset)
208 :     in
209 :     (setA,vA,code@A)
210 :     end
211 :    
212 :     fun mkOp1(E.PowInt n,e) = mkPowInt(e,n)
213 :     | mkOp1(t,e)=let
214 :     val opp=(case t
215 :     of E.Cosine => DstOp.Cosine
216 :     | E.ArcCosine => DstOp.ArcCosine
217 :     | E.Sine => DstOp.Sine
218 :     | E.ArcSine => DstOp.ArcSine
219 :     | E.Tangent => DstOp.Tangent
220 :     | E.ArcTangent => DstOp.ArcTangent
221 :     | E.Sqrt => DstOp.Sqrt
222 :     | E.Exp => DstOp.Exp
223 :     (*end case*))
224 :     in mkSingle(opp,"_op1_",e) end
225 :    
226 :     (*mkMultiple:string*Var list*LowOps.Op *ListIL.Ty -> Var*code list
227 :     *apply rator between each items on list1
228 :     *)
229 :     fun mkMultiple(opsetM,list1,rator,ty)=let
230 :     fun add(opset,[],_,_) = err"no element in mkMultiple"
231 :     | add(opset,[e1],_,_) = (opset,e1,[])
232 :     | add(opset,[e1,e2],code,_) = let
233 :     val (opsetA,vA,A)=assignOP(opset,rator,[e1,e2],"mult_2",ty)
234 :     in (opsetA,vA,code@A) end
235 :     | add(opset,e1::e2::es,code,count) = let
236 :     val (opsetA,vA,A)=assignOP(opset,rator,[e1,e2],String.concat["mult_",iTos count],ty)
237 :     in add(opsetA,vA::es,code@A,count-1)
238 :     end
239 :     in
240 :     add(opsetM,list1,[],List.length list1)
241 :     end
242 :    
243 :    
244 :     (* *************************** DstOp. Greek **************************** *)
245 :     (* deltaToInt:dict*E.mu*E.mu->int
246 :     * delta function
247 :     *)
248 :     fun deltaToInt(mapp,a,b)= let
249 :     val i=mapIndex(a,mapp)
250 :     val j=mapIndex(b,mapp)
251 :     in if(i=j) then 1 else 0 end
252 :    
253 :     fun evalDelta(opset,mapp,a,b)= intToReal(opset,deltaToInt(mapp,a,b))
254 :    
255 :     (*eval Epsilon-2d*)
256 :     fun evalEps2(opset,mapp,a,b)=let
257 :     val i=mapIndex(E.V a,mapp)
258 :     val j=mapIndex(E.V b,mapp)
259 :     in if(i=j) then intToReal(opset,0)
260 :     else
261 :     if(j>i) then intToReal(opset,1)
262 :     else intToReal(opset, ~1)
263 :     end
264 :    
265 :     (*eval Epsilon-3d*)
266 :     fun evalEps3(opset,mapp,a,b,c)=let
267 :     val i=mapIndex(E.V a,mapp)
268 :     val j=mapIndex(E.V b,mapp)
269 :     val k=mapIndex(E.V c,mapp)
270 :     in
271 :     if(i=j orelse j=k orelse i=k) then intToReal(opset, 0)
272 :     else if(j>i)
273 :     then if(j>k andalso k>i) then intToReal (opset, ~1) else intToReal(opset, 1)
274 :     else if(i>k andalso k>j) then intToReal(opset, 1) else intToReal(opset, ~1)
275 :     end
276 :     fun evalG(setG,mapp,b)=(case b
277 :     of E.Epsilon(i,j,k) => evalEps3(setG,mapp,i,j,k)
278 :     | E.Eps2(i,j) => evalEps2(setG,mapp,i,j)
279 :     | E.Delta(i,j) => evalDelta(setG,mapp,i,j)
280 :     (*end case*))
281 :     end
282 :    
283 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0