Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin2.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin2.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2845 - (download) (annotate)
Fri Dec 12 06:46:23 2014 UTC (5 years, 9 months ago) by cchiw
File size: 13417 byte(s)
added norm
(* Currently under construction 
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local
   
    structure E = Ein
    structure mk= mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P=Printer
    structure F=Filter
    structure T=TransformEin
    structure split=Split
    structure cleanI=cleanIndex



    val testing=0


    in

 
(* This file expands probed fields.
 * Note that the original field is an EIN operator of the form <V_alpha * H^(deltas)>(midIL.var list).
 * Param_ids note the position of an argument in the midIL.var list.
 * Index_ids bind the shape of an image or of a differentiation.
 * Generally, we will refer to the following:
 * dim: dimension of field V
 * s: support of kernel H
 * alpha: the alpha in <V_alpha * H^(deltas)>
 * deltas: the deltas in <V_alpha * H^(deltas)>
 * Vid: param_id for V
 * hid: param_id for H
 * nid: param_id for the integer position
 * fid: param_id for the fractional position
 * img: image info about V
 *)
             
             
(* Counter backing genName; incremented on every call. *)
val cnt = ref 0

(* genName: string -> string
 * Builds "<prefix>_<k>" where k is the current counter value, then bumps
 * the counter so the next call yields a different suffix.
 *)
fun genName prefix = let
    val suffix = Int.toString (!cnt)
    in
        cnt := !cnt + 1;
        prefix ^ "_" ^ suffix
    end


(* Short local aliases for helpers that live in the Filter (F) and
 * TransformEin (T) structures. Each alias simply forwards to its target.
 *)
val iterSx = F.iterSx
val transformToIndexSpace = T.transformToIndexSpace
val transformToImgSpace = T.transformToImgSpace
(* assign: pairs a left-hand-side variable with a DstIL.OP right-hand side. *)
fun assign (lhs, oper, operands) = (lhs, DstIL.OP(oper, operands))

(* assignEin: pairs a left-hand-side variable with a DstIL.EINAPP right-hand side. *)
fun assignEin (lhs, ein, operands) = (lhs, DstIL.EINAPP(ein, operands))
(* testp: string list -> int
 * Debug printer. When the `testing` flag is 0 nothing is printed; for any
 * other value the strings are concatenated and printed. Always returns 1.
 *)
fun testp n =
    if testing = 0
        then 1
        else (print(String.concat n); 1)
(* getRHSDst: DstIL.var -> rator * args
 * Looks up the binding of a Mid-IL variable and returns the operator and
 * argument list of the operation that defines it. A binding that is a plain
 * variable copy (VAR x') is followed to x'.
 * Raises Fail when the binding is neither an operator application nor a
 * variable copy; the message names the variable and the binding found.
 *)
fun getRHSDst x  = (case DstIL.Var.binding x
    of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
    | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
    (* separator before "but found" so the variable name does not run into the text *)
    | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, " but found ", DstIL.vbToString vb])
 (* end case *))

    (* iterSx: sum_index list * index_id list -> sum_index list * sum_index list
     * Splits the summation indices sx, each a triple (id, lb, ub), according
     * to whether the id occurs in dx.
     *   first component  ("pre"):  triples whose id is in dx
     *   second component ("post"): triples whose id is not in dx
     * The relative order of the triples is preserved in both results.
     * Used in ProbeEin.
     *)
    fun iterSx(sx,dx)=let
        (* true when the triple's index id appears in dx *)
        fun inDx (e1,_,_) = List.exists (fn x => x = e1) dx
        in
            (* single pass; avoids the repeated list appends of a hand-rolled loop *)
            List.partition inDx sx
        end
    
    

(* getArgsDst: MidIL.var * MidIL.var * 'a -> int * ImageInfo * int
 * Resolves the defining operations of the kernel argument and the image
 * argument and returns (support of the kernel, image info, image dimension).
 * The third parameter is accepted but not inspected.
 * Raises Fail unless the kernel argument is bound to a Kernel operator and
 * the image argument to a LoadImage operator.
 *)
 fun getArgsDst(hArg,imgArg,args)=let
    val kernelRhs = getRHSDst hArg
    val imageRhs = getRHSDst imgArg
    in
        case (kernelRhs, imageRhs)
         of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage img, _)) =>
                (Kernel.support h, img, ImageInfo.dim img)
          | _ => raise Fail "Expected Image and kernel arguments"
        (* end case *)
    end


(* handleArgs: param_id * param_id * param_id * MidIL.var list
 *             -> dim * MidIL.var list * code * support * MidIL.var
 * Picks the image, kernel and position arguments out of args by their
 * param_ids, reads the kernel support and image info from their bindings,
 * and hands the position to transformToImgSpace.
 * Returns the image dimension, the original args extended with the
 * variables produced by the transformation, the code produced by the
 * transformation, the kernel support, and the variable P returned by
 * transformToImgSpace.
 *)
fun handleArgs(Vid,hid,tid,args)=let
    fun argAt i = List.nth(args, i)
    val imgArg = argAt Vid
    val hArg = argAt hid
    val posArg = argAt tid
    val (s,img,dim) = getArgsDst(hArg,imgArg,args)
    val (argsT,P,code) = transformToImgSpace(dim,img,posArg,imgArg)
    in
        (dim, args@argsT, code, s, P)
    end


(* createBody: dim * s * sx * alpha * deltas * Vid * hid * nid * fid -> ein_exp
 * Expands the body for the probed field.
 *   dim    - number of dimensions; one summation index and one kernel term
 *            is generated per dimension
 *   s      - kernel support; every summation index ranges over 1-s .. s
 *   sx     - integer id of the first summation index to use; dimension i
 *            uses id sx+i
 *   alpha  - indices passed through unchanged to the E.Img term
 *   deltas - derivative indices; each is paired with the axis constant of
 *            the kernel term it is attached to
 *   Vid, hid, nid, fid - param_ids used for the image, the kernel, and the
 *            two position tensors
 * The result is E.Sum(indices, E.Prod(image term :: kernel terms)).
 * NOTE(review): the image term is indexed with Tensor fid + sum and the
 * kernels with Tensor nid - sum; confirm this matches the intended roles of
 * nid and fid.
 *)
fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
    
    (* 1-d fields: the position tensors are used as scalars (no index) and
     * the single summation index is sx itself *)
    fun createKRND1 ()=let
        val sum=sx
        val dels=List.map (fn e=>(E.C 0,e)) deltas
        val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
        val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
        in 
            E.Prod [E.Img(Vid,alpha,pos),rest]

        end
    (* createKRN: builds the image term and the kernel terms.
     * Counts the axis down from dim-1 to 0; each step prepends the position
     * expression and the kernel term for that axis, so the finished lists
     * are ordered by increasing axis. At 0 the product is assembled. *)
    fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
    | createKRN(dim,imgpos,rest)=let
        val dim'=dim-1
        (* summation index id for axis dim' *)
        val sum=sx+dim'
        val dels=List.map (fn e=>(E.C dim',e)) deltas 
        val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
        val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
        in
            createKRN(dim',pos@imgpos,[rest']@rest)
        end
    (* dimension 1 gets the scalar form; every other value goes through createKRN *)
    val exp=(case dim
        of 1 => createKRND1()
        | _=> createKRN(dim, [],[])
        (*end case*))

    (* summation indices for the body: ids sx .. sx+dim-1, each with
     * bounds 1-s .. s. The lambda parameter named dim shadows the outer dim
     * and is the tabulate position. *)
    val slb=1-s
    val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
in
    E.Sum(esum, exp)
end

(* getsumshift: sum_index list * index_id list -> int
 * Returns an index id to use as a fresh one.
 * With no summation indices the result is the length of index; otherwise it
 * is one more than the id carried by the last summation index.
 * Also emits a debug trace through testp.
 *)
fun getsumshift(sx,index) =let
    val fresh =
        if List.null sx
            then length index
            else let
                val (E.V lastId,_,_) = List.last sx
                in
                    lastId + 1
                end
    val ids = List.map (fn (E.V v,_,_)=>Int.toString v) sx
    val _ = testp["\n", "SumIndex" ,(String.concatWith"," ids),"\nThink nshift is ", Int.toString fresh]
    in
        fresh
    end

(* formBody: ein_exp -> ein_exp
 * Light clean-up of an expression:
 *   - a summation over no indices is replaced by its (cleaned) operand
 *   - a summation over some indices keeps its indices and has its operand cleaned
 *   - a product of exactly one factor is replaced by that factor (as is)
 *   - anything else is returned unchanged
 *)
fun formBody exp = (case exp
    of E.Sum(sx, inner) => let
        val inner' = formBody inner
        in
            if List.null sx then inner' else E.Sum(sx, inner')
        end
    | E.Prod [single] => single
    | _ => exp
    (* end case *))


(* replaceProbe: ein_exp * params * midIL.var list * index list * sum_index list
 *               -> ein_exp * params * midIL.var list * code
 * Replaces a probe of a convolved field with its expanded form.
 *   1. handleArgs resolves the image/kernel/position arguments and yields
 *      the extra arguments and code.
 *   2. transformToIndexSpace rewrites the derivative indices and yields
 *      summation indices and the P factors to multiply in.
 *   3. createBody builds the expanded body, which is multiplied with the P
 *      factors, summed over the new indices and tidied with formBody.
 * Three parameters are appended to params (slots fid, nid, Pid) and the P
 * argument is appended to the argument list.
 * b must be a Probe of a Conv at a Tensor; any other shape fails the binding.
 *)
 fun replaceProbe(b,params,args,index, sx)=let
    val E.Probe(E.Conv(Vid,alpha,hid,deltas),E.Tensor(tid,_))=b
    (* param slots for the three parameters appended below *)
    val fid = length params
    val nid = fid+1
    val Pid = fid+2
    val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
    val freshIndex = getsumshift(sx,index)
    val (deltas',outerSx,Ps) = transformToIndexSpace(freshIndex,dim,deltas,Pid)
    (* summation ids of the body start past one id per original derivative index *)
    val bodySx = freshIndex + length deltas
    val probeBody = createBody(dim, s, bodySx, alpha, deltas', Vid, hid, nid, fid)
    val rewritten = formBody(E.Sum(outerSx, E.Prod(Ps@[probeBody])))
    in
        ( rewritten,
          params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])],
          argsA@[PArg],
          code )
    end



(* liftedProbe: ein_exp * params * midIL.var list * index list * sum_index list
 *              -> ein_exp * params * midIL.var list * midIL.var * code
 * Variant of replaceProbe that expands the probe without calling
 * transformToIndexSpace: the derivative indices are passed to createBody
 * unchanged and no P factors are multiplied in.
 * Two parameters are appended to params (slots fid, nid); the P argument is
 * returned separately instead of being appended to the argument list.
 * b must be a Probe of a Conv at a Tensor; any other shape fails the binding.
 *)
fun liftedProbe(b,params,args,index, sumIndex)=let
    val E.Probe(E.Conv(Vid,alpha,hid,deltas),E.Tensor(tid,_))=b
    val fid = length params
    val nid = fid+1
    val freshIndex = getsumshift(sumIndex,index)
    val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
    val bodySx = freshIndex + length deltas
    val expanded = createBody(dim, s, bodySx, alpha, deltas, Vid, hid, nid, fid)
    in
        ( expanded,
          params@[E.TEN(3,[dim]),E.TEN(1,[dim])],
          argsA,
          PArg,
          code )
    end

(* expandEinOp99: code -> code list
 * Checks whether the body of an EIN application is a probe in one of three
 * recognised shapes and, if so, replaces the probe with its expansion:
 *     Probe
 *     Sum(sx, Probe)
 *     Sum(sx, Prod[eps, Probe])
 * The surrounding Sum and the eps factor are kept around the expanded probe.
 * The result is the code produced by replaceProbe followed by the rewritten
 * assignment. Any other body is returned untouched as a singleton list.
 *)
fun expandEinOp99( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
    (* expand one probe and rebuild the assignment; `wrap` restores the
     * context the probe was found in *)
    fun expand (probe, sx, wrap) = let
        val (probeBody, params', args', newbies) = replaceProbe(probe, params, args, index, sx)
        val ein = Ein.EIN{params=params', index=index, body=wrap probeBody}
        in
            newbies@[(y, DstIL.EINAPP(ein, args'))]
        end
    in
        case body
         of p as E.Probe _ =>
                expand(p, [], fn b => b)
          | E.Sum(sx, p as E.Probe _) =>
                expand(p, sx, fn b => E.Sum(sx, b))
          | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
                expand(p, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
          | _ => [e]
        (* end case *)
    end


(* expandEinOp2: code -> code list * midIL.var list
 * Same shape analysis as expandEinOp99, but expands the probe with
 * liftedProbe. Besides the generated code it returns a singleton list
 * holding the P argument reported by liftedProbe. When the body is not one
 * of the recognised probe shapes the input is returned unchanged together
 * with an empty argument list.
 *)

fun expandEinOp2( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
    (* expand one probe and rebuild the assignment; `wrap` restores the
     * context the probe was found in *)
    fun expand (probe, sx, wrap) = let
        val (probeBody, params', args', PArg, newbies) = liftedProbe(probe, params, args, index, sx)
        val ein = Ein.EIN{params=params', index=index, body=wrap probeBody}
        in
            (newbies@[(y, DstIL.EINAPP(ein, args'))], [PArg])
        end
    in
        case body
         of p as E.Probe _ =>
                expand(p, [], fn b => b)
          | E.Sum(sx, p as E.Probe _) =>
                expand(p, sx, fn b => E.Sum(sx, b))
          | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
                expand(p, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
          | _ => ([e],[])
        (* end case *)
    end

(* mkbody: index_id list * ein_exp -> sum_index list * ein_exp
 * Rewrites a probe so that its derivative indices are replaced by beta.
 * When the probe sits under a summation, the summation indices are split
 * with F.iterSx against the probe's old derivative indices: the first part
 * is returned to the caller and only the second part stays on the rewritten
 * Sum.
 * Only the three probe shapes below are handled; other expressions are not
 * matched.
 *)
fun mkbody(beta ,body)=let
    (* the probe with its derivative indices swapped for beta *)
    fun reprobe (v, alpha, h, t) = E.Probe(E.Conv(v, alpha, h, beta), E.Tensor t)
    in
        case body
         of E.Probe(E.Conv(v,alpha,h,_), E.Tensor t) =>
                ([], reprobe(v,alpha,h,t))
          | E.Sum(sx, E.Probe(E.Conv(v,alpha,h,dx), E.Tensor t)) => let
                val (pre,post) = F.iterSx(sx,dx)
                in
                    (pre, E.Sum(post, reprobe(v,alpha,h,t)))
                end
          | E.Sum(sx, E.Prod[eps, E.Probe(E.Conv(v,alpha,h,dx), E.Tensor t)]) => let
                val (pre,post) = F.iterSx(sx,dx)
                in
                    (pre, E.Sum(post, E.Prod[eps, reprobe(v,alpha,h,t)]))
                end
        (* end case *)
    end

(* geT: sum_index list * index list * int * index_id list * ein_exp * params
 *      * mid-il var list * mid-il var -> code
 * The goal of this function is to create simple EINAPPs.
 * When differentiation is involved, the deltas in a probed field are
 * rewritten and there is a multiplication by P for each index.
 * This function lifts the probed field out and multiplies its replacement
 * tensor with the Ps.
 * The probed field is rewritten with the new indices (mkbody) and then
 * passed to split.lift.
 * The result is the code for the lifted probe followed by one EINAPP
 * assigned to y that multiplies the lifted result with the Ps; the hope is
 * that this produces fewer loops in mid-to-low.sml.
 *)
fun geT(sx,index,dim,dx,body,params,args,y)=let
    (* NOTE(review): replaceProbe derives Pid from length(params); here it is
     * derived from length(args). Whether this matches the position of the
     * P parameter appended to Rparams below depends on split.lift - confirm. *)
    val Pid=length(args)+1
    val freshIndex=getsumshift(sx,index)
    val (newdx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
    (* newsx2: summation indices that mkbody moved out of the probe expression *)
    val (newsx2,e')= mkbody(newdx,body)
    val newsx=newsx1@newsx2
    (* split.lift must return exactly one new einapp; any other list length
     * fails this binding *)
    val (Re,Rparams,Rargs,[einappN])=split.lift(e',params,index,newsx,args)
    (* einappN is rebound: from here on it is the code list produced by
     * expanding the lifted probe, and Parg is the list of P arguments *)
     val (einappN,Parg)=expandEinOp2 einappN

    (* outer expression: sum over the new indices of the Ps times the
     * replacement expression Re *)
    val body'=E.Sum(newsx,E.Prod(Ps@[Re]))
    val einappO=(y,DstIL.EINAPP(Ein.EIN{params=Rparams@[E.TEN(1,[dim,dim])], index=index, body=body'},Rargs@Parg))
    val code=einappN@[einappO]
    in code
    end


(* liftTransform: code -> code list
 * Used when testing lifted transformations. Inspects the body of the EIN
 * application and, for the three recognised probe shapes, forwards the
 * summation indices and the probe's derivative indices to geT. Any other
 * body is returned unchanged as a singleton list.
 * Note: the dimension handed to geT is hard-coded to 2.
 * Only the bare-probe case emits the debug trace.
 *)
fun liftTransform( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
    fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",split.printEINAPP e, "\n=>\n",
            (String.concatWith",\t"(List.map split.printEINAPP code))]
    (* common call into geT; only sx and dx vary between the cases *)
    fun lift (sx, dx) = geT(sx, index, 2, dx, body, params, args, y)
    in
        case body
         of E.Probe(E.Conv(_,_,_,dx), E.Tensor _) => let
                val code = lift([], dx)
                val _ = printResult code
                in
                    code
                end
          | E.Sum(sx, E.Probe(E.Conv(_,_,_,dx), E.Tensor _)) =>
                lift(sx, dx)
          | E.Sum(sx, E.Prod[_, E.Probe(E.Conv(_,_,_,dx), E.Tensor _)]) =>
                lift(sx, dx)
          | _ => [e]
        (* end case *)
    end

(* Switch between the two expansion strategies: 1 selects liftTransform,
 * any other value selects expandEinOp99. *)
val testlift=0

(* expandEinOp: code -> code list
 * Entry point for probe expansion; dispatches on testlift.
 *)
fun expandEinOp e =
    if testlift = 1
        then liftTransform e
        else expandEinOp99 e



  end; (* local *)

end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0