Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3271 - (download) (annotate)
Fri Oct 9 18:12:58 2015 UTC (3 years, 11 months ago) by cchiw
File size: 13735 byte(s)
lifted substitution
(* Expands probe ein
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure ProbeEin = struct

    local

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P = Printer
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes

    in

    (* This file expands probed fields.
     * See the ProbeEin tex file for worked examples.
     *
     * The original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     *   - Param_ids record the position of an argument in the midIL.var list.
     *   - Index_ids keep track of the shape of an image or of a differentiation.
     *   - Mu binds an Index_id.
     *
     * Naming conventions used below:
     *   dim    : dimension of field V
     *   s      : support of kernel H
     *   alpha  : the alpha in <V_alpha * H^(deltas)>
     *   deltas : the deltas in <V_alpha * H^(deltas)>
     *   Vid    : param_id for V
     *   hid    : param_id for H
     *   nid    : param_id of the integer position
     *   fid    : param_id of the fractional position
     *   img    : ImageInfo for V
     *)

    (* Debug switches: testp prints only when testing <> 0.
     * testlift and cnt are not referenced in this file. *)
    val testing=0
    val testlift=0

    val cnt = ref 0

    (* Thin wrappers over the helper structures. *)
    fun transformToIndexSpace e = T.transformToIndexSpace e
    fun transformToImgSpace e = T.transformToImgSpace e
    fun toStringBind e = MidToString.toStringBind e
    fun mkEin e = Ein.mkEin e
    fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)

    (* testp : string list -> int
     * Prints the concatenation of n when testing <> 0; always returns 1.
     * Note that the argument list is evaluated even when nothing is printed.
     *)
    fun testp n = (case testing
        of 0 => 1
         | _ => (print (String.concat n); 1)
        (* end case *))

    (* getRHSDst : DstIL.var -> rator * args
     * Follows VAR bindings until an OP binding is found and returns its
     * operator and arguments.  Raises Fail for any other binding.
     *)
    fun getRHSDst x = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
         | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
         | vb => raise Fail(concat[
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
        (* end case *))

    (* getArgsDst : MidIL.var * MidIL.var * MidIL.var list -> int * ImageInfo * int
     * hArg must be bound to a Kernel operator and imgArg to a LoadImage
     * operator.  Returns (support of the kernel, image info, image dimension).
     * Raises Fail otherwise.  The third argument is not used.
     *)
    fun getArgsDst (hArg, imgArg, args) = (case (getRHSDst hArg, getRHSDst imgArg)
        of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
              (Kernel.support h, img, ImageInfo.dim img)
         | _ => raise Fail "Expected Image and kernel arguments"
        (* end case *))

    (* handleArgs : int * int * int * MidIL.var list
     *            -> int * MidIL.var list * code * int * MidIL.var
     * Uses the param_ids of the image (Vid), kernel (hid) and position tensor
     * (tid) to pick their Mid-IL variables out of args, then transforms the
     * position with transformToImgSpace.
     * Returns (dim, args @ argsT, code, s, P), where argsT, code and P come
     * from transformToImgSpace.  Per the original notes, P is the variable for
     * the transformation matrix (transpose); confirm in TransformEin.
     *)
    fun handleArgs (Vid, hid, tid, args) = let
        val imgArg = List.nth (args, Vid)
        val hArg = List.nth (args, hid)
        val newposArg = List.nth (args, tid)
        val (s, img, dim) = getArgsDst (hArg, imgArg, args)
        val (argsT, P, code) = transformToImgSpace (dim, img, newposArg, imgArg)
        in
          (dim, args @ argsT, code, s, P)
        end

    (* createBody : int * int * int * mu list * mu list * param_id * param_id * param_id * param_id
     *            -> ein_exp
     * Builds the expanded body of a probed field:
     *     Sum(esum, Img(Vid, alpha, positions) * Krn(hid, ...) * ...)
     * with one kernel term per axis d.  Each term uses summation variable
     * sx + d: the image position is fid[d] + that variable, and the kernel
     * argument is nid[d] - that variable.
     * For dim = 1 the position tensors are indexed with [] instead of [C 0].
     *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
        (* 1-d fields: the position and kernel tensors are scalars *)
        fun createKRND1 () = let
              val sum = sx
              val dels = List.map (fn e => (E.C 0, e)) deltas
              val pos = [E.Add[E.Tensor(fid, []), E.Value sum]]
              val rest = E.Krn(hid, dels, E.Sub(E.Tensor(nid, []), E.Value sum))
              in
                E.Prod [E.Img(Vid, alpha, pos), rest]
              end
        (* createKRN: builds the image term and one kernel term per axis,
         * walking axes dim-1 down to 0 so positions end up in axis order *)
        fun createKRN (0, imgpos, rest) = E.Prod ([E.Img(Vid, alpha, imgpos)] @ rest)
          | createKRN (dim, imgpos, rest) = let
              val dim' = dim - 1
              val sum = sx + dim'
              val dels = List.map (fn e => (E.C dim', e)) deltas
              val pos = [E.Add[E.Tensor(fid, [E.C dim']), E.Value sum]]
              val rest' = E.Krn(hid, dels, E.Sub(E.Tensor(nid, [E.C dim']), E.Value sum))
              in
                createKRN (dim', pos @ imgpos, rest' :: rest)
              end
        val exp = (case dim
              of 1 => createKRND1 ()
               | _ => createKRN (dim, [], [])
              (* end case *))
        (* summation indices sx .. sx+dim-1, each with bounds (1-s, s) *)
        val slb = 1 - s
        val esum = List.tabulate (dim, fn i => (E.V (i + sx), slb, s))
        in
          E.Sum(esum, exp)
        end

    (* getsumshift : sum_index list * int list -> int
     * Returns an index id that is used neither by the free indices
     * (0 .. length index - 1) nor by any summation index in sx.
     * Every summation variable is expected to be of the form E.V v.
     *)
    fun getsumshift (sx, index) = let
        val used = List.map (fn (E.V v, _, _) => v) sx
        (* take the maximum rather than the last entry, so the result is
         * fresh even when sx is not sorted *)
        val nsumshift = List.foldl (fn (v, acc) => Int.max (v + 1, acc)) (length index) used
        val _ = testp ["\n", "SumIndex", String.concatWith "," (List.map Int.toString used),
              "\nThink nshift is ", Int.toString nsumshift]
        in
          nsumshift
        end

    (* formBody : ein_exp -> ein_exp
     * Quick cleanup: drops summations over an empty index list and unwraps
     * a single-element product (looking through any enclosing summations).
     *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Prod [e]) = e
      | formBody e = e

    (* multiPs : ein_exp list * sum_index list * ein_exp -> ein_exp
     * Multiplies body by the transformation terms Ps under summation sx.
     * With exactly three Ps they are placed before body, to match the order
     * used by the vis branch's WorldtoSpace functions; otherwise body comes
     * first.
     *)
    fun multiPs ([P0, P1, P2], sx, body) = formBody (E.Sum(sx, E.Prod [P0, P1, P2, body]))
      | multiPs (Ps, sx, body) = formBody (E.Sum(sx, E.Prod (body :: Ps)))

    (* multiMergePs : like multiPs, but for two Ps and two summation indices
     * it nests each summation next to its own P:
     *     Sum([sx0], P0 * Sum([sx1], P1 * body))
     * It falls back to multiPs otherwise.  Its use in createEinApp is
     * currently disabled.
     *)
    fun multiMergePs ([P0, P1], [sx0, sx1], body) =
          E.Sum([sx0], E.Prod[P0, E.Sum([sx1], E.Prod[P1, body])])
      | multiMergePs e = multiPs e

    (* replaceProbe : int * (var * EINAPP) * ein_exp * sum_index list -> code
     * Replaces the probe p occurring in the EIN application bound to y with
     * its expanded form:
     *   - transforms the position to image space (handleArgs);
     *   - transforms the differentiation indices to index space;
     *   - rewrites the body with createBody and multiPs.
     * Three parameters are appended to the EIN (ids fid, nid, Pid = P) and
     * the arguments become args @ argsT @ [P].  If the original body was
     * Sum(sx, probe) or Sum(sx, Prod[eps, probe]), that outer structure is
     * kept around the new body.
     * Returns the code from handleArgs followed by the new binding of y.
     * testN is not used.
     *)
    fun replaceProbe (testN, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val _ = testp ["\n***************** \n Replace ************ \n"]
        val _ = toStringBind (y, DstIL.EINAPP(e, args))

        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        (* ids of the three parameters appended below *)
        val fid = length params
        val nid = fid + 1
        val Pid = nid + 1
        val nshift = length dx
        val (dim, argsA, code, s, PArg) = handleArgs (Vid, hid, tid, args)
        val freshIndex = getsumshift (sx, index)
        val (dx, newsx1, Ps) = transformToIndexSpace (freshIndex, dim, dx, Pid)
        val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
        val body' = createBody (dim, s, freshIndex + nshift, alpha, dx, Vid, hid, nid, fid)
        val body' = multiPs (Ps, newsx1, body')
        val body' = (case originalb
              of E.Sum(sx, E.Probe _) => E.Sum(sx, body')
               | E.Sum(sx, E.Prod[eps0, E.Probe _]) => E.Sum(sx, E.Prod[eps0, body'])
               | _ => body'
              (* end case *))

        val args' = argsA @ [PArg]
        val einapp = (y, mkEinApp (mkEin (params', index, body'), args'))
        in
          code @ [einapp]
        end

    val tsplitvar=true

    (* createEinApp : ein_exp * mu list * int list * int * int * mu list * sum_index list
     *              -> bool * ein * int list * mu list * mu list
     * Builds the outer EIN used by liftProbe.  Its parameters are the
     * transformation matrix (id 0) and the lifted tensor (id 1); its body
     * multiplies the Ps from transformToIndexSpace with that tensor.
     * Returns (splitvar, ein, sizes, dx, alpha'), where sizes is the shape of
     * the lifted tensor.  dx and alpha' are the differentiation and alpha
     * indices after cleanIndex renaming.  splitvar is false only when the
     * original body was one of the Sum forms matched below.
     * Assumes the body is already clean.
     *)
    fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
        val Pid = 0
        val tid = 1

        val (newdx, newsx, Ps) = transformToIndexSpace (freshIndex, dim, dx, Pid)

        (* Rewrite dx/alpha.  The ids 9 and 7 are placeholders; only the
         * index lists of the cleaned Conv are used. *)
        val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sx @ newsx
              of [] => ([], index, E.Conv(9, alpha, 7, newdx))
               | _ => cleanIndex.cleanIndex (E.Conv(9, alpha, 7, newdx), index, sx @ newsx)
              (* end case *))

        val params = [E.TEN(1, [dim, dim]), E.TEN(1, sizes)]
        (* constant indices of alpha do not contribute to the tensor's shape *)
        val tshape = List.filter (fn E.C _ => false | _ => true) alpha' @ newdx
        val t = E.Tensor(tid, tshape)
        val (splitvar, body) = (case originalb
              of E.Sum(sx, E.Probe _) =>
                    (false, E.Sum(sx, multiPs (Ps, newsx, t)))
               | E.Sum(sx, E.Prod[eps0, E.Probe _]) =>
                    (false, E.Sum(sx, E.Prod[eps0, multiPs (Ps, newsx, t)]))
               | _ =>
                    (* An alternative, selected by tsplitvar, that uses
                     * multiMergePs (Ps, newsx, t) to push the summations in
                     * place is currently disabled. *)
                    (true, multiPs (Ps, newsx, t))
              (* end case *))

        val ein0 = mkEin (params, index, body)
        in
          (splitvar, ein0, sizes, dx, alpha')
        end

    (* liftProbe : int * (var * EINAPP) * ein_exp * sum_index list -> code
     * Alternative to replaceProbe.  expandEinOp uses it when the probe has
     * differentiation indices and none of them is constant.
     * The probe is computed into a fresh tensor variable F (ein1, of shape
     * sizes).  y is then bound to ein0 applied to [P, F], split with
     * Split.splitEinApp when createEinApp reports splitvar.
     * Returns code @ [F = ein1] @ (bindings of y).
     * printStrings is not used.
     *)
    fun liftProbe (printStrings, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val _ = testp ["\n******* Lift ******** \n"]
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val _ = toStringBind (y, DstIL.EINAPP(e, args))

        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        val fid = length params
        val nid = fid + 1
        val nshift = length dx
        val (dim, args', code, s, PArg) = handleArgs (Vid, hid, tid, args)
        val freshIndex = getsumshift (sx, index)

        (* outer EIN: the transformation terms times the lifted tensor F *)
        val (splitvar, ein0, sizes, dx, alpha') =
              createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx)
        val FArg = DstV.new ("F", DstTy.TensorTy sizes)
        val einApp0 = mkEinApp (ein0, [PArg, FArg])
        val rtn0 = if splitvar
              then Split.splitEinApp (y, einApp0)
              else [(y, einApp0)]

        (* lifted probe: computes F *)
        val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
        val body' = createBody (dim, s, freshIndex + nshift, alpha', dx, Vid, hid, nid, fid)
        val ein1 = mkEin (params', sizes, body')
        val einApp1 = mkEinApp (ein1, args')
        val rtn1 = (FArg, einApp1)
        val _ = List.map toStringBind (rtn1 :: rtn0)
        in
          code @ [rtn1] @ rtn0
        end

    (* expandEinOp : (var * EINAPP) * fieldset -> code * fieldset * int * int
     * At this point only simple ein ops remain.
     * If einSet.rtnVar finds an existing variable v for this ein, y is bound
     * to v and the last component is 1.  Otherwise any probe in the body is
     * expanded (replaceProbe/liftProbe) and the last component is 0.
     * The third component is 1 when the body is a probe, a Sum of a probe, or
     * a Sum of Prod[_, probe]; otherwise it is 0.
     * For a Sum(sx, Prod[eps, probe]) body, the eps term is kept in the
     * rewritten body.
     *)
    fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let

        (* Scan dx: replace the probe if any index is constant, otherwise lift it. *)
        fun checkConst ([], a) = liftProbe a
          | checkConst (E.C _ :: _, a) = replaceProbe a
          | checkConst (_ :: es, a) = checkConst (es, a)

        (* liftFieldMat: probes the field with the constant alpha index
         * E.C c1 replaced by E.V newvx, then selects component c1 of the
         * result.  Only referenced by the disabled arms of rewriteBody.
         * NOTE(review): the extra dimension is hard-coded to 3. *)
        fun liftFieldMat (newvx, E.Probe(E.Conv(V, [E.C c1, E.V 0], h, dx), pos)) = let
              val _ = toStringBind e
              val index0 = Ein.index ein
              val index1 = index0 @ [3]
              val body1_unshifted = E.Probe(E.Conv(V, [E.V newvx, E.V 0], h, dx), pos)
              (* clean to get body indices in order *)
              val (_, _, body1) = cleanIndex.cleanIndex (body1_unshifted, index1, [])
              val _ = testp ["\n Shifted ", P.printbody body1_unshifted, "=>", P.printbody body1]

              val lhs1 = DstV.new ("L", DstTy.TensorTy index1)
              val ein1 = mkEin (Ein.params ein, index1, body1)
              val code1 = (lhs1, mkEinApp (ein1, args))
              val codeAll = (case dx
                    of [] => replaceProbe (1, code1, body1, [])
                     | _ => liftProbe (1, code1, body1, [])
                    (* end case *))

              (* probe that tensor at the constant position E.C c1 *)
              val param0 = [E.TEN(1, index1)]
              val nx = List.tabulate (length dx + 1, fn n => E.V n)
              val body0 = E.Tensor(0, [E.C c1] @ nx)
              val ein0 = mkEin (param0, index0, body0)
              val einApp0 = mkEinApp (ein0, [lhs1])
              val code0 = (y, einApp0)
              val _ = toStringBind code0
              in
                codeAll @ [code0]
              end

        (* Disabled arms (they would route to liftFieldMat):
         *   E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos)            => liftFieldMat (1,b)
         *   E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos)       => liftFieldMat (2,b)
         *   E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2]),pos) => liftFieldMat (3,b)
         *)
        fun rewriteBody b = (case b
              of E.Probe(E.Conv(_, _, _, []), _) =>
                    replaceProbe (0, e, b, [])
               | E.Probe(E.Conv(_, _, _, dx), _) =>
                    checkConst (dx, (0, e, b, []))          (* scans dx for a constant *)
               | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                    replaceProbe (0, e, p, sx)              (* no dx *)
               | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _)) =>
                    checkConst (dx, (0, e, p, sx))          (* scalar field *)
               | E.Sum(sx, E.Probe p) =>
                    replaceProbe (0, e, E.Probe p, sx)
               | E.Sum(sx, E.Prod[_, E.Probe p]) =>
                    replaceProbe (0, e, E.Probe p, sx)
               | _ => [e]
              (* end case *))

        val (fieldset, var) = einSet.rtnVar (fieldset, y, DstIL.EINAPP(ein, args))

        fun matchField b = (case b
              of E.Probe _ => 1
               | E.Sum(_, E.Probe _) => 1
               | E.Sum(_, E.Prod[_, E.Probe _]) => 1
               | _ => 0
              (* end case *))

        in
          (case var
            of NONE => let
                 (* debug trace: printed only when testing <> 0 *)
                 val _ = testp ["\n \n mapp_not_replacing:", P.printerE ein, ":"]
                 in
                   (rewriteBody (Ein.body ein), fieldset, matchField (Ein.body ein), 0)
                 end
             | SOME v => let
                 val _ = testp ["\n mapp_replacing", P.printerE ein, ":"]
                 in
                   ([(y, DstIL.VAR v)], fieldset, matchField (Ein.body ein), 1)
                 end
            (* end case *))
        end

    end (* local *)

end (* structure ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0