Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 2615 - (download) (annotate)
Wed May 14 00:22:49 2014 UTC (5 years, 4 months ago) by cchiw
File size: 13655 byte(s)
added tree-il expressions and types
(* Currently under construction
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

(* There are a couple of different approaches.
 * One approach is to find all the Probe(Conv) terms and generate an expression for each,
 * then use the Subst function to substitute them in; that takes care of index matching.
 *)

(*This approach creates probe expanded terms, and adds params to the end. *)

structure ProbeEin = struct

    (* Ein terms and the helper modules used to build, print and simplify them *)
    structure E = Ein
    structure mk = mkOperators
    structure P = Printer
    structure shift = ShiftEin
    structure split = SplitEin
    structure F = Filter
    structure T = TransformEin

    (* source side of this pass: HighIL *)
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcV = SrcIL.Var
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl

    (* destination side of this pass: MidIL *)
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var

(* Debug switch.  0 keeps the diagnostic printing guarded by this flag silent;
 * a nonzero value enables it (some traces check for exactly 1). *)
val testing = 0

(* Tagged right-hand sides.  peanut wraps MidIL (Dst) values and peanut2 wraps
 * HighIL (Src) values: O = operator, E = Ein operator, C = constructor type,
 * S = plain integer tag (getRHS returns S2 2 for a variable with no binding). *)
datatype peanut
    = O of DstOp.rator
    | E of Ein.ein
    | C of DstTy.ty
    | S of int
datatype peanut2
    = O2 of SrcOp.rator
    | E2 of Ein.ein
    | C2 of SrcTy.ty
    | S2 of int

(* assign (lhs, rator, args): MidIL assignment pair binding lhs to the
 * operator application rator(args). *)
fun assign (lhs, rator, args) = let
    val rhs = DstIL.OP(rator, args)
    in (lhs, rhs) end

(* assignEin (lhs, ein, args): MidIL assignment pair binding lhs to the
 * application of the Ein operator ein to args. *)
fun assignEin (lhs, ein, args) = let
    val rhs = DstIL.EINAPP(ein, args)
    in (lhs, rhs) end

(* getRHS x: follow the HighIL binding of x, through VAR aliases, and return
 * its defining right-hand side tagged as a peanut2 value together with its
 * argument variables:
 *   OP(rator, args)   => (O2 rator, args)
 *   EINAPP(e, args)   => (E2 e, args)
 *   CONS(ty, args)    => (C2 ty, args)
 *   VB_NONE           => (S2 2, [])   for a variable with no binding
 * Raises Fail for any other kind of binding.
 *)
fun getRHS x = (case SrcIL.Var.binding x
    of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
     | SrcIL.VB_RHS(SrcIL.EINAPP(e, args)) => (E2 e, args)
     | SrcIL.VB_RHS(SrcIL.CONS(ty, args)) => (C2 ty, args)
     | SrcIL.VB_NONE => (S2 2, [])
     | vb => raise Fail(concat[
            (* the leading space before but keeps the variable name and the
             * following words from running together in the message *)
            "expected rhs operator for ", SrcIL.Var.toString x,
            " but found ", SrcIL.vbToString vb])
    (* end case *))

(* Create fractional and integer position vectors.
 *
 * transformToImgSpace (dim, v, posx): emit MidIL code that maps the position
 * posx into the image space of image v, x = M*posx + T, where M and T come from
 * DstOp.Transform v and DstOp.Translate v.
 * Returns ([n, f], P, code) where
 *   n    - integer position, RealToInt of nd = floor(x)
 *   f    - fractional position, mk.subTen applied to x and nd
 *   P    - transpose of M
 *   code - the assignments defining M, T, x, nd, f, n and P, in that order
 *)
fun transformToImgSpace (dim,v,posx)=let

    val translate=DstOp.Translate v
    val transform=DstOp.Transform v
    val M  = DstV.new ("M", DstTy.tensorTy [dim,dim])   (* transform matrix *)
    (* NOTE(review): T holds the translation but is typed dim x dim; a vector
     * type such as DstTy.vecTy dim may be intended -- confirm *)
    val T  = DstV.new ("T", DstTy.tensorTy [dim,dim])   (* translate *)
    val x  = DstV.new ("x", DstTy.vecTy dim)            (* image-space position *)
    val f  = DstV.new ("f", DstTy.vecTy dim)            (* fractional *)
    val nd = DstV.new ("nd", DstTy.vecTy dim)           (* floored real position *)
    val n  = DstV.new ("n", DstTy.iVecTy dim)           (* integer position *)
    val PosToImgSpace=mk.transform(dim,dim)
    val P  = DstV.new ("P", DstTy.tensorTy [dim,dim])   (* transpose of M *)

    val code=[
        assign(M, transform, []),
        assign(T, translate, []),
        assignEin(x, PosToImgSpace,[M,posx,T]),     (* x = M*posx + T *)
        assign(nd, DstOp.Floor dim, [x]),           (* nd = floor(x) *)
        assignEin(f, mk.subTen([dim]),[x,nd]),      (* fractional part *)
        assign(n, DstOp.RealToInt dim, [nd]),       (* real to int *)
        assignEin(P, mk.transpose([dim,dim]), [M])
        ]
    in
        ([n,f],P,code)
    end

(* replaceH (kvar, place, args): a copy of args in which the element at the
 * 0-based position place is replaced by kvar.  Raises Subscript when place is
 * not a valid position of args. *)
fun replaceH (kvar, place, args) =
    List.take (args, place) @ (kvar :: List.drop (args, place + 1))

(* Get the image and kernel arguments.
 *
 * getArgs (hid, hArg, V, imgArg, args, lift): look up the HighIL definitions of
 * the kernel argument hArg and the image argument imgArg, and create fresh MidIL
 * variables KNL and IMG for them.
 * Returns (support, img, assignments, argsVK) where
 *   support     - Kernel.support of the kernel
 *   img         - the image descriptor carried by LoadImage
 *   assignments - MidIL code defining KNL and IMG
 *   argsVK      - when lift = 0, args with positions hid and V replaced by the
 *                 new kernel and image variables; otherwise [IMG, KNL]
 * Raises Fail if hArg is not a Kernel or imgArg is not a LoadImage.
 *)
fun getArgs(hid,hArg,V,imgArg,args,lift)=(case (getRHS hArg,getRHS imgArg)
    of ((O2(SrcOp.Kernel(h, i)),_),(O2(SrcOp.LoadImage img),_))=> let
        val hvar=DstV.new ("KNL", DstTy.KernelTy)
        val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
        val argsVK= (case lift
            of 0=> let
                val argsN=replaceH(hvar, hid,args)
                in replaceH(imgvar, V,argsN) end
             | _ => [imgvar, hvar]
            (* end case *))
        val assignments=[
            assign (hvar, DstOp.Kernel(h, i), []),
            assign(imgvar,DstOp.LoadImage img,[])
            ]
        in
            (Kernel.support h ,img, assignments,argsVK)
        end
     | ((O2(SrcOp.Kernel _),_),_)=> raise Fail "Not an img Argument"
     |  _ => raise Fail "Not a kernel argument"
    (* end case *))

(* handleArgs (V, h, t, (params, args), origargs, lift): gather everything a
 * probe needs.  V is the position of the image parameter in params, h the
 * position of the kernel in origargs, and t the position of the probe
 * position in args.
 * Returns (dim, argsVH @ [n, f], argcode @ code', support, P). Here argsVH and
 * argcode come from getArgs, and n, f, code' and P from transformToImgSpace.
 *)
fun handleArgs(V,h,t,(params,args),origargs,lift)=let
    (* nonexhaustive binding: raises Bind if parameter V is not an E.IMG *)
    val E.IMG(dim)=List.nth(params,V)
    val kArg=List.nth(origargs,h)
    val imgArg=List.nth(origargs,V)
    val newposArg=List.nth(args, t)
    val (s,img,argcode,argsVH) =getArgs(h,kArg,V,imgArg,args,lift)
    val (argsT,P,code')=transformToImgSpace(dim,img,newposArg)
    in
        (dim,argsVH@argsT,argcode@code', s,P)
    end

(* createDels (ds, dim): pair each index d of ds with the constant index
 * E.C dim, keeping the order of ds.  The result is the delta list that
 * createBody hands to each E.Krn factor. *)
fun createDels (ds, dim) = List.map (fn d => (E.C dim, d)) ds

(* Create the new body for an expanded probe.
 *
 * createBody (dim, s, sx, shape, deltas, V, h, nid, fid) returns
 * E.Sum(esum, exp) where
 *   esum - summation indices E.V sx .. E.V (sx+dim-1), one per image axis,
 *          each ranging over 1-s .. s (callers pass the kernel support as s)
 *   exp  - built by createKRN starting from empty position and factor lists.
 *          createKRN(0, imgpos, rest) is E.Prod([E.Img(V, shape, imgpos)] @ rest).
 *          The step for axis k = dim-1 computes the image position term
 *          fid[k] + E.Value(sx+k) and the kernel factor
 *          E.Krn(h, dels, nid[k] - E.Value(sx+k)).
 * V, h, nid and fid are parameter ids of the image, the kernel and the two
 * position tensors; deltas are paired with E.C k by createDels for each kernel.
 *)
fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let

    (* sumIndex: build the summation indices for the body *)
    fun sumIndex(0)=[]
    |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]

    (* createKRN: image field times one kernel factor per axis *)
    fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
    | createKRN(dim,imgpos,rest)=let
        val dim'=dim-1
        val sum=sx+dim'
        val dels=createDels(deltas,dim')
        val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
        val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
        (* NOTE(review): the in/end of this clause is missing here.  pos and
         * rest' are otherwise unused, so it presumably continued with
         * something like createKRN(dim', pos@imgpos, [rest']@rest) -- confirm *)

    val exp=createKRN(dim, [],[])
    val esum=sumIndex (dim)
    in E.Sum(esum, exp)
    (* NOTE(review): the closing end of this let is also missing -- confirm *)

(* ShapeConv (es, n): keep, in order, the variable indices E.V v of es with
 * v < n; constant indices E.C and variables with v >= n are dropped. *)
fun ShapeConv (es, n) =
    List.filter (fn E.V v => n > v | E.C _ => false) es

(* mapIndex (es, index): for each variable index E.V v of es, in order, take
 * List.nth (index, v); constant indices E.C are skipped.  Raises Subscript
 * if some v is not a valid position of index. *)
fun mapIndex (es, index) =
    List.mapPartial
        (fn E.V v => SOME (List.nth (index, v))
          | E.C _ => NONE)
        es

(* Lift probe and multiply by P.
 *
 * liftProbe (probe, (params, args), index, sumIndex, origargs): move the probe
 * E.Probe(E.Conv(V,alpha,H,dx), E.Tensor(t,_)) out of the current Ein body
 * into a separate Ein operator.
 *   V, H, t   - positions of the image, kernel and probe-position arguments
 *   alpha, dx - index lists of the convolution; VShape is taken from alpha
 *               and HShape from dx
 *   index     - index bounds of the enclosing operator
 *   sumIndex  - summation indices wrapped around the lifted body
 * The lifted operator is assigned to oldArg.  newArg, which T.Transform also
 * returns, is appended to args as the argument of the new parameter
 * E.TEN(1, shapebind).
 * Returns (E.Tensor(np, shape), (params', args'), code) where np is the
 * position of that new parameter.  code is the image/kernel/transform setup,
 * then the lifted assignment, then T.Transform's extra code.
 * Raises Fail for any other body shape.
 *)
fun liftProbe(E.Probe(E.Conv(V,alpha,H,dx),E.Tensor(t,_)),(params,args),index, sumIndex,origargs)=let
    (* debug output is guarded by testing, like the other traces in this file *)
    val _ = if testing = 0 then () else print "Lift Probe"

    val n=length(index)
    val ns=length sumIndex
    val nshift=length(dx)
    val np=length(params)
    (* next index id: n when no summation is in scope, otherwise one past the
     * last summation variable *)
    val nsumshift =(case ns
        of 0=> n
         | _=> let
            val (E.V v,_,_)=List.nth(sumIndex, ns-1)
            in v+1 end
        (* end case *))

    (* outer index ids of the probe *)
    val VShape=ShapeConv(alpha, n)
    val HShape=ShapeConv(dx, n)
    val shape=VShape@HShape

    (* bindings for the shape *)
    val shapebind= mapIndex(shape,index)
    val Vshapebind= mapIndex(VShape,index)

    (* look at the args and get dim, MidIL code, support, and the arg for the
     * transformation matrix P *)
    val (dim,args',code,support,PArg) = handleArgs(V,H,t,(params,args), origargs,1)

    (* new transformations: params, sx and rest are empty if no transformation is made *)
    val (oldArg,newArg,dx, paramsT,sxT,restT,ixT,dataT) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,ns,4)

    (* V = 0, h = 1, nid = 3, fid = 2 are positions in params' below *)
    val bodyExpanded = createBody(dim, support,nsumshift+nshift,alpha,dx,0, 1, 3, 2)

    val sx=sumIndex@sxT
    val body'=(case sx
        of [] =>E.Prod(restT@[bodyExpanded])
         | _ => E.Sum(sx, E.Prod(restT@[bodyExpanded]))
        (*end case*))

    val _ = if testing = 0 then ()
        else print(String.concat["Found this many args ", Int.toString(length(args'))])

    (* create the new Ein operator *)
    val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]@paramsT
    val (p',i',b',a')=shift.clean(params', index@ixT, body', args'@[PArg])
    val newbie'=Ein.EIN{params=p', index=i', body=b'}
    val data=assignEin (oldArg, newbie', a')

    val _ = (case testing
        of 0 => 1
         | _ => (print(String.concat["\n Lift Probe\n", split.printA(newArg, newbie', a'),"\n"]);1)
        (*end case *))
    in
        (E.Tensor(np,shape), (params@[E.TEN(1,shapebind)],args@[newArg]),code@[data]@dataT)
    end
  | liftProbe _ =raise Fail"Incorrect body for Probe"

(* Expand a probe in place.
 *
 * replaceProbe (b, (params, args), index, sumIndex, origargs): expand the
 * probe b = E.Probe(E.Conv(V,alpha,h,dx), E.Tensor(t,_)) inside the current Ein
 * operator instead of lifting it.  Three parameters are appended to params:
 *   - E.TEN(3,[dim]) for the integer position n from transformToImgSpace
 *   - E.TEN(1,[dim]) for the fractional position f
 *   - E.TEN(1,[dim,dim]) for the matrix P
 * args' = argsA @ [PArg] supplies their arguments.
 * Returns (body', (params', args'), code).  body' is the createBody
 * expansion.  When dx is non-empty (nshift > 0) it is wrapped as
 * E.Sum(sxT, E.Prod(restT @ [expansion])).
 * NOTE(review): the original comment said this does not yet do the
 * transformation, although T.Transform is called below -- confirm.
 *)
fun replaceProbe(b,(params,args),index, sumIndex,origargs)=let

    (* nonexhaustive binding: raises Bind unless b has this shape *)
    val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
    val fid=length(params)
    val nid=fid+1
    val n=length(index)
    val ns=length sumIndex
    val nshift=length(dx)
    (* next index id: n when no summation is in scope, otherwise one past the
     * last summation variable *)
    val nsumshift =(case ns
        of 0=> n
         | _=>let
            val (E.V v,_,_)=List.nth(sumIndex, ns-1)
            in v+1 end
        (* end case *))
    (* outer index ids of the probe *)
    val VShape=ShapeConv(alpha, n)
    val HShape=ShapeConv(dx, n)
    val shape=VShape@HShape
    (* bindings for the shape *)
    val shapebind= mapIndex(shape,index)
    val Vshapebind= mapIndex(VShape,index)

    val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,(params,args), origargs,0)
    val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)

    val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
    val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
    val body' =(case nshift
        of 0=> body''
         | _ => E.Sum(sxT, E.Prod(restT@[body'']))
        (*end case*))
    val args'=argsA@[PArg]
    val _ =(case testing
        of 0=> 1
         | _ =>  let
            val subexp=Ein.EIN{params=params', index=index, body=body'}
            val _= print(String.concat["\n Don't replace probe  \n $$$ new sub-expression $$$ \n",P.printerE(subexp),"\n"])
            in 1 end
        (* end case *))

    in
        (body',(params',args') ,code)
    end

(* checkSum (sx, b, info, index, origargs): decide how to handle the probe b
 * that sits under the summation sx.  b must be an
 * E.Probe(E.Conv _, E.Tensor _); otherwise the binding below raises Bind.
 * The probe is lifted with liftProbe only when all three conditions hold:
 *   (1) sx binds exactly one variable E.V i
 *   (2) i equals n = length index
 *   (3) F.countSx(sx, b) reports a count of 1
 * In that case the summation's upper bound ub is appended to index.  In every
 * other case the probe is expanded in place with replaceProbe.
 *)
fun checkSum(sx,b,info,index,origargs)=(case sx
    of [(E.V i,lb,ub)]=>  let
        val E.Probe(E.Conv _, E.Tensor _)=b
        val n=length(index)
        val _=(case testing
            of 1=> (print(String.concat["in check Sum\n " ,P.printbody(E.Sum([(E.V i,lb,ub)],b))]);1)
             |_ => 1)
        in
            if (i=n) then (case F.countSx(sx,b)
                of (1,_) => liftProbe(b,info,index@[ub], [],origargs)
                 | _ => replaceProbe(b, info,index,sx,origargs)
                (*end case*))
            else replaceProbe(b, info,index, sx,origargs)
        end
     | _ =>replaceProbe(b, info,index, sx,origargs)
    (*end case*))

(* flatten: concatenate a list of lists, preserving order. *)
fun flatten lists = List.concat lists

 (* expandEinOp (EIN{params, index, body}, origargs, args): rewrite every
  * E.Probe(E.Conv _, E.Tensor _) term in body.  Per the original note, with
  * no summation in scope (sx = []) the probe is moved out (lifted); otherwise
  * it is kept in (expanded in place).  The choice depends on the sumIndex stack:
  *   - empty stack                => liftProbe
  *   - exactly one entry          => checkSum
  *   - more than one entry        => replaceProbe, with the stack flattened
  * Any other Probe, Conv, Lift, Field or Apply term is replaced by dummy
  * (E.Const 0).  All remaining terms are returned unchanged with no new code.
  * origargs are the operator's original HighIL arguments, passed through to
  * the probe helpers.
  *
  * NOTE(review): this definition is incomplete as shown.  Several let-blocks
  * lack their in and/or end parts: callfn, filterApply, filter, rewriteBody
  * and most of its case arms, and the outer let.  The function's result,
  * presumably built from e' and newbies, is not visible.  Confirm against the
  * repository before relying on it.
  *)
fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let

    (* placeholder term substituted for probe/field forms that are not rewritten *)
    val dummy=E.Const 0
    (* stack of the summation index lists in scope; callfn pushes onto it *)
    val sumIndex=ref []

    (* rewriteBody (b, info): b is the current body, info the current
     * (params, args) pair; returns (newBody, newInfo, newAssignments) *)
    fun rewriteBody(b,info)= let

        (* callfn (c1, body): rewrite body with c1 pushed onto sumIndex, then
         * wrap the result as E.Sum(z, _) unless it became a constant.
         * NOTE(review): z is hd of sumIndex read after the recursive call; the
         * code returning e' (and any pop of sumIndex) is not visible -- confirm *)
        fun callfn(c1,body)=let
            val ref x=sumIndex
            val c'=[c1]@x
            val (bodyK,infoK,dataK)= (sumIndex:=c';rewriteBody(body ,info))
            val ref s=sumIndex
            val z=hd(s)
            val e'=( case bodyK
                of E.Const _ =>bodyK
                | _ => E.Sum(z,bodyK)
                (*end case*))

        (* filter es: rewrite the terms of es left to right, threading info
         * through and concatenating the generated assignments; returns
         * (rewrittenTerms, info, assignments) *)
        fun filter es=let
            fun filterApply([], doneB, infoB, dataB)= (doneB, infoB,dataB)
            | filterApply(B::es, doneA, infoA,dataA)= let
                val (bodyB, infoB,dataB)= rewriteBody(B,infoA)
                (* NOTE(review): the in before this recursive call is missing *)
                    filterApply(es, doneA@[bodyB], infoB,dataA@dataB)
            in filterApply(es, [],info,[])
        in (case b
            (* summation directly around a probe: its indices c go with the probe *)
            of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let
                val ref sx=sumIndex
                in (case sx
                    of   [] => liftProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
                      (* NOTE(review): b here is the whole E.Sum(c, probe) term,
                       * but checkSum pattern-binds its b argument as an E.Probe,
                       * which would raise Bind -- confirm whether the inner
                       * probe was intended *)
                      | [i]=> checkSum(i,b, info,index,origargs)
                      | _ => let
                        val (b,m,code)=replaceProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
                        in (E.Sum(c,b),m,code)
                (* end case*))
        | E.Probe(E.Conv _, E.Tensor _) =>let
            val ref sx=sumIndex
            in (case sx
                of []=> liftProbe(b, info,index, [],origargs)
                | [i]=> checkSum(i,b, info,index,origargs)
                | _ => replaceProbe(b, info,index, flatten sx,origargs)
             (* end case*))
        | E.Probe _=> (dummy,info,[])
        | E.Conv _=>  (dummy,info,[])
        | E.Lift _=> (dummy,info,[])
        | E.Field _ => (dummy,info,[])
        | E.Apply _ => (dummy,info,[])
        | E.Neg e=> let
            val (body',info',data')=rewriteBody(e,info)
            (* NOTE(review): result not visible; presumably (E.Neg body', info', data') *)
        | E.Sum (c,e)=> callfn(c,e)
        | E.Sub(a,b)=>let
            val (bodyA,infoA,dataA)= rewriteBody(a,info)
            val (bodyB, infoB, dataB)= rewriteBody(b,infoA)
            in   (E.Sub(bodyA, bodyB),infoB,dataA@dataB)
        | E.Div(a,b)=>let
            val (bodyA,infoA,dataA)= rewriteBody(a,info)
            val (bodyB, infoB,dataB)= rewriteBody(b,infoA)
            in  (E.Div(bodyA, bodyB),infoB,dataA@dataB) end
        | E.Add es=> let
            val (done, info',data')= filter es
            val (_, e)=F.mkAdd done
            in (e, info',data')
        | E.Prod es=> let
            val (done, info',data')= filter es
            val (_, e)=F.mkProd done
            in (e, info',data')
        | _=>  (b,info,[])
        (* end case *))

     (* NOTE(review): empty is not referenced in the visible code *)
     val empty =fn key =>NONE
     val _ =(case testing
        of 0 => 1
        | _ => (print "\n ************************** \n Starting Expand";1)
        (*end case*))

    (* rewrite the whole body, starting from the operator's own params/args *)
    val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
    val e'=Ein.EIN{params=params', index=index, body=body'}
    (*val _ =(case testing
        of 0 => 1
        | _ => (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "];1)
        (*end case*))*)

  end; (* NOTE(review): presumably closes the outer let of expandEinOp -- confirm *)

end (* NOTE(review): presumably closes structure ProbeEin -- confirm *)

ViewVC Help
Powered by ViewVC 1.0.0