Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/ein16/synth/d2/obj_ty.py
ViewVC logotype

Annotation of /branches/ein16/synth/d2/obj_ty.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3946 - (view) (download) (as text)

1 : cchiw 3915 # -*- coding: utf-8 -*-
2 :    
3 :     from __future__ import unicode_literals
4 :    
# tensor types
class tty:
    """A tensor type described by an integer id, a shorthand name and a shape list."""

    def __init__(self, id, name, shape):
        self.id, self.name, self.shape = id, name, shape

    def toStr(self):
        """Return the literal prefix "tensor" immediately followed by the type's name."""
        label = str(self.name)
        return "tensor" + label

    def isEq_id(a, b):
        """Return True when both tensor types carry the same id (names/shapes are ignored)."""
        same = a.id == b.id
        return same
15 :    
# field types
class fty:
    """A field type or, when ``dim == 0``, a tensor type lifted to the field level.

    Attributes (all set by ``__init__``):
        id: integer identifier; ``isEq_id`` compares types by this value only.
        name: shorthand name of the type.
        dim: dimension of the field's domain; 0 marks a tensor (see ``is_Field``).
        shape: list of index sizes, e.g. [] scalar, [n] vector, [n, m] matrix.
        tensorType: the ``tty`` obtained when a field of this type is probed.
        k: number printed after ``#`` in the Diderot type string (see
            ``toDiderot``); presumably the field's continuity — TODO confirm.
    """

    def __init__(self, id, name, dim, shape, tensorType, k):
        self.id = id
        self.name = name
        self.dim = dim
        self.shape = shape
        self.tensorType = tensorType  # if we probed a field with this field type
        self.k = k

    def toStr(self):
        """Return a short human-readable label ("tensor " for dim 0)."""
        if self.dim == 0:
            return "tensor "
        return "field #" + str(self.k) + "(" + str(self.dim) + ")" + str(self.shape)

    def get_tensorType(self):
        """Return the tensor type produced by probing a field of this type."""
        return self.tensorType

    def get_dim(self):
        """Return the domain dimension (0 for tensors)."""
        return self.dim

    def get_shape(ty0):
        """Return the shape list of ``ty0``."""
        return ty0.shape

    # get vector length
    def get_vecLength(ty0):
        """Return the length of a vector type, i.e. a shape with exactly one index.

        Raises:
            Exception: if ``ty0`` does not have exactly one index.
        """
        shape = ty0.shape
        if len(shape) == 1:
            return shape[0]  # vector length
        # A bare ``raise "..."`` is itself a TypeError on Python 2.6+ and 3,
        # so the intended message never reached the caller.
        raise Exception("unsupported get_vecLength types")

    def is_Field(self):
        """True when this is a real field (``dim > 0``) rather than a tensor."""
        return self.dim > 0

    def is_ScalarField(self):
        """True for a field whose shape is empty."""
        return len(self.shape) == 0 and fty.is_Field(self)

    def is_VectorField(self):
        """True for a field whose shape has exactly one index."""
        return len(self.shape) == 1 and fty.is_Field(self)

    def is_Scalar(self):
        """True when the shape is empty (field or tensor)."""
        return len(self.shape) == 0

    def is_Vector(self):
        """True when the shape has exactly one index (field or tensor)."""
        return len(self.shape) == 1

    def is_Matrix(self):
        """True when the shape has exactly two indices."""
        return len(self.shape) == 2

    def is_Ten3(self):
        """True when the shape has exactly three indices."""
        return len(self.shape) == 3

    def get_first_ix(self):
        """Return the first index size of the shape (IndexError if the shape is empty)."""
        return self.shape[0]

    def get_last_ix(self):
        """Return the last index size of the shape (IndexError if the shape is empty)."""
        return self.shape[-1]

    def drop_first(self):
        """Return a new list of every shape index except the LAST one.

        NOTE(review): the name says "first" but the last index is dropped
        (``drop_last`` is the mirror image). Behaviour is kept because callers
        elsewhere may depend on it — confirm which one is intended.
        """
        return list(self.shape[:-1])

    def drop_last(self):
        """Return a new list of every shape index except the FIRST one.

        NOTE(review): the name says "last" but the first index is dropped;
        see ``drop_first``. Behaviour kept as-is — confirm intent.
        """
        return list(self.shape[1:])

    # compares finfo with fty constant
    def isEq_id(a, b):
        """Return True when the two types carry the same id."""
        return a.id == b.id

    # string for diderot program
    def toDiderot(self):
        """Return the type as written in a Diderot program, e.g. "field#2(3)[3]"."""
        if self.dim == 0:
            return "tensor " + str(self.shape)
        return "field#" + str(self.k) + "(" + str(self.dim) + ")" + str(self.shape)

    # creates ty object
    def convertTy(const, k):
        """Return a new ``fty`` with ``const``'s attributes but the given ``k``.

        The shape list and tensor type are shared with ``const``, not copied.
        """
        return fty(const.id, const.name, const.dim, const.shape, const.tensorType, k)
85 : cchiw 3915
86 : cchiw 3939
# ------------------------------ type name to other properties ------------------------------
# Shorthand names used to refer to the different types; the helper functions
# below map a shorthand to its other properties and create the type objects.

# tensors
ty_noneT = tty(id=0, name="none", shape=[])
ty_scalarT = tty(id=1, name="scalartensor", shape=[])
ty_vec2T = tty(id=2, name="vec2tensor", shape=[2])
ty_vec3T = tty(id=3, name="vec3tensor", shape=[3])
ty_mat2x2T = tty(id=4, name="mat2x2tensor", shape=[2, 2])
ty_mat3x3T = tty(id=5, name="mat3x3tensor", shape=[3, 3])
ty_ten2x2x2T = tty(id=11, name="ten2x2x2tensor", shape=[2, 2, 2])
ty_ten3x3x3T = tty(id=12, name="ten3x3x3tensor", shape=[3, 3, 3])
100 : cchiw 3915
101 :    
# Tensors lifted to the field level: dim 0, probe type ty_noneT and k = -1.
ty_scalarFT = fty(0, "scalartensor", 0, [], ty_noneT, -1)
ty_vec2FT = fty(1, "vec2tensor", 0, [2], ty_noneT, -1)
ty_vec3FT = fty(2, "vec3tensor", 0, [3], ty_noneT, -1)
ty_mat2x2FT = fty(3, "mat2x2tensor", 0, [2, 2], ty_noneT, -1)
ty_mat3x3FT = fty(4, "mat3x3tensor", 0, [3, 3], ty_noneT, -1)
ty_mat2x2x2FT = fty(5, "mat2x2x2tensor", 0, [2, 2, 2], ty_noneT, -1)
# Name was garbled as "mat3x3tx3ensor"; spelled to match its 2x2x2 sibling.
ty_mat3x3x3FT = fty(6, "mat3x3x3tensor", 0, [3, 3, 3], ty_noneT, -1)
110 : cchiw 3915
def get_Tshape(ty):
    """Return the shape list stored on a tensor or field type object."""
    shape = ty.shape
    return shape
113 : cchiw 3915
# Fields: id, name, domain dim, shape, and the tensor type a probe returns.
k_init = 0  # null k

# native shapes for each dimension
ty_scalarF_d2 = fty(id=7, name="scalarfield_d2", dim=2, shape=[],
                    tensorType=ty_scalarT, k=k_init)
ty_vec2F_d2 = fty(id=8, name="vec2field_d2", dim=2, shape=[2],
                  tensorType=ty_vec2T, k=k_init)
ty_mat2x2F_d2 = fty(id=9, name="mat2x2field_d2", dim=2, shape=[2, 2],
                    tensorType=ty_mat2x2T, k=k_init)
ty_ten2x2x2F_d2 = fty(id=10, name="mat2x2x2field_d2", dim=2, shape=[2, 2, 2],
                      tensorType=ty_ten2x2x2T, k=k_init)
ty_scalarF_d3 = fty(id=11, name="scalarfield_d3", dim=3, shape=[],
                    tensorType=ty_scalarT, k=k_init)
ty_vec3F_d3 = fty(id=12, name="vec3field_d3", dim=3, shape=[3],
                  tensorType=ty_vec3T, k=k_init)
ty_mat3x3F_d3 = fty(id=13, name="mat3x3field_d3", dim=3, shape=[3, 3],
                    tensorType=ty_mat3x3T, k=k_init)
ty_ten3x3x3F_d3 = fty(id=14, name="mat3x3x3field_d3", dim=3, shape=[3, 3, 3],
                      tensorType=ty_ten3x3x3T, k=k_init)

# 3-sized shapes over a 2-D domain
ty_vec3F_d2 = fty(id=15, name="vec3field_d2", dim=2, shape=[3],
                  tensorType=ty_vec3T, k=k_init)
ty_mat3x3F_d2 = fty(id=16, name="mat3x3field_d2", dim=2, shape=[3, 3],
                    tensorType=ty_mat3x3T, k=k_init)
ty_ten3x3x3F_d2 = fty(id=17, name="ten3x3x3field_d2", dim=2, shape=[3, 3, 3],
                      tensorType=ty_ten3x3x3T, k=k_init)

# 2-sized shapes over a 3-D domain
ty_vec2F_d3 = fty(id=18, name="vec2field_d3", dim=3, shape=[2],
                  tensorType=ty_vec2T, k=k_init)
ty_mat2x2F_d3 = fty(id=19, name="mat2x2field_d3", dim=3, shape=[2, 2],
                    tensorType=ty_mat2x2T, k=k_init)
ty_ten2x2x2F_d3 = fty(id=20, name="ten2x2x2field_d3", dim=3, shape=[2, 2, 2],
                      tensorType=ty_ten2x2x2T, k=k_init)
132 :    
# check equal dim
def check_dim(fld, b):
    """Return True when ``b`` is a field with the same domain dimension as ``fld``.

    Returns False when ``b`` is not a field (``fty.is_Field(b)`` is False).
    """
    # Without this guard the non-field case fell through with ``dim2`` never
    # bound, yielding None or an UnboundLocalError instead of a boolean.
    if not fty.is_Field(b):
        return False
    return fld.dim == b.dim
138 :    
# list of vector fields
def get_vecF():
    """Return, in order, the entries of ``l_all_F`` whose shape has exactly one index."""
    return [fld for fld in l_all_F if fty.is_Vector(fld)]
def get_scaF():
    """Return, in order, the entries of ``l_all_F`` whose shape is empty."""
    return [fld for fld in l_all_F if fty.is_Scalar(fld)]
154 :    
155 :    
# types for multiplication
def get_mul():
    """Return the [field, type] argument pairs accepted for multiplication.

    The first element ranges over ``l_all_F`` and the second over ``l_all``.
    A pair is kept when at least one side is a scalar, and is skipped when the
    second element is a field whose dimension differs from the first's.
    """
    pairs = []
    for lhs in l_all_F:
        for rhs in l_all:
            mismatched = fty.is_Field(rhs) and not check_dim(lhs, rhs)
            if mismatched:
                continue  # field arguments must share the same dim
            if fty.is_Scalar(lhs) or fty.is_Scalar(rhs):
                pairs.append([lhs, rhs])
    return pairs
170 :    
# Lists of types.
# Fields for which test data can be created.
l_all_F = [ty_scalarF_d2, ty_vec2F_d2, ty_scalarF_d3, ty_vec3F_d3]
# Tensors lifted to the field level.
l_all_FT = [
    ty_scalarFT, ty_vec2FT, ty_vec3FT,
    ty_mat2x2FT, ty_mat3x3FT,
    ty_mat2x2x2FT, ty_mat3x3x3FT,
]
# Every type: the fields first, then the lifted tensors.
l_all = list(l_all_F)
l_all.extend(l_all_FT)
# Fields grouped by shape.
vectorFlds = get_vecF()
scalarFlds = get_scaF()
179 :    
180 :    
# binary operator so two args
def find_field(ty1, ty2):
    """Choose the type that governs a binary operation on ``ty1`` and ``ty2``.

    Returns an ``(ok, ty)`` pair:
      * ``ty1`` is a tensor (dim 0)          -> ``(True, ty2)``
      * ``ty2`` is a tensor (dim 0)          -> ``(True, ty1)``
      * both are fields with equal dims      -> ``(True, ty1)``
      * fields with different dims           -> ``(False, None)``
    """
    d1, d2 = ty1.dim, ty2.dim
    if d1 == 0:
        return (True, ty2)
    if d2 == 0 or d1 == d2:
        return (True, ty1)
    return (False, None)
193 : cchiw 3939
# shape to type
def shapeToTy(shapeout, dim):
    """Return the predefined field type with domain dimension ``dim`` and shape ``shapeout``.

    Args:
        shapeout: list of index sizes, compared with ``==`` against the
            predefined shapes (e.g. [], [3], [2, 2], [3, 3, 3]).
        dim: domain dimension; only 2 and 3 are supported.

    Raises:
        Exception: if ``dim`` is not 2 or 3, or if no predefined field type of
            that dimension has the requested shape.
    """
    # Tables are built per call so the module-level constants are read at
    # call time, exactly as the original if/elif chain did.
    if dim == 2:
        table = [
            ([], ty_scalarF_d2),
            ([2], ty_vec2F_d2),
            ([3], ty_vec3F_d2),
            ([2, 2], ty_mat2x2F_d2),
            ([3, 3], ty_mat3x3F_d2),
            ([2, 2, 2], ty_ten2x2x2F_d2),
            ([3, 3, 3], ty_ten3x3x3F_d2),
        ]
    elif dim == 3:
        table = [
            ([], ty_scalarF_d3),
            ([3], ty_vec3F_d3),
            ([3, 3], ty_mat3x3F_d3),
            ([3, 3, 3], ty_ten3x3x3F_d3),
            ([2], ty_vec2F_d3),
            ([2, 2], ty_mat2x2F_d3),
            ([2, 2, 2], ty_ten2x2x2F_d3),
        ]
    else:
        # Previously ``raise "unsupported dim"``: raising a str is a TypeError.
        raise Exception("unsupported dim", str(dim))
    for shape, ty in table:
        if shapeout == shape:
            return ty
    raise Exception("unsupported shapeout", str(shapeout))
234 : cchiw 3915
# concat two types to form a new type
def concatTys(ty1, ty2):
    """Return the matrix field type obtained by concatenating two vector types.

    Two length-2 vectors give ``ty_mat2x2F_d2``; two length-3 vectors give
    ``ty_mat3x3F_d3``. The result carries the ``k`` of the type selected by
    ``find_field``.

    NOTE(review): the result's domain dim is fixed by the vector length
    (2 -> d2, 3 -> d3), not by the inputs' dim — confirm this is intended.

    Raises:
        Exception: if either argument is not a vector, the vector lengths are
            not both 2 or both 3, or ``find_field`` reports incompatible dims.
    """
    if not fty.is_Vector(ty1) or not fty.is_Vector(ty2):
        raise Exception("unsupported concat types")
    n1 = fty.get_vecLength(ty1)
    n2 = fty.get_vecLength(ty2)
    # find_field returns an (ok, type) pair; reading ``.k`` off the pair itself
    # raised AttributeError on every call.
    ok, fldty = find_field(ty1, ty2)
    if not ok:
        raise Exception("unsupported concat types: field dimensions differ")
    k = fldty.k
    if n1 == 2 and n2 == 2:
        return fty.convertTy(ty_mat2x2F_d2, k)
    if n1 == 3 and n2 == 3:
        return fty.convertTy(ty_mat3x3F_d3, k)
    raise Exception("unsupported concat types")
258 : cchiw 3915
# reduce shape of fields
def reduceIndex(ty0):
    """Map a vector field type to the scalar field type of the same domain dim.

    Only types whose id matches ``ty_vec2F_d2`` or ``ty_vec3F_d3`` are
    supported; the returned type keeps ``ty0``'s ``k``.

    Raises:
        Exception: for any other type.
    """
    k = ty0.k  # keep current k value
    reductions = ((ty_vec2F_d2, ty_scalarF_d2), (ty_vec3F_d3, ty_scalarF_d3))
    for vec_ty, scalar_ty in reductions:
        if fty.isEq_id(ty0, vec_ty):
            return fty.convertTy(scalar_ty, k)
    raise Exception("unsupported field shape:" + ty0.name)
269 : cchiw 3915

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0