Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3668 - (download) (annotate)
Sun Feb 7 16:11:21 2016 UTC (3 years, 6 months ago) by cchiw
File size: 29801 byte(s)
DVF
(* Expands probe ein
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure ProbeEin = struct

    local
   
    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure P = Printer
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes
 
    in

    (* This file expands probed fields.
     * See the ProbeEin tex file for examples.
     * Note that the original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     * Param_ids record the position of an argument in the midIL.var list.
     * Index_ids keep track of the shape of an image or of a differentiation.
     * Mu: binds an Index_id.
     * Throughout, we use the following names:
     *   dim:    dimension of field V
     *   s:      support of kernel H
     *   alpha:  the alpha in <V_alpha * H^(deltas)>
     *   deltas: the deltas in <V_alpha * H^(deltas)>
     *   Vid:    param_id for V
     *   hid:    param_id for H
     *   nid:    param_id of the integer position
     *   fid:    param_id of the fractional position
     *   img:    image info about V
     *)
        
    (* Debug verbosity: 0 silences testp; any other value makes testp print. *)
    val testing = 0

    (* Feature switches; every one of them is off in this revision. *)
    val valnumflag = false    (* look up existing bindings with einVarSet.rtnVarN *)
    val tsplitvar = false     (* read only by createEinApp, where its true branch is commented out *)
    val fieldliftflag = false (* when false, reconstruction always replaces instead of lifting *)
    val constflag = false     (* with fieldliftflag: lift even when dx has constant indices *)
    val detflag = false       (* route unsummed constant-index probes through getF *)
    val detsumflag = false    (* route summed diagonal probes through getF *)
    val liftimgflag = false   (* createBody2: lift the image access into its own EIN op *)
    val pullKrn = false       (* choose replaceProbe3/liftProbe3 over replaceProbe0/liftProbe0 *)

    (* Bump the use count stored in a MidIL variable up or down by one. *)
    fun incUse (DstIL.V{useCnt, ...}) = useCnt := (!useCnt) + 1
    fun decUse (DstIL.V{useCnt, ...}) = useCnt := (!useCnt) - 1

    (* Counter cell; not referenced elsewhere in this file. *)
    val cnt = ref 0
    (* Local forwarding functions for helpers that live in other structures. *)
    fun transformToIndexSpace args = T.transformToIndexSpace args
    fun transformToImgSpace args = T.transformToImgSpace args
    fun transformToImgSpaceF args = T.transformToImgSpaceF args
    fun toStringBind bind = MidToS.toStringBind bind
    fun toStringBindp bind = toStringBind bind
    fun mkEin args = E.mkEin args
    fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)

    (* EIN expression constructors *)
    fun setConst x = E.setConst x
    fun setNeg x = E.setNeg x
    fun setExp x = E.setExp x
    fun setDiv x = E.setDiv x
    fun setSub x = E.setSub x
    fun setProd x = E.setProd x
    fun setAdd x = E.setAdd x

    (* Constant index builders; the boolean component is always true here. *)
    fun mkCxSingle c = E.C (c, true)
    fun mkCx cs = List.map mkCxSingle cs
    
    (* testp: string list -> int
     * Prints the concatenation of the strings unless testing is 0.
     * Always returns 1.
     *)
    fun testp msgs =
        if testing = 0 then 1
        else (print (String.concat msgs); 1)


    (* getRHSDst: returns the (operator, arguments) pair of the OP that defines x.
     * Variable-to-variable copies (VB_RHS(VAR x')) are followed.
     * Raises Fail when x is bound to anything else.
     *)
    fun getRHSDst x = (case DstIL.Var.binding x
        of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
        | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
        | vb => raise Fail(concat[
              "expected rhs operator for ", DstIL.Var.toString x,
              " but found ", DstIL.vbToString vb
            ])
    (* end case *))


    (* getArgsDst: MidIL.var * MidIL.var * MidIL.var list -> int * image info * int
     * Looks up the definitions of the kernel argument (hArg) and the image
     * argument (imgArg).
     * Returns (support of the kernel, image info, image dimension).
     * The third argument is unused; it is kept for existing callers.
     * Raises Fail unless hArg is defined by a Kernel op and imgArg by a LoadImage op.
     *)
    fun getArgsDst (hArg, imgArg, _) = (case (getRHSDst hArg, getRHSDst imgArg)
        of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
              (Kernel.support h, img, ImageInfo.dim img)
        | ((k, _), (i, _)) => raise Fail(String.concat[
              "Expected kernel: ", DstOp.toString k,
              "; expected image: ", DstOp.toString i
            ])
        (*end case*))


    (* handleArgs: int * int * int * MidIL.var list
     *     -> int * MidIL.var list * code * int * MidIL.var
     * Vid, hid and tid are the positions (param_ids) in args of the image, the
     * kernel and the position tensor.  Looks up the image and the kernel, then
     * passes the position argument to transformToImgSpace.
     * Returns (dim, args extended with the new arguments, the code producing
     * them, kernel support, P).  P is the variable returned by
     * transformToImgSpace; per the original notes it holds the (transposed)
     * transformation matrix.
     *)
    fun handleArgs (Vid, hid, tid, args) = let
        fun arg i = List.nth(args, i)
        val vImg = arg Vid
        val vKrn = arg hid
        val vPos = arg tid
        val (support, img, dim) = getArgsDst(vKrn, vImg, args)
        val (newArgs, pArg, code) = transformToImgSpace(dim, img, vPos, vImg)
        in
            (dim, args @ newArgs, code, support, pArg)
        end
        
    (* handleArgsF: the fieldset-threading counterpart of handleArgs.
     * Works like handleArgs but calls transformToImgSpaceF, which also receives
     * and returns a fieldset.  The updated fieldset is returned as the first
     * component: (fieldset, dim, extended args, code, support, P).
     *)
    fun handleArgsF (fieldset, Vid, hid, tid, args) = let
        fun arg i = List.nth(args, i)
        val vImg = arg Vid
        val vKrn = arg hid
        val vPos = arg tid
        val (support, img, dim) = getArgsDst(vKrn, vImg, args)
        val (fieldset', newArgs, pArg, code) =
              transformToImgSpaceF(fieldset, dim, img, vPos, vImg)
        in
            (fieldset', dim, args @ newArgs, code, support, pArg)
        end
        



    (* createBody: int * int * int * mu list * mu list * int * int * int * int -> ein_exp
     * Expands the body of the probed field <V_alpha * H^(deltas)> of dimension dim.
     *   s      -- kernel support; every summation index ranges over 1-s .. s
     *   sx     -- first summation index id; index sx+i is used for axis i
     *   alpha  -- the field's indices
     *   deltas -- differentiation indices applied to the kernel along every axis
     *   Vid, hid, nid, fid -- param_ids of the image, kernel, integer position
     *                         and fractional position
     * Result: Sum over sx .. sx+dim-1 of
     *     Img(Vid, alpha, [fid[i] + (sx+i)]_i) * Krn_0 * ... * Krn_(dim-1),
     * where Krn_i = Krn(hid, deltas tagged with axis i, nid[i] - (sx+i)).
     * For dim = 1 the position tensors are used without an axis index.
     *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
        (* 1-d fields: the positions are scalars *)
        fun createKRND1 () = let
            val sum = sx
            val dels = List.map (fn e => (mkCxSingle 0, e)) deltas
            val pos = [setAdd[E.Tensor(fid, []), E.Value sum]]
            val krn = E.Krn(hid, dels, setSub(E.Tensor(nid, []), E.Value sum))
            in
                setProd[E.Img(Vid, alpha, pos), krn]
            end

        fun mkImg imgpos = E.Img(Vid, alpha, imgpos)

        (* createKRN: handles axis d-1 at step d, prepending its image position and
         * kernel, so both lists end up ordered from axis 0 to axis dim-1.
         *)
        fun createKRN (0, imgpos, krns) = setProd([mkImg imgpos] @ krns)
          | createKRN (d, imgpos, krns) = let
                val axis = d - 1
                val sum = sx + axis
                val dels = List.map (fn e => (mkCxSingle axis, e)) deltas
                val pos = setAdd[E.Tensor(fid, [mkCxSingle axis]), E.Value sum]
                val krn = E.Krn(hid, dels, setSub(E.Tensor(nid, [mkCxSingle axis]), E.Value sum))
                in
                    createKRN(axis, pos :: imgpos, krn :: krns)
                end

        val exp = (case dim
            of 1 => createKRND1 ()
            | _ => createKRN(dim, [], [])
            (*end case*))

        (* one summation index per axis, each over 1-s .. s *)
        val slb = 1 - s
        val esum = List.tabulate(dim, fn i => (E.V (i + sx), slb, s))
        in
            E.Sum(esum, exp)
        end
   


    (* buildPos: int * int * MidIL.var list * int * int * int -> MidIL.var * binding
     * Creates a fresh scalar "kernel_pos" and the EIN application that defines it:
     * BuildPos(s, T1[dir] - 0) over the parameters [KRN, TEN(1,[dim])], applied
     * to the arguments at positions hid and nid of argsA.
     * Returns the new variable together with its (variable, rhs) binding.
     *)
    fun buildPos (dir, dim, argsA, hid, nid, s) = let
        val lhs = DstV.new ("kernel_pos", DstTy.TensorTy [])
        val posExp = setSub(E.Tensor(1, [mkCxSingle dir]), E.Value 0)
        val operands = [List.nth(argsA, hid), List.nth(argsA, nid)]
        val rator = mkEin([E.KRN, E.TEN(1, [dim])], [], E.BuildPos(s, posExp))
        in
            (lhs, (lhs, mkEinApp(rator, operands)))
        end

    (* getKrn1Del: int * int * MidIL.var list * int * int -> MidIL.var * binding
     * Creates a fresh scalar "kernel_del<dx>" defined by EvalKrn dx over the
     * parameters [KRN, TEN(1,[dim])], applied to args.
     * The last two arguments (slb, s) are not used.
     *)
    fun getKrn1Del (dx, dim, args, _, _) = let
        val lhs = DstV.new ("kernel_del" ^ Int.toString dx, DstTy.TensorTy [])
        val rator = mkEin([E.KRN, E.TEN(1, [dim])], [], E.EvalKrn dx)
        in
            (lhs, (lhs, mkEinApp(rator, args)))
        end

    (* mkHolder: int * MidIL.var list -> MidIL.var * binding
     * Creates a fresh "kernel_cons" tensor of length (length args), defined by
     * Holder n over [KRN, TEN(1,[dim])] and applied to args.
     *)
    fun mkHolder (dim, args) = let
        val n = length args
        val lhs = DstV.new ("kernel_cons", DstTy.TensorTy [n])
        val rator = mkEin([E.KRN, E.TEN(1, [dim])], [], E.Holder n)
        in
            (lhs, (lhs, mkEinApp(rator, args)))
        end
        
    (* liftKrn: lifts the kernel evaluations for axis dir into separate EIN ops.
     * It produces, in order:
     *   1. a "kernel_pos" binding (buildPos);
     *   2. one "kernel_del<k>" binding for each k in 0 .. length dx, each applied
     *      to the kernel argument and the kernel_pos variable;
     *   3. a "kernel_cons" Holder gathering the kernel_del variables.
     * Returns the Holder variable and the bindings in the order above.
     * Fresh variables are created in the order kernel_pos,
     * kernel_del<length dx> ... kernel_del0, kernel_cons.
     *)
    fun liftKrn (dx, dir, dim, argsA, hid, nid, slb, s) = let
        val (vPos, posBind) = buildPos(dir, dim, argsA, hid, nid, s)
        val krnArgs = [List.nth(argsA, hid), vPos]
        (* count k down so that the accumulated lists come out in increasing k *)
        fun collect (k, vars, binds) = let
            val (v, b) = getKrn1Del(k, dim, krnArgs, slb, s)
            val vars' = v :: vars
            val binds' = b :: binds
            in
                if k = 0 then (vars', binds') else collect(k - 1, vars', binds')
            end
        val (delVars, delBinds) = collect(length dx, [], [])
        val (vHolder, holderBind) = mkHolder(dim, delVars)
        in
            (vHolder, posBind :: delBinds @ [holderBind])
        end
        


    (* createBody2: variant of createBody that lifts the kernel evaluations.
     * Each axis's kernel evaluation becomes a separate EIN op (via liftKrn);
     * the image access is also lifted when liftimgflag is set.
     * Returns (body, lifted image expression, lifted kernel variables, kernel bindings):
     *   - the image expression is SOME only when liftimgflag is true;
     *   - for dim = 1 nothing is lifted and the last two results are NONE;
     *   - otherwise the body refers to the lifted kernel of step d (d = 1 .. dim,
     *     axis d-1) as E.Tensor(tid + d, []), where tid = length params - 1.
     * The body is wrapped in a Sum over indices sx .. sx+dim-1, each over 1-s .. s.
     *)
    fun createBody2 (dim, s, sx, alpha, deltas, Vid, hid, nid, fid, params, argsA) = let
        val slb = 1 - s

        (* last param_id in params; the lifted kernels are numbered after it *)
        val tid = length params - 1

        (* mkImg: returns (image factor used in the body, lifted image expression) *)
        fun mkImg imgpos = (case liftimgflag
            of true => (E.Tensor(Vid, alpha),
                        SOME(E.Sum(List.tabulate(dim, fn i => (E.V i, slb, s)), E.Img(Vid, alpha, imgpos))))
            | _ => let
                (* NOTE(review): the imgpos argument is ignored on this path; the
                 * position is rebuilt as fid[i] + (sx+i), as in createBody *)
                val imgpos' = List.tabulate(dim, fn i => setAdd[E.Tensor(fid, [mkCxSingle i]), E.Value(i + sx)])
                in (E.Img(Vid, alpha, imgpos'), NONE) end
            (*end case*))

        (* 1-d fields: scalar positions; the kernel is not lifted *)
        fun createKRND1 () = let
            val sum = sx
            val dels = List.map (fn e => (mkCxSingle 0, e)) deltas
            val imgpos = [setAdd[E.Tensor(fid, []), E.Value sum]]
            val rest = E.Krn(hid, dels, setSub(E.Tensor(nid, []), E.Value sum))
            val (talpha, iexp) = mkImg imgpos
            in
                (setProd[talpha, rest], iexp, NONE, NONE)
            end

        (* createKRN: image field and lifted kernels; axis d-1 is handled at step d *)
        fun createKRN (0, orig, imgpos, vAs, krnpos) = let
            val (talpha, iexp) = mkImg imgpos
            in
                (setProd([talpha] @ orig), iexp, SOME vAs, SOME krnpos)
            end
          | createKRN (d, orig, imgpos, vAs, krnpos) = let
            val dim' = d - 1
            val dels = List.map (fn e => (mkCxSingle dim', e)) deltas
            (* offset by index variable dim'; the lifted-image Sum ranges over V 0 .. V(dim-1) *)
            val ipos = setAdd[E.Tensor(fid, [mkCxSingle dim']), E.Value dim']
            val opos = E.Krn(hid, dels, E.Tensor(tid + d, []))
            val (vA, A) = liftKrn(dels, dim', dim, argsA, hid, nid, slb, s)
            in
                createKRN(dim', [opos] @ orig, [ipos] @ imgpos, [vA] @ vAs, A @ krnpos)
            end

        val (oexp, iexp, vAs, keinapp) = (case dim
            of 1 => createKRND1 ()
            | _ => createKRN(dim, [], [], [], [])
            (*end case*))

        val oexp = E.Sum(List.tabulate(dim, fn i => (E.V (i + sx), slb, s)), oexp)
        in
            (oexp, iexp, vAs, keinapp)
        end

    (* createBody3: chooses between createBody2 and createBody based on alpha.
     * Fields with no alpha indices go through createBody2, which lifts the
     * kernels.  Otherwise createBody is used and the three lifted-piece results
     * are NONE.
     *)
    fun createBody3 (dim, s, sx, alpha, deltas, Vid, hid, nid, fid, params, argsA) =
        (case alpha
          of [] => createBody2(dim, s, sx, [], deltas, Vid, hid, nid, fid, params, argsA)
           | _ => (createBody(dim, s, sx, alpha, deltas, Vid, hid, nid, fid), NONE, NONE, NONE)
        (* end case *))

    (* getsumshift: (index * int * int) list * int -> int
     * Returns an index id not yet in use:
     *   - n (the number of output indices) when sx is empty;
     *   - otherwise one more than the id of the last summation index in sx.
     * The last entry of sx must have the form (E.V v, _, _); otherwise Bind is raised.
     *)
    fun getsumshift (sx, n) = (case sx
        of [] => n
        | _ => let
            val (E.V v, _, _) = List.last sx
            in
                v + 1
            end
        (* end case *))

    (* formBody: ein_exp -> ein_exp
     * Light clean-up of a body:
     *   - a Sum with no indices is dropped;
     *   - the body under a Sum is cleaned recursively;
     *   - a one-factor product is replaced by its factor.
     *)
    fun formBody term = (case term
        of E.Sum([], e) => formBody e
        | E.Sum(sx, e) => E.Sum(sx, formBody e)
        | E.Opn(E.Prod, [e]) => e
        | _ => term
        (* end case *))

    (* multiPs: ein_exp list * sumrange list * ein_exp -> ein_exp
     * Multiplies body by the factors Ps under the summation sx, then applies formBody.
     * With exactly three or four factors they are placed before body, to match the
     * product order of the vis branch WorldtoSpace functions.  In every other case
     * body comes first.
     *)
    fun multiPs (Ps, sx, body) = let
        val n = length Ps
        val factors = if n = 3 orelse n = 4 then Ps @ [body] else body :: Ps
        in
            formBody(E.Sum(sx, setProd factors))
        end
    
        
    (* multiMergePs: like multiPs, but nests the sums for the two-factor case.
     * With exactly two factors and two summation indices it gives each factor
     * its own summation:
     *     Sum([sx0], P0 * Sum([sx1], P1 * body))
     * Every other shape falls back to multiPs.
     *)
    fun multiMergePs (args as (Ps, sx, body)) = (case (Ps, sx)
        of ([P0, P1], [sx0, sx1]) =>
              E.Sum([sx0], setProd[P0, E.Sum([sx1], setProd[P1, body])])
        | _ => multiPs args
        (* end case *))
      
    (* *******************************************  setImage *******************************************  *)
    (* replaceImgA: 'a list * int * 'a -> 'a list
     * Returns es with the element at position vid replaced by newbie.
     * Raises Subscript if vid is not a valid position of es.
     *)
    fun replaceImgA (es, vid, newbie) = let
        val (front, back) = (List.take(es, vid), List.drop(es, vid + 1))
        in
            front @ (newbie :: back)
        end
    (* setImage: optionally lifts the image access into its own EIN application.
     * When vexp2 is NONE the inputs are returned unchanged.  When it is SOME vexp:
     *   - a fresh "Img" variable is bound to EIN(paraminstant, alphax, vexp)
     *     applied to argsA, where alphax holds the entries of index selected by
     *     alpha's index ids (alpha must contain only E.V indices, else Match);
     *   - position Vid of argsA becomes that variable;
     *   - position Vid of params' becomes TEN(2, [w*w, w]), with w the number of
     *     integers in 1-s .. s;
     *   - the new binding is appended to code.
     *)
    fun setImage (params', argsA, code, vexp2, index, alpha, paraminstant, Vid, s) =
        (case vexp2
          of NONE => (params', argsA, code)
           | SOME vexp => let
                val iArg = DstV.new ("Img", DstTy.TensorTy [])
                val alphax = List.map (fn (E.V i) => List.nth(index, i)) alpha
                val ieinapp = (iArg, mkEinApp(mkEin(paraminstant, alphax, vexp), argsA))
                val width = s - (1 - s) + 1
                val argsA' = replaceImgA(argsA, Vid, iArg)
                val params'' = replaceImgA(params', Vid, E.TEN(2, [width * width, width]))
                in
                    (params'', argsA', code @ [ieinapp])
                end
        (* end case *))
      
      (* setKernel: appends lifted kernel pieces when both are present.
       * When vAs2 and keinapp2 are both SOME:
       *   - the kernel variables are appended to the arguments;
       *   - dim parameters TEN(2,[]) are appended to the parameters;
       *   - the kernel bindings are appended to code.
       * When both are NONE the inputs are returned unchanged.  A SOME/NONE mix is
       * not handled and raises Match.
       *)
      fun setKernel (params', args', code, vAs2, keinapp2, dim) =
        (case (vAs2, keinapp2)
          of (SOME vAs, SOME keinapp) =>
               (params' @ List.tabulate(dim, fn _ => E.TEN(2, [])), args' @ vAs, code @ keinapp)
           | (NONE, NONE) => (params', args', code)
        (* end case *))
        
      (* setImageKernel: runs setImage, then setKernel.
       * The (params', args', code) triple is threaded through both calls.
       *)
      fun setImageKernel (params', args', code, vexp2, vAs2, keinapp2, dim, index, alpha, paraminstant, Vid, s) = let
            val (imgParams, imgArgs, imgCode) =
                  setImage(params', args', code, vexp2, index, alpha, paraminstant, Vid, s)
            in
                setKernel(imgParams, imgArgs, imgCode, vAs2, keinapp2, dim)
            end
    
    
    (* *******************************************  Replace probe *******************************************  *)
    (* replaceProbe0: replaces the probe p in the binding y = e(args) in place,
     * without lifting.
     *   - p must be Probe(Conv(Vid, alpha, hid, dx), Tensor(tid, _)); the ids are
     *     positions in args.
     *   - handleArgs transforms the position argument; the code it returns is
     *     placed before the rewritten binding.
     *   - Three parameters are appended, with ids fid = length params,
     *     nid = fid+1 and Pid = nid+1: TEN(3,[dim]), TEN(1,[dim]), TEN(1,[dim,dim]).
     *   - dx is rewritten by transformToIndexSpace starting at a fresh index.
     *   - The expanded body (createBody) is multiplied by the returned factors Ps.
     *   - If the original body was Sum(sx, probe) or Sum(sx, eps * probe), that
     *     outer Sum (and eps) is re-applied.
     * fieldset is returned unchanged.
     *)
    fun replaceProbe0 (fieldset, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        val fid = length params
        val nid = fid + 1
        val Pid = nid + 1
        val nshift = length dx
        val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
        val freshIndex = getsumshift(sx, length index)
        val (dx, newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
        val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
        (* createBody's summation indices start nshift past freshIndex (presumably
         * leaving room for the indices of the rewritten dx -- confirm) *)
        val body' = createBody(dim, s, freshIndex + nshift, alpha, dx, Vid, hid, nid, fid)
        val body' = multiPs(Ps, newsx1, body')
        (* restore the outer summation (and eps factor) of the original body *)
        val body' = (case originalb
            of E.Sum(sx, E.Probe _) => E.Sum(sx, body')
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) => E.Sum(sx, setProd[eps0, body'])
            | _ => body'
            (*end case*))
        val einapp = (y, mkEinApp(mkEin(params', index, body'), argsA @ [PArg]))
        in
            (fieldset, code @ [einapp])
        end
        


    (* replaceProbe3: like replaceProbe0, but builds the body with createBody3.
     * That lets image and kernel evaluations be lifted into separate bindings,
     * which setImageKernel then wires into the parameters, arguments and code.
     * The TEN(1,[dim,dim]) parameter for Pid is appended after paraminstant.
     * fieldset is returned unchanged.
     * NOTE(review): when kernels are lifted (vAs2 = SOME), setKernel appends the
     * TEN(2,[]) parameters after TEN(1,[dim,dim]).  The kernel variables, however,
     * go onto argsA before PArg is added, so parameters and arguments end in
     * different orders.  createBody2 also names the kernels Tensor nid+1 ..
     * nid+dim, and the first of those ids is Pid.  This path is reachable only
     * when pullKrn is true -- confirm the intended layout.
     *)
    fun replaceProbe3 (fieldset, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val _ = testp["\n***************** \n Replace ************ \n"]
        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        val fid = length params
        val nid = fid + 1
        val Pid = nid + 1
        val nshift = length dx
        val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
        val freshIndex = getsumshift(sx, length index)
        val (dx, newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
        val paraminstant = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
        val params' = paraminstant @ [E.TEN(1, [dim, dim])]
        val (body', vexp2, vAs2, keinapp2) =
              createBody3(dim, s, freshIndex + nshift, alpha, dx, Vid, hid, nid, fid, paraminstant, argsA)
        val body' = multiPs(Ps, newsx1, body')
        (* restore the outer summation (and eps factor) of the original body *)
        val body' = (case originalb
            of E.Sum(sx, E.Probe _) => E.Sum(sx, body')
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) => E.Sum(sx, setProd[eps0, body'])
            | _ => body'
            (*end case*))
        (* images and kernels *)
        val (params', argsA, code) =
              setImageKernel(params', argsA, code, vexp2, vAs2, keinapp2, dim, index, alpha, paraminstant, Vid, s)
        val einapp = (y, mkEinApp(mkEin(params', index, body'), argsA @ [PArg]))
        in
            (fieldset, code @ [einapp])
        end
        
        
    (* ******************************************* Lift probe *******************************************  *)
    (* createEinApp: builds the outer EIN operator used when a probe is lifted.
     * Parameter 0 (Pid) is the transformation matrix and Tensor 1 is the lifted
     * probe result: params = [TEN(1,[dim,dim]), TEN(1,sizes)].
     *   - dx is rewritten from freshIndex by transformToIndexSpace, giving newdx,
     *     new summation indices newsx and factors Ps.
     *   - If sx @ newsx is non-empty, cleanIndex renumbers Conv(_, alpha, _, newdx)
     *     against index; its second result is sizes (otherwise sizes = index).
     *   - The lifted tensor is indexed by the non-constant entries of the cleaned
     *     alpha, followed by newdx.
     *   - Ps and the original outer Sum / eps structure are re-applied around it.
     * Returns (splitvar, ein0, sizes, cleaned dx, cleaned alpha).  splitvar is
     * false only when the original body was Sum(sx, eps * probe).
     *)
    fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
        val Pid = 0
        val tid = 1
        (* Assumes the body is already clean *)
        val (newdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
        (* 9 and 7 are placeholder image/kernel ids; they are discarded below *)
        val probeShape = E.Conv(9, alpha, 7, newdx)
        val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sx @ newsx
            of [] => ([], index, probeShape)
            | allsx => cleanIndex.cleanIndex(probeShape, index, allsx)
            (*end case*))
        val params = [E.TEN(1, [dim, dim]), E.TEN(1, sizes)]
        (* constant indices in alpha' do not contribute to the tensor's shape *)
        val tshape = List.filter (fn E.C _ => false | _ => true) alpha' @ newdx
        val t = E.Tensor(tid, tshape)
        val (splitvar, body) = (case originalb
            of E.Sum(sx0, E.Probe _) => (true, multiPs(Ps, sx0 @ newsx, t))
            | E.Sum(sx0, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx0, setProd[eps0, multiPs(Ps, newsx, t)]))
            (* tsplitvar once selected multiMergePs here; only the multiPs path is live *)
            | _ => (true, multiPs(Ps, newsx, t))
            (*end case*))
        in
            (splitvar, mkEin(params, index, body), sizes, dx', alpha')
        end
    
    (* liftProbe0: lifts the probe p out of the binding y = e(args).
     *   - The probe itself becomes a separate tensor F with shape sizes.  If
     *     einVarSet.rtnVarN finds an equivalent Conv binding in fieldset, that
     *     variable is reused and no new code is emitted for F.  Otherwise F is
     *     defined by the expanded body from createBody.
     *   - y is rebound to ein0 (from createEinApp) applied to [P, F].  When
     *     splitvar is true, the binding is first passed through SummationEin.main
     *     and then split with Split.splitEinApp.
     * Returns the updated fieldset and the code:
     * handleArgsF code @ F's definition (if any) @ y's binding(s).
     *)
    fun liftProbe0 (fieldset, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        val fid = length params
        val nid = fid + 1
        val (fieldset, dim, args', code, s, PArg) = handleArgsF(fieldset, Vid, hid, tid, args)
        val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
        val (splitvar, ein0, sizes, dx, alpha') = createEinApp(originalb, alpha, index, freshIndex, dim, dx, sx)
        val FArg = DstV.new ("F", DstTy.TensorTy sizes)
        (* lookup key: the probe as an unexpanded Conv over the original args *)
        val rtn9 = (FArg, mkEinApp(mkEin(params, sizes, E.Conv(Vid, alpha', hid, dx)), args))
        val (fieldset, FArg, rtn1) = (case einVarSet.rtnVarN(fieldset, rtn9)
            of (fieldset, SOME v) => (fieldset, v, [])
            | (fieldset, NONE) => let
                (* lifted probe *)
                val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
                val body' = createBody(dim, s, length sizes, alpha', dx, Vid, hid, nid, fid)
                val einApp1 = mkEinApp(mkEin(params', sizes, body'), args')
                in
                    (fieldset, FArg, [(FArg, einApp1)])
                end
            (*end case*))
        val rtn0 = if splitvar
            then Split.splitEinApp (y, DstIL.EINAPP(SummationEin.main ein0, [PArg, FArg]))
            else [(y, mkEinApp(ein0, [PArg, FArg]))]
        in
            (fieldset, code @ rtn1 @ rtn0)
        end

    (* liftProbe3: like liftProbe0, but without value numbering.
     * The lifted probe F is always defined, with its body built by createBody3;
     * setImageKernel wires any lifted image/kernel pieces into F's parameters,
     * arguments and code.  fieldset is returned unchanged.
     * Code order: handleArgs/lifting code @ [F's definition] @ y's binding(s).
     * NOTE(review): setImageKernel receives the outer alpha and index, while F
     * is built from alpha' and sizes.  This matters only if setImage receives
     * SOME (liftimgflag) -- confirm which pair is intended.
     *)
    fun liftProbe3 (fieldset, (y, DstIL.EINAPP(e, args)), p, sx) = let
        val _ = testp["\n******* Lift Geneirc Probe ***\n"]
        val originalb = Ein.body e
        val params = Ein.params e
        val index = Ein.index e
        val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
        val fid = length params
        val nid = fid + 1
        val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
        val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
        val (splitvar, ein0, sizes, dx, alpha') = createEinApp(originalb, alpha, index, freshIndex, dim, dx, sx)
        val FArg = DstV.new ("F", DstTy.TensorTy sizes)
        val rtn0 = if splitvar
            then Split.splitEinApp (y, DstIL.EINAPP(SummationEin.main ein0, [PArg, FArg]))
            else [(y, mkEinApp(ein0, [PArg, FArg]))]
        (* lifted probe *)
        val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
        val (body', vexp2, vAs2, keinapp2) =
              createBody3(dim, s, length sizes, alpha', dx, Vid, hid, nid, fid, params', args')
        (* set image and kernel; the right-hand side still sees the params' above *)
        val (params', args', code) =
              setImageKernel(params', args', code, vexp2, vAs2, keinapp2, dim, index, alpha, params', Vid, s)
        val rtn1 = (FArg, mkEinApp(mkEin(params', sizes, body'), args'))
        in
            (fieldset, code @ [rtn1] @ rtn0)
        end

    (* Dispatch on pullKrn.  The *3 variants build the body with createBody3 and
     * call setImageKernel; the *0 variants use createBody directly.
     *)
    fun replaceProbe e = if pullKrn then replaceProbe3 e else replaceProbe0 e
    fun liftProbe e = if pullKrn then liftProbe3 e else liftProbe0 e

        
    (* ******************************************* Reconstruction -> Lift|Replace probe *******************************************  *)
    (* reconstruction: chooses between lifting and replacing a probe.
     * arg is (fieldset, binding, probe, sumrange list).
     * The probe is replaced in place when:
     *   - there is no differentiation (dx = []), or
     *   - fieldliftflag is off, or
     *   - constflag is off and dx contains a constant index.
     * Otherwise it is lifted.
     *)
    fun reconstruction (dx, arg) =
        if null dx orelse not fieldliftflag then replaceProbe arg
        else if constflag then liftProbe arg
        else if List.exists (fn E.C _ => true | _ => false) dx then replaceProbe arg
        else liftProbe arg
        
    (* **************************************************** Index Tensor **************************************************** *)
    (* getF: pushes a constant component index out of a probe.
     * Handles bodies of the forms
     *     Probe(Conv(V, [c1, v0], h, dx), pos)      (or alpha = [c1]), and
     *     Sum([(vsum, 0, n)], Probe(Conv(V, [c1, v0], h, dx), pos)).
     * The probe is rewritten to produce a tensor L of shape index0 @ dim:
     *   - unsummed forms: c1 becomes E.V newvx and the result is renumbered by
     *     cleanIndex;
     *   - summed form: alpha becomes [V 0, V 1] and dx becomes V 2, V 3, ...
     * L is expanded with reconstruction (or reused via einVarSet.rtnVarN when
     * valnumflag is set).  y is then rebound to an EIN op that reads L at
     * [c1, V 0 .. V(newvx-1)], or sums L over the diagonal (vsum, vsum) in the
     * summed form.
     *   dim   -- extra shape appended to the output shape (callers pass [3] or [3,3])
     *   newvx -- index id given to the former constant (unsummed forms), and the
     *            number of leading output indices passed through
     * Any other body shape raises Match.
     *)
    fun getF (e, fieldset, dim, newvx) = let
        val (y, DstIL.EINAPP(ein, args)) = e
        val index0 = Ein.index ein
        val index1 = index0 @ dim
        val b = Ein.body ein

        val (c1, dx, body1) = (case b
            of E.Sum([(vsum, 0, n)], E.Probe(E.Conv(V, [c1, v0], h, dx), pos)) => let
                val shiftdx = List.tabulate(length dx, fn n => E.V (n + 2))
                val b = E.Probe(E.Conv(V, [E.V 0, E.V 1], h, shiftdx), pos)
                in (c1, dx, b) end
            | E.Probe(E.Conv(V, [c1, v0], h, dx), pos) => let
                val body1_unshifted = E.Probe(E.Conv(V, [E.V newvx, v0], h, dx), pos)
                (* clean to get body indices in order *)
                val (_, _, body1) = cleanIndex.cleanIndex(body1_unshifted, index1, [])
                in (c1, dx, body1) end
            | E.Probe(E.Conv(V, [c1], h, dx), pos) => let
                val body1_unshifted = E.Probe(E.Conv(V, [E.V newvx], h, dx), pos)
                val (_, _, body1) = cleanIndex.cleanIndex(body1_unshifted, index1, [])
                in (c1, dx, body1) end
            (*end case*))

        (* L = the probe with the constant index turned into a free index *)
        val lhs1 = DstV.new ("L", DstTy.TensorTy index1)
        val code1 = (lhs1, mkEinApp(mkEin(Ein.params ein, index1, body1), args))

        val (lhs0, (fieldset, codeAll)) = (case valnumflag
            of false => (lhs1, reconstruction(dx, (fieldset, code1, body1, [])))
            | true => (case einVarSet.rtnVarN(fieldset, code1)
                of (fieldset, NONE) => (lhs1, reconstruction(dx, (fieldset, code1, body1, [])))
                | (fieldset, SOME m) => (m, (fieldset, []))
                (*end case*))
            (*end case*))

        (* probe that tensor at the constant position c1 *)
        val param0 = [E.TEN(1, index1)]
        val nx = List.tabulate(newvx, fn n => E.V n)
        val body0 = (case b
            of E.Sum([(vsum, 0, n)], _) => E.Sum([(vsum, 0, n)], E.Tensor(0, [vsum, vsum] @ nx))
            | _ => E.Tensor(0, [c1] @ nx)
            (*end case*))
        val code0 = (y, mkEinApp(mkEin(param0, index0, body0), [lhs0]))
        in
            (fieldset, codeAll @ [code0])
        end
    (* **************************************************** General Fn **************************************************** *)
    (* expandEinOp: expands any probe in the body of the binding y = ein(args).
     * At this point only simple EIN ops remain.  Bodies of the forms Probe,
     * Sum(sx, Probe) and Sum(sx, eps * Probe) are rewritten through
     * replaceProbe / reconstruction / getF; other bodies are returned unchanged.
     * Eps factors are kept so that only the pieces actually used are generated.
     * Returns (code, fieldset, varset, m, flag):
     *   m    -- 1 if the body has one of the probe forms above, otherwise 0
     *   flag -- 1 if valnumflag is set and rtnVarN found an equivalent variable,
     *           in which case y is simply bound to that variable; otherwise 0
     * varset is returned unchanged.
     *)
    fun expandEinOp (e0 as (y, DstIL.EINAPP(ein, args)), fieldset, varset) = let
        fun rewriteBody (fieldset, e, p as E.Probe(E.Conv(_, alpha, _, dx), _)) = (case (detflag, alpha, dx)
            of (true, [E.C(_, true), E.V 0], [])             => getF(e, fieldset, [3], 1)
            | (true, [E.C(_, true), E.V 0], [E.V 1])         => getF(e, fieldset, [3], 2)
            | (true, [E.C(_, true), E.V 0], [E.V 1, E.V 2])  => getF(e, fieldset, [3], 3)
            | (true, [E.C(_, true)], [])                     => getF(e, fieldset, [3], 0)
            | (true, [E.C(_, true)], [E.V 0])                => getF(e, fieldset, [3], 1)
            | (true, [E.C(_, true)], [E.V 0, E.V 1])         => getF(e, fieldset, [3], 2)
            | (true, [E.C(_, true)], [E.V 0, E.V 1, E.V 2])  => getF(e, fieldset, [3], 3)
            | _                                              => reconstruction(dx, (fieldset, e, p, []))
            (*end case*))
          | rewriteBody (fieldset, e, E.Sum(sx, p as E.Probe(E.Conv(_, alpha, _, dx), _))) = (case (detsumflag, sx, alpha, dx)
            of (true, [(E.V 0, 0, _)], [E.V 0, E.V 0], [])            => getF(e, fieldset, [3, 3], 0)
            | (true, [(E.V 1, 0, _)], [E.V 1, E.V 1], [E.V 0])        => getF(e, fieldset, [3, 3], 1)
            | (true, [(E.V 2, 0, _)], [E.V 2, E.V 2], [E.V 0, E.V 1]) => getF(e, fieldset, [3, 3], 2)
            | (_, _, _, [])  => replaceProbe(fieldset, e, p, sx)  (* no dx *)
            | (_, _, [], _)  => reconstruction(dx, (fieldset, e, p, sx))
            | _              => replaceProbe(fieldset, e, p, sx)
            (* end case *))
          | rewriteBody (fieldset, e, E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) = replaceProbe(fieldset, e, E.Probe p, sx)
          | rewriteBody (fieldset, e, _) = (fieldset, [e])

        val b = Ein.body ein
        (* 1 when the body has one of the probe forms handled above *)
        val m = (case b
            of E.Probe _ => 1
            | E.Sum(_, E.Probe _) => 1
            | E.Sum(_, E.Opn(E.Prod, [_, E.Probe _])) => 1
            | _ => 0
            (*end case*))
        val (fieldset, varset, code, flag) = (case valnumflag
            of true => (case einVarSet.rtnVarN(fieldset, e0)
                of (fieldset, NONE) => let
                    val (fieldset, code) = rewriteBody(fieldset, e0, b)
                    in (fieldset, varset, code, 0) end
                | (fieldset, SOME v) => (fieldset, varset, [(y, DstIL.VAR v)], 1)
                (*end case*))
            | _ => let
                val (fieldset, code) = rewriteBody(fieldset, e0, b)
                in (fieldset, varset, code, 0) end
            (*end case*))
        in
            (code, fieldset, varset, m, flag)
        end

  end; (* local *)

end (* local *)  

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0