Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2522 - (download) (annotate)
Mon Jan 13 18:42:09 2014 UTC (5 years, 7 months ago) by cchiw
File size: 11441 byte(s)
added gentypes to mid-il and split to High-to-mid il
(* expand-integrate.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)


(*
There are a couple of different approaches.
One approach is to find all the Probe(Conv) terms and generate an expression for each,
then use the Subst function to substitute them in; that takes care of index matching.

*)

(* This approach creates probe-expanded terms and appends the new params to the end of the parameter list. *)


(* Expand
 *
 * Expands probes of convolution fields inside Ein expressions.  Each term
 *     E.Probe(E.Conv(V, shape, h, deltas), E.Tensor(t, alpha))
 * is rewritten into an explicit summation of image samples multiplied by
 * kernel evaluations.  The parameters, arguments and MidIL code that the
 * expanded term needs are appended after the operator's existing ones.
 * The top-level function is expandEinOp.
 *
 * NOTE(review): image/position handling is unfinished.  Position always
 * falls back to the createImg stub, whose new arguments have no defining code.
 *)
structure Expand = struct

    local

    structure E = Ein
    structure mk= mkOperators

(* Abbreviations for the source (HighIL) and destination (MidIL) IRs. *)
structure SrcIL = HighIL
structure SrcTy = HighILTypes
structure SrcOp = HighOps
structure SrcSV = SrcIL.StateVar
structure VTbl = SrcIL.Var.Tbl
structure DstIL = MidIL
structure DstTy = MidILTypes
structure DstOp = MidOps
structure DstV = DstIL.Var
structure SrcV = SrcIL.Var
structure P=Printer

(* Classification of a MidIL variable's defining right-hand side, as
 * returned by getRHS:
 *   O   operator application (DstIL.OP)
 *   E   Ein application (DstIL.EINAPP)
 *   C   tensor construction (DstIL.CONS), carrying its type
 *   S   no binding (DstIL.VB_NONE); getRHS always uses S 2
 *)
datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
(* The same classification for HighIL variables, as returned by getRHS2. *)
datatype peanut2=    O2 of  SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
    in


(* assign (x, rator, args) builds the MidIL assignment x = rator(args).
 * assignEin builds x = EINAPP(rator, args), where rator is an Ein operator. *)
fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))

(* HighIL counterparts of assign and assignEin. *)
fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))
fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))

(* A persistent dictionary represented as a lookup function.
 * insert (key, value) d returns a new dictionary that maps key to
 * SOME value and defers every other key to d; d itself is unchanged.
 * lookup k d queries d.  Keys must be an equality type.
 *)
fun insert (key, value) d =fn s =>
        if s = key then SOME value
        else d s
fun lookup k d = d k



(* getRHS x
 * Follows VAR copies to the right-hand side that defines the MidIL
 * variable x, and returns (classification, arguments).  A variable with
 * no binding yields (S 2, []).  Raises Fail for any other kind of binding.
 *)
fun getRHS x  = (case DstIL.Var.binding x
of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
    | DstIL.VB_RHS(DstIL.VAR x') => getRHS x'
    | DstIL.VB_RHS(DstIL.EINAPP (e,args))=>(E e,args)
    | DstIL.VB_RHS(DstIL.CONS (ty,args))=>(C ty,args)
    | DstIL.VB_NONE=>(S 2,[])
    | vb => raise Fail(concat[
        "expected rhs operator for ", DstIL.Var.toString x,
        "but found ", DstIL.vbToString vb])
    (* end case *))


(* getRHS2 x: the HighIL counterpart of getRHS.  It returns a peanut2
 * classification and raises Fail for unsupported bindings. *)
fun getRHS2 x  = (case SrcIL.Var.binding x
    of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
    | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
    | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
    | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
    | SrcIL.VB_NONE=>(S2 2,[])
    | vb => raise Fail(concat[
    "expected rhs operator for ", SrcIL.Var.toString x,
    "but found ", SrcIL.vbToString vb])
    (* end case *))




(* Debugging helpers: print the binding of a MidIL variable (PrintBIND) or
 * a HighIL variable (PrintBIND2).  Nothing is checked; the message wording
 * merely mirrors the error raised by getRHS/getRHS2.
 *)
fun PrintBIND x  = (case DstIL.Var.binding x
    of vb=> print(String.concat[    "\n expected rhs operator for ", DstIL.Var.toString x,
    " but found ", DstIL.vbToString vb,"\n \n"])
(* end case *))

fun PrintBIND2 x  = (case SrcIL.Var.binding x
of vb=> print(String.concat[    "\n expected rhs operator for ", SrcIL.Var.toString x,
" but found ", SrcIL.vbToString vb,"\n \n"])
(* end case *))






(* createArgs (dim, v, posx, pos)
 * Creates the fractional and integer position vectors.  Generates MidIL
 * code that maps the position posx into the space of image v
 * (PosToImgSpace) and splits the result into fractional and integer parts:
 *     M = Transform v, T = Translate v,
 *     x = M posx + T   (mk.transform),
 *     nd = Floor x,    f = mk.subTen [x, nd],    n = RealToInt nd.
 * The extra assignment pos is placed right before the transform
 * (presumably it defines posx).  Returns ([f, n], code).
 * In this file it is only referenced from commented-out code.
 *)
fun createArgs(dim,v,posx,pos)=let

    val translate=DstOp.Translate v
    val transform=DstOp.Transform v
    val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
    val T = DstV.new ("T", DstTy.vecTy dim)          (*translate*)
    val x = DstV.new ("x", DstTy.vecTy dim)
    val f = DstV.new ("f", DstTy.vecTy dim)         (*fractional*)
    val nd = DstV.new ("nd", DstTy.vecTy dim)       (*real position*)
    val n = DstV.new ("n", DstTy.iVecTy dim)        (*integer position*)


    val PosToImgSpace=mk.transform(dim,dim)
    val code=[
            assign(M, transform, []),
            assign(T, translate, []),
            pos,
            assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)
            assign(nd, DstOp.Floor dim, [x]),   (*nd *)
            assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)
            assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
            ]

    in ([f,n],code)
    end


(* createArgs2 (dim, v, posx)
 * Same as createArgs, but without the extra assignment pos.
 * Returns ([f, n], code).  In this file it is only referenced from
 * commented-out code.
 *)
fun createArgs2  (dim,v,posx)=let

    val translate=DstOp.Translate v
    val transform=DstOp.Transform v
    val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
    val T = DstV.new ("T", DstTy.vecTy dim)          (*translate*)
    val x = DstV.new ("x", DstTy.vecTy dim)
    val f = DstV.new ("f", DstTy.vecTy dim)         (*fractional*)
    val nd = DstV.new ("nd", DstTy.vecTy dim)       (*real position*)
    val n = DstV.new ("n", DstTy.iVecTy dim)        (*integer position*)

    val PosToImgSpace=mk.transform(dim,dim)
    val code=[
        assign(M, transform, []),
        assign(T, translate, []),

        assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)
        assign(nd, DstOp.Floor dim, [x]),   (*nd *)
        assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)
        assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
    ]

    in ([f,n],code)
    end


(*Currently can't get rhs of image*)

(* createImg (dict, dim, pos1, imgArg, posArg)
 * Stub used by Position.  It prints a diagnostic together with the
 * bindings of imgArg and posArg, then returns
 *     (pos1, pos1+1, dict, [E.TEN 1, E.TEN 1], [a, b], [])
 * where a and b are fresh vector variables of length dim.  dict is
 * returned unchanged and no code is generated.
 * NOTE(review): nothing defines a and b, so the expanded term refers to
 * variables without bindings.  TODO.
 *)
fun createImg (dict,dim,pos1,imgArg,posArg)=let
    val g=print "\n\n \t\t *** Did not find the proper bindings for img\n\n"
    val info= PrintBIND imgArg
    val info2= PrintBIND posArg
    val a= DstV.new ("pos1F", DstTy.vecTy dim)
    val b= DstV.new ("pos2F", DstTy.vecTy dim)
    val bug=(pos1,pos1+1,dict,[E.TEN 1,E.TEN 1],[a,b],[])
    in bug end

(* Position (V, t, dict, dim, pos1, orig, args)
 * Provides the Ein parameters for the position of a probe of image
 * parameter V at position parameter t.  pos1 is the id that the first new
 * parameter will receive.
 *
 * Returns (fid, nid, dict', params', args', code'):
 *   fid, nid   parameter ids used by expandEinProbe for the position vectors
 *              (fractional and integer, going by createArgs and the
 *              commented-out code)
 *   dict'      the updated dictionary
 *   params'    the Ein params to append
 *   args'      the MidIL arguments to append
 *   code'      the code that defines those arguments
 *
 * If t is already in dict, its (fid, nid) are reused and nothing is added.
 * Otherwise the HighIL bindings of the original arguments for V and t are
 * inspected.  Currently both branches fall back to createImg; the intended
 * LoadImage/CONS handling is commented out below.
 * NOTE(review): this path never inserts t into dict, so probes at the same
 * position are not shared yet.
 *)
fun Position(V,t,dict,dim,pos1,orig,args)=let
    val l=lookup t dict

    in (case l
        of NONE =>let
            val newimgArg=List.nth(args, V)
            val newposArg=List.nth(args, t)
            val imgArg=List.nth(orig,V)
            val posArg=List.nth(orig,t)
            in (case (getRHS2 imgArg,getRHS2 posArg)
                of ((O2(SrcOp.LoadImage img), a1),(C2 ty, a2))=>

                createImg(dict,dim,pos1,newimgArg,newposArg)
                (*How to reassign arguments?*)
(*let

                    val (args',code')=createArgs2(dim,img,newposArg)
                    val pos2=pos1+1
                    val dict'=insert(t,(pos1,pos2)) dict
                    val params'=[E.TEN 1,E.TEN 1]
                    val code= [assign(newimgArg,DstOp.LoadImage img,[])]

                    in (pos1,pos2, dict',params',args',code'@code)
                    end
*)


(*
                    val posx= DstV.new ("pos", DstTy.vecTy dim)
                    val pos= assignEin(posx, e,[])
                    val (args',code')=createArgs(dim,img,posx,pos)
                    val pos2=pos1+1
                    val dict'=insert(t,(pos1,pos2)) dict
                    val params'=[E.TEN 1,E.TEN 1]
                    in (pos1,pos2, dict',params',args',code')
                    end
             | ((O2(Src.LoadImage v,_),(C _,_))=>
                    let
                        val (args',code')=createArgs2(dim,v,posArg)
                        val pos2=pos1+1
                        val dict'=insert(t,(pos1,pos2)) dict
                        val params'=[E.TEN 1,E.TEN 1]
                    in (pos1,pos2, dict',params',args',code')
                    end

*)
                |_=>createImg(dict,dim,pos1,newimgArg,newposArg)
                (*end case*))
            end

        | SOME (fid,nid)=>(fid,nid, dict,[],[],[]))
        (*end case*)
    end

(*
createDels=> creates the Kronecker deltas for each kernel.
For each dimension a, and each index b in the derivative, create the element (a,b).
Concretely, every entry d of the list is paired with the constant index
E.C dim, keeping the original order.
*)
fun createDels([],_)= []
| createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)




(* kernTransition (k, x)
 * Checks that the HighIL variable k is bound to SrcOp.Kernel(h, i).
 * Returns (Kernel.support h, [x = DstOp.Kernel(h, i)]): the kernel's
 * support and the MidIL code that assigns the same kernel to x.
 * Raises Fail if k is not bound to a kernel operator.
 *)
fun kernTransition(k,x)= case getRHS2 k
    of (O2(SrcOp.Kernel(h, i)),arg)=> (Kernel.support h ,[assign (x, DstOp.Kernel(h, i), [])])
    |_ => raise Fail "Not a kernel argument"




(* expandEinProbe ((body, data), origargs, sx)
 * Expands one probe term
 *     body = E.Probe(E.Conv(V, shape, h, deltas), E.Tensor(t, alpha))
 * with data = (params, index, args, dict, code, change):
 *   params, args   Ein parameters and the matching MidIL arguments
 *   index          the operator's index list (returned unchanged)
 *   dict           maps a position parameter t to its (fid, nid) ids
 *   code           MidIL assignments generated so far
 *   origargs       the HighIL arguments, used to inspect bindings
 *   sx             first index-variable id that is free at this point
 *
 * The result body is E.Sum(esum, E.Prod(image :: kernels)), where esum
 * binds E.V (sx+i), ranging over 1-s .. s, for every image axis i, and s
 * is the kernel support.  The returned data is
 *     (params @ params', index, args @ args', dict', code @ harg @ code', 1)
 * so change is set to 1.
 *
 * Raises:
 *   Match  if body is not a probe of E.Conv at an E.Tensor
 *   Bind   if params[V] is not an E.IMG
 *   Fail   if the kernel argument is not a kernel
 *)
fun expandEinProbe((body,(params,index,args,d,code,change)),origargs,sx)=(case body
    of E.Probe(E.Conv(V,shape,h,deltas),E.Tensor(t,alpha)) =>let
            (* Dimension of the image parameter; raises Bind unless it is E.IMG. *)
            val E.IMG(dim)=List.nth(params,V)



            (*Kernel arg*)
            (* NOTE(review): hnew is never used, and hnewx (assigned by harg) is
             * not added to the returned arguments.  The E.Krn terms below still
             * refer to the original kernel parameter h. *)
            val hnew=List.nth(args,h)
            val hnewx=DstIL.Var.new("hnew " ,DstTy.KernelTy)
            (* s is the kernel support; harg is the code assigning the kernel to hnewx. *)
            val (s,harg)=kernTransition(List.nth(origargs,h),hnewx)
            (* debugging output *)
            val ss=print(String.concat["\n Support",Int.toString(s)])


            (* The new parameters are appended, so their ids start at pnum. *)
            val pnum=length params

            val (fid,nid,d',params',args',code')=Position(V,t,d,dim,pnum,origargs,args)



            (*sumIndex creates the summation index for the body: one entry
             * (E.V (sx+i), 1-s, s) per image axis i, in axis order. *)
            fun sumIndex(0)=[]
                |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]

            (*createKRN creates the image and kernel factors.  For each axis i,
             * from dim-1 down to 0, it prepends the image position term
             * fid[i] + E.Value(sx+i) and the kernel term
             * E.Krn(h, dels, nid[i] - E.Value(sx+i)), so both lists end up
             * in axis order.  The result is E.Prod(E.Img(V, shape, positions)
             * followed by the kernel terms).
             * NOTE(review): if fid is the fractional and nid the integer
             * position, the image/kernel roles look swapped relative to the
             * usual sum of V[n+i] times h(f-i); confirm. *)
            fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
                | createKRN(dim,imgpos,rest)=let
                    val dim'=dim-1
                    (* Use the id that sumIndex binds for this axis, sx + dim'.
                     * Using length index + dim' would be wrong inside an
                     * enclosing E.Sum, where sx lies past that sum's variables
                     * and the two ids differ. *)
                    val sum=dim'+sx
                    val dels=createDels(deltas,dim')
                    val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
                    val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
            in
                createKRN(dim',pos@imgpos,[rest']@rest)
            end


            val exp=createKRN(dim, [],[])
            val esum=sumIndex (dim)


            in (E.Sum(esum, exp),(params@params',index,args@args',d',code@harg@code',1)) end
            (*end case*))


(*copied from high-to-mid.sml*)
(* expandEinOp (ein, origargs, args)
 * Rewrites the body of an Ein operator, expanding every
 * E.Probe(E.Conv _, _) subterm with expandEinProbe.
 *   origargs   the HighIL arguments, used to inspect bindings
 *   args       the corresponding MidIL arguments
 *
 * Returns (change, ein', args', code'):
 *   change   1 if at least one probe was expanded, otherwise 0
 *   ein'     the operator with extended params and the rewritten body
 *   args'    the extended argument list
 *   code'    the MidIL assignments generated for the new arguments
 *
 * A bare E.Conv, E.Field or E.Apply term is replaced by E.Const 0.0; an
 * E.Conv also prints a diagnostic.  Probes of anything other than
 * E.Conv are left as they are.
 *)
fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let

    (* Placeholder for terms that cannot be expanded here. *)
    val dummy=E.Const 0.0  (*tmp variables*)
        (*Maybe conv->0 somewhere else *)
    (* Next free index-variable id, passed to expandEinProbe as sx.  It
     * starts after the operator's own indices and is moved past the
     * variables of each E.Sum that is entered.  It is not reset afterwards,
     * so later terms simply use higher (still unused) ids. *)
    val sumIndex=ref (length index)
    (* Id of the last variable in a sum's index list, assumed to be the
     * largest.  Raises Subscript on an empty list. *)
    fun sumI(e)=let
        val (E.V v,_,_)=List.nth(e, length(e)-1)
    in v end

    (* rewriteBody (body, data) traverses body left to right, threading
     * data = (params, index, args, dict, code, change) through the
     * rewrite of each subterm. *)
    fun rewriteBody exp= let
        val (body, data)=exp
        in (case body
        of E.Const _=>exp
        | E.Tensor _=>exp
        | E.Krn _=>exp
        | E.Delta _=>exp
        | E.Value _ =>exp
        | E.Epsilon _=>exp
        | E.Partial _=>exp
        | E.Img _=>  exp
        | E.Conv _=>(print "\n No Probe, used, can not expand \n" ; (dummy,data))
        | E.Field _ =>(dummy, data)
        | E.Apply _ =>(dummy, data)
        | E.Neg e=> let
            val (body',exp')=rewriteBody (e, data)
            in
                (E.Neg(body'),exp')
            end
        | E.Sum (c,e)=> let

            val m=(sumI(c))+1
            val (body',exp')=(sumIndex:=m;rewriteBody (e,data))
            in
                ((E.Sum(c, body'), exp'))
            end
        | E.Probe(E.Conv _, _) => expandEinProbe(exp, origargs, !sumIndex)
        | E.Sub(a,b)=>let
            val (bodya,dataa)= rewriteBody(a, data)
            val (bodyb, datab)= rewriteBody(b, dataa)
            in   (E.Sub( bodya, bodyb),datab)
            end
        | E.Div(a,b)=>let
            val (bodya,dataa)= rewriteBody(a, data)
            val (bodyb, datab)= rewriteBody(b, dataa)
            in  (E.Div(bodya, bodyb),datab) end
        | E.Add es=> let
            fun filter([], done, data')= (E.Add done, data')
                | filter(e::es, done, data')= let
                    val (body', data'')= rewriteBody(e, data')
                      in filter(es, done@[body'], data'') end
            in filter(es, [],data) end

        | E.Prod es=> let
            fun filter([], done, data)= (E.Prod done, data)
                | filter(e::es, done, data)= let
                    val (body', data')= rewriteBody(e, data)
                    in filter(es, done@[body'], data') end
                in filter(es, [],data) end
        | E.Probe _=> exp
        (* end case *))
        end

    (* Initially the position dictionary is empty, no code has been
     * generated and change is 0. *)
     val empty =fn key =>NONE
    val (body',(params', index', args',_,code',change))=rewriteBody(body,(params,index,args,empty,[],0))
    val newbie=Ein.EIN{params=params', index=index', body=body'}
    in (change,newbie,args',code') end

  end; (* local *)

end (* structure Expand *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0