Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/synth/d2/test_writeDiderot.py
ViewVC logotype

View of /branches/ein16/synth/d2/test_writeDiderot.py

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3946 - (download) (as text) (annotate)
Sat Jun 11 00:39:19 2016 UTC (2 years, 11 months ago) by cchiw
File size: 8066 byte(s)
 added outer prod
# -*- coding: utf-8 -*-

from __future__ import unicode_literals

import codecs
import sys
import os
import re


from obj_ex import *
from obj_apply import *
from obj_ty import *
from obj_operator import *
from obj_field import *

# Diderot template file that readDiderot() fills in.
template = "template/foo.ddro"

# Placeholder tokens searched for in the template; a template line that
# contains one of them is replaced wholesale by generated code.
foo_in = "foo_in"              # input field declarations
foo_outTen = "foo_outTen"      # output tensor declaration
foo_op = "foo_op"              # definition of the operation field
foo_probe = "foo_probe"        # probes of the operation field
foo_length = "foo_length"      # number of probe positions

# Name of the output variable in the generated Diderot program.
foo_out = "out"

def fieldName(i):
    """Return the Diderot variable name of the i-th input, e.g. ``F0``."""
    return "F{0}".format(i)

# Variable that holds the result of the operation under test.
opfieldname1 = "G"

# type of field
def fieldShape(f, fldty):
    """Write the Diderot spelling of type ``fldty`` to ``f``.

    The text comes from ``fty.toDiderot`` and is written as UTF-8 bytes;
    no separator or newline is appended.
    """
    f.write(fty.toDiderot(fldty).encode('utf8'))

#field input lines: one declaration per input of the application
#f: file to write to
#appC: application whose input fields/tensors are declared
def inShape(f, appC):
    """Write one Diderot declaration per input of ``appC``.

    The inputs are those returned by ``apply.get_all_Fields(appC)``; the
    i-th one is named ``fieldName(i)`` and preceded by its type (written by
    ``fieldShape``). A field input is bound to its kernel convolved with the
    image ``<inputfile>.nrrd``; any other input is a tensor whose value is
    written as ``str(field.get_data(exp))``.

    f    -- output file, written with UTF-8 encoded bytes
    appC -- the application whose inputs are declared
    """
    for i, exp in enumerate(apply.get_all_Fields(appC)):
        fieldShape(f, exp.fldty)
        if field.get_isField(exp):
            decl = (fieldName(i) + " = " + exp.krn + u'⊛'
                    + "  image(\"" + exp.inputfile + ".nrrd\");\n")
        else:  # tensor type
            decl = fieldName(i) + " = " + str(field.get_data(exp)) + ";\n"
        f.write(decl.encode('utf8'))


#instantiate output tensor
def _zeroTensor(shape):
    """Return a Diderot literal for an all-zero tensor of the given shape.

    Innermost entries are separated by ", " and nested sub-tensors by ",",
    e.g. () -> "0.0", (2,) -> "[0.0, 0.0]", (2, 2) -> "[[0.0, 0.0],[0.0, 0.0]]".
    """
    if not shape:
        return "0.0"
    inner = _zeroTensor(shape[1:])
    sep = ", " if len(shape) == 1 else ","
    return "[" + sep.join([inner] * shape[0]) + "]"


def outLine(f, type):
    """Write the declaration of the output tensor, initialized to zero.

    The tensor type is ``fty.get_tensorType(type)`` and must be the scalar,
    2-/3-vector, 2x2/3x3 matrix or 2x2x2/3x3x3 tensor type. The line written
    is ``\\toutput tensor [<dims>] out = <zeros>;``.

    f    -- output file, written with UTF-8 encoded bytes
    type -- the application's output type (name kept for existing callers;
            it shadows the builtin only inside this function)

    Raises Exception for any other tensor type.
    """
    otype = fty.get_tensorType(type)
    # Built per call (not at module level) so the ty_* names are only looked
    # up when the function runs; compared with == in the original order.
    shapes = (
        (ty_scalarT, ()),
        (ty_vec2T, (2,)),
        (ty_vec3T, (3,)),
        (ty_mat2x2T, (2, 2)),
        (ty_mat3x3T, (3, 3)),
        (ty_ten2x2x2T, (2, 2, 2)),
        (ty_ten3x3x3T, (3, 3, 3)),
    )
    for ty, shape in shapes:
        if ty == otype:
            break
    else:
        raise Exception("unsupported output type", otype)
    dims = ",".join(str(d) for d in shape)
    # A space always separates the shape from the name; the scalar case
    # previously produced "tensor []out".
    foo = ("\toutput tensor [" + dims + "] " + foo_out
           + " = " + _zeroTensor(shape) + ";\n")
    f.write(foo.encode('utf8'))

#write a single (non-nested) operation between fields, defining G
def gotop1(f, app):
    """Write the definition of G for a single (non-nested) operator application.

    The result type, computed with applyUnaryOp/applyBinaryOp, is written via
    fieldShape, followed by ``G = <op>F0;`` (unary) or ``G = F0<op>F1;``
    (binary).

    f   -- output file, written with UTF-8 encoded bytes
    app -- the application; ``app.opr`` supplies ``arity`` and ``symb``

    Raises Exception if the operator cannot be applied to the input types
    or its arity is not 1 or 2.
    """
    op1 = app.opr
    arity = op1.arity
    itypes = apply.get_types(app)
    foo = opfieldname1 + " = "
    if arity == 1:
        (ft, typ2) = applyUnaryOp(op1, itypes)
        # `false` is not a Python builtin and string exceptions cannot be
        # raised; test the success flag directly and raise a real exception.
        if not ft:
            raise Exception("unsupported application of unary operator")
        fieldShape(f, typ2)
        foo += op1.symb + fieldName(0)
    elif arity == 2:
        (ft, typ2) = applyBinaryOp(op1, itypes)
        if not ft:
            raise Exception("unsupported application of binary operator")
        fieldShape(f, typ2)
        foo += fieldName(0) + op1.symb + fieldName(1)
    else:
        raise Exception("unsupported arity")
    foo += ";\n"
    f.write(foo.encode('utf8'))

#write a nested operation (an operator applied to the result of another), defining G
def gotop2(f, app_outer):
    """Write the definition of G for an operator applied to another application.

    ``apply.get_unary(app_outer)`` is the inner application. Its result type
    is computed from the input types, then the outer operator is applied to
    that result (plus, for a binary outer operator, the type of the second
    argument, which is assumed to be a field). The emitted expressions are:

        inner unary,  outer unary : <out>(<in>F0)
        inner unary,  outer binary: (<in>F0)<out>F1
        inner binary, outer unary : <out>(F0<in>F1)
        inner binary, outer binary: (F0<in>F1)<out>F2

    f         -- output file, written with UTF-8 encoded bytes
    app_outer -- the outer application

    Raises Exception if either operator's arity is not 1 or 2, or if an
    operator cannot be applied to its argument types.
    """
    opr_outer = app_outer.opr
    app_inner = apply.get_unary(app_outer)
    opr_inner = app_inner.opr
    arity_inner = opr_inner.arity
    arity_outer = opr_outer.arity
    itypes_inner = apply.get_types(app_inner)

    # Inner application: result type, Diderot expression, and the index of
    # the first input field it does not consume.
    if arity_inner == 1:
        (ok, typ_inner) = applyUnaryOp(opr_inner, itypes_inner)
        inner = opr_inner.symb + fieldName(0)
        next_field = 1
    elif arity_inner == 2:
        (ok, typ_inner) = applyBinaryOp(opr_inner, itypes_inner)
        inner = fieldName(0) + opr_inner.symb + fieldName(1)
        next_field = 2
    else:
        raise Exception("unsupported arity")
    if not ok:
        raise Exception("unsupported application of inner operator")

    # Outer application, taking the inner result as its first argument.
    if arity_outer == 1:
        (ok, typ_outer) = applyUnaryOp(opr_outer, [typ_inner])
        expr = opr_outer.symb + "(" + inner + ")"
    elif arity_outer == 2:
        # assumes the second argument is a field
        (_, g_outer) = apply.get_binary(app_outer)
        (ok, typ_outer) = applyBinaryOp(opr_outer, [typ_inner, g_outer.fldty])
        # The type above treats the inner result as the LEFT operand, so the
        # expression keeps it on the left for both inner arities.
        expr = "(" + inner + ")" + opr_outer.symb + fieldName(next_field)
    else:
        raise Exception("unsupported arity")
    if not ok:
        raise Exception("unsupported application of outer operator")

    fieldShape(f, typ_outer)
    foo = opfieldname1 + " = " + expr + ";\n"
    f.write(foo.encode('utf8'))

def replaceOp(f, app):
    """Write the definition of G for ``app``.

    When ``app.isrootlhs`` is truthy the application is written by gotop1
    (single operator); otherwise by gotop2 (nested application). Returns
    whatever the chosen writer returns.
    """
    writer = gotop1 if app.isrootlhs else gotop2
    return writer(f, app)

def indexG(f, pos):
    """Write an if / else-if chain that probes G at the given positions.

    Branch ``i==0`` probes ``pos[0]`` and branch ``i==k`` (k >= 1) probes
    ``pos[k-1]``, so the first position appears twice and there are
    ``len(pos) + 1`` branches. Each branch assigns ``out = G(<position>);``.
    ``pos[0]`` is read unconditionally, so an empty ``pos`` raises.

    NOTE(review): which values of ``i`` the template visits is not visible
    here (setLength writes ``len(pos)``); if it is ``0..len(pos)-1`` the last
    position is never probed -- confirm against the template.
    """
    probes = [pos[0]] + list(pos)
    chunks = []
    for branch, point in enumerate(probes):
        keyword = "if" if branch == 0 else "else if"
        chunks.append("\t\t" + keyword + "(i==" + str(branch) + "){\n")
        chunks.append("\t\t\t" + foo_out + " = " + opfieldname1
                      + "(" + str(point) + ");\n")
        chunks.append("\t\t}\n")
    f.write("".join(chunks).encode('utf8'))

def setLength(f, n):
    """Write the declaration ``int length =<n>;`` followed by a newline.

    readDiderot calls this in place of the template line that contains the
    length placeholder.

    f -- output file, written with UTF-8 encoded bytes
    n -- number of probe positions
    """
    # The template line is replaced wholesale (including its newline), so the
    # newline must be written here; otherwise the next template line is
    # appended to this one.
    foo = "int length =" + str(n) + ";\n"
    f.write(foo.encode('utf8'))

#p_out: output path, without the .diderot extension
#app: application being tested
#pos: positions at which G is probed
def readDiderot(p_out, app, pos):
    """Generate ``<p_out>.diderot`` from the template file ``template``.

    The first template line is skipped. Each remaining line is copied
    unchanged unless it contains a placeholder token, in which case the whole
    line is replaced by generated code. Tokens are tried in this order and
    the first match wins:

        foo_in      -> input declarations        (inShape)
        foo_outTen  -> output tensor declaration (outLine)
        foo_op      -> definition of G           (replaceOp)
        foo_probe   -> probes of G at ``pos``    (indexG)
        foo_length  -> ``len(pos)``              (setLength)

    p_out -- output path without the ``.diderot`` extension
    app   -- the application being tested (``app.oty`` is its output type)
    pos   -- positions at which G is probed
    """
    # re.search is kept rather than `token in line`: with unicode_literals,
    # `in` would implicitly ASCII-decode each template line under Python 2.
    replacements = (
        (foo_in, lambda out: inShape(out, app)),
        (foo_outTen, lambda out: outLine(out, app.oty)),
        (foo_op, lambda out: replaceOp(out, app)),
        (foo_probe, lambda out: indexG(out, pos)),
        (foo_length, lambda out: setLength(out, len(pos))),
    )
    # `with` closes both files even if a writer raises part-way through.
    with open(template, 'r') as ftemplate:
        ftemplate.readline()  # the first template line is not copied
        with open(p_out + ".diderot", 'w+') as f:
            for line in ftemplate:
                for token, write in replacements:
                    if re.search(token, line):
                        write(f)
                        break
                else:
                    # nothing is being replaced
                    f.write(line)

def runDiderot(p_out, app, pos, output, runtimepath, isVis):
    """Run the compiled program ``./<p_out>`` unless ``isVis`` is truthy.

    In visualization mode nothing is executed; only an (unused) reshape
    argument is computed from ``len(pos)``. ``app``, ``output`` and
    ``runtimepath`` are accepted but not used.
    """
    if not isVis:
        os.system("./" + p_out)
        return
    # Visualization path (disabled). The commented-out commands were:
    #   ./<p_out> | unu save -f nrrd -o tmp.nrrd
    #   unu reshape -i tmp.nrrd <w_shape> | unu save -f text -o <p_out>.txt
    side = len(pos)
    w_shape = " -s 1 " + str(side * side)  # noqa: F841 (unused while disabled)

def writeDiderot(p_out, app, pos, output, runtimepath, isVis):
    """Generate, build, run and archive one Diderot test program.

    The steps, in order, are:

    1. Write ``<p_out>.diderot`` (readDiderot).
    2. Invoke ``<runtimepath> <p_out>.diderot``, the compile step.
    3. Run the program (runDiderot).
    4. Copy the ``.txt``, ``.diderot`` and ``.c`` files to ``<output>.*``.
    5. Delete every file matching ``<p_out>*``.

    Progress messages are echoed through the shell.

    NOTE(review): paths are spliced unquoted into shell commands, so names
    with spaces or shell metacharacters will break, and ``rm <p_out>*`` also
    removes any other file whose name merely starts with ``p_out``. Failures
    of cp/rm are not checked.
    """
    readDiderot(p_out, app, pos)
    os.system(" echo \"pout " + p_out + "\"")
    os.system(" echo \"create diderot program " + p_out + ".diderot \"")
    # compile, then run
    os.system(runtimepath + " " + p_out + ".diderot")
    runDiderot(p_out, app, pos, output, runtimepath, isVis)
    for ext in (".txt", ".diderot", ".c"):
        os.system("cp " + p_out + ext + " " + output + ext)
    os.system(" echo \"copied diderot program to " + output + ".diderot \"")
    os.system(" echo \"copied o-out " + output + ".txt \"")
    os.system("rm " + p_out + "*")

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0