Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/helper.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2867 - (view) (download)

1 : cchiw 2859 (*Helper functions
2 :     *)
3 :     structure Helper = struct
4 :     local
5 :    
6 :     structure DstIL = LowIL
7 :     structure DstTy = LowILTypes
8 :     structure DstOp = LowOps
9 :     structure Var = LowIL.Var
10 :     structure SrcIL = MidIL
11 :     structure SrcOp = MidOps
12 :     structure E = Ein
13 :     structure LowToS= LowToString
14 :    
15 :     in
16 :    
17 :     val testing=0
18 :     val bV= ref 0
19 :     fun err str=raise Fail(str)
20 :     val realTy=DstTy.TensorTy []
21 :     val intTy=DstTy.intTy
22 :     fun iTos e1=Int.toString e1
23 :     fun iToss es=String.concat(List.map iTos es)
24 :     fun lookup k d = d k
25 :     fun testp n= (case testing
26 :     of 0=> 0
27 :     | _ =>((print (String.concat n));1)
28 :     (*end case*))
29 :    
30 :     fun insert (key, value) d =fn s =>
31 :     if s = key then SOME value
32 :     else d s
33 :    
34 :     fun find(v, mapp)=(case (lookup v mapp)
35 :     of NONE=> raise Fail ("Outside Bound("^Int.toString(v)^")")
36 :     |SOME s => s
37 :     (*end case*))
38 :    
39 :     (*mapIndex:E.mu * dict-> int
40 :     * lookup
41 :     *)
42 :     fun mapIndex(e1,mapp)=(case e1
43 :     of E.V e => find(e,mapp)
44 :     | E.C c=> c
45 :     (*end case*))
46 :    
47 :    
48 :     (* mkReal:int->Var*LowIL.ASSGN list
49 :     *)
50 :     fun mkReal n=let
51 :     val a=DstIL.Var.new("Real" ,realTy)
52 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
53 :     val _ =testp[LowToS.toStringAll(realTy,code)]
54 :     in
55 :     (a,[code])
56 :     end
57 :    
58 :     (* mkINt:int->Var*LowIL.ASSGN list
59 :     *)
60 :     fun mkInt n=let
61 :     val a=DstIL.Var.new("Int" ,intTy)
62 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int(IntInf.fromInt n)))
63 :     val _ =testp[LowToS.toStringAll(intTy,code)]
64 :     in
65 :     (a,[code])
66 :     end
67 :    
68 :     (*assgnCons int list * Var list->Var*LowIL.ASSGN list
69 :     * cons elements on list
70 :     *)
71 :     fun assgnCons(pre,shape, args)=let
72 :     val ty=DstTy.TensorTy shape
73 :     val a=DstIL.Var.new("cons"^"_"^pre^"_" ,ty)
74 :     val code=DstIL.ASSGN (a,DstIL.CONS(ty ,args))
75 :     val _ =testp[LowToS.toStringAll(ty,code)]
76 :     in
77 :     (a, [code])
78 :     end
79 :    
80 :     (*LowOps.Op.op * var list*string*LowIL.Ty
81 :     * -> Var*LowIL.ASSGN list
82 :     * Make lowIL assignment
83 :     *)
84 :     fun assgn(opss,args,pre,ty)=let
85 :     val a=DstIL.Var.new(pre ,ty)
86 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
87 :     val _ =testp[LowToS.toStringAll(ty,code)]
88 :     in
89 :     (a,[code])
90 :     end
91 :    
92 :     (*getTensorTy:E.params*E.tensor_id-> LowIL.Ty
93 :     * Integer, or Generic Tensor
94 :     *)
95 :     fun getTensorTy(params, id)=(case List.nth(params,id)
96 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
97 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
98 :     |_=> err"NONE Tensor Param"
99 :     (*end case*))
100 :    
101 :     (* indexTensor:dict*string*E.params*Var list*E.tensor_id*E.alpha
102 :     * ->Var*LowIL.ASSGN list
103 :     * Index Tensor at specific indices to give a scalar result
104 :     *)
105 :     fun indexTensor(_,(_,_,args,id, [],ty)) = (List.nth(args,id),[])
106 :     | indexTensor(mapp,(lhs,params,args,id,ix,ty))= let
107 :     val nU=List.nth(args,id)
108 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
109 :     val ix'=DstTy.indexTy ixx
110 :     val argTy=getTensorTy(params,id)
111 :     val opp=DstOp.IndexTensor(id,ix',argTy)
112 :     val name=String.concat[Var.name nU,"_",iToss ixx,"_"]
113 :     in
114 :     assgn(opp,[nU],name,ty)
115 :     end
116 :    
117 :     (*projTensor:dict*(string* E.params*Var list*int*E.tensor_id*E.alpha)->Var*LowIL.ASSGN list
118 :     * projects tensor to a vector
119 :     *just used by EintoVecOps but made sense to keep it here
120 :     *)
121 :     fun projTensor(_,(_,_,args,_,id,[]))= (List.nth(args,id),[])
122 :     | projTensor(mapp,(lhs,params,args,vecIX,id,ix))= let
123 :     val nU=List.nth(args,id)
124 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
125 :     val ix'=DstTy.indexTy ixx
126 :     val argTy= getTensorTy(params,id)
127 :     val vecTy=DstTy.TensorTy [vecIX]
128 :     val opp=DstOp.ProjectTensor(id,vecIX,ix',argTy)
129 :     val name=String.concat[Var.name nU,"V_",iToss ixx,"_"]
130 :     in
131 :     assgn(opp,[nU],name,vecTy)
132 :     end
133 :    
134 : cchiw 2867
135 :     fun mkSqrt(nU,code)= let
136 :     val opp=DstOp.Sqrt
137 :     val name=String.concat[Var.name nU,"_Sqrt_"]
138 :     val (vA,A)=assgn(opp,[nU],name,realTy)
139 :     in
140 :     (vA,code@A)
141 :     end
142 :    
143 :    
144 :    
145 :    
146 : cchiw 2859 (* Some shortcuts. Arguements are Low-IL variables already indexed/projected
147 :     * string*Var list ->Var*LowIL.ASSGN list
148 :     *)
149 :     fun mkAddSca(lhs,args)= assgn(DstOp.addSca,args,lhs^"addSca",realTy)
150 :     fun mkAddInt(lhs,args)= assgn(DstOp.addSca,args,lhs^"addInt",intTy)
151 :     fun mkAddPtr(lhs,args,ty)= assgn(DstOp.addSca,args,lhs^"addPtr",ty)
152 :     fun mkAddVec(lhs,vecIX,args)=assgn(DstOp.addVec vecIX,args,lhs^"addV",DstTy.TensorTy([vecIX]))
153 :     fun mkSubSca(lhs,args)= assgn(DstOp.subSca,args,lhs^"subSca",realTy)
154 :     fun mkProdSca(lhs,args)=assgn(DstOp.prodSca,args,lhs^"prodSca",realTy)
155 :     fun mkProdInt(lhs,args)=assgn(DstOp.prodSca,args,lhs^"prodInt",intTy)
156 :     fun mkProdVec(lhs,vecIX,args)=assgn(DstOp.prodVec vecIX,args,lhs^"prodV",DstTy.TensorTy([vecIX]))
157 :     fun mkDivSca(lhs,args)= assgn(DstOp.divSca,args,lhs^"divSca",realTy)
158 :     fun mkSumVec(lhs,vecIX,args)= assgn(DstOp.sumVec vecIX,args,lhs^"sumVec",realTy)
159 :     fun mkDotVec(lhs,vecIX,args)=let
160 :     val (vD, D)=mkProdVec(lhs,vecIX,args)
161 :     val (vE, E)=mkSumVec(lhs,vecIX,[vD])
162 :     in (vE,D@E) end
163 :    
164 :    
165 :     (*mkMultiple:string*Var list*LowOps.Op *ListIL.Ty -> Var*LowIL.ASSGN list
166 :     *apply rator between each items on list1
167 :     *)
168 :     fun mkMultiple(lhs,list1,rator,ty)=let
169 :     fun add([],_,_) = err"no element in mkMultiple"
170 :     | add([e1],_,_) = (e1,[])
171 :     | add([e1,e2],code,_) = let
172 :     val (vA,A)=assgn(rator,[e1,e2],lhs^"_2",ty)
173 :     in (vA,code@A) end
174 :     | add(e1::e2::es,code,count) = let
175 :     val (vA,A)=assgn(rator,[e1,e2],String.concat[lhs,"_",iTos count],ty)
176 :     in add(vA::es,code@A,count-1)
177 :     end
178 :     in
179 :     add(list1,[],List.length list1)
180 :     end
181 :    
182 :     (* deltaToInt:dict*E.mu*E.mu->int
183 :     * delta function
184 :     *)
185 :     fun deltaToInt(mapp,a,b)= let
186 :     val i=mapIndex(a,mapp)
187 :     val j=mapIndex(b,mapp)
188 :     in
189 :     if(i=j) then 1 else 0
190 :     end
191 :    
192 :     fun evalDelta(mapp,a,b)= mkReal(deltaToInt(mapp,a,b))
193 :    
194 :     (*eval Epsilon-2d*)
195 :     fun evalEps2(mapp,a,b)=let
196 :     val i=mapIndex(E.V a,mapp)
197 :     val j=mapIndex(E.V b,mapp)
198 :     in
199 :     if(i=j) then mkReal 0
200 :     else
201 :     if(j>i) then mkReal 1
202 :     else mkReal ~1
203 :     end
204 :    
205 :     (*eval Epsilon-3d*)
206 :     fun evalEps3(mapp,a,b,c)=let
207 :     val i=mapIndex(E.V a,mapp)
208 :     val j=mapIndex(E.V b,mapp)
209 :     val k=mapIndex(E.V c,mapp)
210 :     in
211 :     if(i=j orelse j=k orelse i=k) then mkReal 0
212 :     else
213 :     if(j>i) then
214 :     if(j>k andalso k>i) then mkReal ~1 else mkReal 1
215 :     else if(i>k andalso k>j) then mkReal 1 else mkReal ~1
216 :    
217 :     end
218 :    
219 :     end
220 :    
221 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0