Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/ein16/synth/d2/obj_ty.py
ViewVC logotype

Annotation of /branches/ein16/synth/d2/obj_ty.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4411 - (view) (download) (as text)

1 : cchiw 3915 # -*- coding: utf-8 -*-
2 :    
3 :     from __future__ import unicode_literals
4 :    
5 : cchiw 4308 nonefield_k= -1
6 :     nonefield_dim = 0
7 :    
8 : cchiw 3939 #tensor types
9 :     class tty:
10 :     def __init__(self, id, name, shape):
11 :     self.id=id
12 :     self.name=name
13 :     self.shape=shape
14 :     def toStr(self):
15 :     return ("tensor"+str(self.name))
16 :     def isEq_id(a,b):
17 :     return (a.id==b.id)
18 :    
19 :     # field types
20 : cchiw 3915 class fty:
21 : cchiw 3939 def __init__(self, id, name, dim, shape, tensorType,k):
22 : cchiw 3915 self.id=id
23 :     self.name=name
24 :     self.dim=dim
25 :     self.shape=shape
26 :     self.tensorType=tensorType # if we probed a field with this field type
27 : cchiw 3939 self.k=k
28 : cchiw 3915 def toStr(self):
29 : cchiw 4308 if(self.dim==nonefield_dim):
30 : cchiw 3939 return "tensor "
31 :     else:
32 :     return ("field #"+str(self.k)+"("+str(self.dim)+")"+str(self.shape))
33 :     def get_tensorType(self):
34 : cchiw 3915 return self.tensorType
35 : cchiw 3939 def get_dim(self):
36 :     return self.dim
37 :     # get vector length
38 :     def get_shape(ty0):
39 :     return ty0.shape
40 : cchiw 3915
41 : cchiw 3939 def get_vecLength(ty0):
42 :     shape = ty0.shape
43 :     if (len(shape)==1):
44 :     return shape[0] # vector length
45 :     else:
46 :     raise "unsupported get_vecLength types"
47 :     def is_Field(self):
48 : cchiw 4308 return (not (self.dim==nonefield_dim))
49 : cchiw 4321 def is_Tensor(self):
50 :     return (self.dim==nonefield_dim)
51 : cchiw 3939 def is_ScalarField(self):
52 :     return ((len(self.shape)==0) and fty.is_Field(self))
53 :     def is_VectorField(self):
54 :     return ((len(self.shape)==1) and fty.is_Field(self))
55 : cchiw 4230 def is_MatrixField(self):
56 :     return ((len(self.shape)==2) and fty.is_Field(self))
57 : cchiw 3939 def is_Scalar(self):
58 :     return (len(self.shape)==0)
59 :     def is_Vector(self):
60 :     return (len(self.shape)==1)
61 :     def is_Matrix(self):
62 :     return (len(self.shape)==2)
63 :     def is_Ten3(self):
64 :     return (len(self.shape)==3)
65 :     def get_first_ix(self):
66 :     return self.shape[0]
67 :     def get_last_ix(self):
68 :     return self.shape[len(self.shape)-1]
69 :     def drop_first(self):
70 :     rtn = []
71 :     for i in range(len(self.shape)-1):
72 :     rtn.append(self.shape[i])
73 :     return rtn
74 :     def drop_last(self):
75 :     rtn = []
76 :     for i in range(len(self.shape)-1):
77 :     rtn.append(self.shape[i+1])
78 :     return rtn
79 : cchiw 3915
80 :     #compares finfo with fty constant
81 :     def isEq_id(a,b):
82 : cchiw 3939 return (a.id==b.id)
83 : cchiw 3915 #string for diderot program
84 : cchiw 3939 def toDiderot(self):
85 :     if(self.dim==0):
86 :     return "tensor "+str(self.shape)
87 :     else:
88 :     return "field#"+str(self.k)+"("+str(self.dim)+")"+str(self.shape)
89 :     #creates ty object
90 :     def convertTy(const,k):
91 :     return fty(const.id,const.name, const.dim, const.shape, const.tensorType, k)
92 : cchiw 3915
93 : cchiw 3939
94 : cchiw 3915 # ------------------------------ type name to other properties ------------------------------
95 :     # shorthand used to refer to different types
96 :     # the helper functions match shorthand to other properties and creates ty object
97 :    
98 : cchiw 4210 def tyToStr(s):
99 :     n=len(s)
100 :     if(n==0):
101 : cchiw 4230 return "sc"
102 : cchiw 4210 elif(n==1):
103 :     [v] = s
104 :     return "v"+str(v)
105 :     elif(n==2):
106 :     [v,m] = s
107 :     return "m"+str(v)+"x"+str(m)
108 :     elif(n==3):
109 :     [v,m,l] = s
110 :     return "t"+str(v)+"x"+str(m)+"x"+str(l)
111 :     return "t"
112 :    
113 :     def mkTensor(id, shape):
114 :     #print "id",str(id)
115 : cchiw 4329 name = "t_"+tyToStr(shape)
116 : cchiw 4210 return tty(id, name, shape)
117 :    
118 :     ty_noneT = mkTensor(id,[])
119 : cchiw 4308 # distinctive features of lifted tensors or NoneFields
120 :     # are dim=0 and k=-1
121 :     def mkNoneField(id, outputtensor):
122 : cchiw 4210 #print "id",str(id)
123 : cchiw 4308 shape = outputtensor.shape
124 : cchiw 4329 name = "T_"+tyToStr(shape)
125 : cchiw 4308 return fty(id, name,nonefield_dim, shape, outputtensor, nonefield_k)
126 : cchiw 4210
127 : cchiw 4230 k_init=2#null k
128 :    
129 : cchiw 4210 #fields: #id,name, dim, shape in string form,probe field type returns tensor type
130 :     def mkField(id, dim, outputtensor):
131 :     #print "id",str(id)
132 :     name = "F_"+tyToStr(outputtensor.shape)+"_d"+str(dim)
133 :     return fty(id, name, dim, outputtensor.shape, outputtensor, k_init)
134 :    
135 :     def get_Tshape(ty):
136 :     return ty.shape
137 :    
138 : cchiw 3915 #tensors
139 : cchiw 4158 id=0
140 : cchiw 4247 ty_scalarT = mkTensor(id,[])
141 :     ty_vec1T = mkTensor(id+1,[1])
142 :     ty_vec2T = mkTensor(id+2,[2])
143 :     ty_vec3T = mkTensor(id+3,[3])
144 :     ty_mat1x1T = mkTensor(id+4,[1,1])
145 :     ty_mat2x2T = mkTensor(id+5,[2,2])
146 :     ty_mat3x3T = mkTensor(id+6,[3,3])
147 :     ty_mat2x3T = mkTensor(id+7,[2,3])
148 :     ty_mat3x2T = mkTensor(id+8,[3,2])
149 :     ty_ten2x2x2T = mkTensor(id+9,[2,2,2])
150 :     ty_ten3x3x3T = mkTensor(id+10,[3,3,3])
151 : cchiw 3915
152 : cchiw 4343 ty_ten2x3x2T = mkTensor(id+11,[2,3,2])
153 :     ty_ten3x2x2T = mkTensor(id+12,[3,2,2])
154 :     ty_ten3x3x2T = mkTensor(id+13,[3,3,2])
155 :    
156 :     ty_ten2x3x3T = mkTensor(id+14,[2,3,3])
157 :     ty_ten3x2x3T = mkTensor(id+15,[3,2,3])
158 :     ty_ten2x2x3T = mkTensor(id+16,[2,2,3])
159 :    
160 :    
161 : cchiw 4158 id=0
162 : cchiw 3939 #lift tensor to field level
163 : cchiw 4308 ty_scalarFT = mkNoneField(id, ty_scalarT)
164 :     ty_vec1FT = mkNoneField(id+1, ty_vec1T)
165 :     ty_vec2FT = mkNoneField(id+2, ty_vec2T)
166 :     ty_vec3FT = mkNoneField(id+3, ty_vec3T)
167 :     ty_mat2x2FT = mkNoneField(id+4, ty_mat2x2T)
168 :     ty_mat3x3FT =mkNoneField(id+5, ty_mat3x3T)
169 :     ty_mat2x3FT = mkNoneField(id+6, ty_mat2x3T)
170 :     ty_mat3x2FT = mkNoneField(id+7, ty_mat3x2T)
171 :     ty_ten2x2x2FT = mkNoneField(id+8, ty_ten2x2x2T)
172 :     ty_ten3x3x3FT = mkNoneField(id+9, ty_ten3x3x3T)
173 : cchiw 3915
174 : cchiw 4210 #dimension 1
175 : cchiw 4158 id=10
176 :     dim = 1
177 : cchiw 4252 ty_scalarF_d1 = mkField(id, dim, ty_scalarT)
178 :     ty_vec1F_d1 = mkField(id+1, dim, ty_vec1T)
179 :     ty_vec2F_d1 = mkField(id+2, dim, ty_vec2T)
180 :     ty_vec3F_d1 = mkField(id+3, dim, ty_vec3T)
181 :     ty_mat2x2F_d1 = mkField(id+4, dim, ty_mat2x2T)
182 :     ty_mat3x3F_d1 = mkField(id+5, dim, ty_mat3x3T)
183 :     ty_mat2x3F_d1 = mkField(id+6, dim, ty_mat2x3T)
184 :     ty_mat3x2F_d1 = mkField(id+7, dim, ty_mat3x2T)
185 :     ty_ten2x2x2F_d1 = mkField(id+8, dim, ty_ten2x2x2T)
186 :     ty_ten3x3x3F_d1 = mkField(id+9, dim, ty_ten3x3x3T)
187 : cchiw 4158 #dimension 2
188 :     id=20
189 : cchiw 4210 dim=2
190 : cchiw 4247 ty_scalarF_d2 = mkField(id, dim, ty_scalarT)
191 :     ty_vec2F_d2 = mkField(id+1, dim, ty_vec2T)
192 :     ty_vec3F_d2 = mkField(id+2, dim, ty_vec3T)
193 :     ty_mat2x2F_d2 = mkField(id+3, dim, ty_mat2x2T)
194 :     ty_mat3x3F_d2 = mkField(id+4, dim, ty_mat3x3T)
195 :     ty_mat2x3F_d2 = mkField(id+5, dim, ty_mat2x3T)
196 :     ty_mat3x2F_d2 = mkField(id+6, dim, ty_mat3x2T)
197 : cchiw 4210 ty_ten2x2x2F_d2 = mkField(id+7, dim, ty_ten2x2x2T)
198 :     ty_ten3x3x3F_d2 = mkField(id+8, dim, ty_ten3x3x3T)
199 : cchiw 4158 #dimension 3
200 :     id=30
201 : cchiw 4210 dim=3
202 : cchiw 4247 ty_scalarF_d3 = mkField(id, dim, ty_scalarT)
203 :     ty_vec2F_d3 = mkField(id+1, dim, ty_vec2T)
204 :     ty_vec3F_d3 = mkField(id+2, dim, ty_vec3T)
205 :     ty_mat2x2F_d3 = mkField(id+3, dim, ty_mat2x2T)
206 :     ty_mat3x3F_d3 = mkField(id+4, dim, ty_mat3x3T)
207 :     ty_mat2x3F_d3 = mkField(id+5, dim, ty_mat2x3T)
208 :     ty_mat3x2F_d3 = mkField(id+6, dim, ty_mat3x2T)
209 : cchiw 4210 ty_ten2x2x2F_d3 = mkField(id+7, dim, ty_ten2x2x2T)
210 :     ty_ten3x3x3F_d3 = mkField(id+8, dim, ty_ten3x3x3T)
211 : cchiw 4158
212 : cchiw 4236 #list of types
213 :     #fields that we can create data
214 : cchiw 4243 #ty_vec2F_d1,
215 : cchiw 4385 l_all_F= [ty_scalarF_d1, ty_scalarF_d2, ty_vec2F_d2, ty_vec3F_d2, ty_mat2x2F_d2, ty_mat3x3F_d2, ty_scalarF_d3, ty_vec2F_d3, ty_vec3F_d3,ty_mat2x2F_d3, ty_mat3x3F_d3]
216 : cchiw 4308 l_all_T = [ty_scalarFT, ty_vec2FT, ty_vec3FT, ty_mat2x2FT, ty_mat3x3FT, ty_mat2x3FT , ty_mat3x2FT, ty_ten2x2x2FT, ty_ten3x3x3FT]
217 : cchiw 4397 l_all = l_all_T+l_all_F
218 : cchiw 4158
219 : cchiw 4236
220 : cchiw 3939 # check equal dim
221 :     def check_dim(fld,b):
222 :     if(fty.is_Field(b)):
223 : cchiw 4158 return (fld.dim==b.dim)
224 : cchiw 4308 return true
225 : cchiw 3939
226 : cchiw 4308 def get_scaF(es):
227 : cchiw 3998 rtn = []
228 :     # binary operator
229 : cchiw 4308 for f in es:
230 : cchiw 3998 if(fty.is_Scalar(f)):
231 :     rtn.append(f)
232 :     return rtn
233 :    
234 : cchiw 4158 def get_scaFnotd1():
235 :     rtn = []
236 :     # binary operator
237 :     for f in l_all_F:
238 :     dim = f.dim
239 :     if(fty.is_Scalar(f) and (dim!=1)):
240 :     rtn.append(f)
241 :     return rtn
242 :    
243 :    
244 : cchiw 3939 #list of vector fields
245 : cchiw 4308 def get_vecF(es):
246 : cchiw 3939 rtn = []
247 :     # binary operator
248 : cchiw 4308 for f in es:
249 : cchiw 3939 if(fty.is_Vector(f)):
250 :     rtn.append(f)
251 :     return rtn
252 : cchiw 3998
253 : cchiw 4248 #list of vector fields (d)=n.
254 : cchiw 4308 def get_vecF_samedim(es):
255 : cchiw 4248 rtn = []
256 :     # binary operator
257 : cchiw 4308 for f in es:
258 : cchiw 4248 if(fty.is_Vector(f)):
259 :     [n] =f.shape
260 :     if(f.dim== n):
261 :     rtn.append(f)
262 :     return rtn
263 :    
264 : cchiw 4411 #list of vector fields (d)=n.
265 :     def get_vecF_matF(es):
266 :     rtn = []
267 :     # binary operator
268 :     for f in es:
269 :     if(fty.is_Vector(f) or fty.is_Matrix(f)):
270 :     rtn.append(f)
271 :     return rtn
272 : cchiw 4248
273 : cchiw 4411
274 : cchiw 3998 #list of matrix fields
275 : cchiw 4308 def get_matF(es):
276 : cchiw 3939 rtn = []
277 :     # binary operator
278 : cchiw 4308 for f in es:
279 : cchiw 3998 if(fty.is_Matrix(f)):
280 : cchiw 3939 rtn.append(f)
281 : cchiw 4411 #print "ty", f.name
282 : cchiw 3939 return rtn
283 :    
284 : cchiw 4308
285 : cchiw 4243 #list of matrix fields
286 : cchiw 4308 def get_mat_symmal(es):
287 : cchiw 4243 rtn = []
288 :     # binary operator
289 : cchiw 4308 for f in es:
290 :     if(fty.is_Matrix(f)):
291 :     [n, m] = f.shape
292 :     if(n==m):
293 :     rtn.append(f)
294 :     return rtn
295 :    
296 :    
297 :     #list of matrix fields
298 :     def get_Ten3(es):
299 :     rtn = []
300 :     # binary operator
301 :     for f in es:
302 : cchiw 4243 if(fty.is_Ten3(f)):
303 :     rtn.append(f)
304 :     return rtn
305 : cchiw 3939
306 :     #binary operator so two args
307 :     def find_field(ty1, ty2):
308 :     dim1=ty1.dim
309 :     dim2=ty2.dim
310 :     if (dim1==0): # tensors
311 : cchiw 3946 return (True , ty2)
312 : cchiw 3939 elif(dim2==0):# tensors
313 : cchiw 3946 return (True , ty1)
314 : cchiw 3939 elif(dim1==dim2):
315 : cchiw 3946 return (True , ty1)
316 : cchiw 3939 else :
317 : cchiw 3946 return (False, None)
318 : cchiw 3939
319 :     #shape to type
320 : cchiw 4236 #shape to type
321 :     def shapeToTyhelper(shapeout, dim):
322 : cchiw 4308 if (dim==nonefield_dim):
323 : cchiw 3939 if (shapeout==[]):
324 : cchiw 4308 return (True, ty_scalarFT)
325 :     elif (shapeout==[1]):
326 :     return (True, ty_vec1FT)
327 :     elif (shapeout==[2]):
328 :     return (True, ty_vec2FT)
329 :     elif (shapeout==[3]):
330 :     return (True, ty_vec3FT)
331 :     elif(shapeout==[1,1]):
332 :     return (True, ty_mat1x1FT)
333 :     elif(shapeout==[2,2]):
334 :     return (True, ty_mat2x2FT)
335 :     elif(shapeout==[3,3]):
336 :     return (True, ty_mat3x3FT)
337 :     elif(shapeout==[2,3]):
338 :     return (True, ty_mat2x3FT)
339 :     elif(shapeout==[3,2]):
340 :     return (True, ty_mat3x2FT)
341 :     elif (shapeout==[2,2,2]):
342 :     return (True, ty_ten2x2x2FT)
343 :     elif(shapeout==[3,3, 3]):
344 :     return (True, ty_ten3x3x3FT)
345 :     else:
346 :     #print "shapeout",shapeout,"dim", dim
347 :     return (False, ("unsupported shapeout dim-1 "+ str(shapeout)))
348 :     elif (dim==1):
349 :     if (shapeout==[]):
350 : cchiw 4236 return (True, ty_scalarF_d1)
351 : cchiw 4158 elif (shapeout==[1]):
352 : cchiw 4236 return (True, ty_vec1F_d1)
353 : cchiw 4158 elif (shapeout==[2]):
354 : cchiw 4236 return (True, ty_vec2F_d1)
355 : cchiw 4158 elif (shapeout==[3]):
356 : cchiw 4236 return (True, ty_vec3F_d1)
357 : cchiw 4158 elif(shapeout==[1,1]):
358 : cchiw 4236 return (True, ty_mat1x1F_d1)
359 : cchiw 4158 elif(shapeout==[2,2]):
360 : cchiw 4236 return (True, ty_mat2x2F_d1)
361 : cchiw 4158 elif(shapeout==[3,3]):
362 : cchiw 4236 return (True, ty_mat3x3F_d1)
363 : cchiw 4252 elif(shapeout==[2,3]):
364 :     return (True, ty_mat2x3F_d1)
365 :     elif(shapeout==[3,2]):
366 :     return (True, ty_mat3x2F_d1)
367 : cchiw 4158 elif (shapeout==[2,2,2]):
368 : cchiw 4236 return (True, ty_ten2x2x2F_d1)
369 : cchiw 4158 elif(shapeout==[3,3, 3]):
370 : cchiw 4236 return (True, ty_ten3x3x3F_d1)
371 : cchiw 4158 else:
372 :     #print "shapeout",shapeout,"dim", dim
373 : cchiw 4236 return (False, ("unsupported shapeout dim-1 "+ str(shapeout)))
374 : cchiw 4158 elif (dim==2):
375 :     if (shapeout==[]):
376 : cchiw 4236 return (True, ty_scalarF_d2)
377 : cchiw 3939 elif (shapeout==[2]):
378 : cchiw 4236 return (True, ty_vec2F_d2)
379 : cchiw 3939 elif(shapeout==[3]):
380 : cchiw 4236 return (True, ty_vec3F_d2)
381 : cchiw 3939 elif (shapeout==[2,2]):
382 : cchiw 4236 return (True, ty_mat2x2F_d2)
383 : cchiw 3939 elif (shapeout==[3,3]):
384 : cchiw 4236 return (True, ty_mat3x3F_d2)
385 : cchiw 4247 elif (shapeout==[2,3]):
386 :     return (True, ty_mat2x3F_d2)
387 :     elif (shapeout==[3,2]):
388 :     return (True, ty_mat3x2F_d2)
389 : cchiw 3939 elif (shapeout==[2,2,2]):
390 : cchiw 4236 return (True, ty_ten2x2x2F_d2)
391 : cchiw 3939 elif(shapeout==[3,3, 3]):
392 : cchiw 4236 return (True, ty_ten3x3x3F_d2)
393 : cchiw 3939 else:
394 : cchiw 3946 #print "shapeout",shapeout,"dim", dim
395 : cchiw 4236 return(False, "unsupported shapeout dim-2 "+str(shapeout))
396 : cchiw 3939 elif (dim==3):
397 :     if (shapeout==[]):
398 : cchiw 4236 return (True, ty_scalarF_d3)
399 : cchiw 3939 elif(shapeout==[3]):
400 : cchiw 4236 return (True, ty_vec3F_d3)
401 : cchiw 3939 elif(shapeout==[3,3]):
402 : cchiw 4236 return (True, ty_mat3x3F_d3)
403 : cchiw 3939 elif(shapeout==[3,3, 3]):
404 : cchiw 4236 return(True, ty_ten3x3x3F_d3)
405 : cchiw 3939 elif(shapeout==[2]):
406 : cchiw 4236 return (True, ty_vec2F_d3)
407 : cchiw 3939 elif(shapeout==[2,2]):
408 : cchiw 4236 return (True, ty_mat2x2F_d3)
409 : cchiw 4247 elif(shapeout==[2,3]):
410 :     return (True, ty_mat2x3F_d3)
411 :     elif(shapeout==[3,2]):
412 :     return (True, ty_mat3x2F_d3)
413 : cchiw 3939 elif(shapeout==[2,2,2]):
414 : cchiw 4236 return (True, ty_ten2x2x2F_d3)
415 : cchiw 3939 else:
416 : cchiw 4236 return (False, "unsupported shapeout dim-3"+str(shapeout))
417 : cchiw 3915 else:
418 : cchiw 4236 return (False, "unsupported dim")
419 : cchiw 3915
420 : cchiw 4236 def shapeToTy(shapeout, dim):
421 : cchiw 4261 (tf, shape) = shapeToTyhelper(shapeout, dim)
422 :     if(tf):
423 :     return shape
424 :     else:
425 :     raise Exception ("shapeout",shapeout, "dim", dim, "rtn:",shape)
426 : cchiw 3939 #concat two types to form a new type
427 :     def concatTys(ty1, ty2):
428 :     if (fty.is_Vector(ty1)):
429 :     n1 = fty.get_vecLength(ty1)
430 :     if (fty.is_Vector(ty2)):
431 :     n2 = fty.get_vecLength(ty2)
432 :     fldty = find_field(ty1, ty2)
433 :     k = fldty.k
434 :     if (n1==2):
435 :     if(n2==2):
436 :     return fty.convertTy(ty_mat2x2F_d2, k)
437 :     else:
438 :     raise "unsupported concat types"
439 :     elif (n1==3):
440 :     if (n2==3):
441 :     return fty.convertTy(ty_mat3x3F_d3, k)
442 :     else:
443 :     raise "unsupported concat types"
444 :     # elif (fty.is_Matrix(ty2)):
445 :     else:
446 :     raise "unsupported concat types"
447 : cchiw 3915 else:
448 : cchiw 3939 raise "unsupported concat types"
449 : cchiw 3915
450 : cchiw 3939 #reduce shape of fields
451 :     def reduceIndex(ty0):
452 :     # keep current k value
453 :     k=ty0.k
454 :     if (fty.isEq_id(ty0, ty_vec2F_d2)):
455 :     return fty.convertTy(ty_scalarF_d2,k)
456 :     elif (fty.isEq_id(ty0, ty_vec3F_d3)):
457 :     return fty.convertTy(ty_scalarF_d3,k)
458 : cchiw 3915 else:
459 : cchiw 3939 raise Exception ("unsupported field shape:"+ty0.name)
460 : cchiw 3915

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0