Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2584 - (download) (annotate)
Tue Apr 15 03:22:58 2014 UTC (6 years, 10 months ago) by cchiw
File size: 13132 byte(s)
Multiply Fields
(* expand-integrate.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)


(*
A couple of different approaches were considered.
One approach is to find every Probe(Conv) term, generate an expanded expression
for it, and then use the Subst function to substitute it in; that takes care of
index matching.
*)

(* This approach creates expanded probe terms and appends the new parameters to the end of the parameter list. *)


structure Expand = struct

    local

    structure E = Ein
    structure mk = mkOperators

    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P = Printer
    structure shiftHtM = shiftHtM
    structure split = splitHtM

    (* Classification of a variable's right-hand side, as returned by getRHS (mid-IL)
     * and getRHS2 (high-IL): an operator application (O/O2), an Ein application
     * (E/E2), a constructor application (C/C2), or S 2 / S2 2 for a variable with
     * no binding (VB_NONE).
     *)
    datatype peanut = O of DstOp.rator | E of Ein.ein | C of DstTy.ty | S of int
    datatype peanut2 = O2 of SrcOp.rator | E2 of Ein.ein | C2 of SrcTy.ty | S2 of int

    in

    (* assign (x, rator, args): the mid-IL assignment  x = rator(args) *)
    fun assign (x, rator, args) = (x, DstIL.OP(rator, args))

    (* assignEin (x, ein, args): the mid-IL assignment  x = ein(args) *)
    fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))

    (* High-IL counterparts of assign / assignEin. *)
    fun assign2 (x, rator, args) = (x, SrcIL.OP(rator, args))
    fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))

    (* Maps represented as lookup functions: `insert (key, value) d` extends d with a
     * binding for key, and `lookup k d` applies d to k.
     *)
    fun insert (key, value) d = fn s => if s = key then SOME value else d s
    fun lookup k d = d k

    (* getRHS x
     * Follows VAR copies to the right-hand side that defines the mid-IL variable x,
     * and returns its classification (see peanut) together with its arguments.
     * An unbound variable (VB_NONE) yields (S 2, []).  Any other binding raises Fail.
     *)
    fun getRHS x = (case DstIL.Var.binding x
           of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
            | DstIL.VB_RHS(DstIL.VAR x') => getRHS x'
            | DstIL.VB_RHS(DstIL.EINAPP(e, args)) => (E e, args)
            | DstIL.VB_RHS(DstIL.CONS(ty, args)) => (C ty, args)
            | DstIL.VB_NONE => (S 2, [])
            | vb => raise Fail(concat[
                  "expected rhs operator for ", DstIL.Var.toString x,
                  " but found ", DstIL.vbToString vb
                ])
          (* end case *))

    (* getRHS2 x: the high-IL counterpart of getRHS. *)
    fun getRHS2 x = (case SrcIL.Var.binding x
           of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
            | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
            | SrcIL.VB_RHS(SrcIL.EINAPP(e, args)) => (E2 e, args)
            | SrcIL.VB_RHS(SrcIL.CONS(ty, args)) => (C2 ty, args)
            | SrcIL.VB_NONE => (S2 2, [])
            | vb => raise Fail(concat[
                  "expected rhs operator for ", SrcIL.Var.toString x,
                  " but found ", SrcIL.vbToString vb
                ])
          (* end case *))

    (* Debug helpers: print the binding of a mid-IL (PrintBIND) or high-IL
     * (PrintBIND2) variable.  The message wording is informational only; nothing
     * is raised.
     *)
    fun PrintBIND x = (case DstIL.Var.binding x
           of vb => print(String.concat["\n expected rhs operator for ", DstIL.Var.toString x,
                " but found ", DstIL.vbToString vb, "\n \n"])
          (* end case *))

    fun PrintBIND2 x = (case SrcIL.Var.binding x
           of vb => print(String.concat["\n expected rhs operator for ", SrcIL.Var.toString x,
                " but found ", SrcIL.vbToString vb, "\n \n"])
          (* end case *))

    (* createArgs (dim, v, posx)
     * Emits mid-IL code that maps the position posx into the index space of image v
     * (x = M*posx + T, built with mk.transform), and then splits x into
     *     nd = floor x        (real-valued integer part)
     *     f  = x - nd         (fractional part)
     *     n  = RealToInt nd   (integer position)
     * Returns ([n, f], code).  Callers rely on n and f being consecutive arguments
     * in that order (see createBody).
     *)
    fun createArgs (dim, v, posx) = let
        val translate = DstOp.Translate v
        val transform = DstOp.Transform v
        val M = DstV.new ("M", DstTy.tensorTy [dim, dim])   (* transform *)
        (* NOTE(review): T is typed dim x dim, but a translation is normally a
         * dim-vector; confirm against the result type of DstOp.Translate. *)
        val T = DstV.new ("T", DstTy.tensorTy [dim, dim])   (* translate *)
        val x = DstV.new ("x", DstTy.vecTy dim)
        val f = DstV.new ("f", DstTy.vecTy dim)             (* fractional *)
        val nd = DstV.new ("nd", DstTy.vecTy dim)           (* real position *)
        val n = DstV.new ("n", DstTy.iVecTy dim)            (* integer position *)
        val PosToImgSpace = mk.transform(dim, dim)
        val code = [
                assign (M, transform, []),
                assign (T, translate, []),
                assignEin (x, PosToImgSpace, [M, posx, T]),  (* M x + T *)
                assign (nd, DstOp.Floor dim, [x]),
                assignEin (f, mk.subTen([dim]), [x, nd]),    (* x - nd *)
                assign (n, DstOp.RealToInt dim, [nd])
              ]
        in
          ([n, f], code)
        end

    (* createDels (deltas, dim)
     * Pairs the image axis dim with each derivative index in deltas, giving the
     * (E.C dim, d) entries that E.Krn takes for the kernel along that axis.
     *)
    fun createDels (deltas, dim) = List.map (fn d => (E.C dim, d)) deltas

    (* loadKernelAndImage (hArg, imgArg)
     * Shared by getArgs and getArgs3.  hArg and imgArg must be high-IL variables
     * defined by a Kernel and a LoadImage operator respectively.  Creates fresh
     * mid-IL variables KNL and IMG (in that order) together with the code that
     * defines them.
     * Returns (support of the kernel, img, code, knlVar, imgVar).
     * Raises Fail "Not an img Argument" or Fail "Not a kernel argument" otherwise.
     *)
    fun loadKernelAndImage (hArg, imgArg) = (case (getRHS2 hArg, getRHS2 imgArg)
           of ((O2(SrcOp.Kernel(h, i)), _), (O2(SrcOp.LoadImage img), _)) => let
                val hvar = DstV.new ("KNL", DstTy.KernelTy)
                val imgvar = DstV.new ("IMG", DstTy.ImageTy img)
                val code = [
                        assign (hvar, DstOp.Kernel(h, i), []),
                        assign (imgvar, DstOp.LoadImage img, [])
                      ]
                in
                  (Kernel.support h, img, code, hvar, imgvar)
                end
            | ((O2(SrcOp.Kernel _), _), _) => raise Fail "Not an img Argument"
            | _ => raise Fail "Not a kernel argument"
          (* end case *))

    (* getArgs (hid, hArg, V, imgArg, args)
     * Loads the kernel and image (see loadKernelAndImage).  Returns
     * (support, img, code, args'), where args' is args with position hid replaced
     * by the new kernel variable and position V replaced by the new image variable.
     *)
    fun getArgs (hid, hArg, V, imgArg, args) = let
        val (support, img, code, hvar, imgvar) = loadKernelAndImage (hArg, imgArg)
        fun replaceAt (x, place, xs) = List.take(xs, place) @ (x :: List.drop(xs, place + 1))
        val args' = replaceAt (imgvar, V, replaceAt (hvar, hid, args))
        in
          (support, img, code, args')
        end

    (* getArgs3 (hid, hArg, V, imgArg, args)
     * Like getArgs, but returns just [imageVar, kernelVar] as the argument list,
     * for a stand-alone Ein op whose first two parameters are the image and the
     * kernel.  hid, V and args are not used.
     *)
    fun getArgs3 (hid, hArg, V, imgArg, args) = let
        val (support, img, code, hvar, imgvar) = loadKernelAndImage (hArg, imgArg)
        in
          (support, img, code, [imgvar, hvar])
        end

    (* createBody (dim, s, sx, shape, deltas, V, h, nid, fid)
     * Builds the expanded probe body
     *     Sum_{i_k in [1-s, s], k < dim}
     *         Img_V(shape; fid[0]+i_0, ..., fid[dim-1]+i_{dim-1})
     *       * Prod_k Krn_h(dels_k; nid[k] - i_k)
     * where the summation variables i_k are E.V (sx+k), s is the kernel support,
     * and dels_k = createDels(deltas, k).
     * Naming caveat: fid is the parameter index of the *integer* position, which is
     * offset to index the image, and nid is the parameter index of the *fractional*
     * position, at which the kernels are evaluated.  The callers pass the n and f
     * vectors from createArgs at those indices.
     *)
    fun createBody (dim, s, sx, shape, deltas, V, h, nid, fid) = let
        val esum = List.tabulate (dim, fn k => (E.V (sx + k), 1 - s, s))
        fun imgPos k = E.Add[E.Tensor(fid, [E.C k]), E.Value(sx + k)]
        fun kernel k =
              E.Krn(h, createDels(deltas, k), E.Sub(E.Tensor(nid, [E.C k]), E.Value(sx + k)))
        val exp = E.Prod(E.Img(V, shape, List.tabulate (dim, imgPos)) :: List.tabulate (dim, kernel))
        in
          E.Sum(esum, exp)
        end

    (* expandEinProbe4 (probe, (params, args), index, sx, origargs)
     * Expands the probe in place, inside the enclosing Ein op.
     * - The kernel and image arguments are replaced by freshly loaded mid-IL
     *   variables.
     * - Two parameters, TEN(3,[dim]) and TEN(1,[dim]), are appended; their args
     *   are the integer and fractional positions.
     * - sx is the first free index variable for the new summation.
     * Returns (params', body, args', code).
     *)
    fun expandEinProbe4 (E.Probe(E.Conv(V, shape, h, deltas), E.Tensor(t, alpha)), (params, args), index, sx, origargs) = let
        val E.IMG(dim) = List.nth(params, V)
        val kArg = List.nth(origargs, h)
        val imgArg = List.nth(origargs, V)
        val newposArg = List.nth(args, t)
        val (s, img, argcode, args2) = getArgs(h, kArg, V, imgArg, args)
        val params' = [E.TEN(3, [dim]), E.TEN(1, [dim])]
        val (args', code') = createArgs(dim, img, newposArg)
        val fid = length params     (* index of n (integer position) *)
        val nid = fid + 1           (* index of f (fractional position) *)
        val body = createBody(dim, s, sx, shape, deltas, V, h, nid, fid)
        in
          (params @ params', body, args2 @ args', argcode @ code')
        end

    (* expandEinProbe3 (probe, (params, args), index, sx, origargs)
     * Builds a stand-alone Ein op for the probe, used when the probe is lifted.
     * Its parameters are [IMG dim, KRN, TEN(3,[dim]), TEN(1,[dim])] and its
     * arguments are [image, kernel, n, f].
     * Returns (params', body, args', code).
     *)
    fun expandEinProbe3 (E.Probe(E.Conv(V, shape, h, deltas), E.Tensor(t, alpha)), (params, args), index, sx, origargs) = let
        val E.IMG(dim) = List.nth(params, V)
        val kArg = List.nth(origargs, h)
        val imgArg = List.nth(origargs, V)
        val newposArg = List.nth(args, t)
        val (s, img, argcode, argsVH) = getArgs3(h, kArg, V, imgArg, args)
        val params' = [E.IMG(dim), E.KRN, E.TEN(3, [dim]), E.TEN(1, [dim])]
        val (argsT, code') = createArgs(dim, img, newposArg)
        (* parameter positions within params' *)
        val imgId = 0
        val krnId = 1
        val fid = 2     (* n: integer position *)
        val nid = 3     (* f: fractional position *)
        val body = createBody(dim, s, sx, shape, deltas, imgId, krnId, nid, fid)
        in
          (params', body, argsVH @ argsT, argcode @ code')
        end

    (* Debug helpers: render a high-IL (TS) or mid-IL (TT) variable and its binding. *)
    fun TS x = (case SrcIL.Var.binding x
           of vb => String.concat[SrcIL.Var.toString x, "\n Found ", SrcIL.vbToString vb, "\n"]
          (* end case *))

    fun TT x = (case DstIL.Var.binding x
           of vb => String.concat[DstIL.Var.toString x, "\n Found ", DstIL.vbToString vb, "\n"]
          (* end case *))

    (* ShapeConv (ixs, n)
     * Keeps the E.V entries of ixs whose variable is < n, in their original order.
     * These are the indices that are free in the enclosing op.  Constant (E.C)
     * entries are dropped.
     *)
    fun ShapeConv ([], n) = []
      | ShapeConv (E.C c :: es, n) = ShapeConv(es, n)
      | ShapeConv (E.V v :: es, n) =
          if (n > v) then E.V v :: ShapeConv(es, n) else ShapeConv(es, n)

    (* nextSumIndex (n, sumIndex)
     * The first index variable that is safe for a new summation: one more than the
     * largest E.V variable bound in sumIndex, and never less than n (the number of
     * free indices).  Using the maximum rather than the last entry avoids clashing
     * with inner sums, because the sum stack is kept innermost-first.
     *)
    fun nextSumIndex (n, sumIndex) =
          List.foldl (fn ((E.V v, _, _), next) => Int.max (next, v + 1) | (_, next) => next) n sumIndex

    (* foundProbe3 (b, (params, args), index, sumIndex, origargs)
     * Lifts the probe b, together with the summation sumIndex that directly
     * encloses it (possibly []), into its own assignment  PC = ein(...).
     * - In the enclosing op, b is replaced by E.Tensor(newId, shape), which refers
     *   to a new parameter TEN(1, shape').  That parameter is appended to params
     *   and PC is appended to args.
     * - The returned code is the setup code followed by the PC assignment.
     *)
    fun foundProbe3 (b, (params, args), index, sumIndex, origargs) = let
        val _ = print "\n Lift Probe\n"
        val E.Probe(E.Conv(_, alpha, _, dx), _) = b
        val newId = length params
        val n = length index
        (* replacement tensor: indexed by the free indices of the probe *)
        val shape = ShapeConv(alpha @ dx, n)
        val newB = E.Tensor(newId, shape)
        (* new parameter: each free index takes its range from the enclosing op *)
        val shape' = List.map (fn E.V v => List.nth(index, v)) shape
        val newP = E.TEN(1, shape')
        (* new argument holding the lifted result *)
        val newArg = DstV.new ("PC", DstTy.tensorTy shape')
        val sx = nextSumIndex (n, sumIndex)
        val (params', body', args', code) = expandEinProbe3(b, (params, args), index, sx, origargs)
        val body'' = (case sumIndex of [] => body' | _ => E.Sum(sumIndex, body'))
        val (p', i', b', a') = shiftHtM.clean(params', index, body'', args')
        val newbie' = Ein.EIN{params=p', index=i', body=b'}
        val _ = print(String.concat["\n ", split.printA(newArg, newbie', a'), "\n"])
        val data = assignEin (newArg, newbie', a')
        in
          (newB, (params @ [newP], args @ [newArg]), code @ [data])
        end

    (* foundProbe4 (b, (params, args), index, sumIndex, origargs)
     * Expands the probe b in place.  sumIndex lists every summation index in scope
     * at b; it is used only to pick fresh index variables.  Re-applying any
     * summation that b was removed from is the caller's job.
     * Returns (body, (params', args'), code).
     *)
    fun foundProbe4 (b, (params, args), index, sumIndex, origargs) = let
        val _ = print "\n Don't replace probe \n"
        val sx = nextSumIndex (length index, sumIndex)
        val (params', body', args', code) = expandEinProbe4(b, (params, args), index, sx, origargs)
        val sub = Ein.EIN{params=params', index=index, body=body'}
        val _ = print(String.concat["\n $$$ new sub-expression $$$ \n", P.printerE(sub), "\n"])
        in
          (body', (params', args'), code)
        end

    (* flatten xss: concatenates a list of lists. *)
    fun flatten xss = List.concat xss

    (* expandEinOp (ein, origargs, args)
     * Rewrites every probe of a convolution in the body of ein:
     * - a probe with no enclosing E.Sum is lifted into its own assignment
     *   (foundProbe3);
     * - a probe inside an E.Sum is expanded in place (foundProbe4).
     * Other field-level terms (other Probes, Conv, Field, Apply) are replaced by the
     * placeholder E.Const 0.  origargs are the high-IL arguments of the op; args
     * are the corresponding mid-IL arguments.
     * Returns ((ein', args'), newAssignments); the new assignments must be emitted
     * before the rewritten op.
     *)
    fun expandEinOp (Ein.EIN{params, index, body}, origargs, args) = let
        val dummy = E.Const 0
        (* summation-index lists of the E.Sum nodes enclosing the current node,
         * innermost first *)
        val sumIndex = ref []

        (* rewriteBody (b, info): info is (params, args), threaded left to right.
         * Returns (b', info', newAssignments). *)
        fun rewriteBody (b, info) = let
            val _ = print(String.concat["\n\n ***:", P.printbody(b), "$ \n"])
            in
              (case b
                of E.Sum(c, E.Probe(E.Conv v, E.Tensor t)) => let
                     val probe = E.Probe(E.Conv v, E.Tensor t)
                     in
                       (case !sumIndex
                         of [] => foundProbe3(probe, info, index, c, origargs)
                          | outer => let
                              val (body', info', data') =
                                    foundProbe4(probe, info, index, List.concat outer @ c, origargs)
                              (* the expansion replaces the whole Sum node, so the
                               * summation over c must be re-applied here *)
                              val body'' = (case c of [] => body' | _ => E.Sum(c, body'))
                              in
                                (body'', info', data')
                              end
                       (* end case *))
                     end
                 | E.Probe(E.Conv _, E.Tensor _) => (case !sumIndex
                      of [] => foundProbe3(b, info, index, [], origargs)
                       | outer => foundProbe4(b, info, index, List.concat outer, origargs)
                     (* end case *))
                 | E.Probe _ => (dummy, info, [])
                 | E.Conv _ => (dummy, info, [])
                 | E.Field _ => (dummy, info, [])
                 | E.Apply _ => (dummy, info, [])
                 | E.Neg e => let
                     val (e', info', data') = rewriteBody(e, info)
                     in
                       (E.Neg e', info', data')
                     end
                 | E.Sum(c, e) => let
                     val outer = !sumIndex
                     val _ = sumIndex := c :: outer
                     val (e', info', data') = rewriteBody(e, info)
                     in
                       sumIndex := outer;
                       (E.Sum(c, e'), info', data')
                     end
                 | E.Sub(e1, e2) => let
                     val (e1', info1, data1) = rewriteBody(e1, info)
                     val (e2', info2, data2) = rewriteBody(e2, info1)
                     in
                       (E.Sub(e1', e2'), info2, data1 @ data2)
                     end
                 | E.Div(e1, e2) => let
                     val (e1', info1, data1) = rewriteBody(e1, info)
                     val (e2', info2, data2) = rewriteBody(e2, info1)
                     in
                       (E.Div(e1', e2'), info2, data1 @ data2)
                     end
                 | E.Add es => let
                     val (es', info', data') = rewriteList(es, info)
                     in
                       (E.Add es', info', data')
                     end
                 | E.Prod es => let
                     val (es', info', data') = rewriteList(es, info)
                     in
                       (E.Prod es', info', data')
                     end
                 | _ => (b, info, [])
              (* end case *))
            end

        (* rewriteList (es, info): rewrites the operands left to right, threading
         * info and concatenating the generated assignments in order. *)
        and rewriteList (es, info) = let
            fun step (e, (done, info, data)) = let
                val (e', info', data') = rewriteBody(e, info)
                in
                  (e' :: done, info', data @ data')
                end
            val (revDone, info', data') = List.foldl step ([], info, []) es
            in
              (List.rev revDone, info', data')
            end

        val _ = print "\n ************************** \n Starting Exapnd"
        val (body', (params', args'), newbies) = rewriteBody(body, (params, args))
        val e' = Ein.EIN{params=params', index=index, body=body'}
        val _ = print (String.concat[P.printerE(e'), "\n DONE expand ************************** \n "])
        in
          ((e', args'), newbies)
        end

    end (* local *)

end (* structure Expand *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0