Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3440 - (download) (annotate)
Mon Nov 16 19:16:54 2015 UTC (3 years, 10 months ago) by cchiw
File size: 19808 byte(s)
added trig
(* Expands probe ein
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure ProbeEin = struct

    local
   
    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P = Printer
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes
 
    in

    (* This file expands probed fields.
     * See the ProbeEin TeX file for examples.
     * Note that the original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     * Param_ids record the position of an argument in the midIL.var list.
     * Index_ids keep track of the shape of an image or of a differentiation.
     * A mu binds an Index_id.
     * Throughout, we use the following names:
     *   dim:    dimension of field V
     *   s:      support of kernel H
     *   alpha:  the alpha in <V_alpha * H^(deltas)>
     *   deltas: the deltas in <V_alpha * H^(deltas)>
     *   Vid:    param_id for V
     *   hid:    param_id for H
     *   nid:    param_id of the integer position
     *   fid:    param_id of the fractional position
     *   img:    ImageInfo for V
     *)
        
    (* Switches for this pass.
     *   testing       - when non-zero, testp takes its debug branch.
     *   testlift      - NOTE(review): not referenced in the visible code.
     *   detflag       - enables the liftFieldMat / liftFieldSum / liftFieldVec
     *                   cases in expandEinOp.
     *   fieldliftflag - only referenced from commented-out code in expandEinOp.
     *   valnumflag    - enables reuse of existing vars via einSet.rtnVar
     *                   (expandEinOp) and einSet.rtnVarN (searchFullField).
     *)
    val testing=0
    val testlift=1
    val detflag =true 
    val fieldliftflag=true
    val valnumflag=true
    
    
    (* Shared counter.  NOTE(review): not referenced in the visible code. *)
    val cnt = ref 0

    (* Short local names for helpers that live in other modules. *)
    val transformToIndexSpace = T.transformToIndexSpace
    val transformToImgSpace = T.transformToImgSpace
    val toStringBind = MidToS.toStringBind
    val mkEin = E.mkEin
    val mkEinApp = DstIL.EINAPP

    (* testp: string list -> int
     * Debug hook.  When `testing` is non-zero, prints the concatenation of
     * the message pieces; otherwise prints nothing.  Always returns 1, so
     * callers write `val _ = testp [...]`.
     * The print is what makes the debug branch useful: merely building the
     * string (as an earlier version did) discards it.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print (String.concat n); 1)
        (*end case*))


    (* getRHSDst: returns the (rator, args) pair of the operator bound to x,
     * following chains of variable copies (VB_RHS(VAR x')).
     * Raises Fail, naming the last var in the chain, when the binding is not
     * an operator application.
     *)
    fun getRHSDst x  = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, " but found ", DstIL.vbToString vb])
    (* end case *))


    (* getArgsDst: MidIL.var * MidIL.var * MidIL.var list -> int * image info * int
     * Looks up the definitions of the kernel argument (hArg) and the image
     * argument (imgArg) and returns
     *   (support of the kernel, image info, dimension of the image).
     * The third (argument-list) component is not used.
     * Raises Fail unless hArg is bound to a Kernel op and imgArg to a
     * LoadImage op (getRHSDst may also raise Fail).
     *)
    fun getArgsDst (hArg, imgArg, _) =
        (case (getRHSDst hArg, getRHSDst imgArg)
          of ((DstOp.Kernel(kern, _), _), (DstOp.LoadImage(_, _, info), _)) =>
               (Kernel.support kern, info, ImageInfo.dim info)
           | _ => raise Fail "Expected Image and kernel arguments"
        (* end case *))


    (* handleArgs: int * int * int * MidIL.var list
     *     -> int * MidIL.var list * code * int * MidIL.var
     * Vid, hid and tid are the Param_ids (positions in args) of the image,
     * the kernel and the position tensor.  Looks up the image and kernel
     * definitions, then calls transformToImgSpace to move the position into
     * image space.  Returns
     *   (dimension, args extended with the new arguments, code computing them,
     *    kernel support, P)
     * where P is the Mid-IL var returned by transformToImgSpace (described in
     * earlier notes as the transpose of the transformation matrix).
     *)
    fun handleArgs (Vid, hid, tid, args) = let
        val nth = fn i => List.nth (args, i)
        val imgArg = nth Vid
        val hArg = nth hid
        val posArg = nth tid
        val (support, img, dim) = getArgsDst (hArg, imgArg, args)
        val (newArgs, pArg, code) = transformToImgSpace (dim, img, posArg, imgArg)
        in
            (dim, args @ newArgs, code, support, pArg)
        end

    (*createBody: int * int * int * mu list * mu list * param_id * param_id * param_id * param_id -> ein_exp
     * Expands the body of a probed field <V_alpha * H^(deltas)>:
     *   Sum over dim summation indices of
     *     Img(Vid, alpha, positions) * Krn(hid, deltas, ...) for each axis.
     *   dim    - dimension of the field
     *   s      - support of the kernel; summation bounds are (1-s, s)
     *   sx     - id of the first summation index (ids sx .. sx+dim-1 are used)
     *   alpha  - field indices; deltas - derivative indices
     *   Vid, hid, nid, fid - param ids of the image, the kernel, the integer
     *            position and the fractional position
     * For dim = 1 the position tensors are used without an index; otherwise
     * axis i of the positions is selected with E.C i.
     *)
    fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
        (*1-d fields: scalar position tensors*)
        fun createKRND1 ()=let
            val sum=sx
            val dels=List.map (fn e=>(E.C 0,e)) deltas
            val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
            val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
            in
                E.Prod [E.Img(Vid,alpha,pos),rest]
            end
        (*createKRN: image field and kernels, one position and one kernel
         * factor per axis, built from the last axis down so that both lists
         * come out in axis order *)
        fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
        | createKRN(dim,imgpos,rest)=let
            val dim'=dim-1
            val sum=sx+dim'
            val dels=List.map (fn e=>(E.C dim',e)) deltas
            val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
            val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
            in
                createKRN(dim',pos@imgpos,[rest']@rest)
            end
        val exp=(case dim
            of 1 => createKRND1()
            | _=> createKRN(dim, [],[])
            (*end case*))
        (*summation indices E.V sx .. E.V (sx+dim-1), each with bounds (1-s, s)*)
        val slb=1-s
        val esum=List.tabulate(dim, (fn i=>(E.V (i+sx),slb,s)))
    in
        E.Sum(esum, exp)
    end

    (*getsumshift: sum_index list * int -> int
     * Returns a fresh (unused) index id:
     *   - one more than the id of the last summation index in sx, or
     *   - n (the length of the result index) when sx is empty.
     * Assumes the last entry of sx has the largest id.
     * Raises Fail if that entry is not a variable index (E.V).
     *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = (case List.last sx
          of (E.V v, _, _) => v+1
          | _ => raise Fail "getsumshift: expected an E.V summation index"
          (* end case *))

    (*formBody: ein_exp -> ein_exp
     * Light clean-up of a body:
     *   - drops empty summations (recursing into their bodies);
     *   - recurses into the body of a non-empty summation;
     *   - unwraps a product with a single factor.
     * Anything else is returned unchanged.
     *)
    fun formBody exp = (case exp
        of E.Sum([], inner) => formBody inner
        | E.Sum(sx, inner) => E.Sum(sx, formBody inner)
        | E.Prod [single] => single
        | other => other
        (*end case*))

    (*multiPs: ein_exp list * sum_index list * ein_exp -> ein_exp
     * Multiplies body by the factors Ps under the summation sx, then cleans
     * the result with formBody.  With exactly three or four factors the Ps
     * come first (this ordering matches the vis branch's WorldtoSpace
     * functions); otherwise body comes first.
     *)
    fun multiPs (Ps, sx, body) = let
        val factors = (case Ps
            of [_, _, _] => Ps @ [body]
            | [_, _, _, _] => Ps @ [body]
            | _ => body :: Ps
            (*end case*))
        in
            formBody (E.Sum (sx, E.Prod factors))
        end
    
        
    (*multiMergePs: like multiPs, but with exactly two factors and two
     * summation indices it nests the sums instead:
     *   Sum(sx0, P0 * Sum(sx1, P1 * body)).
     * NOTE(review): only referenced from commented-out code in the visible file.
     *)
    fun multiMergePs (args as (Ps, sxs, body)) = (case (Ps, sxs)
        of ([P0, P1], [sx0, sx1]) =>
             E.Sum([sx0], E.Prod[P0, E.Sum([sx1], E.Prod[P1, body])])
        | _ => multiPs args
        (*end case*))
        
        
    (* replaceProbe: int * (MidIL.var * MidIL.rhs) * ein_exp * sum_index list -> code list
     * Replaces the probe p inside the EIN application bound to y by its
     * expanded form (image samples multiplied by kernel factors), in place,
     * without lifting it into a separate EIN application.
     *   testN - not used.
     *   p     - must have the form E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_));
     *           any other form raises Bind.
     *   sx    - summation indices around the probe; used to pick a fresh index id.
     * Steps:
     *   - handleArgs moves the position into image space, producing extra
     *     arguments, the code that computes them, and the var PArg;
     *   - transformToIndexSpace rewrites the derivative indices dx and returns
     *     the factors Ps with their summation indices newsx1;
     *   - createBody builds the convolution and multiPs multiplies in the Ps;
     *   - a surrounding Sum (optionally with one leading factor, e.g. an eps)
     *     from the original body is put back around the result.
     * Returns the position-transform code followed by the new binding for y.
     * Raises Match if the rhs is not an EINAPP.
     *)
     fun replaceProbe(testN,(y, DstIL.EINAPP(e,args)),p ,sx)
        =let
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e
        val _ = testp["\n***************** \n Replace ************ \n"]
        val _=  toStringBind (y, DstIL.EINAPP(e,args))
        
        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        (* three params are appended after the existing ones:
         * fid (fractional position), nid (integer position) and Pid (the
         * matrix param passed to transformToIndexSpace) *)
        val fid=length(params)
        val nid=fid+1
        val Pid=nid+1
        (* count of the ORIGINAL derivative indices; dx is rebound below *)
        val nshift=length(dx)
        val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,length(index))
        val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
        (* the convolution's summation indices start after freshIndex+nshift *)
        val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
        val body' = multiPs(Ps,newsx1,body')
        
        (* restore the original surrounding summation; the sx bound in these
         * patterns shadows the parameter sx *)
        val body'=(case originalb
            of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => E.Sum(sx,E.Prod[eps0,body'])
            | _                                  => body'
            (*end case*))
   
        
        (* argsA is args followed by the new position args from handleArgs;
         * PArg supplies the Pid param.
         * NOTE(review): assumes handleArgs adds exactly the two args for fid
         * and nid -- confirm against TransformEin. *)
        val args'=argsA@[PArg]
        val einapp=(y,mkEinApp(mkEin(params',index,body'),args'))
        in 
            code@[einapp]
        end
        
    (* Only consulted in createEinApp, where the case on it has a single
     * wildcard branch (the alternatives are commented out), so its value
     * currently has no effect. *)
    val tsplitvar=true
    (* createEinApp: ein_exp * mu list * int list * int * int * mu list * sum_index list
     *     -> bool * ein * int list * mu list * mu list
     * Builds the outer EIN operator of a lifted probe (used by liftProbe):
     * the lifted field tensor F (param 1, shape sizes) multiplied by the
     * factors Ps that transformToIndexSpace returns for the matrix param
     * Pid = 0 (declared E.TEN(1,[dim,dim])).
     *   originalb  - original body; a surrounding Sum (and leading factor) is kept
     *   alpha, dx  - field and derivative indices of the probe
     *   index      - result shape of the original operator
     *   freshIndex - first unused index id
     *   sx         - summation indices around the probe
     * Returns (splitvar, ein0, sizes, dx', alpha'): sizes is the shape of F,
     * dx'/alpha' are the probe indices as returned by cleanIndex, and
     * liftProbe uses splitvar to decide whether ein0 goes through
     * SummationEin/Split.
     *)
    fun createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)= let
        val Pid=0 (* param 0: transformation matrix P *)
        val tid=1 (* param 1: lifted field tensor F *)
      
        (*Assumes body is already clean*)
        val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
        
        (*need to rewrite dx*)
        (* 9 and 7 are placeholder image/kernel ids: only the indices of the
         * cleaned Conv are used.  With no summation indices the shape of F
         * is just index. *)
        val (_,sizes,e as E.Conv(_,alpha',_,dx))=(case sx@newsx
            of []=> ([],index,E.Conv(9,alpha,7,newdx))
            | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx),index,sx@newsx)
            (*end case*))
                
        val params=[E.TEN(1,[dim,dim]),E.TEN(1,sizes)]
        (* drops constant (E.C) indices *)
        fun filterAlpha []=[]
          | filterAlpha(E.C _::es)= filterAlpha es
          | filterAlpha(e1::es)=[e1]@(filterAlpha es)
        
        (* NOTE(review): combines alpha' (from cleanIndex) with the uncleaned
         * newdx; confirm this mix is intended. *)
        val tshape=filterAlpha(alpha')@newdx
        val t=E.Tensor(tid,tshape)
        (* choose the outer body from the shape of the original body; the sx
         * bound in these patterns shadows the parameter sx *)
        val (splitvar,body)=(case originalb
            of E.Sum(sx, E.Probe _)              => (*(false,E.Sum(sx,multiPs(Ps,newsx,t)))*) (true,multiPs(Ps,sx@newsx,t))
            | E.Sum(sx,E.Prod[eps0,E.Probe _ ])  => (false,E.Sum(sx,E.Prod[eps0,multiPs(Ps,newsx,t)]))
            | _                                  => (case tsplitvar
                of(* true =>   (true,multiMergePs(Ps,newsx,t))  (*pushes summations in place*)
                | false*) _ =>   (true,multiPs(Ps,newsx,t))
                (*end case*))
            (*end case*))
        
        (* debug strings: built and discarded, nothing is printed *)
        val _ =(case splitvar
        of true=> (String.concat["splitvar is true", P.printbody body])
        | _ => (String.concat["splitvar is false",P.printbody body])
        (*end case*))
        
        
        val ein0=mkEin(params,index,body)
        in
            (splitvar,ein0,sizes,dx,alpha')
        end
    
    (* liftProbe: int * (MidIL.var * MidIL.rhs) * ein_exp * sum_index list -> code list
     * Expands a probe by lifting it into two EIN applications:
     *   1. F = the convolution in image space (createBody), applied to the
     *      original arguments plus the position args added by handleArgs;
     *   2. y = the outer operator from createEinApp, applied to [P, F].  When
     *      createEinApp returns splitvar = true, that operator is first
     *      passed through SummationEin.main and split with Split.splitEinApp.
     * printStrings is not used.
     * p must have the form E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_));
     * otherwise Bind is raised.
     * Returns the position-transform code, then the binding for F, then the
     * code for y.
     *)
    fun liftProbe(printStrings,(y, DstIL.EINAPP(e,args)),p ,sx)=let
        val originalb=Ein.body e
        val params=Ein.params e
        val index=Ein.index e
        val _ =  (toStringBind (y, DstIL.EINAPP(e,args)))

        val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=p
        (* params appended for the lifted probe: fid (fractional position)
         * and nid (integer position) *)
        val fid=length(params)
        val nid=fid+1
        val (dim,args',code,s,PArg) = handleArgs(Vid,hid,tid,args)
        val freshIndex=getsumshift(sx,length(index))

        (*outer operator T*P*P..Ps; dx and alpha' are rebound to the cleaned indices*)
        val (splitvar,ein0,sizes,dx,alpha')= createEinApp(originalb,alpha,index,freshIndex,dim,dx,sx)
        val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
        val rtn0=(case splitvar
            of false => [(y,mkEinApp(ein0,[PArg,FArg]))]
            | _      => Split.splitEinApp (y,DstIL.EINAPP(SummationEin.main ein0,[PArg,FArg]))
            (*end case*))

        (*lifted probe: F is indexed by sizes, so its indices start at 0*)
        val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
        val freshIndex'= length(sizes)
        val body' = createBody(dim, s,freshIndex',alpha',dx,Vid, hid, nid, fid)
        val ein1=mkEin(params',sizes,body')
        val rtn1=(FArg,mkEinApp(ein1,args'))
        val _= List.map toStringBind ([rtn1]@rtn0)
        in
            code@[rtn1]@rtn0
        end

    (* searchFullField: einSet * (MidIL.var * MidIL.rhs) * ein_exp * mu list
     *     -> einSet * (MidIL.var * code list)
     * Tries to reuse an existing binding for the field application code1.
     * If valnumflag is set and einSet.rtnVarN finds an existing var m, the
     * result is (fieldset, (m, [])): no new code is generated.
     * Otherwise code1 is expanded with replaceProbe when there are no
     * derivative indices (dx = []) and with liftProbe otherwise; the result is
     * (fieldset, (lhs of code1, generated code)).
     * Debug messages go through testp, so they only appear when testing is on.
     *)
    fun searchFullField (fieldset,code1,body1,dx)=let
        val (lhs,_)=code1
        fun continueReconstruction ()=let
            val _=testp ["Tash:don't replaced"]
            in (case dx
                of []=> (lhs,replaceProbe(1,code1,body1,[]))
                | _ =>(lhs,liftProbe(1,code1,body1,[]))
                (*end case*))
             end
        in  (case valnumflag
            of false => (fieldset,continueReconstruction())
            | true => (case  (einSet.rtnVarN(fieldset,code1))
                of (fieldset,NONE)     => (fieldset,continueReconstruction())
                 | (fieldset,SOME m)   =>(testp ["TASH:replaced"]; (fieldset,(m,[])))
                (*end case*))
            (*end case*))
        end

    (* liftFieldMat: int * (MidIL.var * MidIL.rhs) -> code list
     * Handles a probe whose first field index is a constant c1,
     *   y = <V_[c1,v0] * H^dx>(pos).
     * Probes the field with c1 replaced by E.V newvx (an extra result axis)
     * into a fresh var L, expanding it with replaceProbe (dx = []) or
     * liftProbe, then binds y = L[c1, V 0, ..., V (length dx)].
     * NOTE(review): the extra axis is hard-coded to size 3 -- assumes the
     * constant-indexed axis has size 3; confirm.
     * Debug messages go through testp, so they only appear when testing is on.
     * Raises Bind if e is not an EINAPP whose body has this shape.
     *)
    fun liftFieldMat(newvx,e)=
        let
            val _=testp ["\n ***************************** start FieldMat\n"]
            val (y, DstIL.EINAPP(ein,args))=e
            val E.Probe(E.Conv(V,[c1,v0],h,dx),pos)=Ein.body ein
            val index0=Ein.index ein
            val index1 = index0@[3]
            val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx, v0],h,dx),pos)
            (* clean to get body indices in order *)
            val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])
            val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]

            val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
            val ein1 = mkEin(Ein.params ein,index1,body1)
            val code1= (lhs1,mkEinApp(ein1,args))
            val codeAll= (case dx
                of []=> replaceProbe(1,code1,body1,[])
                | _ =>liftProbe(1,code1,body1,[])
                (*end case*))

            (*fix the first index of that tensor to the constant c1*)
            val param0 = [E.TEN(1,index1)]
            val nx=List.tabulate(length(dx)+1,fn n=>E.V n)
            val body0 =  E.Tensor(0,[c1]@nx)
            val ein0 = mkEin(param0,index0,body0)
            val einApp0 = mkEinApp(ein0,[lhs1])
            val code0 = (y,einApp0)
            val _= toStringBind code0
            val _=testp ["\n end FieldMat *****************************\n "]
        in
            codeAll@[code0]
    end
    
    (* liftFieldVec: int * (MidIL.var * MidIL.rhs) * einSet -> code list
     * Handles a probe whose only field index is a constant c1,
     *   y = <V_[c1] * H^dx>(pos).
     * Probes the field with c1 replaced by E.V newvx into a fresh var L.
     * searchFullField may substitute an existing equivalent var for L.
     * Then binds y = L[c1, V 0, ..., V (length dx - 1)].
     * NOTE(review): the extra axis is hard-coded to size 3, and the fieldset
     * returned by searchFullField is not passed back to the caller, so its
     * updates are dropped -- confirm both are intended.
     * Debug messages go through testp, so they only appear when testing is on.
     *)
    fun liftFieldVec(newvx,e,fieldset)=
    let
        val _=testp ["\n ***************************** start FieldVec\n"]
        val (y, DstIL.EINAPP(ein,args))=e
        val E.Probe(E.Conv(V,[c1],h,dx),pos)=Ein.body ein
        val index0=Ein.index ein
        val index1 = index0@[3]
        val body1_unshifted = E.Probe(E.Conv(V,[E.V newvx],h,dx),pos)
        (* clean to get body indices in order *)
        val ( _ , _, body1)= cleanIndex.cleanIndex(body1_unshifted,index1,[])

        val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
        val ein1 = mkEin(Ein.params ein,index1,body1)
        val code1= (lhs1,mkEinApp(ein1,args))
        (* lhs0 is lhs1, or an existing var found by searchFullField *)
        val (_,(lhs0,codeAll))=searchFullField (fieldset,code1,body1,dx)

        (*fix the index of that tensor to the constant c1*)
        val param0 = [E.TEN(1,index1)]
        val nx=List.tabulate(length(dx),fn n=>E.V n)
        val body0 =  E.Tensor(0,[c1]@nx)
        val ein0 = mkEin(param0,index0,body0)
        val einApp0 = mkEinApp(ein0,[lhs0])
        val code0 = (y,einApp0)

        val _ = testp ["\n Shifted ",P.printbody body1_unshifted,"=>",P.printbody body1]
        val _=  (toStringBind code0)
        val _=testp ["\n end FieldVec *****************************\n "]
        in
            codeAll@[code0]
    end
    
    
    
    (* liftFieldSum: (MidIL.var * MidIL.rhs) -> code list
     * Called from expandEinOp for bodies Sum(vsum, <V_[vsum,vsum] * H^dx>(pos)).
     * The alpha indices of the matched body are not inspected here.
     * Probes the whole matrix field into a fresh var L:
     *   - alpha is [E.V 0, E.V 1];
     *   - the derivative indices are shifted to start at E.V 2.
     * Then binds y = Sum_vsum L[vsum, vsum, V 0, ..., V (length dx - 1)].
     * The shape of L lists the two matrix axes first, then index0, so that it
     * matches the axis numbering of body1.  The previous index0 @ [3,3] only
     * agreed with body1 when every derivative axis also had size 3.
     * NOTE(review): the matrix axes are hard-coded to size 3 -- confirm.
     * Debug messages go through testp, so they only appear when testing is on.
     *)
    fun liftFieldSum e =
    let
        val _=testp ["\n************************************* Start Lift Field Sum\n"]
        val (y, DstIL.EINAPP(ein,args))=e
        val E.Sum([(vsum,0,n)],E.Probe(E.Conv(V,[_,_],h,dx),pos))=Ein.body ein
        val index0=Ein.index ein
        (* matrix axes (E.V 0, E.V 1) first, then the derivative axes *)
        val index1 = [3,3]@index0
        val shiftdx=List.tabulate(length(dx),fn i=>E.V (i+2))
        val body1 = E.Probe(E.Conv(V,[E.V 0,E.V 1],h,shiftdx),pos)

        val lhs1=DstV.new ("L", DstTy.TensorTy(index1))
        val ein1 = mkEin(Ein.params ein,index1,body1)
        val code1= (lhs1,mkEinApp(ein1,args))
        val codeAll= (case dx
            of []=> replaceProbe(1,code1,body1,[])
            | _ =>liftProbe(1,code1,body1,[])
            (*end case*))

        (*sum the diagonal of that tensor*)
        val param0 = [E.TEN(1,index1)]
        val nx=List.tabulate(length(dx),fn i=>E.V i)
        val body0 =  E.Sum([(vsum,0,n)],E.Tensor(0,[vsum,vsum]@nx))
        val ein0 = mkEin(param0,index0,body0)
        val einApp0 = mkEinApp(ein0,[lhs1])
        val code0 = (y,einApp0)
        val _ = toStringBind  e
        val _ = toStringBind code0
        val _ = testp ["\norig",P.printbody(Ein.body ein),"\n replace i  ",P.printbody body1,"\nfreshtensor",P.printbody body0]
        val _ = List.map toStringBind (codeAll@[code0])
        val _ = testp ["\n*** end Field Sum*************************************\n"]
        in
        codeAll@[code0]
    end
    

    (* expandEinOp: (MidIL.var * MidIL.rhs) * einSet -> code list * einSet * int * int
     * At this point only simple EIN operators remain.
     * Looks at the body of the EIN application bound to y; if it contains a
     * probe, the probe is replaced by its expansion.  Surrounding eps factors
     * are kept, so only the pieces that are used are generated.
     * Returns (code, fieldset, isField, reused):
     *   isField - 1 when the body is a probe, a Sum of a probe, or a Sum of a
     *             two-factor product whose second factor is a probe; else 0.
     *   reused  - 1 when valnumflag is set and einSet.rtnVar found an
     *             existing var v, in which case the code is just y = v; else 0.
     * Debug output goes through testp, so it only appears when testing is on.
     * Raises Match if the rhs of e is not an EINAPP.
     *)
    fun expandEinOp( e as (y, DstIL.EINAPP(ein,args)),fieldset)=let

        (* Probes with derivatives are always lifted.  (A disabled variant
         * chose between liftProbe and replaceProbe based on fieldliftflag and
         * on whether dx contains a constant index.) *)
        fun checkConst (_,a) = liftProbe a

        (* NOTE: rewriteBody sees the fieldset passed to expandEinOp, not the
         * one returned by einSet.rtnVar below. *)
        fun rewriteBody b=(case (detflag,b)
            of (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[]),pos))
                => liftFieldMat (1,e)
            | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1]),pos))
                => liftFieldMat (2,e)
            | (true,E.Probe(E.Conv(_,[E.C _ ,E.V 0],_,[E.V 1,E.V 2] ),pos))
                => liftFieldMat (3,e)
            | (true, E.Sum([(E.V 0,0,_)],E.Probe(E.Conv(_,[E.V 0 ,E.V 0],_,[]),pos)))
                => liftFieldSum e
            | (true, E.Sum([(E.V 1,0,_)],E.Probe(E.Conv(_,[E.V 1 ,E.V 1],_,[E.V 0]),pos)))
                => liftFieldSum e
            | (true, E.Sum([(E.V 2,0,_)],E.Probe(E.Conv(_,[E.V 2 ,E.V 2],_,[E.V 0,E.V 1]),pos)))
                => liftFieldSum e
            | (true,E.Probe(E.Conv(_,[E.C _ ],_,[]),pos))
                => liftFieldVec (0,e,fieldset)
            | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0]),pos))
                => liftFieldVec (1,e,fieldset)
            | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1] ),pos))
                => liftFieldVec (2,e,fieldset)
            | (true,E.Probe(E.Conv(_,[E.C _],_,[E.V 0,E.V 1,E.V 2] ),pos))
                => liftFieldVec (3,e,fieldset)
            | (_,E.Probe(E.Conv(_,_,_,[]),_))
                => replaceProbe(0,e,b,[])
            | (_,E.Probe(E.Conv (_,alpha,_,dx),_))
                => checkConst(dx,(0,e,b,[])) (*lifted, see checkConst*)
            | (_,E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
                => replaceProbe(0,e,p, sx)  (*no dx*)
            | (_,E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
                => checkConst(dx,(0,e,p,sx)) (*scalar field*)
            | (_,E.Sum(sx,E.Probe p))
                => replaceProbe(0,e,E.Probe p, sx)
            | (_,E.Sum(sx,E.Prod[eps,E.Probe p]))
                => replaceProbe(0,e,E.Probe p,sx)
            | (_,_) => [e]
            (* end case *))

        val (fieldset,var) = (case valnumflag
            of true => einSet.rtnVar(fieldset,y,DstIL.EINAPP(ein,args))
            | _     => (fieldset,NONE)
            (*end case*))

        (* 1 when the body has one of the probe shapes handled here, else 0 *)
        fun matchField b=(case b
            of E.Probe _ => 1
            | E.Sum (_, E.Probe _)=>1
            | E.Sum(_, E.Prod[ _ ,E.Probe _])=>1
            | _ =>0
            (*end case*))

        val b=Ein.body ein

        (* debug dump of probe bodies *)
        val _ = if matchField b = 1 then testp ["\n", P.printbody b] else 1

        in  (case var
            of NONE=> (rewriteBody b,fieldset,matchField b,0)
            | SOME v=> (testp ["\n mapp_replacing", P.printerE ein, ":"];
                ([(y,DstIL.VAR v)],fieldset,matchField b,1))
            (*end case*))
        end

  end; (* local *)

end (* local *)  

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0