Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2870 - (view) (download)

1 : cchiw 2859 (*Helper functions
2 :     *)
3 :     structure Helper = struct
4 :     local
5 :    
6 :     structure DstIL = LowIL
7 :     structure DstTy = LowILTypes
8 :     structure DstOp = LowOps
9 :     structure Var = LowIL.Var
10 :     structure SrcIL = MidIL
11 :     structure SrcOp = MidOps
12 :     structure E = Ein
13 :     structure LowToS= LowToString
14 : cchiw 2870 structure FL=FloatLit
15 : cchiw 2859
16 :     in
17 :    
18 :     val testing=0
19 :     val bV= ref 0
20 :     fun err str=raise Fail(str)
21 :     val realTy=DstTy.TensorTy []
22 :     val intTy=DstTy.intTy
23 :     fun iTos e1=Int.toString e1
24 :     fun iToss es=String.concat(List.map iTos es)
25 :     fun lookup k d = d k
26 :     fun testp n= (case testing
27 :     of 0=> 0
28 :     | _ =>((print (String.concat n));1)
29 :     (*end case*))
30 :    
31 :     fun insert (key, value) d =fn s =>
32 :     if s = key then SOME value
33 :     else d s
34 :    
35 :     fun find(v, mapp)=(case (lookup v mapp)
36 :     of NONE=> raise Fail ("Outside Bound("^Int.toString(v)^")")
37 :     |SOME s => s
38 :     (*end case*))
39 :    
40 :     (*mapIndex:E.mu * dict-> int
41 :     * lookup
42 :     *)
43 :     fun mapIndex(e1,mapp)=(case e1
44 :     of E.V e => find(e,mapp)
45 :     | E.C c=> c
46 :     (*end case*))
47 :    
48 :    
49 : cchiw 2870 (* intToReal:int->Var*LowIL.ASSGN list
50 : cchiw 2859 *)
51 : cchiw 2870 fun intToReal n=let
52 :     val a=DstIL.Var.new("real" ,realTy)
53 :     val b=DstIL.Var.new("cast" ,realTy)
54 :     val code=[DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n))),
55 :     DstIL.ASSGN (b,DstIL.OP(DstOp.IntToReal,[a]))]
56 : cchiw 2859 in
57 : cchiw 2870 (b,code)
58 : cchiw 2859 end
59 :    
60 : cchiw 2870
61 :    
62 : cchiw 2859 (* mkINt:int->Var*LowIL.ASSGN list
63 :     *)
64 :     fun mkInt n=let
65 :     val a=DstIL.Var.new("Int" ,intTy)
66 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
67 :     val _ =testp[LowToS.toStringAll(intTy,code)]
68 :     in
69 :     (a,[code])
70 :     end
71 :    
72 :     (*assgnCons int list * Var list->Var*LowIL.ASSGN list
73 :     * cons elements on list
74 :     *)
75 :     fun assgnCons(pre,shape, args)=let
76 :     val ty=DstTy.TensorTy shape
77 : cchiw 2870 val a=DstIL.Var.new("cons"^"_" ,ty)
78 : cchiw 2859 val code=DstIL.ASSGN (a,DstIL.CONS(ty ,args))
79 :     val _ =testp[LowToS.toStringAll(ty,code)]
80 :     in
81 :     (a, [code])
82 :     end
83 :    
84 :     (*LowOps.Op.op * var list*string*LowIL.Ty
85 :     * -> Var*LowIL.ASSGN list
86 :     * Make lowIL assignment
87 :     *)
88 :     fun assgn(opss,args,pre,ty)=let
89 :     val a=DstIL.Var.new(pre ,ty)
90 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
91 :     val _ =testp[LowToS.toStringAll(ty,code)]
92 :     in
93 :     (a,[code])
94 :     end
95 :    
96 :     (*getTensorTy:E.params*E.tensor_id-> LowIL.Ty
97 :     * Integer, or Generic Tensor
98 :     *)
99 :     fun getTensorTy(params, id)=(case List.nth(params,id)
100 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
101 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
102 :     |_=> err"NONE Tensor Param"
103 :     (*end case*))
104 :    
105 :     (* indexTensor:dict*string*E.params*Var list*E.tensor_id*E.alpha
106 :     * ->Var*LowIL.ASSGN list
107 :     * Index Tensor at specific indices to give a scalar result
108 :     *)
109 :     fun indexTensor(_,(_,_,args,id, [],ty)) = (List.nth(args,id),[])
110 :     | indexTensor(mapp,(lhs,params,args,id,ix,ty))= let
111 :     val nU=List.nth(args,id)
112 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
113 :     val ix'=DstTy.indexTy ixx
114 :     val argTy=getTensorTy(params,id)
115 :     val opp=DstOp.IndexTensor(id,ix',argTy)
116 : cchiw 2870 val name=String.concat["Indx_",iToss ixx,"_"]
117 : cchiw 2859 in
118 :     assgn(opp,[nU],name,ty)
119 :     end
120 :    
121 :     (*projTensor:dict*(string* E.params*Var list*int*E.tensor_id*E.alpha)->Var*LowIL.ASSGN list
122 :     * projects tensor to a vector
123 :     *just used by EintoVecOps but made sense to keep it here
124 :     *)
125 :     fun projTensor(_,(_,_,args,_,id,[]))= (List.nth(args,id),[])
126 :     | projTensor(mapp,(lhs,params,args,vecIX,id,ix))= let
127 :     val nU=List.nth(args,id)
128 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
129 :     val ix'=DstTy.indexTy ixx
130 :     val argTy= getTensorTy(params,id)
131 :     val vecTy=DstTy.TensorTy [vecIX]
132 :     val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
133 : cchiw 2870 val name=String.concat["Proj_",iToss ixx,"_"]
134 : cchiw 2859 in
135 :     assgn(opp,[nU],name,vecTy)
136 :     end
137 :    
138 : cchiw 2867
139 :     fun mkSqrt(nU,code)= let
140 :     val opp=DstOp.Sqrt
141 : cchiw 2870 val name=String.concat["_Sqrt_"]
142 : cchiw 2867 val (vA,A)=assgn(opp,[nU],name,realTy)
143 :     in
144 :     (vA,code@A)
145 :     end
146 :    
147 :    
148 : cchiw 2870 fun mkPowInt((nU,code),n)= let
149 :     val opp=DstOp.powInt
150 :     val name=String.concat["_PowInt_"]
151 :     val (r,rcode)=mkInt n
152 :     val (vA,A)=assgn(opp,[nU,r],name,realTy)
153 :     in
154 :     (vA,code@rcode@A)
155 :     end
156 : cchiw 2867
157 : cchiw 2870 fun mkPowRat((nU,code),rat)= let
158 :     val opp=DstOp.powRat(DstTy.R rat)
159 :     val name=String.concat["_PowRat_"]
160 :    
161 :     val (vA,A)=assgn(opp,[nU],name,realTy)
162 :     in
163 :     (vA,code@A)
164 :     end
165 :    
166 :    
167 :    
168 :    
169 :    
170 :    
171 : cchiw 2859 (* Some shortcuts. Arguements are Low-IL variables already indexed/projected
172 :     * string*Var list ->Var*LowIL.ASSGN list
173 :     *)
174 : cchiw 2870 fun mkAddSca(lhs,args)= assgn(DstOp.addSca,args,"addSca",realTy)
175 :     fun mkAddInt(lhs,args)= assgn(DstOp.addSca,args,"addInt",intTy)
176 :     fun mkAddPtr(lhs,args,ty)= assgn(DstOp.addSca,args,"addPtr",ty)
177 :     fun mkAddVec(lhs,vecIX,args)=assgn(DstOp.addVec vecIX,args,"addV",DstTy.TensorTy([vecIX]))
178 :     fun mkSubSca(lhs,args)= assgn(DstOp.subSca,args,"subSca",realTy)
179 :     fun mkProdSca(lhs,args)=assgn(DstOp.prodSca,args,"prodSca",realTy)
180 :     fun mkProdInt(lhs,args)=assgn(DstOp.prodSca,args,"prodInt",intTy)
181 :     fun mkProdVec(lhs,vecIX,args)=assgn(DstOp.prodVec vecIX,args,"prodV",DstTy.TensorTy([vecIX]))
182 :     fun mkDivSca(lhs,args)= assgn(DstOp.divSca,args,"divSca",realTy)
183 :     fun mkSumVec(lhs,vecIX,args)= assgn(DstOp.sumVec vecIX,args,"sumVec",realTy)
184 : cchiw 2859 fun mkDotVec(lhs,vecIX,args)=let
185 : cchiw 2870 val (vD, D)=mkProdVec("",vecIX,args)
186 :     val (vE, E)=mkSumVec("",vecIX,[vD])
187 : cchiw 2859 in (vE,D@E) end
188 :    
189 :    
190 :     (*mkMultiple:string*Var list*LowOps.Op *ListIL.Ty -> Var*LowIL.ASSGN list
191 :     *apply rator between each items on list1
192 :     *)
193 :     fun mkMultiple(lhs,list1,rator,ty)=let
194 :     fun add([],_,_) = err"no element in mkMultiple"
195 :     | add([e1],_,_) = (e1,[])
196 :     | add([e1,e2],code,_) = let
197 : cchiw 2870 val (vA,A)=assgn(rator,[e1,e2],"mult_2",ty)
198 : cchiw 2859 in (vA,code@A) end
199 :     | add(e1::e2::es,code,count) = let
200 : cchiw 2870 val (vA,A)=assgn(rator,[e1,e2],String.concat["mult_",iTos count],ty)
201 : cchiw 2859 in add(vA::es,code@A,count-1)
202 :     end
203 :     in
204 :     add(list1,[],List.length list1)
205 :     end
206 :    
207 :     (* deltaToInt:dict*E.mu*E.mu->int
208 :     * delta function
209 :     *)
210 :     fun deltaToInt(mapp,a,b)= let
211 :     val i=mapIndex(a,mapp)
212 :     val j=mapIndex(b,mapp)
213 :     in
214 :     if(i=j) then 1 else 0
215 :     end
216 :    
217 : cchiw 2870 fun evalDelta(mapp,a,b)= intToReal(deltaToInt(mapp,a,b))
218 : cchiw 2859
219 :     (*eval Epsilon-2d*)
220 :     fun evalEps2(mapp,a,b)=let
221 :     val i=mapIndex(E.V a,mapp)
222 :     val j=mapIndex(E.V b,mapp)
223 :     in
224 : cchiw 2870 if(i=j) then intToReal 0
225 : cchiw 2859 else
226 : cchiw 2870 if(j>i) then intToReal 1
227 :     else intToReal ~1
228 : cchiw 2859 end
229 :    
230 :     (*eval Epsilon-3d*)
231 :     fun evalEps3(mapp,a,b,c)=let
232 :     val i=mapIndex(E.V a,mapp)
233 :     val j=mapIndex(E.V b,mapp)
234 :     val k=mapIndex(E.V c,mapp)
235 :     in
236 : cchiw 2870 if(i=j orelse j=k orelse i=k) then intToReal 0
237 : cchiw 2859 else
238 :     if(j>i) then
239 : cchiw 2870 if(j>k andalso k>i) then intToReal ~1 else intToReal 1
240 :     else if(i>k andalso k>j) then intToReal 1 else intToReal ~1
241 : cchiw 2859
242 :     end
243 :    
244 :     end
245 :    
246 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0