Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/synth/d2/test_eval.py
ViewVC logotype

View of /branches/ein16/synth/d2/test_eval.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4236 - (download) (as text) (annotate)
Wed Jul 20 03:02:00 2016 UTC (3 years, 2 months ago) by cchiw
File size: 28378 byte(s)
added generic cases for trace|det and added test cases
import sympy
from sympy import *
#symbols
x,y,z =symbols('x y z')
import sys
import re
import math
from obj_ty import *
from obj_apply import *
from obj_operator import *
from obj_field import *



# ***************************  operator implementations ***************************
# binary operators
def fn_add(exp1,exp2):
    """Return exp1 + exp2 (numbers or sympy expressions)."""
    total = exp1 + exp2
    return total
def fn_subtract(exp1,exp2):
    """Return exp1 - exp2 (numbers or sympy expressions)."""
    difference = exp1 - exp2
    return difference
# scaling operator
def fn_multiplication(exp_s, t):
    """Multiply every component of field ``t`` by the scalar expression ``exp_s``.

    Args:
        exp_s: scalar expression used as the scale factor (left operand).
        t: field whose data is a scalar expression or nested lists of
           expressions (vector, matrix or 3rd-order tensor).

    Returns:
        ``exp_s*data`` for a scalar field; otherwise nested lists with the
        same shape as ``t``'s data, each entry multiplied by ``exp_s``.

    Raises:
        Exception: if ``t`` is not a scalar, vector, matrix or Ten3 field.
    """
    exp_t = field.get_data(t)
    ityp1 = field.get_ty(t)
    shape1 = fty.get_shape(ityp1)
    if(field.is_Scalar(t)):
        return exp_s*exp_t
    elif(field.is_Vector(t)):
        [n1] = shape1
        return [exp_s*exp_t[i] for i in range(n1)]
    elif(field.is_Matrix(t)):
        [n1, n2] = shape1
        return [[exp_s*exp_t[i][j] for j in range(n2)] for i in range(n1)]
    elif(field.is_Ten3(t)):
        [n1, n2, n3] = shape1
        return [[[exp_s*exp_t[i][j][k] for k in range(n3)]
                 for j in range(n2)]
                for i in range(n1)]
    else:
        # Raising a bare str is itself a TypeError (Python >= 2.6 and 3.x),
        # so raise a real exception object carrying the message.
        raise Exception("unsupported scaling")

#scaling of a field
def fn_scaling(fld1, fld2):
    """Scale one field by a scalar field; either argument may be the scalar.

    If ``fld1`` is a scalar field its data is the scale factor and ``fld2`` is
    scaled; otherwise the roles are swapped. The scaling is done by
    fn_multiplication.
    """
    if field.is_Scalar(fld1):
        scalar_fld, target_fld = fld1, fld2
    else:
        scalar_fld, target_fld = fld2, fld1
    factor = field.get_data(scalar_fld)
    return fn_multiplication(factor, target_fld)

#division of a field
def fn_division(t, s):
    """Divide field ``t`` component-wise by the scalar field ``s``.

    Implemented as scaling ``t`` by ``1/data(s)``.
    NOTE(review): under Python 2, if the scalar data were a plain int, 1/n
    would truncate; presumably the data are sympy expressions - confirm.

    Raises:
        Exception: if ``s`` is not a scalar field.
    """
    if not field.is_Scalar(s):
        raise Exception ("err second arg in division should be a scalar")
    reciprocal = (1/field.get_data(s))
    return fn_multiplication(reciprocal, t)


# negation  of field
def fn_negation(exp):
    """Return the negation of ``exp``, computed as ``-1*exp``.

    (The module previously defined this function twice, the first copy under
    a stray "sine" comment; only this single definition is needed.)
    """
    return -1*exp


#evaluate cross product
def fn_cross(fld1, fld2):
    """Cross product of two vector fields.

    The vector length of ``fld1`` selects the case:
      * length 2: returns the scalar ``a0*b1 - a1*b0``;
      * length 3: returns the 3-vector ``a x b``.
    The second vector is assumed (not checked here) to have the same length.

    Raises:
        Exception: if the length is neither 2 nor 3.
    """
    exp1 = field.get_data(fld1)
    ityp1 = field.get_ty(fld1)
    exp2 = field.get_data(fld2)
    n1 = fty.get_vecLength(ityp1)
    if(n1==2):
        return (exp1[0]*exp2[1]) -(exp1[1]*exp2[0])
    elif(n1==3):
        x0= (exp1[1]*exp2[2]) -(exp1[2]*exp2[1])
        x1= (exp1[2]*exp2[0]) -(exp1[0]*exp2[2])
        x2= (exp1[0]*exp2[1]) -(exp1[1]*exp2[0])
        return [x0,x1,x2]
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported type for cross product")

#gradient of field
def fn_grad(exp, dim):
    """Symbolic gradient of a scalar expression.

    Differentiates with sympy ``diff`` against the module-level symbols
    x, y, z. For dim 1 it returns a single derivative; for dim 2/3 it
    returns the list of partial derivatives.

    Raises:
        Exception: for any other dimension.
    """
    if (dim==1):
        return diff(exp,x)
    elif (dim==2):
        return [diff(exp,x), diff(exp,y)]
    elif (dim==3):
        return [diff(exp,x), diff(exp,y), diff(exp,z)]
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("dimension not supported")

#evaluate divergence
def fn_divergence(fld):
    """Symbolic divergence of a 2- or 3-vector field.

    Returns d(v0)/dx + d(v1)/dy [+ d(v2)/dz], using the module-level
    sympy symbols x, y, z.

    Raises:
        Exception: if the vector length is neither 2 nor 3.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    n1 = fty.get_vecLength(ityp)
    if(n1==2):
        return diff(exp[0],x)+diff(exp[1],y)
    elif(n1==3):
        return diff(exp[0],x)+diff(exp[1],y)+diff(exp[2],z)
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported type for divergence")

#evaluate curl
def fn_curl(fld):
    """Symbolic curl of a vector field.

    A 2-D field yields the scalar d(v1)/dx - d(v0)/dy. A 3-D field yields
    the 3-vector curl. The vector length must equal the field dimension.

    Raises:
        Exception: if the length differs from the dimension, or is not 2/3.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    dim = field.get_dim(fld)
    n = fty.get_vecLength(ityp)
    if(n!=dim):
        # was `raise ("...")`, which raises a str and is itself a TypeError
        raise Exception(" type not supported for curl")
    if(n==2):
        return diff(exp[1], x) - diff(exp[0], y)
    elif(n==3):
        x0= diff(exp[2],y) - diff(exp[1],z)
        x1= diff(exp[0],z) - diff(exp[2],x)
        x2= diff(exp[1],x) - diff(exp[0],y)
        return [x0,x1,x2]
    else:
        # message previously said "cross product" (copy-paste slip)
        raise Exception("unsupported type for curl")

#evaluate jacob
def fn_jacob(fld):
    """Symbolic Jacobian of a 2- or 3-vector field.

    Row r holds the partial derivatives of component r with respect to
    x, y (, z): result[r][c] = d(exp[r]) / d(coordinate c).

    Raises:
        Exception: if the vector length differs from the field dimension,
        or is neither 2 nor 3.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    dim = field.get_dim(fld)
    n = fty.get_vecLength(ityp)
    if(n!=dim):
        # was `raise ("...")`, which raises a str and is itself a TypeError
        raise Exception(" type not supported for jacob")
    if(n==2):
        return [[diff(exp[0],x), diff(exp[0],y)],
                [diff(exp[1],x), diff(exp[1],y)]]
    elif(n==3):
        return [[diff(exp[0],x), diff(exp[0],y), diff(exp[0],z)],
                [diff(exp[1],x), diff(exp[1],y), diff(exp[1],z)],
                [diff(exp[2],x), diff(exp[2],y), diff(exp[2],z)]]
    else:
        raise Exception("unsupported type for jacob")

#evaluate norm
def fn_norm(fld, dim):
    """Euclidean (Frobenius) norm of a field's data.

    For vector, matrix and Ten3 fields, returns
    ``(sum of squared components) ** 0.5``, with the components taken in
    row-major order. For a scalar field, returns the expression unchanged.
    NOTE(review): a mathematical norm of a scalar would be |exp|; confirm
    this matches the semantics being tested.

    Args:
        fld: the field.
        dim: ignored; overwritten with the field's own dimension.

    Raises:
        Exception: for unsupported field kinds.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    dim = field.get_dim(fld)

    def _euclid(components):
        # local names avoid shadowing the builtins `iter`/`sum`
        total = 0
        for c in components:
            total += c*c
        return (total)**0.5

    if(field.is_Scalar(fld)):
        [] = fty.get_shape(ityp)  # unpacking asserts the shape is empty
        return exp
    elif(field.is_Vector(fld)):
        [n] = fty.get_shape(ityp)
        return _euclid([exp[i] for i in range(n)])
    elif(field.is_Matrix(fld)):
        [n, m] = fty.get_shape(ityp)
        return _euclid([exp[i][j] for i in range(n) for j in range(m)])
    elif(field.is_Ten3(fld)):
        [n, m, p] = fty.get_shape(ityp)
        return _euclid([exp[i][j][k]
                        for i in range(n)
                        for j in range(m)
                        for k in range(p)])
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported type for norm")

#evaluate normalize
def fn_normalize(fld, dim):
    """Divide every component of a field by the field's norm (see fn_norm).

    A scalar field is returned unchanged. Vector, matrix and Ten3 fields are
    returned as nested lists of the same shape, each entry divided by the norm.

    Args:
        fld: the field.
        dim: ignored; overwritten with the field's own dimension.

    Raises:
        Exception: for unsupported field kinds.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    dim = field.get_dim(fld)
    norm = fn_norm(fld, dim)
    if(field.is_Scalar(fld)):
        return exp
    elif(field.is_Vector(fld)):
        [n] = fty.get_shape(ityp)
        return [exp[i]/norm for i in range(n)]
    elif(field.is_Matrix(fld)):
        [n, m] = fty.get_shape(ityp)
        return [[exp[i][j]/norm for j in range(m)] for i in range(n)]
    elif(field.is_Ten3(fld)):
        [n, m, p] = fty.get_shape(ityp)
        return [[[exp[i][j][k]/norm for k in range(p)]
                 for j in range(m)]
                for i in range(n)]
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported type for norm")

#evaluate slice
def fn_slice(fld1):
    """Slice a matrix field with its second index fixed at 0 (i.e. column 0).

    Returns ``[exp[0][0], exp[1][0], ..., exp[n-1][0]]``, where n is the row
    count (first shape index).
    NOTE(review): the fixed index/value (second index = 0) is hard-coded.

    Raises:
        Exception: if the field type is not a matrix.
    """
    exp1 = field.get_data(fld1)
    ityp1 = field.get_ty(fld1)
    if(fty.is_Matrix(ityp1)):
        [n2, n3] = fty.get_shape(ityp1)
        # iterate over rows (n2); looping over the column count n3 only
        # worked for square matrices
        return [exp1[j][0] for j in range(n2)]
    else:
        raise Exception("unsupported type for slice")

#evaluate trace
def fn_trace(fld):
    """Trace of a square matrix field: exp[0][0] + exp[1][1] + ...

    Diagonal entries are added left to right, so for 2x2 and 3x3 matrices
    the result is built exactly as before; other square sizes are now
    supported as well.

    Raises:
        Exception: if the field is not a matrix, or the matrix is not square.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    if(field.is_Matrix(fld)):
        [n, m] = fty.get_shape(ityp)
        if (n!=m):
            raise Exception("matrix is not identitical")
        rtn = exp[0][0]
        for i in range(1, n):
            rtn = rtn + exp[i][i]
        return rtn
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported trace")

#evaluate transpose
def fn_transpose(fld):
    """Transpose of an n x m matrix field, returned as an m x n nested list.

    result[i][j] = exp[j][i]. The previous loops used the wrong bounds, so
    non-square matrices produced the wrong shape or an IndexError; square
    matrices give the same result as before.

    Raises:
        Exception: if the field is not a matrix.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    if(field.is_Matrix(fld)):
        [n, m] = fty.get_shape(ityp)
        return [[exp[j][i] for j in range(n)] for i in range(m)]
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported transpose")

#evaluate det
def fn_det(fld):
    """Determinant of a 2x2 or 3x3 matrix field.

    2x2: a*d - b*c. 3x3: cofactor expansion along the first row.

    Raises:
        Exception: if the field is not a matrix, is not square, or its size
        is not 2 or 3.
    """
    exp = field.get_data(fld)
    ityp = field.get_ty(fld)
    if(field.is_Matrix(fld)):
        [n, m] = fty.get_shape(ityp)
        if (n!=m):
            raise Exception("matrix is not identitical")
        if(n==2):
            # entries are read only once the size is known, so a 1x1 matrix
            # now reaches the explicit error below instead of an IndexError
            a = exp[0][0]
            b = exp[0][1]
            c = exp[1][0]
            d = exp[1][1]
            return a*d-b*c
        elif(n==3):
            a = exp[0][0]
            b = exp[0][1]
            c = exp[0][2]
            d = exp[1][0]
            e = exp[1][1]
            f = exp[1][2]
            g = exp[2][0]
            h = exp[2][1]
            i = exp[2][2]
            return a*(e*i-f*h)-b*(d*i-f*g)+c*(d*h-e*g)
        else:
            raise Exception("unsupported matrix field")
    else:
        # message previously said "trace" (copy-paste from fn_trace)
        raise Exception("unsupported det")


#evaluate outer product
def fn_outer(fld1, fld2):
    """Outer (tensor) product of two fields.

    Supported argument kinds and result shapes (nested lists, row-major):
        vector(n1)    x vector(n2)     -> n1 x n2
        vector(n1)    x matrix(n2,n3)  -> n1 x n2 x n3
        matrix(n1,n2) x vector(n3)     -> n1 x n2 x n3
        matrix(n1,n2) x matrix(n3,n4)  -> n1 x n2 x n3 x n4
    Each entry is the product of the corresponding entries of the arguments.

    Raises:
        Exception: for any other combination of argument kinds.
    """
    exp1 = field.get_data(fld1)
    ityp1 = field.get_ty(fld1)
    exp2 = field.get_data(fld2)
    ityp2 = field.get_ty(fld2)
    if(fty.is_Vector(ityp1)):
        n1 = fty.get_vecLength(ityp1)
        if(fty.is_Vector(ityp2)):
            n2 = fty.get_vecLength(ityp2)
            return [[exp1[i]*exp2[j] for j in range(n2)] for i in range(n1)]
        elif(fty.is_Matrix(ityp2)):
            [n2, n3] = fty.get_shape(ityp2)
            return [[[exp1[i]*exp2[j][k] for k in range(n3)]
                     for j in range(n2)]
                    for i in range(n1)]
    elif(fty.is_Matrix(ityp1)):
        [n1, n2] = fty.get_shape(ityp1)
        if(fty.is_Vector(ityp2)):
            n3 = fty.get_vecLength(ityp2)
            return [[[exp1[i][j]*exp2[k] for k in range(n3)]
                     for j in range(n2)]
                    for i in range(n1)]
        elif(fty.is_Matrix(ityp2)):
            [n3, n4] = fty.get_shape(ityp2)
            return [[[[exp1[i][j]*exp2[k][l] for l in range(n4)]
                      for k in range(n3)]
                     for j in range(n2)]
                    for i in range(n1)]
    # Every supported combination returned above. A vector with an
    # unsupported second argument used to fall through and return None.
    raise Exception("outer product is not supported")
#evaluate inner product
def fn_inner(fld1, fld2):
    """Inner product: contract the last index of fld1 with the first of fld2.

    Supported argument kinds and results:
        vector . vector -> scalar   sum_s a[s]*b[s]
        vector . matrix -> vector   r[i]    = sum_s a[s]*M[s][i]
        vector . ten3   -> matrix   r[i][j] = sum_s a[s]*T[s][i][j]
        matrix . vector -> vector   r[i]    = sum_s M[i][s]*b[s]
        ten3   . vector -> matrix   r[i][j] = sum_s T[i][j][s]*b[s]
    Each sum starts from 0 and adds terms in index order.

    Raises:
        Exception: for any other combination of argument kinds.
    """
    exp1 = field.get_data(fld1)
    ityp1 = field.get_ty(fld1)
    exp2 = field.get_data(fld2)
    ityp2 = field.get_ty(fld2)
    if(fty.is_Vector(ityp1)):
        n1 = fty.get_vecLength(ityp1)
        if(fty.is_Vector(ityp2)):
            rtn = 0
            for s in range(n1):
                rtn += exp1[s]*exp2[s]
            return rtn
        elif(fty.is_Matrix(ityp2)):
            [n2] = fty.drop_last(ityp2)
            rtn = []
            for i in range(n2):
                sumrtn = 0
                for s in range(n1):
                    sumrtn += exp1[s]*exp2[s][i]
                rtn.append(sumrtn)
            return rtn
        elif(fty.is_Ten3(ityp2)):
            [n2, n3] = fty.drop_last(ityp2)
            rtn = []
            for i in range(n2):
                tmpJ = []
                for j in range(n3):
                    sumrtn = 0
                    for s in range(n1):
                        sumrtn += exp1[s]*exp2[s][i][j]
                    tmpJ.append(sumrtn)
                rtn.append(tmpJ)
            return rtn
    elif(fty.is_Matrix(ityp1)):
        n2 = fty.get_first_ix(ityp1)
        if(fty.is_Vector(ityp2)):
            ns = fty.get_vecLength(ityp2)
            rtn = []
            for i in range(n2):
                sumrtn = 0
                for s in range(ns):
                    sumrtn += exp1[i][s]*exp2[s]
                rtn.append(sumrtn)
            return rtn
    elif(fty.is_Ten3(ityp1)):
        [n1, n2] = fty.drop_first(ityp1)
        if(fty.is_Vector(ityp2)):
            ns = fty.get_vecLength(ityp2)
            rtn = []
            for i in range(n1):
                tmpI = []
                for j in range(n2):
                    sumrtn = 0
                    for s in range(ns):
                        sumrtn += exp1[i][j][s]*exp2[s]
                    tmpI.append(sumrtn)
                rtn.append(tmpI)
            return rtn
    # Every supported combination returned above. Raise a real exception;
    # raising a bare str is itself a TypeError.
    raise Exception("inner product is not supported")


# ***************************  generic apply operators ***************************
#unary operator on a vector
def applyToVector(vec, unary):
    """Return a new list with ``unary`` applied to each element of ``vec``."""
    return [unary(item) for item in vec]
#binary operator on a vector
def applyToVectors(vecA, vecB,  binary):
    """Combine two lists element-wise with ``binary``.

    Pairs are formed with zip, so the result is as long as the shorter input.
    """
    return [binary(left, right) for left, right in zip(vecA, vecB)]

def applyToM(vec, unary):
    """Apply ``unary`` to every entry of a nested list (matrix), row by row."""
    return [[unary(entry) for entry in row] for row in vec]

def applyToMs(vecA,vecB, unary):
    """Combine two matrices entry-wise with the two-argument function ``unary``.

    Despite its name, ``unary`` receives two arguments (one entry from each
    matrix). Rows and entries are paired with zip, so the result is truncated
    to the shorter input along each axis.
    """
    return [[unary(left, right) for left, right in zip(rowA, rowB)]
            for rowA, rowB in zip(vecA, vecB)]


def applyToT3(vec, unary):
    """Apply ``unary`` to every entry of a 3-level nested list (3rd-order tensor)."""
    return [[[unary(entry) for entry in row] for row in plane] for plane in vec]

# ***************************  apply to scalars or vectors ***************************
#apply operators to expression
# return output types and expression
# unary operator
# exp: scalar types

def applyToExp_U_S(fn_name, fld):
    """Apply unary operator ``fn_name`` to a scalar field ``fld``.

    Supports probing (returns the data unchanged), negation and gradient.

    Raises:
        Exception: for any other operator.
    """
    exp = field.get_data(fld)
    dim = field.get_dim(fld)
    if op_probe == fn_name:
        return exp
    if op_negation == fn_name:
        return fn_negation(exp)
    if op_gradient == fn_name:
        return fn_grad(exp, dim)
    raise Exception("unsupported unary operator on scalar field:"+ fn_name.name)

# unary operator
# exp: vector  types
def applyToExp_U_V(fn_name, fld):
    """Apply unary operator ``fn_name`` to a vector field ``fld``.

    Supports probing (data unchanged), component-wise negation, divergence,
    curl and the Jacobian.

    Raises:
        Exception: for any other operator.
    """
    exp = field.get_data(fld)
    if op_probe == fn_name:
        return exp
    if op_negation == fn_name:
        return applyToVector(exp, fn_negation)
    if op_divergence == fn_name:
        return fn_divergence(fld)
    if op_curl == fn_name:
        return fn_curl(fld)
    if op_jacob == fn_name:
        return fn_jacob(fld)
    raise Exception("unsupported unary operator:"+ fn_name.name)

def applyToExp_U_M(fn_name, fld):
    """Apply unary operator ``fn_name`` to a matrix field ``fld``.

    Supports probing (data unchanged), entry-wise negation, jacob, slice,
    trace, transpose and determinant.

    Raises:
        Exception: for any other operator.
    """
    exp = field.get_data(fld)
    if op_probe == fn_name:
        return exp
    if op_negation == fn_name:
        return applyToM(exp, fn_negation)
    if op_jacob == fn_name:
        return fn_jacob(fld)
    if op_slice == fn_name:
        return fn_slice(fld)
    if op_trace == fn_name:
        return fn_trace(fld)
    if op_transpose == fn_name:
        return fn_transpose(fld)
    if op_det == fn_name:
        return fn_det(fld)
    raise Exception("unsupported unary operator:"+ fn_name.name)

def applyToExp_U_T3(fn_name, fld):
    """Apply unary operator ``fn_name`` to a 3rd-order tensor field ``fld``.

    Supports probing (data unchanged), entry-wise negation and jacob.

    Raises:
        Exception: for any other operator.
    """
    exp = field.get_data(fld)
    if op_probe == fn_name:
        return exp
    if op_negation == fn_name:
        return applyToT3(exp, fn_negation)
    if op_jacob == fn_name:
        return fn_jacob(fld)
    raise Exception("unsupported unary operator:"+ fn_name.name)

# binary operator
# exp: scalar types
def applyToExp_B_S(e):
    """Apply the binary operator of ``e`` to two scalar-field operands.

    Supports addition, subtraction, scaling and division.

    Raises:
        Exception: for any other operator.
    """
    fn_name = e.opr
    (fld1, fld2) = apply.get_binary(e)
    lhs = field.get_data(fld1)
    rhs = field.get_data(fld2)
    if op_add == fn_name:
        return fn_add(lhs, rhs)
    if op_subtract == fn_name:
        return fn_subtract(lhs, rhs)
    if op_scale == fn_name:
        return fn_scaling(fld1, fld2)
    if op_division == fn_name:
        return fn_division(fld1, fld2)
    raise Exception("unsupported binary operator on scalar fields:"+ fn_name.name)

# binary, args do not need to have the same shape
def applyToExp_B_uneven(e):
    """Apply a binary operator whose operands may have different shapes.

    The operator is ``e.opr`` and the operands are ``apply.get_binary(e)``.
    Supports outer product, inner product and scaling.

    Raises:
        Exception: for any other operator.
    """
    fn_name=e.opr
    (fld1,fld2) =  apply.get_binary(e)
    if(op_outer==fn_name):
        return fn_outer(fld1, fld2)
    elif(op_inner==fn_name):
        return fn_inner(fld1, fld2)
    elif(op_scale==fn_name): #scaling
        return fn_scaling(fld1,fld2)
    else:
        # The original referenced an undefined `op_name` (NameError) and
        # called this a "unary" operator; report the actual binary operator.
        raise Exception("unsupported binary operator:"+ fn_name.name)


# binary operator
# args have the same shape
def applyToExp_B_V(e):
    """Apply the binary operator of ``e`` when the first operand is a vector.

    Addition and subtraction are component-wise, and cross product is
    handled here. Any other operator is delegated to applyToExp_B_uneven.
    """
    fn_name = e.opr
    (fld1, fld2) = apply.get_binary(e)
    lhs = field.get_data(fld1)
    rhs = field.get_data(fld2)
    if op_add == fn_name:
        return applyToVectors(lhs, rhs, fn_add)
    if op_subtract == fn_name:
        return applyToVectors(lhs, rhs, fn_subtract)
    if op_cross == fn_name:
        return fn_cross(fld1, fld2)
    return applyToExp_B_uneven(e)

def applyToExp_B_M(e):
    """Apply the binary operator of ``e`` to two matrix-field operands.

    Addition and subtraction are entry-wise. Any other operator is delegated
    to applyToExp_B_uneven.
    """
    fn_name = e.opr
    (fld1, fld2) = apply.get_binary(e)
    lhs = field.get_data(fld1)
    rhs = field.get_data(fld2)
    if op_add == fn_name:
        return applyToMs(lhs, rhs, fn_add)
    if op_subtract == fn_name:
        return applyToMs(lhs, rhs, fn_subtract)
    return applyToExp_B_uneven(e)


# ***************************  unary/binary operators ***************************
def unary(e):
    """Evaluate a unary application ``e`` on its field argument.

    norm and normalize are handled for every field kind. Other operators are
    dispatched on the argument kind: scalar, vector, matrix, and anything
    else as a 3rd-order tensor.
    """
    fld = apply.get_unary(e)
    fn_name = e.opr
    dim = field.get_dim(fld)
    if op_norm == fn_name:
        return fn_norm(fld, dim)
    if op_normalize == fn_name:
        return fn_normalize(fld, dim)
    if field.is_Scalar(fld):
        return applyToExp_U_S(fn_name, fld)
    if field.is_Vector(fld):
        return applyToExp_U_V(fn_name, fld)
    if field.is_Matrix(fld):
        return applyToExp_U_M(fn_name, fld)
    return applyToExp_U_T3(fn_name, fld)

def binary(e):
    """Evaluate a binary application ``e``.

    Division is handled first. Two scalar operands go to applyToExp_B_S, and
    two matrix operands go to applyToExp_B_M. Every other combination goes
    to applyToExp_B_V, which falls back to the shape-agnostic operators.
    Operand types are checked elsewhere.
    """
    (f, g) = apply.get_binary(e)
    fn_name = e.opr
    if op_division == fn_name:
        return fn_division(f, g)
    if field.is_Scalar(f) and field.is_Scalar(g):
        return applyToExp_B_S(e)
    both_matrices = (not field.is_Vector(f)) and field.is_Matrix(f) and field.is_Matrix(g)
    if both_matrices:
        return applyToExp_B_M(e)
    return applyToExp_B_V(e)

def applyUnaryOnce(oexp_inner,app_inner,app_outer):
    """Apply the outer (unary) operator to an already-evaluated inner result.

    The inner expression is wrapped in a temporary field named "tmp", typed
    with the inner application's output type. A temporary application of the
    outer operator is built around it and evaluated with unary().
    NOTE(review): `true` is not defined in this chunk; presumably it comes
    from one of the star imports - confirm.

    Returns:
        (outer output type, resulting expression)
    """
    inner_ty = apply.get_oty(app_inner)
    outer_ty = apply.get_oty(app_outer)
    wrapped = field(true, "tmp", inner_ty, "", oexp_inner, "")
    tmp_app = apply("tmp", app_outer.opr, wrapped, None, outer_ty, true, true)
    return (outer_ty, unary(tmp_app))

def applyBinaryOnce(oexp_inner,app_inner,app_outer,rhs):
    """Apply the outer (binary) operator to an evaluated inner result and ``rhs``.

    The inner expression becomes a temporary left-hand field named "tmp",
    typed with the inner output type. The outer operator is applied to it
    and ``rhs`` with binary().
    NOTE(review): `true` is not defined in this chunk; presumably it comes
    from one of the star imports - confirm.

    Returns:
        (outer output type, resulting expression)
    """
    inner_ty = apply.get_oty(app_inner)
    outer_ty = apply.get_oty(app_outer)
    wrapped = field(true, "tmp", inner_ty, "", oexp_inner, "")
    tmp_app = apply("tmp", app_outer.opr, wrapped, rhs, outer_ty, true, true)
    return (outer_ty, binary(tmp_app))


# apply the operator(s) of an application (root or one level of nesting) to field expressions
def sort(e):
    """Compute the symbolic result of application ``e``.

    For a root application (``e.isrootlhs``), the operator is applied
    directly to its field argument(s). Otherwise ``e`` is an outer operator
    applied to an inner application. The inner one is evaluated first, then
    the outer operator is applied to that result (plus the second operand
    for a binary outer operator).

    Returns:
        (output type, expression)

    Raises:
        Exception: if the outer or inner arity is not 1 or 2.
    """
    arity = apply.get_arity(e)
    if(e.isrootlhs): # root: operator applied directly to fields
        oty = apply.get_oty(e)
        if (arity ==1):
            return (oty, unary(e))
        elif (arity ==2):
            return (oty, binary(e))
        else:
            # was str(arity_outer), which is undefined on this path (NameError)
            raise Exception ("arity is not supported: "+str(arity))
    app_outer = e
    if (arity==1):
        app_inner = apply.get_unary(app_outer)
        oexp_inner = _eval_inner_app(app_inner)
        return applyUnaryOnce(oexp_inner, app_inner, app_outer)
    elif (arity==2):
        (app_inner, G) = apply.get_binary(app_outer)
        oexp_inner = _eval_inner_app(app_inner)
        return applyBinaryOnce(oexp_inner, app_inner, app_outer, G)
    else:
        raise Exception ("arity is not supported: "+str(arity))


def _eval_inner_app(app_inner):
    """Evaluate a nested (non-root) application according to its operator's arity."""
    arity_inner = app_inner.opr.arity
    if (arity_inner==1):
        return unary(app_inner)
    elif (arity_inner==2):
        return binary(app_inner)
    else:
        # report the inner arity, which is the one actually unsupported
        raise Exception ("arity is not supported: "+str(arity_inner))

# ***************************  evaluate at positions ***************************
#evaluate scalar field exp
def eval_d1(pos0, exp):
    """Evaluate a 1-D expression: substitute ``pos0`` for the symbol x."""
    substituted = exp.subs(x, pos0)
    return substituted

def eval_d2(pos0, pos1, exp):
    """Evaluate a 2-D expression: substitute x=pos0, then y=pos1."""
    return exp.subs(x, pos0).subs(y, pos1)

def eval_d3(pos0, pos1, pos2, exp):
    """Evaluate a 3-D expression: substitute x=pos0, then y=pos1, then z=pos2."""
    return exp.subs(x, pos0).subs(y, pos1).subs(z, pos2)

#evaluate vector field [exp]
def eval_vec_d1(pos0, vec):
    """Evaluate each component of a 1-D vector expression at ``pos0``."""
    return [eval_d1(pos0, component) for component in vec]

#evaluate vector field [exp,exp]
def eval_vec_d2(pos0, pos1, vec):
    """Evaluate each component of a 2-D vector expression at (pos0, pos1)."""
    return [eval_d2(pos0, pos1, component) for component in vec]

def eval_ten3_d1(pos0,ten3):
    """Evaluate a 1-D 3rd-order tensor at ``pos0``, flattened row-major."""
    return [eval_d1(pos0, entry)
            for plane in ten3
            for row in plane
            for entry in row]



#evaluate matrix field [[exp,..],..]
def eval_mat_d1(pos0, mat):
    """Evaluate a 1-D matrix expression at ``pos0``, flattened row-major."""
    return [eval_d1(pos0, entry) for row in mat for entry in row]

#evaluate matrix field [[exp,..],..]
def eval_mat_d2(pos0, pos1, mat):
    """Evaluate a 2-D matrix expression at (pos0, pos1), flattened row-major."""
    return [eval_d2(pos0, pos1, entry) for row in mat for entry in row]

def eval_ten3_d2(pos0, pos1, ten3):
    """Evaluate a 2-D 3rd-order tensor at (pos0, pos1), flattened row-major."""
    return [eval_d2(pos0, pos1, entry)
            for plane in ten3
            for row in plane
            for entry in row]



#evaluate vector field [exp,exp]
def eval_vec_d3(pos0, pos1, pos2, vec):
    """Evaluate each component of a 3-D vector expression at (pos0, pos1, pos2)."""
    return [eval_d3(pos0, pos1, pos2, component) for component in vec]


#evaluate matrix field [[exp,..],..]
def eval_mat_d3(pos0, pos1, pos2, mat):
    """Evaluate a 3-D matrix expression at (pos0, pos1, pos2), flattened row-major."""
    return [eval_d3(pos0, pos1, pos2, entry) for row in mat for entry in row]

def eval_ten3_d3(pos0, pos1, pos2,ten3):
    """Evaluate a 3-D 3rd-order tensor at (pos0, pos1, pos2), flattened row-major."""
    return [eval_d3(pos0, pos1, pos2, entry)
            for plane in ten3
            for row in plane
            for entry in row]



def iter_d1(k, pos, exp):
    """Call ``k(p, exp)`` for every 1-D position p in ``pos``; return the list of results."""
    return [k(p, exp) for p in pos]

def iter_d2(k, pos, exp):
    """Call ``k(p[0], p[1], exp)`` for every 2-D position p; return the results.

    Only the first two coordinates of each position are used.
    """
    return [k(p[0], p[1], exp) for p in pos]

def iter_d3(k, pos, exp):
    """Call ``k(p[0], p[1], p[2], exp)`` for every 3-D position p; return the results."""
    results = []
    for point in pos:
        px, py, pz = point[0], point[1], point[2]
        results.append(k(px, py, pz, exp))
    return results

def probeField(otyp1,pos, ortn):
    """Evaluate the symbolic result ``ortn`` at every position in ``pos``.

    The field dimension of ``otyp1`` selects the position iterator
    (iter_d1/2/3) and a family of evaluators. The output kind (scalar field,
    vector field, matrix, Ten3, tested in that order) selects the evaluator
    within that family. Non-scalar results are flattened row-major per
    position by the eval_* helpers.

    Returns:
        list with one evaluated entry per position.

    Raises:
        Exception: for an unsupported dimension or output kind.
    """
    dim = fty.get_dim(otyp1)
    # evaluator families, each ordered (scalar, vector, matrix, ten3)
    if (dim==1):
        kernels = (eval_d1, eval_vec_d1, eval_mat_d1, eval_ten3_d1)
        iterate = iter_d1
    elif (dim==2):
        kernels = (eval_d2, eval_vec_d2, eval_mat_d2, eval_ten3_d2)
        iterate = iter_d2
    elif (dim==3):
        kernels = (eval_d3, eval_vec_d3, eval_mat_d3, eval_ten3_d3)
        iterate = iter_d3
    else:
        # a bare string cannot be raised; use a real exception object
        raise Exception("unsupported dimension")
    (k_scalar, k_vec, k_mat, k_ten3) = kernels
    if (fty.is_ScalarField(otyp1)):
        k = k_scalar
    elif (fty.is_VectorField(otyp1)):
        k = k_vec
    elif (fty.is_Matrix(otyp1)):
        k = k_mat
    elif (fty.is_Ten3(otyp1)):
        k = k_ten3
    else:
        raise Exception("error"+otyp1.name)
    return iterate(k, pos, ortn)

# ***************************  main  ***************************

# apply the operators of an application symbolically, then evaluate the result at positions
def eval(app, pos):
    """Evaluate application ``app`` symbolically, then sample it at ``pos``.

    sort() produces the output type and symbolic expression, and
    probeField() evaluates that expression at each position.
    (Note: this name shadows the builtin ``eval`` within this module.)
    """
    output_ty, symbolic = sort(app)
    return probeField(output_ty, pos, symbolic)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0