Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2843 - (download) (annotate)
Mon Dec 8 01:27:25 2014 UTC (4 years, 11 months ago) by cchiw
File size: 7616 byte(s)
added 2-d cross product, new rep. of 2-d curl
(* Currently under construction 
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local
   
    structure E = Ein
    structure mk= mkOperators
    structure SrcIL = HighIL
    structure SrcTy = HighILTypes
    structure SrcOp = HighOps
    structure SrcSV = SrcIL.StateVar
    structure VTbl = SrcIL.Var.Tbl
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure DstOp = MidOps
    structure DstV = DstIL.Var
    structure SrcV = SrcIL.Var
    structure P=Printer
    structure F=Filter
    structure T=TransformEin
    structure split=Split
    structure cleanI=cleanIndex


    (* Debug switch: when nonzero, testp prints the messages it is given. *)
    val testing=1


    in

 
(* This file expands probed fields.
 * Note that the original field is an EIN operator of the form <V_alpha * H^(deltas)>(midIL.var list).
 * Param_ids record the position of each argument in the midIL.var list.
 * Index_ids bind the shape of an image or of a differentiation.
 * Throughout, we use the following names:
 * dim: dimension of field V
 * s: support of kernel H
 * alpha: the alpha in <V_alpha * H^(deltas)>
 * deltas: the deltas in <V_alpha * H^(deltas)>
 * Vid: param_id for V
 * hid: param_id for H
 * nid: param_id for the integer position
 * fid: param_id for the fractional position
 * img: image info about V
 *)
             
             
(* Counter that makes the names produced by genName unique. *)
val cnt = ref 0

(* genName: string -> string
 * Returns prefix ^ "_" ^ n, where n is the current value of cnt,
 * and then advances cnt.
 *)
fun genName prefix = let
    val current = !cnt
    val () = cnt := current + 1
    in
        prefix ^ "_" ^ Int.toString current
    end


(* Local shorthands for the Filter and TransformEin helpers used below. *)
fun iterSx sums = F.iterSx sums
fun transformToIndexSpace input = T.transformToIndexSpace input
fun transformToImgSpace input = T.transformToImgSpace input

(* Build MidIL assignments of an operator application / an EIN application. *)
fun assign (lhs, rator, operands) = (lhs, DstIL.OP(rator, operands))
fun assignEin (lhs, rator, operands) = (lhs, DstIL.EINAPP(rator, operands))
(* testp: string list -> int
 * Prints the concatenation of msgs when the testing switch is nonzero.
 * Always returns 1, so it can be used in a `val _ = testp [...]` binding.
 *)
fun testp msgs =
    if testing = 0 then 1
    else (print (String.concat msgs); 1)
(* getRHSDst: MidIL.Var -> rator * MidIL.Var list
 * Follows the binding of x, chasing plain variable copies (VB_RHS(VAR x')),
 * until it reaches an operator application, and returns that operator and
 * its arguments.  Raises Fail for any other kind of binding.
 *)
fun getRHSDst x = (case DstIL.Var.binding x
    of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
    | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
    | vb => raise Fail(concat[
        "expected rhs operator for ", DstIL.Var.toString x,
        " but found ", DstIL.vbToString vb])
    (* end case *))


(* getArgsDst: MidIL.Var * MidIL.Var * MidIL.Var list -> int * ImageInfo * int
 * Looks up the definitions of the kernel argument (hArg) and the image
 * argument (imgArg) and returns
 *   (support of the kernel, image info, dimension of the image).
 * Raises Fail unless hArg is defined by a Kernel operator and imgArg by a
 * LoadImage operator.  The third component of the argument tuple is not
 * used; it is kept so existing callers are unaffected.
 *)
fun getArgsDst (hArg, imgArg, _) = (case (getRHSDst hArg, getRHSDst imgArg)
    of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage img, _)) =>
        (Kernel.support h, img, ImageInfo.dim img)
    | _ => raise Fail "Expected Image and kernel arguments"
    (* end case *))


(* handleArgs: int * int * int * MidIL.Var list
 *               -> int * MidIL.Var list * code * int * MidIL.Var
 * Picks the image (Vid), kernel (hid) and position (tid) arguments out of
 * args, transforms the position into image space, and returns
 *   (dimension of the image,
 *    args extended with the variables produced by the transformation,
 *    code that computes those variables,
 *    support of the kernel,
 *    P, the MidIL var for the (transposed) transformation matrix).
 *)
fun handleArgs (Vid, hid, tid, args) = let
    fun argAt i = List.nth(args, i)
    val imgArg = argAt Vid
    val hArg = argAt hid
    val posArg = argAt tid
    val (support, img, dim) = getArgsDst(hArg, imgArg, args)
    val (newArgs, P, code) = transformToImgSpace(dim, img, posArg, imgArg)
    in
        (dim, args @ newArgs, code, support, P)
    end


(* createBody: int * int * int * alpha * deltas * param_id * param_id * param_id * param_id
 *               -> ein_exp
 * Expands the body of the probed field <V_alpha * H^(deltas)>.  One summation
 * index i_k (ids sx, sx+1, ..., sx+dim-1) is introduced per axis k, each
 * ranging over 1-s .. s, and the result is
 *     Sum_i  V_alpha[fid_k + i_k] * Prod_k H^(deltas on axis k)(nid_k - i_k)
 * For 1-d fields the fid and nid tensors are used without a component index.
 *)
fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
    (* image position and kernel factor for one axis: comps indexes the
     * fid/nid tensors, axisC tags the deltas, sumId is the summation index *)
    fun axisTerms (comps, axisC, sumId) = let
        val shift = E.Value sumId
        val pos = E.Add[E.Tensor(fid, comps), shift]
        val krn = E.Krn(hid, List.map (fn d => (axisC, d)) deltas,
                        E.Sub(E.Tensor(nid, comps), shift))
        in
            (pos, krn)
        end
    (* Walk the axes from dim-1 down to 0, consing onto the accumulators so
     * that positions and kernel factors end up in increasing axis order. *)
    fun collect (0, poss, krns) = E.Prod (E.Img(Vid, alpha, poss) :: krns)
      | collect (k, poss, krns) = let
            val axis = k - 1
            val (pos, krn) = axisTerms ([E.C axis], E.C axis, sx + axis)
            in
                collect (axis, pos :: poss, krn :: krns)
            end
    val exp = if dim = 1
        then let
            val (pos, krn) = axisTerms ([], E.C 0, sx)
            in
                E.Prod [E.Img(Vid, alpha, [pos]), krn]
            end
        else collect (dim, [], [])
    (* summation indices for the kernel support *)
    val lower = 1 - s
    val sumRanges = List.tabulate(dim, fn k => (E.V (k + sx), lower, s))
    in
        E.Sum(sumRanges, exp)
    end

(* getsumshift: sum_index_id list * index_id list -> int
 * Returns the first index_id that is used neither by the free indices
 * (0 .. length index - 1) nor by any summation index in sx.  New summation
 * indices can be numbered from this id on.
 * Raises Fail if sx contains a summation index that is not a variable.
 *)
fun getsumshift (sx, index) = let
    fun sumId (E.V v, _, _) = v
      | sumId _ = raise Fail "getsumshift: expected variable summation index"
    val ids = List.map sumId sx
    (* take the maximum, not the last element: sx need not be in increasing
     * order, and the result must also stay clear of the free indices *)
    val nsumshift = List.foldl Int.max (length index) (List.map (fn v => v + 1) ids)
    val _ = testp["\n", "SumIndex", String.concatWith "," (List.map Int.toString ids),
                  "\nThink nshift is ", Int.toString nsumshift]
    in
        nsumshift
    end

(* formBody: ein_exp -> ein_exp
 * Light clean-up of an expanded body: drops summations over an empty index
 * list, simplifies the body of the remaining summations, and unwraps a
 * product with a single factor.
 *)
fun formBody exp = (case exp
    of E.Sum([], inner) => formBody inner
    | E.Sum(sx, inner) => E.Sum(sx, formBody inner)
    | E.Prod [single] => single
    | other => other
    (* end case *))


(* replaceProbe: ein_exp * param list * MidIL.Var list * index_id list * sum_index_id list
 *                 -> ein_exp * param list * MidIL.Var list * code
 * Replaces the probe b = Probe(Conv(V, alpha, H, dx), Tensor(pos, _)) with its
 * expanded version:
 *  - the position is transformed into image space (handleArgs),
 *  - the derivative indices dx are transformed to index space using the
 *    transformation matrix P (transformToIndexSpace),
 *  - the probe itself is rewritten by createBody.
 * Three parameters are appended to params, with ids fid (fractional position),
 * nid (integer position) and Pid (transformation matrix); the matching
 * arguments are appended to args.  Returns the new body, params, args and the
 * code that computes the new arguments.
 * Raises Fail if b does not have the form above.
 *)
fun replaceProbe (b, params, args, index, sx) = let
    val (Vid, alpha, hid, dx, tid) = (case b
        of E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) => (Vid, alpha, hid, dx, tid)
        | _ => raise Fail "replaceProbe: expected a probe of a convolution at a tensor position"
        (* end case *))
    val fid = length params
    val nid = fid + 1
    val Pid = nid + 1
    val nshift = length dx
    val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
    val freshIndex = getsumshift(sx, index)
    val (dx', newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
    val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
    (* the probe's own summation indices start nshift ids after freshIndex;
     * presumably the first nshift fresh ids are taken by the index-space
     * transformation of dx -- TODO confirm against TransformEin *)
    val body' = createBody(dim, s, freshIndex + nshift, alpha, dx', Vid, hid, nid, fid)
    val body' = formBody(E.Sum(newsx1, E.Prod(Ps @ [body'])))
    val args' = argsA @ [PArg]
    in
        (body', params', args', code)
    end


(* expandEinOp: code -> code list
 * If the body of the EIN application in e has one of the shapes
 *     Probe(...)
 *     Sum(sx, Probe(...))
 *     Sum(sx, Prod[eps, Probe(...)])
 * then the probe is expanded with replaceProbe.  The result is the code that
 * computes the new arguments, followed by the rewritten EIN application.
 * The surrounding Sum and the extra factor (eps) are kept as they are,
 * presumably to leave less work for the mid-to-low-IL stage.
 * Any other expression, including a right-hand side that is not an EIN
 * application, is returned unchanged as [e].
 *)
fun expandEinOp (e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
    (* expand probe p (inside summation indices sx) and rebuild the assignment
     * to y; rewrap restores whatever context surrounded the probe *)
    fun expand (p, sx, rewrap) = let
        val (body', params', args', newbies) = replaceProbe(p, params, args, index, sx)
        val ein = Ein.EIN{params=params', index=index, body=rewrap body'}
        in
            newbies @ [(y, DstIL.EINAPP(ein, args'))]
        end
    in
        (case body
            of E.Probe _ => expand(body, [], fn b => b)
            | E.Sum(sx, p as E.Probe _) => expand(p, sx, fn b => E.Sum(sx, b))
            | E.Sum(sx, E.Prod[eps, p as E.Probe _]) =>
                expand(p, sx, fn b => E.Sum(sx, E.Prod[eps, b]))
            | _ => [e]
        (* end case *))
    end
  | expandEinOp e = [e]



  end; (* local *)

end (* structure ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0