Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/synth/d2/test_writeDiderot.py
ViewVC logotype

View of /branches/ein16/synth/d2/test_writeDiderot.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4456 - (download) (as text) (annotate)
Thu Aug 25 04:03:27 2016 UTC (2 years, 11 months ago) by cchiw
File size: 14930 byte(s)
added missing cubic
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import codecs
import sys
import os
import re


from obj_ex import *
from obj_apply import *
from obj_ty import *
from obj_operator import *
from obj_field import *

# Path of the Diderot program template that readDiderot instantiates.
template = "template/foo.ddro"

# Placeholder markers searched for in each template line; a line that
# contains one of these is replaced by generated code.
foo_in = "foo_in"              # input field/tensor declarations (inShape)
foo_outTen = "foo_outTen"      # output tensor declaration (outLine)
foo_op = "foo_op"              # operator application (replaceOp)
foo_probe = "foo_probe"        # body written inside the update method
foo_length = "foo_length"      # number of probe positions (setLength)

# Names of variables used inside the generated Diderot program.
foo_out = "out"
foo_pos = "pos"

def fieldName(i):
    """Return the generated-program variable name of the i-th input: F0, F1, ..."""
    return "F%s" % (i,)

# Variable name given to the derived (operator-applied) field in the program.
opfieldname1 = "G"

# Write the Diderot type of a field/tensor.
def fieldShape(f, fldty):
    """Write ``fty.toDiderot(fldty)`` to file ``f`` as UTF-8 bytes.

    f:     open output file.
    fldty: type object understood by ``fty.toDiderot``.
    """
    f.write(fty.toDiderot(fldty).encode('utf8'))

# Write the declaration lines for every input of the application.
def inShape(f, appC):
    """Write one declaration per input of ``appC`` to file ``f``.

    Inputs come from ``apply.get_all_Fields(appC)`` and are named F0, F1, ...
    in that order. Each one is written as its Diderot type (via fieldShape)
    followed by either

      * ``Fi = <krn>(U+229B)  image("<inputfile>.nrrd");`` when
        ``field.get_isField`` is true, or
      * ``Fi = <str(field.get_data(exp))>;`` for a tensor input.

    f:    open output file (UTF-8 bytes are written).
    appC: the application whose inputs are declared.
    """
    for i, exp in enumerate(apply.get_all_Fields(appC)):
        fieldShape(f, exp.fldty)
        name = fieldName(i)
        if field.get_isField(exp):
            # field input: kernel convolved with an image loaded from disk
            foo = name + " = " + exp.krn + u'⊛' + "  image(\"" + exp.inputfile + ".nrrd\");\n"
        else:
            # tensor input: inline its literal value
            foo = name + " = " + str(field.get_data(exp)) + ";\n"
        f.write(foo.encode('utf8'))


# Build a constant tensor literal for the output type.
def outLineTF(type, s):
    """Return a Diderot tensor literal of type ``type`` filled with ``s``.

    ``s`` is the text of each component (e.g. "0.0"). A scalar type yields
    ``s`` itself. Vector literals look like `` [s,s]`` and higher-order
    literals nest them, e.g. ``[ [s,s], [s,s]]`` for a 2x2 matrix.

    Supported: scalar, vec1-4, mat2x2, mat3x3, mat4x4, mat2x3, mat3x2,
    ten2x2x2, ten3x3x3. Anything else raises Exception.
    """
    otype = fty.get_tensorType(type)

    def vec(n):
        # innermost level: n copies of s, with a leading space
        return " [" + ",".join([s] * n) + "]"

    def nest(inner, n):
        # outer level: n copies of an inner literal
        return "[" + ",".join([inner] * n) + "]"

    v2, v3, v4 = vec(2), vec(3), vec(4)
    m22 = nest(v2, 2)
    m33 = nest(v3, 3)

    # Checked in order with `candidate == otype`, like the original chain.
    table = (
        (ty_scalarT, s),
        (ty_vec1T, vec(1)),
        (ty_vec2T, v2),
        (ty_vec3T, v3),
        (ty_vec4T, v4),
        (ty_mat2x2T, m22),
        (ty_mat3x3T, m33),
        (ty_mat4x4T, nest(v4, 4)),
        (ty_mat2x3T, nest(v3, 2)),
        (ty_mat3x2T, nest(v2, 3)),
        (ty_ten2x2x2T, nest(m22, 2)),
        (ty_ten3x3x3T, nest(m33, 3)),
    )
    for candidate, literal in table:
        if candidate == otype:
            return literal
    raise Exception("unsupported input type", otype.name)


# Declare the zero-initialised output tensor.
def outLineF(f, type):
    """Write ``\\toutput tensor <shape> out = <zeros>;\\n\\t`` to file ``f``.

    The shape text is ``str()`` of the tensor type's ``shape``. The zeros
    literal comes from ``outLineTF(type, "0.0")``.
    """
    otype = fty.get_tensorType(type)
    shape_text = str(otype.shape)
    zeros = outLineTF(type, "0.0")
    decl = "\toutput tensor " + shape_text + " " + foo_out + " = " + zeros + ";\n\t"
    f.write(decl.encode('utf8'))




# Render a unary operator application as Diderot source text.
def prntUnary(opr, e):
    """Return the text of ``opr`` applied to expression string ``e``.

    The operand is always parenthesised. ``opr.placement`` selects the form:
      place_split -> ``lhs(e)rhs`` with ``opr.symb == (lhs, rhs)``
      place_left  -> ``symb(e)``
      place_right -> ``(e)symb``
    Any other placement raises Exception("unsupported placement").
    """
    placement = opr.placement
    if placement == place_split:
        left, right = opr.symb
        return left + "(" + e + ")" + right
    if placement == place_left:
        return opr.symb + "(" + e + ")"
    if placement == place_right:
        return "(" + e + ")" + opr.symb
    raise Exception("unsupported placement")

def prntBinary(opr, e1, e2):
    """Return the text of binary operator ``opr`` applied to ``e1`` and ``e2``.

    ``opr.placement`` selects the form:
      place_left   -> ``symb((e1),(e2))``
      place_middle -> ``(e1)symb(e2)``
      place_right  -> ``((e1),(e2))symb``
    Any other placement raises Exception("unhandled placement").
    """
    placement = opr.placement
    if placement == place_left:
        return opr.symb + "((" + e1 + "),(" + e2 + "))"
    if placement == place_middle:
        return "(" + e1 + ")" + opr.symb + "(" + e2 + ")"
    if placement == place_right:
        return "((" + e1 + "),(" + e2 + "))" + opr.symb
    raise Exception("unhandled placement")

# Write one typed assignment line.
def write_shape(pre, f, typ, lhs, rhs):
    """Write ``<pre><type of typ> <lhs> = <rhs>;\\n`` to file ``f``.

    pre: text emitted first (e.g. indentation or "\\toutput ").
    typ: type passed to fieldShape for the declaration.
    """
    f.write(pre.encode('utf8'))
    fieldShape(f, typ)
    assignment = "%s = %s;\n" % (lhs, rhs)
    f.write(assignment.encode('utf8'))


# Write a single operator application between the inputs.
def gotop1(f, app, pre, lhs):
    """Write ``<pre><type> <lhs> = op(F0)`` or ``op(F0, F1)`` for ``app``.

    A unary ``app.opr`` is applied to F0; a binary one to F0 and F1. The
    line's type is ``app.oty``. Other arities raise
    Exception("unsupported arity") before anything is written.
    """
    operator = app.opr
    arity = operator.arity
    result_ty = app.oty
    if arity == 1:
        rhs = prntUnary(operator, fieldName(0))
    elif arity == 2:
        rhs = prntBinary(operator, fieldName(0), fieldName(1))
    else:
        raise Exception("unsupported arity")
    write_shape(pre, f, result_ty, lhs, rhs)


# Write a nested application outer(inner(...)) between the inputs.
def gotop2(f, app_outer, pre, lhs):
    """Write the definition ``<pre><type> <lhs> = outer(inner(...));``.

    Inputs are named F0, F1, F2. The inner operator consumes F0 (unary) or
    F0 and F1 (binary). A binary outer operator takes the inner expression
    plus the next input: F1 after a unary inner, F2 after a binary inner.

    Two cases first write an intermediate line ``\\t<type> F2 = <inner>;``:
      * unary outer with right placement after a unary inner
        (result ``F2<symb>``), and
      * unary outer after a binary inner (result ``outer(F2)``).

    Raises Exception("unsupported arity") if either arity is not 1 or 2.
    """
    outer_op = app_outer.opr
    app_inner = apply.get_unary(app_outer)
    inner_op = app_inner.opr
    inner_arity = inner_op.arity
    outer_arity = outer_op.arity
    outer_ty = app_outer.oty
    inner_ty = app_inner.oty

    a0, a1, tmp = fieldName(0), fieldName(1), fieldName(2)

    def emit_temp(expr):
        # intermediate line holding the inner result in F2
        write_shape("\t", f, inner_ty, tmp, expr)

    def emit_result(expr):
        write_shape(pre, f, outer_ty, lhs, expr)

    # Render the inner application first, as the original did.
    if inner_arity == 1:
        inner_expr = prntUnary(inner_op, a0)
    elif inner_arity == 2:
        inner_expr = prntBinary(inner_op, a0, a1)
    else:
        raise Exception("unsupported arity")

    if outer_arity == 1:
        if inner_arity == 2:
            emit_temp(inner_expr)
            emit_result(prntUnary(outer_op, tmp))
        elif outer_op.placement == place_right:
            emit_temp(inner_expr)
            emit_result(tmp + outer_op.symb)
        else:
            emit_result(prntUnary(outer_op, inner_expr))
    elif outer_arity == 2:
        # the outer operator's second argument is the next unused input
        second = a1 if inner_arity == 1 else tmp
        emit_result(prntBinary(outer_op, inner_expr, second))
    else:
        raise Exception("unsupported arity")


                          

def replaceOp(f, app):
    """Write the derived field ``G = ...`` when ``app`` produces a field.

    A single application (``app.isrootlhs``) goes through gotop1 and a
    nested one through gotop2, both with an empty prefix. When the output
    type is not a field, nothing is written. Always returns None.
    """
    if not fty.is_Field(app.oty):
        return
    writer = gotop1 if app.isrootlhs else gotop2
    return writer(f, app, "", opfieldname1)


# Probe an expression at the program's position variable when it is a field.
def isProbe(exp, fld):
    """Return ``(exp)(pos)`` if ``fld`` is a field type, else ``(exp)``."""
    suffix = "(pos)" if fty.is_Field(fld) else ""
    return "(" + exp + ")" + suffix

# Get the restraint on operator arguments,
# e.g. sqrt(x) requires x to be positive.
def getCond(app, set):
    """Wrap the assignment text ``set`` in a guard for the operators' limits.

    ``app`` is a nested application outer(inner(...)). Each operator whose
    ``limit`` is not None contributes one Diderot boolean condition (see
    ``limit`` below).

    * If no operator has a limit, ``set`` is returned unchanged.
    * Otherwise the result is ``if(<cond>){ set } else{ out = <7.2 tensor>; }``.
    * When both operators have a limit, both conditions are required
      (joined with ``&&``).

    Returns the text to write.
    """
    oty = app.oty
    app_outer = app
    opr_outer = app_outer.opr
    app_inner = apply.get_unary(app_outer)
    opr_inner = app_inner.opr
    # same input numbering as inShape/gotop2
    exp0 = fieldName(0)
    exp1 = fieldName(1)
    exp2 = fieldName(2)

    def limit(e, opr):
        # Diderot condition that keeps expression e inside opr's domain.
        if opr.limit == limit_small:
            return "|(" + e + ")|>0.01"
        elif opr.limit == limit_det:
            return "| det(" + e + ") |>0.01"
        elif opr.limit == limit_trig:
            return "| 0.1*(" + e + ")|<=1.0"
        elif opr.limit == limit_nonzero:
            return "|(" + e + ")| > 0.0 "
        raise Exception(opr.name, "unknown limit:", opr.limit)

    def ifelse(cond):
        # Guarded assignment; positions failing cond get a 7.2-filled tensor.
        k = "\n\tif(" + cond + "){\n\t" + set + "\t}\n\t"
        k += "\n\telse{\n\t\t" + foo_out + " = " + outLineTF(oty, "7.2") + ";\n\t}"
        return k

    def inner_cond():
        if opr_inner.limit == limit_small:
            # restrict the inner operator's second argument (presumably a divisor)
            texp = isProbe(exp1, app_inner.rhs.fldty)
        else:
            texp = isProbe(exp0, app_inner.oty)
        return limit(texp, opr_inner)

    def outer_cond():
        if opr_outer.limit == limit_small:
            # outer's second argument: F2 after a binary inner op, else F1
            pexp = exp2 if opr_inner.arity == 2 else exp1
            texp = isProbe(pexp, app_outer.rhs.fldty)
        else:
            # the outer operator's argument is the inner expression itself
            if opr_inner.arity == 2:
                pexp = prntBinary(opr_inner, exp0, exp1)
            else:
                pexp = prntUnary(opr_inner, exp0)
            texp = isProbe(pexp, app_inner.oty)
        return limit(texp, opr_outer)

    conds = []
    if opr_inner.limit is not None:
        conds.append(inner_cond())
    if opr_outer.limit is not None:
        # Checked even when the inner op also has a limit; otherwise the
        # outer operator could be evaluated outside its domain.
        conds.append(outer_cond())
    if not conds:
        return set
    if len(conds) == 1:
        return ifelse(conds[0])
    return ifelse(" && ".join("(" + c + ")" for c in conds))

# Write the assignment of the output variable, guarded if needed.
def check_conditional(f, ff, app):
    """Write ``\\tout = <probe of ff>;`` to file ``f``.

    ff: name of the field being probed, or of the tensor variable, placed
        inside the (possible) if statement.

    For a single application (``app.isrootlhs``) the assignment is written
    as is. For a nested one it is first passed through getCond, which may
    wrap it in a domain-limit conditional.
    """
    assignment = "\t" + foo_out + " = " + isProbe(ff, app.oty) + ";\n"
    text = assignment if app.isrootlhs else getCond(app, assignment)
    f.write(text.encode('utf8'))


# Declare the position variable and select a position by index.
def index_field_at_positions(f, pos, app):
    """Write the declaration of ``pos`` and an if/else-if chain selecting it.

    The declaration type follows ``app.oty.dim``:
      * 1 -> ``real``
      * 2 -> ``tensor [2]``
      * 3 -> ``tensor [3]``
    NOTE(review): for any other dim no declaration is emitted; confirm that
    cannot happen.

    Branch ``i==0`` assigns ``pos[0]``, and branch ``i==k`` (k >= 1) assigns
    ``pos[k-1]``. So ``pos[0]`` appears twice and ``len(pos)+1`` branches are
    written. Raises IndexError (writing nothing) when ``pos`` is empty.
    """
    declarations = {
        1: "real  " + foo_pos + "=0;\n",
        2: "tensor [2] " + foo_pos + "=[0,0];\n",
        3: "tensor [3] " + foo_pos + "=[0,0,0];\n",
    }
    chunks = ["\t\t", declarations.get(app.oty.dim, "")]

    first = str(pos[0])
    chunks.append("\t\tif(i==0){\n")
    chunks.append("\t\t\t" + foo_pos + " = (" + first + ");\n")
    chunks.append("\t\t}\n")
    for branch, point in enumerate(pos, 1):
        chunks.append("\t\telse if(i==" + str(branch) + "){\n")
        chunks.append("\t\t\t" + foo_pos + " = (" + str(point) + ");\n")
        chunks.append("\t\t}\n")
    f.write("".join(chunks).encode('utf8'))


# Written inside the update method of the template.
def update_method(f, pos, app):
    """Write the update-method body for ``app``.

    When the output is a field, the position-selection chain is written
    first and then the probe of ``G``. Otherwise only the (possibly guarded)
    assignment of the tensor ``out`` is written.
    """
    if not fty.is_Field(app.oty):
        check_conditional(f, foo_out, app)
        return
    index_field_at_positions(f, pos, app)
    check_conditional(f, opfieldname1, app)


def outLine(f, app):
    """Write the output declaration for ``app``.

    For a field result, a zero-initialised output tensor is declared
    (outLineF). For a tensor result, the computation itself is written as
    the ``output`` declaration of ``out`` via gotop1 (single application)
    or gotop2 (nested). Returns None.
    """
    result_ty = app.oty
    if fty.is_Field(result_ty):
        outLineF(f, result_ty)
        return
    emit = gotop1 if app.isrootlhs else gotop2
    emit(f, app, "\toutput ", foo_out)




def setLength(f, n):
    """Write ``int length =<n>;`` (no trailing newline) to ``f`` as UTF-8."""
    f.write("int length ={0};".format(n).encode('utf8'))

# Generate the Diderot program for one test case from the template.
def readDiderot(p_out, app, pos):
    """Instantiate ``template`` into ``<p_out>.diderot`` for ``app``.

    The template's first line is skipped. Each remaining line that matches
    one of the placeholder markers is replaced by generated code:
      * foo_in     -> inShape
      * foo_outTen -> outLine
      * foo_op     -> replaceOp
      * foo_probe  -> update_method
      * foo_length -> setLength
    Only the first matching marker on a line is handled, in that order.
    Every other line is copied through unchanged.

    p_out: output path prefix; the program is written to ``p_out + ".diderot"``.
    app:   the operator application under test.
    pos:   probe positions, passed to update_method; ``len(pos)`` is written
           as the program's ``length``.
    """
    # (marker, writer) pairs, tried in order.
    replacements = (
        (foo_in, lambda out: inShape(out, app)),
        (foo_outTen, lambda out: outLine(out, app)),
        (foo_op, lambda out: replaceOp(out, app)),
        (foo_probe, lambda out: update_method(out, pos, app)),
        (foo_length, lambda out: setLength(out, len(pos))),
    )
    # `with` closes both files even if a code generator raises.
    with open(template, 'r') as ftemplate, open(p_out + ".diderot", 'w+') as f:
        ftemplate.readline()  # skip the template's first line
        for line in ftemplate:
            for marker, write_replacement in replacements:
                # re.search (not `in`) so non-ASCII template lines are safe
                # to match against the unicode markers under Python 2.
                if re.search(marker, line):
                    write_replacement(f)
                    break
            else:
                # no marker on this line: copy it verbatim
                f.write(line)

def runDiderot(p_out, app, pos, output, runtimepath, isVis):
    """Run the compiled executable ``./<p_out>``.

    With ``isVis`` true:
      * the program runs with ``-o tmp.nrrd``;
      * ``unu reshape`` resizes the result to (product of ``app.oty.shape``)
        by ``len(pos)+1``;
      * ``unu save -f text`` writes it to ``<p_out>.txt``.

    Otherwise the executable runs directly, with diagnostic prints about
    whether ``<p_out>.txt`` and ``<p_out>`` exist.

    ``output`` and ``runtimepath`` are accepted for interface symmetry but
    unused here.
    """
    n_components = 1
    for extent in app.oty.shape:
        n_components *= extent
    exe = "./" + p_out
    if isVis:
        n_rows = len(pos) + 1
        reshape_args = " -s " + str(n_components) + " " + str(n_rows)
        os.system(exe + " -o tmp.nrrd")
        os.system("unu reshape -i tmp.nrrd " + reshape_args + " | unu save -f text -o " + p_out + ".txt")
        return
    txt = p_out + ".txt"
    if os.path.exists(txt):
        print("found txt pre run")
    if os.path.exists(p_out):
        print("does ./ exsist?")
    os.system(exe)
    if os.path.exists(txt):
        print("found txt post run")



def writeDiderot(p_out, app, pos, output, runtimepath, isVis):
    print "Read diderot-2"
    print apply.oprToStr(app)
    print "pos",pos
    print "app", app.name
    print "app", app.oty
    readDiderot(p_out,app,pos)
    # compile and run diderot program
    print runtimepath
    os.system(" echo \"pout "+p_out+"\"")
    os.system(" echo \"create diderot program "+p_out+".diderot \"")
    os.system("cp "+p_out+".diderot "+output+".diderot")
    name ="tmp.nrrd"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "tmp.nrrd:exists-pre-star",name
    name = p_out+".txt"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "p_out:exists-pre-star",name
    name = output+".txt"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "output:exists-pre-star",name
    name =p_out
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "./p_out:exists-pre-star",name
    os.system(runtimepath+" "+p_out+".diderot")
    print "run diderot"
    runDiderot(p_out, app, pos, output, runtimepath,isVis)
    print "post run diderot"
    name = p_out+".txt"
    if(os.path.exists(name)):
        print "p_out:exists-post-star",name
        os.system("cp "+p_out+".txt "+output+".txt")
    name = output+".txt"
    if(os.path.exists(name)):
        print "output:exists-post-star",name
    os.system("cp "+p_out+".diderot "+output+".diderot")
    os.system("cp "+p_out+".c "+output+".c")
    os.system(" echo \"copied diderot program to "+output+".diderot \"")

    os.system("rm "+p_out+"*")

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0