Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 2838 - (download) (annotate)
Tue Nov 25 03:40:24 2014 UTC (4 years, 8 months ago) by cchiw
File size: 8427 byte(s)
edit split-ein
(* Currently under construction 
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.

A couple of different approaches.
One approach is to find all the Probe(Conv) terms and generate an expression for each,
then use the Subst function to substitute them in. That takes care of index matching.
*)


(*This approach creates probe expanded terms, and adds params to the end. *)

structure ProbeEin = struct

    (* Abbreviations for the HighIL (source) side. *)
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcV = SrcIL.Var
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl

    (* Abbreviations for the MidIL (destination) side. *)
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var

    (* Ein representation and helper modules used below. *)
    structure E = Ein
    structure mk = mkOperators
    structure P = Printer
    structure F = Filter
    structure T = TransformEin
    structure split = Split

    (* Debug switch read by testp: 0 disables its diagnostic printing. *)
    val testing = 0


(* Counter backing genName; bumped on every call so generated names differ. *)
val cnt = ref 0

(* genName prefix
 * Returns prefix ^ "_" ^ <current counter value> and increments the counter,
 * so successive calls yield e.g. "M_0", "T_1", ...
 *)
fun genName prefix = let
      val n = !cnt
      in
        cnt := n + 1;
        String.concat [prefix, "_", Int.toString n]
      end

(* assign (lhs, rator, args): a MidIL assignment pair whose right-hand side
 * is the operator application DstIL.OP (rator, args). *)
fun assign (lhs, rator, args) = let
      val rhs = DstIL.OP (rator, args)
      in
        (lhs, rhs)
      end

(* assignEin (lhs, rator, args): a MidIL assignment pair whose right-hand
 * side applies the Ein operator rator to args (DstIL.EINAPP). *)
fun assignEin (lhs, rator, args) = let
      val rhs = DstIL.EINAPP (rator, args)
      in
        (lhs, rhs)
      end
(* testp msgs: when the `testing` switch is non-zero, print the concatenation
 * of the string list msgs.  Always returns 1; the caller in this file
 * discards the result. *)
fun testp msgs =
      if testing = 0
        then 1
        else (print (String.concat msgs); 1)

(*transform world-space position posx to image-space position x*)

(* getTys dim: returns (type of an integer position, shape of a position
 * vector, shape of a dim x dim matrix).  In 1-D everything is scalar, so
 * both shapes are empty. *)
fun getTys dim =
      if dim = 1
        then (DstTy.intTy, [], [])
        else (DstTy.iVecTy dim, [dim], [dim, dim])

(* WorldToImagespace (dim, v, posx, imgArgDst)
 * Emits MidIL code mapping the world-space position posx into image space:
 *   M  = Transform v (imgArgDst)     -- rank-0 tensor when dim = 1
 *   T  = Translate v (imgArgDst)
 *   x0 = PosToImgSpaceA (M, posx)    -- x0 = M x
 *   x  = PosToImgSpaceB (x0, T)      -- x  = x0 + T
 * Returns (M, x, code) where code is the list of those assignments.
 *)
fun WorldToImagespace (dim, v, posx, imgArgDst) = let
      val translate = DstOp.Translate v
      val transform = DstOp.Transform v
      val (_, fty, pty) = getTys dim
      val mty = DstTy.TensorTy pty
      val rty = DstTy.TensorTy fty

      val M  = DstV.new (genName "M", mty)    (* transform dim by dim? *)
      val T  = DstV.new (genName "T", rty)
      val x  = DstV.new (genName "x", rty)    (* Image-Space position *)
      val x0 = DstV.new (genName "x0", rty)
      (* 1-D positions are scalars; otherwise use the matrix/vector operators *)
      val (PosToImgSpaceA, PosToImgSpaceB) = (case dim
             of 1 => (mk.prodScalar, mk.addScalar)
              | _ => (mk.transformA (dim, dim), mk.transformB dim)
            (* end case *))
      val code = [
              assign (M, transform, [imgArgDst]),
              assign (T, translate, [imgArgDst]),
              assignEin (x0, PosToImgSpaceA, [M, posx]),   (* x0 = MX *)
              assignEin (x, PosToImgSpaceB, [x0, T])       (* x = x0 + T *)
            ]
      in
        (M, x, code)
      end

(*Create fractional and integer position vectors*)
(* transformToImgSpace (dim, v, posx, imgArgDst)
 * Maps posx into image space (via WorldToImagespace) giving x, then emits
 *   nd = Floor x,  f = subTen (x, nd),  n = RealToInt nd
 * so n is the integer and f the fractional image-space position.
 * P is the transpose of the world-to-image matrix M; when dim = 1 no
 * transpose is emitted and M itself is returned.
 * Returns ([n, f], P, code).
 *)
fun transformToImgSpace (dim, v, posx, imgArgDst) = let
      val (ity, fty, pty) = getTys dim
      val mty = DstTy.TensorTy pty
      val rty = DstTy.TensorTy fty

      val f  = DstV.new ("f", rty)      (* fractional *)
      val nd = DstV.new ("nd", rty)     (* real position *)
      val n  = DstV.new ("n", ity)      (* integer position *)
      val P  = DstV.new ("P", mty)      (* transform dim by dim? *)

      val (M, x, code1) = WorldToImagespace (dim, v, posx, imgArgDst)
      val (P, PCode) = (case dim
             of 1 => (M, [])
              | _ => (P, [assignEin (P, mk.transpose pty, [M])])
            (* end case *))
      val code = [
              assign (nd, DstOp.Floor dim, [x]),         (* nd *)
              assignEin (f, mk.subTen fty, [x, nd]),     (* fractional *)
              assign (n, DstOp.RealToInt dim, [nd])      (* real to Int *)
            ]
      in
        ([n, f], P, code1 @ PCode @ code)
      end

 (* getRHSDst x: follows x's MidIL binding to the operator application that
  * defines it, looking through plain variable copies (VB_RHS (VAR x')).
  * Returns (rator, args); raises Fail for any other kind of binding.
  *)
 fun getRHSDst x = (case DstIL.Var.binding x
       of DstIL.VB_RHS (DstIL.OP (rator, args)) => (rator, args)
        | DstIL.VB_RHS (DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail (concat [
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
      (* end case *))

 (*Get Img, and Kern Args*)
 (* getArgsDst (hid, hArg, imgArg, args)
  * Looks up the defining operators of the kernel argument hArg and the image
  * argument imgArg and returns (Kernel.support of the kernel, image info).
  * hid and args are not used; they are kept so the signature is unchanged.
  * Raises Fail unless hArg is bound by Kernel and imgArg by LoadImage.
  *)
 fun getArgsDst (hid, hArg, imgArg, args) = (case (getRHSDst hArg, getRHSDst imgArg)
       of ((DstOp.Kernel (h, _), _), (DstOp.LoadImage img, _)) => (Kernel.support h, img)
        | _ => raise Fail "Expected Image and kernel argument"
      (* end case *))


(* handleArgs (V, hid, t, args)
 * V, hid and t are positions in args of the image, the kernel and the probe
 * position.  Builds the code that moves the position into image space and
 * returns (dim, args @ [n, f], code, s, P): the image dimension, the extended
 * argument list, that code, the kernel support, and the matrix variable P
 * produced by transformToImgSpace.
 *)
fun handleArgs (V, hid, t, args) = let
      val hArg = List.nth (args, hid)
      val imgArg = List.nth (args, V)
      val newposArg = List.nth (args, t)
      val (s, img) = getArgsDst (hid, hArg, imgArg, args)
      val dim = ImageInfo.dim img
      val (argsT, P, code) = transformToImgSpace (dim, img, newposArg, imgArg)
      in
        (dim, args @ argsT, code, s, P)
      end

(*Create new body for probe*)
(* createBody (dim, s, sx, shape, deltas, V, h, nid, fid)
 * Builds the Ein body of an expanded probe: E.Sum over dim new indices
 * v_sx .. v_(sx+dim-1), each ranging over 1-s .. s, of the product of
 *   - the image V indexed by shape, at position param fid + E.Value i per axis,
 *   - one kernel h per axis (derivative indices from deltas), evaluated at
 *     param nid - E.Value i.
 * NOTE(review): E.Value i is presumably the value of summation index i — confirm.
 * In 1-D the fid/nid params are scalars; otherwise component i is used for axis i.
 *)
fun createBody (dim, s, sx, shape, deltas, V, h, nid, fid) = let

      (* sumIndex: summation indices for the body, each over [1-s, s] *)
      fun sumIndex 0 = []
        | sumIndex dim = sumIndex (dim-1) @ [(E.V (dim+sx-1), 1-s, s)]

      (* 1-D probe: image and kernel use unindexed (scalar) params *)
      fun createKRND1 () = let
            val sum = sx
            val dels = List.map (fn e => (E.C 0, e)) deltas
            val pos = [E.Add [E.Tensor (fid, []), E.Value sum]]
            val rest = E.Krn (h, dels, E.Sub (E.Tensor (nid, []), E.Value sum))
            in
              E.Prod [E.Img (V, shape, pos), rest]
            end

      (* createKRN (d, imgpos, rest): handles axes d-1 down to 0, prepending each
       * axis' image-position term to imgpos and its kernel to rest; at d = 0
       * forms the product of the image field with all kernels. *)
      fun createKRN (0, imgpos, rest) = E.Prod ([E.Img (V, shape, imgpos)] @ rest)
        | createKRN (dim, imgpos, rest) = let
            val dim' = dim - 1
            val sum = sx + dim'
            val dels = List.map (fn e => (E.C dim', e)) deltas
            val pos = [E.Add [E.Tensor (fid, [E.C dim']), E.Value sum]]
            val rest' = E.Krn (h, dels, E.Sub (E.Tensor (nid, [E.C dim']), E.Value sum))
            in
              (* NOTE(review): this recursive step was absent from the source text;
               * reconstructed so image positions end up ordered axis 0 .. dim-1. *)
              createKRN (dim', pos @ imgpos, [rest'] @ rest)
            end

      val exp = (case dim
             of 1 => createKRND1 ()
              | _ => createKRN (dim, [], [])
            (* end case *))

      val esum = sumIndex dim
      in
        E.Sum (esum, exp)
      end

(* ShapeConv (ix, n): the variable indices E.V v of ix with v < n, in their
 * original order; constant indices E.C _ and variables v >= n are dropped. *)
fun ShapeConv (ix, n) =
      List.filter (fn E.V v => v < n | E.C _ => false) ix

(* mapIndex (ix, index): for each variable index E.V v of ix, in order, the
 * element List.nth (index, v); constant indices E.C _ contribute nothing.
 * Raises Subscript if some v is out of range for index. *)
fun mapIndex (ix, index) =
      List.mapPartial
        (fn E.V v => SOME (List.nth (index, v)) | E.C _ => NONE)
        ix

 (* Expand probe in place: replaceProbe(b, params, args, index, sumIndex) *)
 (* replaceProbe (b, params, args, index, sumIndex)
  * b must have the form E.Probe (E.Conv (V, alpha, h, dx), E.Tensor (t, _));
  * any other shape raises Bind.  Expands the probe into an explicit body (see
  * createBody) and appends three params and arguments (n, f, P) describing
  * the image-space position.  sumIndex is the enclosing summation, if any;
  * new summation indices start just after its last index.
  * Returns (body', params', args', code) where code computes the new arguments.
  *)
 fun replaceProbe (b, params, args, index, sumIndex) = let
      val E.Probe (E.Conv (V, alpha, h, dx), E.Tensor (t, _)) = b
      val fid = length params      (* param position of the first new param (n) *)
      val nid = fid + 1
      val n = length index
      val nshift = length dx       (* measured on the original dx *)

      (* first index id not used by the outer index or enclosing sum *)
      val nsumshift = (case sumIndex
             of [] => n
              | _ => let
                  val (E.V v, _, _) = List.last sumIndex
                  in
                    v + 1
                  end
            (* end case *))

      val aa = List.map (fn (E.V v, _, _) => Int.toString v) sumIndex
      val _ = testp ["\n", "SumIndex", String.concatWith "," aa,
                     "\nThink nshift is ", Int.toString nsumshift]

      (* Outer Index-id Of Probe *)
      val VShape = ShapeConv (alpha, n)
      val HShape = ShapeConv (dx, n)
      val shape = VShape @ HShape
      (* Bindings for Shape *)
      val shapebind = mapIndex (shape, index)
      val Vshapebind = mapIndex (VShape, index)

      val (dim, argsA, code, s, PArg) = handleArgs (V, h, t, args)
      val (_, _, dx, _, sxT, restT, _, _) =
            T.Transform (dx, shapebind, Vshapebind, dim, PArg, nsumshift, 1, nid + 1)

      val params' = params @ [E.TEN (3, [dim]), E.TEN (1, [dim]), E.TEN (1, [dim, dim])]
      val body'' = createBody (dim, s, nsumshift + nshift, alpha, dx, V, h, nid, fid)
      (* with a nonempty dx, wrap the body in the sum/product from T.Transform *)
      val body' = (case nshift
             of 0 => body''
              | _ => E.Sum (sxT, E.Prod (restT @ [body'']))
            (* end case *))
      val args' = argsA @ [PArg]
      in
        (body', params', args', code)
      end

 (* sx-[] then move out, otherwise keep in *)
(* expandEinOp (y, EINAPP (EIN {params, index, body}, args))
 * If body is E.Probe (E.Conv _, E.Tensor _), or E.Sum (sx, E.Probe _), the
 * probe is expanded by replaceProbe; the result is the code computing the new
 * arguments followed by the rewritten Ein assignment to y (the rewrite is also
 * printed).  Any other body yields [e] unchanged.
 * NOTE(review): the result expression was missing from the source text;
 * returning the assignment list is assumed — confirm against the caller.
 *)
fun expandEinOp (e as (y, DstIL.EINAPP (Ein.EIN {params, index, body}, args))) = let

      (* (1, code) when a probe was expanded, (0, [e]) otherwise *)
      fun rewriteBody b = (case b
             of E.Probe (E.Conv _, E.Tensor _) => let
                  val (body', params', args', newbies) =
                        replaceProbe (b, params, args, index, [])
                  val einapp =
                        (y, DstIL.EINAPP (Ein.EIN {params=params', index=index, body=body'}, args'))
                  in
                    (1, newbies @ [einapp])
                  end
              | E.Sum (sx, E.Probe probe) => let
                  val (body', params', args', newbies) =
                        replaceProbe (E.Probe probe, params, args, index, sx)
                  val body' = E.Sum (sx, body')
                  val einapp =
                        (y, DstIL.EINAPP (Ein.EIN {params=params', index=index, body=body'}, args'))
                  in
                    (1, newbies @ [einapp])
                  end
              | _ => (0, [e])
            (* end case *))

      val (c, code) = rewriteBody body
      val b = String.concatWith ",\t" (List.map split.printEINAPP code)
      (* debugging output: show the original assignment and its rewrite *)
      val _ = (case c
             of 1 => print (String.concat ["\nbody", split.printEINAPP e, "\n=>\n", b])
              | _ => ()
            (* end case *))
      in
        code
      end

end (* local *)

ViewVC Help
Powered by ViewVC 1.0.0