Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3092 - (download) (annotate)
Tue Mar 17 20:02:38 2015 UTC (4 years, 5 months ago) by cchiw
File size: 11162 byte(s)
change det 3x3
(* Expands probe ein
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P = Printer
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes

    in

    (* This file expands probed fields.
     * See the ProbeEin tex file for worked examples.
     * The original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(MidIL.var list)
     * Param_ids give the position of an argument in the MidIL.var list.
     * Index_ids keep track of the shape of an image or of a differentiation.
     * Mu binds an Index_id.
     * Naming conventions used below:
     *   dim    : dimension of field V
     *   s      : support of kernel H
     *   alpha  : the alpha in <V_alpha * H^(deltas)>
     *   deltas : the deltas in <V_alpha * H^(deltas)>
     *   Vid    : param_id for V
     *   hid    : param_id for H
     *   nid    : param_id of the integer position
     *   fid    : param_id of the fractional position
     *   img    : ImageInfo for V
     *)

    (* Debug switches: a non-zero value enables the corresponding trace output.
     * testing gates testp; testlift gates einapptostring.
     * NOTE(review): testlift is 1, so every lifted probe is printed to stdout;
     * confirm this is intended outside of debugging.
     *)
    val testing=0
    val testlift=1
    (* NOTE(review): cnt is not referenced in this file; it is kept because the
     * structure has no signature, so the name is externally visible. *)
    val cnt = ref 0

    fun printEINAPP e=MidToString.printEINAPP e
    fun transformToIndexSpace e=T.transformToIndexSpace e
    fun transformToImgSpace  e=T.transformToImgSpace  e

    (* transitionToString: int * ein_exp * ein_exp -> int
     * When testreplace is non-zero, prints body a (before) and body b (after)
     * a probe replacement. Always returns 1.
     *)
    fun transitionToString(testreplace,a,b)=(case testreplace
        of 0=> 1
        | _ => (print(String.concat["\n\n\n Replace probe:\n",P.printbody a,"\n=>",P.printbody b]);1)
        (*end case*))

    (* Constructors for EIN operators and EINAPP right-hand sides. *)
    fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
    fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)

    (* getBody/setBody read or replace the body of an (lhs, EINAPP) assignment.
     * Any right-hand side other than EINAPP raises Match.
     *)
    fun getBody(_,DstIL.EINAPP(E.EIN{body,...},_))=body
    fun setBody(body',(y,DstIL.EINAPP(E.EIN{params,index,body},args)))=
            (y,DstIL.EINAPP(E.EIN{params=params,index=index,body=body'},args))

    (* testp: string list -> int
     * Prints the concatenated strings when testing is non-zero. Always returns 1.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* einapptostring: ein_exp * (var * rhs) * (var * rhs) -> int
     * When testlift is non-zero, prints the probe body and the two assignments
     * it was lifted into. Always returns 1.
     *)
    fun  einapptostring (body,a,b)=(case testlift
        of 0=>1
        | _=> (print(String.concat["\n lift probe of ",P.printbody body,"=>\n\t", printEINAPP a,  "&\n\t", printEINAPP b]);1)
        (*end case*))

    (* getRHSDst: MidIL.var -> rator * MidIL.var list
     * Returns the operator and arguments bound to x. Chains of variable-to-variable
     * copies (VB_RHS(VAR x')) are followed. Raises Fail if x is not eventually
     * bound to an OP right-hand side.
     *)
    fun getRHSDst x  = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, " but found ", DstIL.vbToString vb])
        (* end case *))

    (* getArgsDst: MidIL.var * MidIL.var * MidIL.var list -> int * ImageInfo * int
     * Looks up the definitions of the kernel argument hArg and the image argument
     * imgArg. Returns (support of the kernel, image info, image dimension).
     * Raises Fail unless hArg is bound to a Kernel op and imgArg to a LoadImage op.
     * The args parameter is unused; it is kept for interface compatibility.
     *)
    fun getArgsDst(hArg,imgArg,args) = (case (getRHSDst hArg, getRHSDst imgArg)
        of ((DstOp.Kernel(h, _), _ ), (DstOp.LoadImage(_, _, img), _ )) =>
             (Kernel.support h, img, ImageInfo.dim img)
        |  _ => raise Fail "Expected Image and kernel arguments"
        (*end case*))

    (* handleArgs: int * int * int * MidIL.var list
     *     -> int * MidIL.var list * code * int * MidIL.var
     * Uses the param_ids Vid (image), hid (kernel) and tid (position tensor) to
     * select the corresponding MidIL vars from args. It then transforms the
     * position to image space via transformToImgSpace.
     * Returns (dim, args extended with the vars produced by the transformation,
     * the code that computes them, kernel support, P).
     * Per the original notes, P is the MidIL var for the (transposed)
     * transformation matrix.
     *)
    fun handleArgs(Vid,hid,tid,args)=let
        val imgArg=List.nth(args,Vid)
        val hArg=List.nth(args,hid)
        val newposArg=List.nth(args,tid)
        val (s,img,dim) =getArgsDst(hArg,imgArg,args)
        val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
        in
            (dim,args@argsT,code, s,P)
        end

    (* createBody: int * int * int * mu list * mu list * param_id * param_id * param_id * param_id
     *     -> ein_exp
     * Expands the body of the probed field. sx is the id of the first fresh
     * summation index; indices E.V sx .. E.V (sx+dim-1) each range over 1-s .. s.
     * The result is Sum(indices, Prod[Img(Vid, alpha, positions), K_0, ..., K_{dim-1}]).
     * For each axis d:
     *   - the image position is fid[d] + V(sx+d);
     *   - K_d is the kernel hid with each delta paired with C d, evaluated at
     *     nid[d] - V(sx+d).
     * For 1-d fields the position tensors fid and nid are scalars (no index).
     *)
    fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
        (*1-d fields*)
        fun createKRND1 ()=let
            val sum=sx
            val dels=List.map (fn e=>(E.C 0,e)) deltas
            val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
            val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
            in
                E.Prod [E.Img(Vid,alpha,pos),rest]
            end
        (* createKRN: image field and kernels. Recurses from axis dim-1 down to 0,
         * prepending each axis' position and kernel term, so both lists end up
         * ordered by axis 0..dim-1.
         *)
        fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
        | createKRN(dim,imgpos,rest)=let
            val dim'=dim-1
            val sum=sx+dim'
            val dels=List.map (fn e=>(E.C dim',e)) deltas
            val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
            val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
            in
                createKRN(dim',pos@imgpos,[rest']@rest)
            end
        val exp=(case dim
            of 1 => createKRND1()
            | _=> createKRN(dim, [],[])
            (*end case*))
        (* summation indices for the body, one per axis *)
        val slb=1-s
        val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
    in
        E.Sum(esum, exp)
    end

    (* getsumshift: sum_id list * index list -> int
     * Returns an index_id for fresh indices. If there are no summation indices,
     * this is length(index). Otherwise it is one past the id of the last
     * summation index (assumes the last entry of sx carries the largest id;
     * TODO confirm).
     * Emits a trace through testp.
     *)
    fun getsumshift(sx,index) =let
        val nsumshift= (case sx
            of []=> length(index)
            | _=>let
                val (E.V v,_,_)=List.last sx
                in v+1
                end
            (* end case *))
        val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
        val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
            "\nThink nshift is ", Int.toString nsumshift]
        in
            nsumshift
        end

    (* formBody: ein_exp -> ein_exp
     * Quick cleanup rewrite:
     *   - drops summations over an empty index list;
     *   - recurses into the body of a summation;
     *   - unwraps a single-factor product.
     *)
    fun formBody(E.Sum([],e))=formBody e
    | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
    | formBody(E.Prod [e])=e
    | formBody e=e

    (* multiPs: ein_exp list * sum_id list * ein_exp -> ein_exp
     * Multiplies body by the transformation factors Ps under a summation over sx.
     * With exactly three factors, the product order is [P0,P1,P2,body], to match
     * the vis branch WorldtoSpace functions. Otherwise the order is body :: Ps.
     *)
    fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
      | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))

    (* replaceProbe: int * (var * rhs) * ein_exp * sum_id list -> code
     * Replaces the probe p inside the EINAPP bound to y with its expanded form.
     * Steps:
     *   - transforms the position argument into image space (handleArgs);
     *   - rewrites the derivative indices into index space (transformToIndexSpace);
     *   - expands the convolution (createBody) and multiplies by the
     *     transformation factors (multiPs).
     * Three parameters are appended: fid, nid, and the matrix at Pid.
     * If the original body was Sum(sx, probe) or Sum(sx, Prod[eps, probe]),
     * that wrapper is kept around the expanded body.
     * testN non-zero prints the before/after bodies.
     * Returns the code from handleArgs followed by the new assignment to y.
     *)
    fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
        =let
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e

        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        val fid=length(params)
        val nid=fid+1
        val Pid=nid+1
        val nshift=length(dx)
        val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,index)
        val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
        (* convolution summation indices start after the nshift derivative indices *)
        val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
        val body' = multiPs(Ps,newsx1,body')

        val body'=(case originalb
            of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
            | _                                  => body'
            (*end case*))
        val _=transitionToString(testN,originalb,body')

        val args'=argsA@[PArg]
        val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
        in
            code@[einapp]
        end

    (* createEinApp: ein_exp * mu list * index list * int * int * mu list * sum_id list
     *     -> ein * index list * mu list
     * Builds the outer EIN operator, whose parameters are:
     *   - param 0: the transformation matrix, TEN(1,[dim,dim]);
     *   - param 1: the lifted field tensor, TEN(1,sizes).
     * It multiplies the tensor by the transformation factors and keeps any
     * Sum / Sum-with-eps wrapper from the original body.
     * Returns the operator, the shape (sizes) of the lifted tensor, and the
     * rewritten derivative indices dx.
     * NOTE(review): 9 and 7 are placeholder ids in the Conv handed to
     * cleanIndex.cleanIndex; only the resulting sizes and dx are used here.
     *)
    fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
        val Pid=0
        val tid=1

        val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)

        (* rewrite dx against the cleaned index space *)
        val (_,sizes,E.Conv(_,_,_,dx))=(case sx@newsx
            of []=> ([],index,E.Conv(9,alpha,7,newdx))
            | _ =>cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
            (*end case*))

        val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
        val tshape=alpha@newdx
        val t=E.Tensor(tid,tshape)
        val exp = multiPs(Ps,newsx,t)
        val body=(case originalb
            of E.Sum(sx, E.Probe _)              => E.Sum(sx,exp)
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,exp])
            | _                                  => exp
            (*end case*))

        val ein0=mkEin(params,index,body)
        in
            (ein0,sizes,dx)
        end

    (* liftProbe: int * (var * rhs) * ein_exp * sum_id list -> code
     * Splits the probe into two assignments instead of expanding it in place:
     *   - rtn1: F := the expanded probe, with index shape sizes;
     *   - rtn0: y := ein0(P, F), which applies the transformation factors.
     * Returns the code from handleArgs, then rtn1, then rtn0.
     * The testN argument is unused. Trace output goes through einapptostring.
     *)
    fun liftProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)=let
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e

        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        val fid=length(params)
        val nid=fid+1
        val nshift=length(dx)
        val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,index)

        (*transform T*P*P..Ps*)
        val (ein0,sizes,dx)= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
        val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
        val einApp0=mkEinApp(ein0,[PArg,FArg])
        val rtn0=(y,einApp0)

        (*lifted probe*)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
        val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
        val ein1=mkEin(params',sizes,body')
        val einApp1=mkEinApp(ein1,args')
        val rtn1=(FArg,einApp1)
        val rtn=code@[rtn1,rtn0]
        val _= einapptostring (p,rtn1,rtn0)
        in
            rtn
        end

    (* expandEinOp: (var * rhs) -> code
     * At this point only simple EIN ops remain. If the body contains a probe,
     * the probe is replaced; otherwise the assignment is returned unchanged.
     * Choice of strategy:
     *   - no derivatives: replaceProbe;
     *   - otherwise, a constant (E.C) index among the checked indices selects
     *     replaceProbe; with no constants, liftProbe is used.
     * Eps expressions around the probe are kept, so only the pieces that are
     * actually used get generated.
     * NOTE(review): every replaceProbe call made directly from rewriteBody passes
     * testN = 1, which prints the replacement trace; confirm intended.
     *)
    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args))) = let
        (* liftProbe when no E.C index occurs in the list, else replaceProbe *)
        fun checkConst ([],a) = liftProbe a
        | checkConst ((E.C _::_),a) =replaceProbe a
        | checkConst ((_ ::es),a)=checkConst(es,a)
        fun rewriteBody b=(case b
            of E.Probe(E.Conv(_,_,_,[]),_)
                => replaceProbe(1,e,b, [])
            | E.Probe(E.Conv (_,alpha,_,dx),_)
                => checkConst(alpha@dx,(0,e,b,[]))
            | E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_))
                => replaceProbe(1,e,p, sx)
            | E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_))
                => checkConst(dx,(0,e,p,sx))
            | E.Sum(sx,E.Probe p)
                => replaceProbe(1,e,E.Probe p, sx)
            | E.Sum(sx,E.Prod[_,E.Probe p])
                => replaceProbe(1,e,E.Probe p,sx)
            | _ => [e]
            (* end case *))
        in
            rewriteBody (Ein.body ein)
        end

  end; (* local *)

end (* struct *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0