Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2525 - (download) (annotate)
Tue Jan 21 19:14:22 2014 UTC (5 years, 7 months ago) by cchiw
File size: 10974 byte(s)
eintypes->mid-iltypes
(* expand-integrate.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)


(*
A couple of different approaches are possible.
One approach is to find every Probe(Conv) and generate an expression for it,
then use the Subst function to substitute that expression in. That takes care of index matching.

*)

(* The approach used here replaces each probe with its expanded image/kernel terms and appends the new parameters to the end of the parameter list. *)


structure Expand = struct

    local

    structure E = Ein
    structure mk = mkOperators

    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer

    (* Classification of a MidIL variable's right-hand side, as produced by getRHS:
     * O = primitive operator, E = EIN application, C = CONS with its type,
     * S = marker returned for a variable with no binding (VB_NONE). *)
    datatype peanut = O of DstOp.rator | E of Ein.ein | C of DstTy.ty | S of int

    (* The same classification for HighIL variables, as produced by getRHS2. *)
    datatype peanut2 = O2 of SrcOp.rator | E2 of Ein.ein | C2 of SrcTy.ty | S2 of int

    in

    (* assign (x, rator, args): the MidIL assignment x = OP(rator, args). *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))

    (* assignEin (x, ein, args): the MidIL assignment x = EINAPP(ein, args). *)
    fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))

    (* HighIL counterparts of assign and assignEin. *)
    fun assign2 (x, rator, args) = (x, SrcIL.OP(rator, args))
    fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))

    (* A persistent dictionary represented as a function from keys to options.
     * insert (key, value) d yields a dictionary that maps key to value and defers
     * every other key to d; lookup k d queries it. Keys must be equality types. *)
    fun insert (key, value) d = fn s =>
        if s = key then SOME value
        else d s
    fun lookup k d = d k

    (* getRHS x: classify what MidIL variable x is bound to, following VAR copies.
     * A variable with no binding (VB_NONE) yields (S 2, []).
     * Raises Fail for any other kind of binding. *)
    fun getRHS x = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
         | DstIL.VB_RHS(DstIL.VAR x') => getRHS x'
         | DstIL.VB_RHS(DstIL.EINAPP(e, args)) => (E e, args)
         | DstIL.VB_RHS(DstIL.CONS(ty, args)) => (C ty, args)
         | DstIL.VB_NONE => (S 2, [])
         | vb => raise Fail(concat[
                "expected rhs operator for ", DstIL.Var.toString x,
                " but found ", DstIL.vbToString vb])
        (* end case *))

    (* getRHS2 x: HighIL counterpart of getRHS. *)
    fun getRHS2 x = (case SrcIL.Var.binding x
        of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
         | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
         | SrcIL.VB_RHS(SrcIL.EINAPP(e, args)) => (E2 e, args)
         | SrcIL.VB_RHS(SrcIL.CONS(ty, args)) => (C2 ty, args)
         | SrcIL.VB_NONE => (S2 2, [])
         | vb => raise Fail(concat[
                "expected rhs operator for ", SrcIL.Var.toString x,
                " but found ", SrcIL.vbToString vb])
        (* end case *))

    (* Debugging aid: print the binding of MidIL variable x to stdout. *)
    fun PrintBIND x = let
        val vb = DstIL.Var.binding x
        in
            print(String.concat["\n expected rhs operator for ", DstIL.Var.toString x,
                " but found ", DstIL.vbToString vb, "\n \n"])
        end

    (* Debugging aid: print the binding of HighIL variable x to stdout. *)
    fun PrintBIND2 x = let
        val vb = SrcIL.Var.binding x
        in
            print(String.concat["\n expected rhs operator for ", SrcIL.Var.toString x,
                " but found ", SrcIL.vbToString vb, "\n \n"])
        end

    (* createArgs (dim, v, posx, pos): emit MidIL code that maps position posx into
     * the index space of image v (x = M posx + T, using v's Transform/Translate)
     * and splits x into a fractional part f = x - floor x and an integer part
     * n = int(floor x). The caller-supplied statement pos is spliced in right after
     * M and T are loaded. Returns ([f, n], code). *)
    fun createArgs (dim, v, posx, pos) = let
        val translate = DstOp.Translate v
        val transform = DstOp.Transform v
        val M = DstV.new ("M", DstTy.tensorTy [dim, dim])   (* transform *)
        val T = DstV.new ("T", DstTy.vecTy dim)             (* translate *)
        val x = DstV.new ("x", DstTy.vecTy dim)             (* index-space position *)
        val f = DstV.new ("f", DstTy.vecTy dim)             (* fractional part *)
        val nd = DstV.new ("nd", DstTy.vecTy dim)           (* floor, as reals *)
        val n = DstV.new ("n", DstTy.iVecTy dim)            (* floor, as integers *)
        val PosToImgSpace = mk.transform(dim, dim)
        val code = [
                assign(M, transform, []),
                assign(T, translate, []),
                pos,
                assignEin(x, PosToImgSpace, [M, posx, T]),   (* M posx + T *)
                assign(nd, DstOp.Floor dim, [x]),
                assignEin(f, mk.subTen([dim]), [x, nd]),      (* fractional part *)
                assign(n, DstOp.RealToInt dim, [nd])          (* real to int *)
            ]
        in
            ([f, n], code)
        end

    (* createArgs2 (dim, v, posx): same as createArgs without a spliced-in statement.
     * Returns ([f, n], code). *)
    fun createArgs2 (dim, v, posx) = let
        val translate = DstOp.Translate v
        val transform = DstOp.Transform v
        val M = DstV.new ("M", DstTy.tensorTy [dim, dim])   (* transform *)
        (* T is the additive term of M posx + T, so it is a dim-vector (as in
         * createArgs); it was previously mistyped as a dim x dim tensor. *)
        val T = DstV.new ("T", DstTy.vecTy dim)             (* translate *)
        val x = DstV.new ("x", DstTy.vecTy dim)             (* index-space position *)
        val f = DstV.new ("f", DstTy.vecTy dim)             (* fractional part *)
        val nd = DstV.new ("nd", DstTy.vecTy dim)           (* floor, as reals *)
        val n = DstV.new ("n", DstTy.iVecTy dim)            (* floor, as integers *)
        val PosToImgSpace = mk.transform(dim, dim)
        val code = [
                assign(M, transform, []),
                assign(T, translate, []),
                assignEin(x, PosToImgSpace, [M, posx, T]),   (* M posx + T *)
                assign(nd, DstOp.Floor dim, [x]),
                assignEin(f, mk.subTen([dim]), [x, nd]),      (* fractional part *)
                assign(n, DstOp.RealToInt dim, [nd])          (* real to int *)
            ]
        in
            ([f, n], code)
        end

    (* Position (img, t, newposArg, dict, dim, pos1, ppos): find the EIN parameter
     * ids (fid, nid) of the fractional and integer position vectors for the probe
     * position held in parameter t.
     * - Cache miss: generate the conversion code (createArgs2), assign fid = pos1
     *   and nid = pos1 + 1, and record them in dict. Returns
     *   (fid, nid, dict', ppos, [f, n], code).
     * - Cache hit: reuse the recorded ids and add no params, args, or code.
     * NOTE(review): the cache is keyed on t only, so two probes of different
     * images at the same position would share f/n computed with the first image's
     * transform -- confirm this cannot happen. *)
    fun Position (img, t, newposArg, dict, dim, pos1, ppos) = (case lookup t dict
        of NONE => let
            val (args', code') = createArgs2(dim, img, newposArg)
            val pos2 = pos1 + 1
            val dict' = insert(t, (pos1, pos2)) dict
            in
                (pos1, pos2, dict', ppos, args', code')
            end
         | SOME(fid, nid) => (fid, nid, dict, [], [], [])
        (* end case *))

    (* createDels (ds, dim): pair dimension index dim (as E.C dim) with each
     * derivative index d in ds, giving the deltas attached to one kernel factor. *)
    fun createDels (ds, dim) = List.map (fn d => (E.C dim, d)) ds

    (* replaceH (kvar, place, args): args with the element at position place
     * replaced by kvar. *)
    fun replaceH (kvar, place, args) = let
        val l1 = List.take(args, place)
        val l2 = List.drop(args, place + 1)
        in
            l1 @ [kvar] @ l2
        end

    (* getArgs (hid, hArg, V, imgArg, args): hArg must be bound to a Kernel operator
     * and imgArg to a LoadImage operator in HighIL. Creates fresh MidIL variables for
     * both, substitutes them into args at positions hid and V, and returns
     * (kernel support, image info, code defining the two variables, new args).
     * Raises Fail if either argument has the wrong kind of binding. *)
    fun getArgs (hid, hArg, V, imgArg, args) = (case (getRHS2 hArg, getRHS2 imgArg)
        of ((O2(SrcOp.Kernel(h, i)), _), (O2(SrcOp.LoadImage img), _)) => let
            val hvar = DstV.new ("KNL", DstTy.KernelTy)
            val imgvar = DstV.new ("IMG", DstTy.ImageTy img)
            val args1 = replaceH(hvar, hid, args)
            val args2 = replaceH(imgvar, V, args1)
            in
                (Kernel.support h, img,
                 [assign(hvar, DstOp.Kernel(h, i), []), assign(imgvar, DstOp.LoadImage img, [])],
                 args2)
            end
         | ((O2(SrcOp.Kernel _), _), _) => raise Fail "Not an img Argument"
         | _ => raise Fail "Not a kernel argument"
        (* end case *))

    (* expandEinProbe ((body, data), origargs, sx): expand a probe
     * Probe(Conv(V, shape, h, deltas), Tensor(t, _)) into
     *   Sum([(V sx, 1-s, s), ..., (V (sx+dim-1), 1-s, s)],
     *       Prod[Img(V, shape, [n_0 + i_0, ...]), Krn(h, dels_0, f_0 - i_0), ...])
     * where s is the kernel support, i_k = E.Value(sx+k), and f/n are the fractional
     * and integer position vectors produced by Position. sx is the first free
     * summation-index id.
     * data = (params, index, args, dict, code, change); the new params, args and
     * code are appended and change is set to 1.
     * Raises Fail if body is not such a probe or parameter V is not an image. *)
    fun expandEinProbe ((body, (params, index, args, d, code, change)), origargs, sx) = (case body
        of E.Probe(E.Conv(V, shape, h, deltas), E.Tensor(t, alpha)) => let
            val dim = (case List.nth(params, V)
                of E.IMG dim => dim
                 | _ => raise Fail "expandEinProbe: expected an image parameter"
                (* end case *))
            val kArg = List.nth(origargs, h)
            val imgArg = List.nth(origargs, V)
            val (s, img, argcode, args2) = getArgs(h, kArg, V, imgArg, args)
            val newposArg = List.nth(args, t)
            val ppos = [E.TEN(1, [dim]), E.TEN(3, [dim])]
            val (fid, nid, d', params', args', code') =
                  Position(img, t, newposArg, d, dim, length params, ppos)

            (* one summation index per dimension, each ranging over the support *)
            val esum = List.tabulate(dim, fn k => (E.V(sx + k), 1 - s, s))

            (* build the image factor and one kernel factor per dimension *)
            fun createKRN (0, imgpos, rest) = E.Prod([E.Img(V, shape, imgpos)] @ rest)
              | createKRN (k, imgpos, rest) = let
                    val k' = k - 1
                    val sum = sx + k'
                    val dels = createDels(deltas, k')
                    (* the image is sampled at the integer lattice point n + i,
                     * and the kernel is evaluated at the fractional offset f - i
                     * (these two were previously swapped) *)
                    val pos = E.Add[E.Tensor(nid, [E.C k']), E.Value(sum)]
                    val krn = E.Krn(h, dels, E.Sub(E.Tensor(fid, [E.C k']), E.Value(sum)))
                    in
                        createKRN(k', pos :: imgpos, krn :: rest)
                    end

            val exp = createKRN(dim, [], [])
            in
                (E.Sum(esum, exp),
                 (params @ params', index, args2 @ args', d', argcode @ code @ code', 1))
            end
         | _ => raise Fail "expandEinProbe: expected a probe of a convolution at a tensor position"
        (* end case *))

    (* Debugging aids: render a HighIL (TS) or MidIL (TT) variable with its binding. *)
    fun TS x = let
        val vb = SrcIL.Var.binding x
        in
            String.concat[SrcIL.Var.toString x, "\n Found ", SrcIL.vbToString vb, "\n"]
        end

    fun TT x = let
        val vb = DstIL.Var.binding x
        in
            String.concat[DstIL.Var.toString x, "\n Found ", DstIL.vbToString vb, "\n"]
        end

    (* expandEinOp (ein, origargs, args): rewrite ein's body, expanding every
     * Probe(Conv _, _) via expandEinProbe. origargs are the HighIL arguments (used
     * to find the kernel and image bindings); args are the corresponding MidIL
     * arguments.
     * Field, Apply and unprobed Conv subterms are replaced by Const 0.0.
     * Returns (change, newEin, newArgs, code): change is 1 if some probe was
     * expanded and 0 otherwise, and code defines the newly added arguments.
     * (copied from high-to-mid.sml) *)
    fun expandEinOp (Ein.EIN{params, index, body}, origargs, args) = let
        val dummy = E.Const 0.0
        (* next free summation-index id; starts after the free indices *)
        val sumIndex = ref (length index)

        (* max of the counter and one past the last index bound by a Sum *)
        fun sumI c = (case List.rev c
            of (E.V v, _, _) :: _ => Int.max(!sumIndex, v + 1)
             | _ => !sumIndex
            (* end case *))

        (* rewrite a body, threading data = (params, index, args, dict, code, change) *)
        fun rewriteBody (exp as (body, data)) = (case body
            of E.Const _ => exp
             | E.Tensor _ => exp
             | E.Krn _ => exp
             | E.Delta _ => exp
             | E.Value _ => exp
             | E.Epsilon _ => exp
             | E.Partial _ => exp
             | E.Img _ => exp
             | E.Conv _ => (print "\n No Probe, used, can not expand \n"; (dummy, data))
             | E.Field _ => (dummy, data)
             | E.Apply _ => (dummy, data)
             | E.Neg e => let
                val (e', data') = rewriteBody(e, data)
                in
                    (E.Neg e', data')
                end
             | E.Sum(c, e) => let
                val m = sumI c
                val _ = sumIndex := m
                val (e', data') = rewriteBody(e, data)
                in
                    (E.Sum(c, e'), data')
                end
             | E.Probe(E.Conv _, _) => expandEinProbe(exp, origargs, !sumIndex)
             | E.Sub(a, b) => let
                val (a', dataA) = rewriteBody(a, data)
                val (b', dataB) = rewriteBody(b, dataA)
                in
                    (E.Sub(a', b'), dataB)
                end
             | E.Div(a, b) => let
                val (a', dataA) = rewriteBody(a, data)
                val (b', dataB) = rewriteBody(b, dataA)
                in
                    (E.Div(a', b'), dataB)
                end
             | E.Add es => let
                val (es', data') = rewriteList(es, data)
                in
                    (E.Add es', data')
                end
             | E.Prod es => let
                val (es', data') = rewriteList(es, data)
                in
                    (E.Prod es', data')
                end
             | E.Probe _ => exp
            (* end case *))

        (* rewrite a list of terms left to right, threading data through each *)
        and rewriteList (es, data) = let
            fun step (e, (acc, data)) = let
                val (e', data') = rewriteBody(e, data)
                in
                    (e' :: acc, data')
                end
            val (revEs, data') = List.foldl step ([], data) es
            in
                (List.rev revEs, data')
            end

        (* empty position cache for Position *)
        val empty = fn _ => NONE
        val (body', (params', index', args', _, code', change)) =
              rewriteBody(body, (params, index, args, empty, [], 0))
        val newbie = Ein.EIN{params=params', index=index', body=body'}
        in
            (change, newbie, args', code')
        end

    end (* local *)

end (* structure Expand *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0