Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/synth/d2/obj_ty.py
ViewVC logotype

View of /branches/ein16/synth/d2/obj_ty.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3946 - (download) (as text) (annotate)
Sat Jun 11 00:39:19 2016 UTC (2 years, 9 months ago) by cchiw
File size: 8362 byte(s)
 added outer prod
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

#tensor types
class tty:
    """Tensor type descriptor: a numeric id, a short name and a shape list."""

    def __init__(self, id, name, shape):
        # store the three descriptor fields verbatim
        self.id, self.name, self.shape = id, name, shape

    def toStr(self):
        """Return "tensor" immediately followed by the type's name."""
        label = str(self.name)
        return "tensor" + label

    def isEq_id(a, b):
        """Two tensor types are considered equal when their ids match."""
        same = a.id == b.id
        return same

# field types
class fty:
    """Field type descriptor; with dim == 0 it describes a tensor lifted to field level.

    Attributes (set by __init__):
        id:         numeric identifier; isEq_id compares types by this alone.
        name:       short name string.
        dim:        domain dimension; 0 is treated as "tensor" by toStr,
                    toDiderot and is_Field.
        shape:      list of index extents: [] scalar, [n] vector, [n, m] matrix, ...
        tensorType: tensor type obtained if a field of this type is probed.
        k:          value printed after "field#" by toStr/toDiderot
                    (presumably Diderot's continuity parameter -- confirm).
    """
    def __init__(self, id, name, dim, shape, tensorType,k):
        self.id=id
        self.name=name
        self.dim=dim
        self.shape=shape
        self.tensorType=tensorType # if we probed a field with this field type
        self.k=k

    def toStr(self):
        """Short description: "tensor " for dim 0, else "field #k(dim)shape"."""
        if(self.dim==0):
            return "tensor "
        else:
            return ("field #"+str(self.k)+"("+str(self.dim)+")"+str(self.shape))

    def get_tensorType(self):
        """Return the tensor type produced by probing a field of this type."""
        return self.tensorType

    def get_dim(self):
        """Return the domain dimension (0 for lifted tensors)."""
        return self.dim

    def get_shape(ty0):
        """Return the shape list."""
        return  ty0.shape

    def get_vecLength(ty0):
        """Return the vector length of a rank-1 type.

        Raises:
            Exception: if ty0's shape is not of rank 1.
        """
        shape = ty0.shape
        if len(shape) == 1:
            return shape[0]  # vector length
        # Raising a bare string is itself a TypeError (Python >= 2.6 and 3),
        # which hides the message; raise a real exception object instead.
        raise Exception("unsupported get_vecLength types")

    def is_Field(self):
        """True when the type has a positive domain dimension."""
        return  (self.dim>0)

    def is_ScalarField(self):
        """True for a field (dim > 0) with an empty shape."""
        return  ((len(self.shape)==0) and fty.is_Field(self))

    def is_VectorField(self):
        """True for a field (dim > 0) whose shape has rank 1."""
        return  ((len(self.shape)==1) and fty.is_Field(self))

    def is_Scalar(self):
        """True when the shape is empty (field or tensor)."""
        return  (len(self.shape)==0)

    def is_Vector(self):
        """True when the shape has rank 1 (field or tensor)."""
        return  (len(self.shape)==1)

    def is_Matrix(self):
        """True when the shape has rank 2 (field or tensor)."""
        return  (len(self.shape)==2)

    def is_Ten3(self):
        """True when the shape has rank 3 (field or tensor)."""
        return  (len(self.shape)==3)

    def get_first_ix(self):
        """Return the first extent of the shape (IndexError if the shape is empty)."""
        return  self.shape[0]

    def get_last_ix(self):
        """Return the last extent of the shape (IndexError if the shape is empty)."""
        return  self.shape[-1]

    def drop_first(self):
        """Return a new list holding the shape WITHOUT its last extent.

        NOTE(review): despite its name this drops the LAST index (and
        drop_last drops the FIRST). The behaviour is kept because callers
        may rely on it; confirm before renaming or swapping the two.
        """
        return list(self.shape[:-1])

    def drop_last(self):
        """Return a new list holding the shape WITHOUT its first extent.

        NOTE(review): see drop_first -- the names appear to be swapped.
        """
        return list(self.shape[1:])

    def isEq_id(a,b):
        """Compare two type objects by id only."""
        return (a.id==b.id)

    def toDiderot(self):
        """Return the type as written in a Diderot program.

        dim 0 -> "tensor <shape>", otherwise "field#<k>(<dim>)<shape>".
        """
        if(self.dim==0):
            return "tensor "+str(self.shape)
        else:
            return "field#"+str(self.k)+"("+str(self.dim)+")"+str(self.shape)

    def convertTy(const,k):
        """Return a copy of the type constant `const` with its k replaced by `k`."""
        return  fty(const.id,const.name, const.dim, const.shape, const.tensorType, k)


# ------------------------------ type name to other properties ------------------------------
# shorthand used to refer to different types
# the helper functions below map a shorthand to its other properties and create ty objects

#tensors
ty_noneT = tty(id=0, name="none", shape=[])  # placeholder: no tensor type
ty_scalarT = tty(id=1, name="scalartensor", shape=[])
ty_vec2T = tty(id=2, name="vec2tensor", shape=[2])
ty_vec3T = tty(id=3, name="vec3tensor", shape=[3])
ty_mat2x2T = tty(id=4, name="mat2x2tensor", shape=[2, 2])
ty_mat3x3T = tty(id=5, name="mat3x3tensor", shape=[3, 3])
ty_ten2x2x2T = tty(id=11, name="ten2x2x2tensor", shape=[2, 2, 2])
ty_ten3x3x3T = tty(id=12, name="ten3x3x3tensor", shape=[3, 3, 3])


#lift tensor to field level
# All lifted tensors have dim 0, no probe result (ty_noneT) and k = -1.
ty_scalarFT = fty(0,"scalartensor",0, [],ty_noneT,-1)
ty_vec2FT = fty(1,"vec2tensor",0, [2],ty_noneT,-1)
ty_vec3FT = fty(2,"vec3tensor",0,[3],ty_noneT,-1)
ty_mat2x2FT = fty(3,"mat2x2tensor",0,[2,2],ty_noneT,-1)
ty_mat3x3FT = fty(4,"mat3x3tensor",0,[3,3],ty_noneT,-1)
ty_mat2x2x2FT = fty(5,"mat2x2x2tensor",0,[2,2,2],ty_noneT,-1)
# name fixed from the typo "mat3x3tx3ensor" to match "mat2x2x2tensor" above
ty_mat3x3x3FT = fty(6,"mat3x3x3tensor",0,[3,3,3],ty_noneT,-1)

def get_Tshape(ty):
    """Return the `shape` attribute of a type object (tty or fty)."""
    shape = ty.shape
    return shape

# fields: fty(id, name, dim, shape list, tensorType, k); tensorType is the tensor type returned by probing the field
k_init = 0  # null k: placeholder until convertTy assigns a real value

# fields in a 2D domain
ty_scalarF_d2 = fty(id=7, name="scalarfield_d2", dim=2, shape=[], tensorType=ty_scalarT, k=k_init)
ty_vec2F_d2 = fty(id=8, name="vec2field_d2", dim=2, shape=[2], tensorType=ty_vec2T, k=k_init)
ty_mat2x2F_d2 = fty(id=9, name="mat2x2field_d2", dim=2, shape=[2, 2], tensorType=ty_mat2x2T, k=k_init)
# NOTE(review): name uses "mat" while the variable and its d3 sibling use "ten"
ty_ten2x2x2F_d2 = fty(id=10, name="mat2x2x2field_d2", dim=2, shape=[2, 2, 2], tensorType=ty_ten2x2x2T, k=k_init)
# fields in a 3D domain
ty_scalarF_d3 = fty(id=11, name="scalarfield_d3", dim=3, shape=[], tensorType=ty_scalarT, k=k_init)
ty_vec3F_d3 = fty(id=12, name="vec3field_d3", dim=3, shape=[3], tensorType=ty_vec3T, k=k_init)
ty_mat3x3F_d3 = fty(id=13, name="mat3x3field_d3", dim=3, shape=[3, 3], tensorType=ty_mat3x3T, k=k_init)
ty_ten3x3x3F_d3 = fty(id=14, name="mat3x3x3field_d3", dim=3, shape=[3, 3, 3], tensorType=ty_ten3x3x3T, k=k_init)

# 3-extent shapes over a 2D domain
ty_vec3F_d2 = fty(id=15, name="vec3field_d2", dim=2, shape=[3], tensorType=ty_vec3T, k=k_init)
ty_mat3x3F_d2 = fty(id=16, name="mat3x3field_d2", dim=2, shape=[3, 3], tensorType=ty_mat3x3T, k=k_init)
ty_ten3x3x3F_d2 = fty(id=17, name="ten3x3x3field_d2", dim=2, shape=[3, 3, 3], tensorType=ty_ten3x3x3T, k=k_init)

# 2-extent shapes over a 3D domain
ty_vec2F_d3 = fty(id=18, name="vec2field_d3", dim=3, shape=[2], tensorType=ty_vec2T, k=k_init)
ty_mat2x2F_d3 = fty(id=19, name="mat2x2field_d3", dim=3, shape=[2, 2], tensorType=ty_mat2x2T, k=k_init)
ty_ten2x2x2F_d3 = fty(id=20, name="ten2x2x2field_d3", dim=3, shape=[2, 2, 2], tensorType=ty_ten2x2x2T, k=k_init)

# check equal dim
def check_dim(fld,b):
    """Return True when `b` is a field whose dim equals `fld.dim`.

    Returns False when `b` is not a field (dim 0, i.e. a lifted tensor),
    so the predicate always yields a bool.
    """
    if not fty.is_Field(b):
        return False
    return fld.dim == b.dim

#list of vector fields
def get_vecF():
    """Return, in order, the entries of l_all_F whose shape has rank 1."""
    return [fld for fld in l_all_F if fty.is_Vector(fld)]
def get_scaF():
    """Return, in order, the entries of l_all_F whose shape is empty."""
    return [fld for fld in l_all_F if fty.is_Scalar(fld)]


# types for multiplication
def get_mul():
    """List the [field, operand] argument pairs allowed for multiplication.

    The first element ranges over l_all_F and the second over l_all. A pair is
    kept when (a) the operand is not a field, or is a field of the same dim,
    and (b) at least one of the two is scalar-shaped.
    """
    return [
        [sf, a]
        for sf in l_all_F
        for a in l_all
        if (not fty.is_Field(a) or check_dim(sf, a))
        and (fty.is_Scalar(sf) or fty.is_Scalar(a))
    ]

#list of types
#fields for which we can create data
l_all_F = [ty_scalarF_d2, ty_vec2F_d2, ty_scalarF_d3, ty_vec3F_d3]
# tensor types lifted to field level (dim 0)
l_all_FT = [ty_scalarFT, ty_vec2FT, ty_vec3FT, ty_mat2x2FT, ty_mat3x3FT, ty_mat2x2x2FT, ty_mat3x3x3FT]
# every operand type: creatable fields first, then the lifted tensors
l_all = list(l_all_F)
l_all.extend(l_all_FT)
# creatable fields grouped by shape rank
vectorFlds = get_vecF()
scalarFlds = get_scaF()


#binary operator so two args
def find_field(ty1, ty2):
    """Choose the result type for a binary operator over ty1 and ty2.

    Returns a (found, ty) pair:
      * ty1 is a tensor (dim 0)       -> (True, ty2)
      * ty2 is a tensor (dim 0)       -> (True, ty1)
      * both fields with equal dims   -> (True, ty1)
      * fields with different dims    -> (False, None)
    """
    d1, d2 = ty1.dim, ty2.dim
    if d1 == 0:
        return (True, ty2)
    if d2 == 0 or d1 == d2:
        return (True, ty1)
    return (False, None)

#shape to type
def shapeToTy(shapeout, dim):
    """Map an output shape and a domain dimension to the matching field type constant.

    Args:
        shapeout: shape list, e.g. [], [2], [3, 3], [2, 2, 2]
                  (compared with ==, as before, so it must be a list).
        dim: domain dimension; only 2 and 3 are supported.

    Returns:
        One of the module-level ty_*F_d2 / ty_*F_d3 constants.

    Raises:
        Exception: for an unsupported dim or shape. The dim-3 and bad-dim
            branches previously raised bare strings, which is itself a
            TypeError on Python >= 2.6.
    """
    if dim == 2:
        candidates = [
            ([], ty_scalarF_d2),
            ([2], ty_vec2F_d2),
            ([3], ty_vec3F_d2),
            ([2, 2], ty_mat2x2F_d2),
            ([3, 3], ty_mat3x3F_d2),
            ([2, 2, 2], ty_ten2x2x2F_d2),
            ([3, 3, 3], ty_ten3x3x3F_d2),
        ]
    elif dim == 3:
        candidates = [
            ([], ty_scalarF_d3),
            ([3], ty_vec3F_d3),
            ([3, 3], ty_mat3x3F_d3),
            ([3, 3, 3], ty_ten3x3x3F_d3),
            ([2], ty_vec2F_d3),
            ([2, 2], ty_mat2x2F_d3),
            ([2, 2, 2], ty_ten2x2x2F_d3),
        ]
    else:
        raise Exception("unsupported dim", str(dim))
    for shape, ty in candidates:
        if shapeout == shape:
            return ty
    raise Exception("unsupported shapeout", str(shapeout))

#concat two types to form a new type
def concatTys(ty1, ty2):
    """Return the field type for the outer product of two vector types.

    Both arguments must be vectors of the same length (2 or 3), and their
    dimensions must be compatible under find_field. The result takes its k
    value from the type that find_field selects.

    Raises:
        Exception: for non-vector arguments, incompatible dimensions, or
            unsupported vector lengths.
    """
    if not (fty.is_Vector(ty1) and fty.is_Vector(ty2)):
        raise Exception("unsupported concat types")
    n1 = fty.get_vecLength(ty1)
    n2 = fty.get_vecLength(ty2)
    # find_field returns a (found, type) pair, not a bare type; reading `.k`
    # from the tuple always raised AttributeError.
    found, fldty = find_field(ty1, ty2)
    if not found:
        raise Exception("unsupported concat types")
    k = fldty.k
    # NOTE(review): the result dimension follows the vector length
    # (2 -> 2D, 3 -> 3D), not fldty.dim, so e.g. a vec2 field over a 3D
    # domain maps to a 2D matrix field -- confirm this is intended.
    if n1 == 2 and n2 == 2:
        return fty.convertTy(ty_mat2x2F_d2, k)
    if n1 == 3 and n2 == 3:
        return fty.convertTy(ty_mat3x3F_d3, k)
    raise Exception("unsupported concat types")

#reduce shape of fields
def reduceIndex(ty0):
    """Return the scalar field type obtained by reducing a vector field's index.

    Inputs are matched by id: ty_vec2F_d2 -> ty_scalarF_d2 and
    ty_vec3F_d3 -> ty_scalarF_d3. The result carries ty0's k value.

    Raises:
        Exception: for any other input type.
    """
    k = ty0.k  # the result keeps the current k value
    reductions = (
        (ty_vec2F_d2, ty_scalarF_d2),
        (ty_vec3F_d3, ty_scalarF_d3),
    )
    for vec_ty, scalar_ty in reductions:
        if fty.isEq_id(ty0, vec_ty):
            return fty.convertTy(scalar_ty, k)
    raise Exception ("unsupported field shape:"+ty0.name)


root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0