Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log

Revision 2867 - (download) (annotate)
Tue Feb 10 06:52:58 2015 UTC (4 years, 9 months ago) by cchiw
File size: 7894 byte(s)
moved split around, added norm to typechecker, added sqrt to ein
(* Expands probes of fields in EIN expressions.
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P=Printer
    structure T=TransformEin
    structure MidToS=MidToString

    in

    (* This file expands probed fields.
     * See the ProbeEin tex file for worked examples.
     *
     * The original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     *   - Param_ids record the position of an argument in the midIL.var list.
     *   - Index_ids keep track of the shape of an image or of a differentiation.
     *   - Mu binds an Index_id.
     *
     * Naming conventions used below:
     *   dim    : dimension of the field V
     *   s      : support of the kernel H
     *   alpha  : the alpha in <V_alpha * H^(deltas)>
     *   deltas : the deltas in <V_alpha * H^(deltas)>
     *   Vid    : param_id for V
     *   hid    : param_id for H
     *   nid    : position param_id used in the kernel argument (nid - i)
     *   fid    : position param_id used in the image index (fid + i)
     *   img    : ImageInfo for V
     * NOTE(review): earlier notes called nid the "integer" position and fid
     * the "fractional" one; the way createBody uses them suggests the reverse.
     * Confirm against TransformEin.transformToImgSpace.
     *)

    (* debugging switch: any value other than 0 makes testp print *)
    val testing=0
    val cnt = ref 0

    fun transformToIndexSpace e=T.transformToIndexSpace e
    fun transformToImgSpace  e=T.transformToImgSpace  e

    (* testp : string list -> int
     * Prints the concatenated strings when testing is non-zero; always returns 1.
     * The strings are built by the caller whether or not they are printed.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* getRHSDst : follows VAR-to-VAR bindings of a Mid-IL variable until it
     * reaches an operator application, and returns (operator, arguments).
     * Raises Fail for any other kind of binding.
     *)
    fun getRHSDst x  = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[
            "expected rhs operator for ", DstIL.Var.toString x,
            " but found ", DstIL.vbToString vb])
        (* end case *))

    (* getArgsDst : (kernel var, image var, args) -> (support, ImageInfo, dim)
     * Looks up the definitions of the kernel and image arguments and returns
     * the support of the kernel, the image info, and the image dimension.
     * The third argument is not used.
     * Raises Fail unless the arguments are bound to Kernel and LoadImage.
     *)
    fun getArgsDst(hArg,imgArg,_)=(case (getRHSDst hArg,getRHSDst imgArg)
        of ((DstOp.Kernel(h, _), _ ),(DstOp.LoadImage img, _ ))=>
            (Kernel.support h, img, ImageInfo.dim img)
        |  _ => raise Fail "Expected Image and kernel arguments"
        (*end case*))

    (* handleArgs : (Vid, hid, tid, args) -> (dim, args', code, s, P)
     * Uses the param_ids of the image (Vid), kernel (hid) and position (tid)
     * to select their Mid-IL variables, then transforms the position to
     * image space.  Returns the dimension, args extended with the variables
     * produced by the transformation, the code computing them, the kernel
     * support, and P, the Mid-IL variable for the (transpose of the)
     * transformation matrix.
     *)
    fun handleArgs(Vid,hid,tid,args)=let
        val imgArg=List.nth(args,Vid)
        val hArg=List.nth(args,hid)
        val newposArg=List.nth(args,tid)
        val (s,img,dim) =getArgsDst(hArg,imgArg,args)
        val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
        in
            (dim,args@argsT,code,s,P)
        end

    (* createBody : builds the expanded body of a probed field.
     * The result is a sum over one index per axis (ids sx .. sx+dim-1, with
     * bounds 1-s and s) of the image V_alpha indexed at fid + i times the
     * kernels H^(deltas) applied to nid - i.
     *)
    fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
        (* 1-d fields: the position tensors are scalars, so they are indexed
         * with [] rather than [E.C 0] *)
        fun createKRND1 ()=let
            val dels=List.map (fn e=>(E.C 0,e)) deltas
            val pos=[E.Add[E.Tensor(fid,[]),E.Value sx]]
            val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value sx))
            in
                E.Prod [E.Img(Vid,alpha,pos),rest]
            end
        (* createKRN: builds the image term and one kernel term per axis,
         * working from the last axis down to axis 0 so that the image
         * positions are accumulated in axis order *)
        fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
          | createKRN(d,imgpos,rest)=let
            val d'=d-1
            val sum=sx+d'
            val dels=List.map (fn e=>(E.C d',e)) deltas
            val pos=[E.Add[E.Tensor(fid,[E.C d']),E.Value(sum)]]
            val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C d']),E.Value(sum)))
            in
                createKRN(d',pos@imgpos,rest'::rest)
            end
        val exp=(case dim
            of 1 => createKRND1()
            | _=> createKRN(dim, [],[])
            (*end case*))
        (* summation indices for the body: one per axis *)
        val slb=1-s
        val esum=List.tabulate(dim, fn i=>(E.V (i+sx),slb,s))
        in
            E.Sum(esum, exp)
        end

    (* getsumshift : (sum indices, index) -> int
     * Returns a fresh (unused) index_id: one more than the id of the last
     * summation index in sx, or length(index) when sx is empty.
     * Raises Fail if sx contains a non-variable summation index.
     *)
    fun getsumshift(sx,index) =let
        fun idOf (E.V v,_,_) = v
          | idOf _ = raise Fail "getsumshift: expected a variable summation index"
        val nsumshift= (case sx
            of []=> length index
            | _=> idOf (List.last sx) + 1
            (* end case *))
        val aa=List.map (Int.toString o idOf) sx
        val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
            "\nThink nshift is ", Int.toString nsumshift]
        in
            nsumshift
        end

    (* formBody : quick clean-up rewrite of an ein_exp.
     * Drops empty summations (recursively) and unwraps a one-factor product.
     *)
    fun formBody(E.Sum([],e))=formBody e
      | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
      | formBody(E.Prod [e])=e
      | formBody e=e

    (* replaceProbe : (probe, params, args, index, sx) -> (body', params', args', code)
     * Replaces a probe E.Probe(E.Conv(Vid,alpha,hid,dx), E.Tensor(tid,_))
     * with its expanded version:
     *   - transforms the position argument to image space (handleArgs),
     *   - transforms the differentiation indices to index space,
     *   - builds and simplifies the expanded body.
     * Three parameters are appended to params (ids fid, nid and Pid), and
     * the transformation-matrix variable is appended to the arguments.
     * Raises Fail if b is not a probe of a convolution at a tensor.
     *)
    fun replaceProbe(b,params,args,index, sx)=let
        val (Vid,alpha,hid,dx,tid) = (case b
            of E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) => (Vid,alpha,hid,dx,tid)
            | _ => raise Fail "replaceProbe: expected a probe of a convolution at a tensor position"
            (* end case *))
        val fid=length params
        val nid=fid+1
        val Pid=nid+1
        (* number of differentiation indices before the index-space transform *)
        val nshift=length dx
        val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,index)
        val (dx',newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
        val body = createBody(dim,s,freshIndex+nshift,alpha,dx',Vid,hid,nid,fid)
        val body' =formBody(E.Sum(newsx, E.Prod(body::Ps)))
        val args'=argsA@[PArg]
        in
            (body',params',args',code)
        end

    (* expandEinOp : Mid-IL EIN assignment -> code list
     * If the body is a probe, a summation of a probe, or a summation of a
     * product of an eps term and a probe, the probe is replaced with its
     * expansion.  The result is the code computing the new arguments
     * followed by the rewritten EIN application.  Any other body is returned
     * unchanged as [e].
     * eps expressions are kept so that only the pieces that are used get
     * generated.
     *)
    fun expandEinOp( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
        (* debugging helper (not currently called) *)
        fun printResult code=testp["\nINSIDE PROBEEIN","\nbody",
                MidToS.printEINAPP e, "\n=>\n",
                (String.concatWith",\t"(List.map MidToS.printEINAPP code))]
        (* expand probe p, whose enclosing summation indices are sx, and
         * rebuild the EIN body around the expansion with wrap *)
        fun expand (p, sx, wrap) = let
            val (body',params',args',newbies)=replaceProbe(p,params,args, index, sx)
            val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=wrap body'},args'))
            in
                newbies@[einapp]
            end
        fun rewriteBody b=(case b
            of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator. Argument needs to be applied in High-IL"
            | E.Probe _ => expand (b, [], fn b' => b')
            | E.Sum(sx, p as E.Probe _) => expand (p, sx, fn b' => E.Sum(sx,b'))
            | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
                expand (p, sx, fn b' => E.Sum(sx,E.Prod[eps,b']))
            | _=> [e]
            (* end case *))
        in
            rewriteBody body
        end

    end (* local *)

end (* structure ProbeEin *)

ViewVC Help
Powered by ViewVC 1.0.0