Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/synth/d2/test_writeDiderot.py
ViewVC logotype

View of /branches/ein16/synth/d2/test_writeDiderot.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 4243 - (download) (as text) (annotate)
Thu Jul 21 17:54:34 2016 UTC (3 years, 2 months ago) by cchiw
File size: 11889 byte(s)
added modulate
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import codecs
import sys
import os
import re


from obj_ex import *
from obj_apply import *
from obj_ty import *
from obj_operator import *
from obj_field import *

# Diderot program template instantiated by readDiderot.
template = "template/foo.ddro"

# Placeholder markers looked for (with re.search) in each template line;
# a line containing one of them is replaced by generated code.
foo_in = "foo_in"              # input field declarations
foo_outTen = "foo_outTen"      # output tensor declaration
foo_op = "foo_op"              # operator application
foo_probe = "foo_probe"        # probing the result at each position
foo_length = "foo_length"      # number of positions

# Name of the output variable in the generated Diderot program.
foo_out = "out"

def fieldName(i):
    """Return the Diderot variable name of the i-th input ("F0", "F1", ...)."""
    return "".join(["F", str(i)])

# Name of the Diderot variable holding the result of the operator(s).
opfieldname1 = "G"

# type of field
def fieldShape(f, fldty):
    """Write the Diderot rendering of type ``fldty`` (fty.toDiderot) to f, UTF-8 encoded."""
    f.write(fty.toDiderot(fldty).encode('utf8'))

#field input line
def inShape(f,appC):
    """Write one declaration per input of ``appC`` to f.

    The inputs are those returned by ``apply.get_all_Fields(appC)``; the i-th
    one is declared with its type (via fieldShape) under the name
    ``fieldName(i)``:

    * if ``field.get_isField(exp)`` holds, it is bound to
      ``<exp.krn> ⊛ image("<exp.inputfile>.nrrd")``;
    * otherwise it is treated as a tensor and bound to the literal
      ``str(field.get_data(exp))``.

    f:    open output file; UTF-8 encoded text is written to it.
    appC: the application whose inputs are declared.
    """
    for i, exp in enumerate(apply.get_all_Fields(appC)):
        fieldShape(f, exp.fldty)
        if field.get_isField(exp):
            foo = fieldName(i)+" = "+exp.krn+u'⊛'+"  image(\""+exp.inputfile+".nrrd\");\n"
        else: # tensor type
            foo = fieldName(i)+" = "+str(field.get_data(exp))+";\n"
        f.write(foo.encode('utf8'))


def _zeroTensorLiteral(dims):
    """Return the Diderot literal of an all-zero tensor with extents ``dims``.

    The innermost axis is joined with ", " and outer axes with ",", e.g.
    () -> "0.0", (2,) -> "[0.0, 0.0]", (2,2) -> "[[0.0, 0.0],[0.0, 0.0]]".
    """
    if not dims:
        return "0.0"
    if len(dims) == 1:
        return "[" + ", ".join(["0.0"] * dims[0]) + "]"
    return "[" + ",".join([_zeroTensorLiteral(dims[1:])] * dims[0]) + "]"


#instantiate output tensor
def outLine(f, type):
    """Write the zero-initialised output tensor declaration to f.

    ``type`` is converted with ``fty.get_tensorType`` and compared (in the
    order listed below) against the supported tensor types; the matching one
    yields e.g. ``\\toutput tensor [2,2] out = [[0.0, 0.0],[0.0, 0.0]];``.

    f:    open output file; UTF-8 encoded text is written to it.
    type: the output type of the application (e.g. ``app.oty``).

    Raises Exception("unsupported output type", name) for any other type.
    """
    otype = fty.get_tensorType(type)
    # (tensor type, its dimensions) in the order they are tried.
    supported = [
        (ty_scalarT, ()),
        (ty_vec1T, (1,)),
        (ty_vec2T, (2,)),
        (ty_vec3T, (3,)),
        (ty_mat2x2T, (2, 2)),
        (ty_mat3x3T, (3, 3)),
        (ty_ten2x2x2T, (2, 2, 2)),
        (ty_ten3x3x3T, (3, 3, 3)),
    ]
    for ty, dims in supported:
        if ty == otype:
            shape = "[" + ",".join([str(d) for d in dims]) + "]"
            foo = "\toutput tensor "+shape+" "+foo_out+" = "+_zeroTensorLiteral(dims)+";\n"
            f.write(foo.encode('utf8'))
            return
    raise Exception("unsupported output type",otype.name)

#write operation between fields
def gotop1(f,app):
    """Write the declaration of ``G`` as one operator applied to the inputs.

    The result type of applying ``app.opr`` to ``apply.get_types(app)`` is
    declared with fieldShape, then ``G`` is bound to the operator applied to
    ``F0`` (unary) or to ``F0`` and ``F1`` (binary), laid out according to the
    operator's placement.

    f:   open output file; UTF-8 encoded text is written to it.
    app: the application to emit.

    Raises Exception if the operator cannot be applied to the input types,
    or if its arity or placement is not supported.
    """
    op1 = app.opr
    arity = op1.arity
    itypes = apply.get_types(app)
    foo = opfieldname1+" = "
    if arity == 1:
        (ft, typ2) = applyUnaryOp(op1,itypes)
        # NOTE(review): ``false`` is not a Python builtin; it is presumably
        # supplied by the obj_* star imports — confirm.
        if ft == false:
            # Raising a bare string is itself a TypeError on Python 2.6+, so
            # raise a real exception to keep the message.
            raise Exception("unsupported application of unary operator")
        fieldShape(f, typ2)
        if op1.placement == place_left:
            foo += op1.symb+"("+fieldName(0)+")"
        elif op1.placement == place_split:
            foo += op1.symb+"("+fieldName(0)+")"+op1.symb
        elif op1.placement == place_right:
            foo += "(("+fieldName(0)+")"+op1.symb+")"
        else:
            raise Exception ("unhandled placement")
    elif arity == 2:
        (ft,typ2) = applyBinaryOp(op1,itypes)
        if ft == false:
            raise Exception("unsupported application of binary operator")
        fieldShape(f, typ2)
        if op1.placement == place_left:
            foo += op1.symb+"("+fieldName(0)+","+fieldName(1)+")"
        elif op1.placement == place_split:
            foo += "("+fieldName(0)+op1.symb+fieldName(1)+")"
        else:
            # Without this the function emitted the invalid line "G = ;".
            raise Exception("unhandled placement")
    else:
        raise Exception("unsupported arity")
    foo += ";\n"
    f.write(foo.encode('utf8'))

# print unary operator
def prntUnary(opr, e):
    """Return expression ``e`` wrapped by unary operator ``opr``.

    place_split -> ``<symb>(e)<symb>``; place_left -> ``<symb>(e)``;
    place_right -> ``(e)<symb>``.  Any other placement raises Exception.
    """
    where = opr.placement
    if where == place_split:
        return opr.symb + "(" + e + ")" + opr.symb
    if where == place_left:
        return opr.symb + "(" + e + ")"
    if where == place_right:
        return "(" + e + ")" + opr.symb
    raise Exception ("unsupported placement")

#write operation between fields
def gotop2(f,app_outer):
    """Write the declaration of ``G`` as a composition of two operators.

    ``app_outer.opr`` is applied to the result of the inner application
    ``apply.get_unary(app_outer)``.  Depending on the two arities and the
    outer operator's placement, ``G`` is written on a single line or via a
    temporary ``F2`` declared on a preceding line.  Result types come from
    applyUnaryOp/applyBinaryOp and are declared with fieldShape; the success
    flag those helpers return is discarded here (unlike gotop1).

    f:         open output file; UTF-8 encoded text is written to it.
    app_outer: the outer application.

    Raises Exception("unsupported arity") for arities other than 1 or 2.
    """
    # The bare prints ("A", "C", "F", ...) only trace which branch was taken.
    opr_outer=app_outer.opr
    app_inner=apply.get_unary(app_outer)
    opr_inner=app_inner.opr
    arity_inner= opr_inner.arity
    arity_outer= opr_outer.arity
    # NOTE(review): s_inner and s_outer are computed but never used below.
    s_inner=apply.toStr(app_inner,0)
    s_outer= apply.toStr(app_outer,0)
    #print s_outer
    itypes_inner = apply.get_types(app_inner)
    foo = opfieldname1+" = "
    # Dispatch on the inner arity first, then on the outer arity.
    if (arity_inner==1):
        print "A"
        (_, typ_inner) = applyUnaryOp(opr_inner,itypes_inner)
        if(arity_outer==1):
            print"C"
            k = prntUnary(opr_inner, fieldName(0)) # inner placement
            (_, typ_outer) = applyUnaryOp(opr_outer,[typ_inner])
            if(opr_outer.placement==place_right):
                # multiple lines: F2 = inner(F0); then G = F2<outer symb>
                print "F"
                fieldShape(f, typ_inner)
                foo0 = fieldName(2)+" = "+k+";\n"
                f.write(foo0.encode('utf8'))
                fieldShape(f, typ_outer)
                foo += fieldName(2)+(opr_outer.symb)
            else:
                print "E"
                #single line: G = outer(inner(F0))
                fieldShape(f, typ_outer)
                foo += prntUnary(opr_outer, k)
        elif(arity_outer==2):
            print "d"
            #assumes second arg is a field
            # NOTE(review): f_outer is unused.  The type is computed as
            # outer(inner(..), g), but the text emitted is
            # F0 <outer> inner(F1), i.e. the inner operator is applied to the
            # second field — confirm this matches the field order produced by
            # apply.get_all_Fields.
            (f_outer,g_outer) = apply.get_binary(app_outer)
            g_ty = g_outer.fldty
            (_, typ_outer) = applyBinaryOp(opr_outer,[typ_inner,  g_ty])
            fieldShape(f, typ_outer)
            foo += fieldName(0)+(opr_outer.symb)+prntUnary(opr_inner, fieldName(1))
        else:
            raise Exception("unsupported arity")
    elif(arity_inner==2):

        (_, typ_inner)= applyBinaryOp(opr_inner,itypes_inner)
        if(arity_outer==1):

            (_, typ_outer) = applyUnaryOp(opr_outer,[typ_inner])
            if(opr_outer.placement==place_right):
                # multiple lines: F2 = (F0<inner>F1); then G = F2<outer symb>

                fieldShape(f, typ_inner)
                foo0 = fieldName(2)+" = ("+fieldName(0)+(opr_inner.symb)+fieldName(1)+");\n"
                f.write(foo0.encode('utf8'))
                fieldShape(f, typ_outer)
                foo += fieldName(2)+(opr_outer.symb)
            elif(opr_outer.placement==place_left):
                # multiple lines
                # NOTE(review): this branch is selected by the OUTER operator
                # being place_left, yet the inner operator is printed in
                # prefix form "<inner>(F0,F1)" and the outer operator is
                # appended as a suffix "F2<outer>".  The inner/outer
                # placements look swapped — confirm intended output.

                fieldShape(f, typ_inner)
                foo0 = fieldName(2)+" ="+(opr_inner.symb)+" ("+fieldName(0)+","+fieldName(1)+");\n"
                f.write(foo0.encode('utf8'))
                fieldShape(f, typ_outer)
                foo += fieldName(2)+(opr_outer.symb)
        
            else:
                print "H"
                # one line: G = outer(F0<inner>F1) laid out by prntUnary
                fieldShape(f, typ_outer)
                tmp = fieldName(0)+(opr_inner.symb)+fieldName(1)
                foo+= prntUnary(opr_outer, tmp)

        elif(arity_outer==2):
            print "Z"
            # NOTE(review): f_outer is unused; only g_outer's type is needed.
            (f_outer,g_outer) = apply.get_binary(app_outer)
            #print "f-outer", f_outer
            #print "g-outer", g_outer
            g_ty = g_outer.fldty
            (_, typ_outer) = applyBinaryOp(opr_outer,[typ_inner,  g_ty])
            fieldShape(f, typ_outer)
            # G = (F0<inner>F1)<outer>F2
            foo += "("+fieldName(0)+(opr_inner.symb)+fieldName(1)+")"+(opr_outer.symb)+fieldName(2)
        else:
            raise Exception("unsupported arity")
    else:
        raise Exception("unsupported arity")
    foo += ";\n"
    f.write(foo.encode('utf8'))

def replaceOp(f,app):
    """Write the operator line(s) for ``app`` to f.

    When ``app.isrootlhs`` is true a single operator application is written
    (gotop1); otherwise a composition of two operators is written (gotop2).
    """
    writer = gotop1 if app.isrootlhs else gotop2
    return writer(f, app)

def indexG(f, pos, oty):
    """Write the if/else-if chain that probes ``G`` at each position.

    Branch ``i==0`` probes ``pos[0]``; branches ``i==1`` .. ``i==len(pos)``
    then probe every element of ``pos`` in order.  ``pos[0]`` is therefore
    probed twice, and ``len(pos)+1`` branches are written.  runDiderot
    reshapes visual output with ``len(pos)+1`` rows, which matches this.

    f:   open output file; UTF-8 encoded text is written to it.
    pos: non-empty indexable sequence of positions; each is formatted with
         str().  An empty ``pos`` raises IndexError.
    oty: output type; accepted for interface compatibility but not used.
    """
    probes = [pos[0]] + list(pos)
    lines = []
    for i, p in enumerate(probes):
        keyword = "if" if i == 0 else "else if"
        lines.append("\t\t"+keyword+"(i=="+str(i)+"){\n")
        lines.append("\t\t\t"+foo_out+" = "+opfieldname1+"("+str(p)+");\n")
        lines.append("\t\t}\n")
    foo = "".join(lines)
    f.write(foo.encode('utf8'))

def setLength(f, n):
    """Write the declaration ``int length =<n>;`` to f, UTF-8 encoded.

    The line is newline-terminated like every other generated line.  Without
    the newline, the following template line was glued onto this one.
    """
    foo="int length ="+str(n)+";\n"
    f.write(foo.encode('utf8'))

def readDiderot(p_out,app,pos):
    """Instantiate the Diderot template for ``app`` as ``<p_out>.diderot``.

    The template file ``template`` is copied line by line, except that its
    first line is skipped and each line containing a placeholder marker is
    replaced by generated code:

    * ``foo_in``     -> input declarations (inShape)
    * ``foo_outTen`` -> output tensor declaration (outLine)
    * ``foo_op``     -> the operator application (replaceOp)
    * ``foo_probe``  -> probes of the result at ``pos`` (indexG)
    * ``foo_length`` -> ``len(pos)`` (setLength)

    Markers are tried in that order and only the first match applies.

    p_out: output path without the ``.diderot`` extension.
    app:   the application being tested (its ``oty`` is the output type).
    pos:   positions at which the result is probed.
    """
    # ``with`` guarantees both files are closed even if a generator raises.
    with open(template, 'r') as ftemplate:
        ftemplate.readline()  # the template's first line is not copied
        with open(p_out+".diderot", 'w+') as f:
            for line in ftemplate:
                if re.search(foo_in, line):
                    inShape(f,app)
                elif re.search(foo_outTen, line):
                    outLine(f, app.oty)
                elif re.search(foo_op, line):
                    replaceOp(f, app)
                elif re.search(foo_probe, line):
                    indexG(f, pos, app.oty)
                elif re.search(foo_length, line):
                    setLength(f,len(pos))
                else:
                    # nothing is being replaced
                    f.write(line)

def runDiderot(p_out, app, pos, output, runtimepath, isVis):
    """Run the compiled Diderot program ``./<p_out>`` through the shell.

    When ``isVis`` is true, the program writes ``tmp.nrrd``.  ``unu`` then
    reshapes it with ``-s <components> <len(pos)+1>``, where components is the
    product of ``app.oty.shape``, and saves it as text in ``<p_out>.txt``.
    Otherwise the program is simply run.  ``output`` and ``runtimepath`` are
    unused, and exit statuses are ignored.
    """
    components = 1
    for extent in app.oty.shape:
        components *= extent
    exe = "./"+p_out
    if not isVis:
        os.system(exe)
        return
    rows = len(pos)+1
    sizes = " -s "+str(components)+" "+str(rows)
    os.system(exe+" -o tmp.nrrd")
    os.system("unu reshape -i tmp.nrrd "+sizes+" | unu save -f text -o "+p_out+".txt")

def writeDiderot(p_out, app, pos, output, runtimepath, isVis):
    readDiderot(p_out,app,pos)
    # compile and run diderot program
    # print runtimepath
    os.system(" echo \"pout "+p_out+"\"")
    os.system(" echo \"create diderot program "+p_out+".diderot \"")
    os.system("cp "+p_out+".diderot "+output+".diderot")
    name ="tmp.nrrd"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "tmp.nrrd:exists-pre-star",name
    name = p_out+".txt"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "p_out:exists-pre-star",name
    name = output+".txt"
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "output:exists-pre-star",name
    name =p_out
    os.system(" rm "+name)
    if(os.path.exists(name)):
        print "./p_out:exists-pre-star",name
    


    os.system(runtimepath+" "+p_out+".diderot")
    runDiderot(p_out, app, pos, output, runtimepath,isVis)
    name = p_out+".txt"
    if(os.path.exists(name)):
        print "p_out:exists-post-star",name
        os.system("cp "+p_out+".txt "+output+".txt")
    name = output+".txt"
    if(os.path.exists(name)):
        print "output:exists-post-star",name

    os.system("cp "+p_out+".diderot "+output+".diderot")
    os.system("cp "+p_out+".c "+output+".c")
    os.system(" echo \"copied diderot program to "+output+".diderot \"")
    os.system(" echo \"copied o-out "+output+".txt \"")
# os.system("rm "+p_out+"*")

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0