Revision 3067 - (download) (annotate)
Sat Mar 14 17:24:52 2015 UTC (4 years, 4 months ago) by cchiw
File size: 8120 byte(s)
matrix fields
structure EvalImg = struct
    (* Short aliases for the destination-language modules used below. *)
    structure DstIL = LowIL
    structure DstOp = LowOps
    structure DstTy = LowILTypes
    structure Var = LowIL.Var
    structure LowToS = LowToString

    (* Short aliases for the Ein source language and supporting utilities. *)
    structure E = Ein
    structure H = Helper
    structure P = Printer


    (* Debug switch read by testp: 0 disables debug printing,
     * any other value enables it. *)
    val testing = 0

    (* Plain forwarders to the Helper structure. *)
    fun lookup x = H.lookup x
    fun insert x = H.insert x
    fun find x = H.find x
    fun mapIndex x = H.mapIndex x
    fun mkInt k = H.mkInt k
    fun assgn x = H.assgn x
    fun indexTensor x = H.indexTensor x

    (* Forwarders that supply an empty string as the first argument.
     * NOTE(review): presumably a name/prefix for the generated variable;
     * confirm against Helper. *)
    fun mkAddInt args = H.mkAddInt ("", args)
    fun mkAddPtr (args, ty) = H.mkAddPtr ("", args, ty)
    fun mkProdInt args = H.mkProdInt ("", args)

    (* Shorthands: integer-to-string conversion and raising Fail. *)
    fun iTos k = Int.toString k
    fun err msg = raise Fail msg
    val intTy = DstTy.IntTy

    (* Debug output: when testing is 0 nothing is printed; otherwise the
     * string list is concatenated and printed. The result is always 1. *)
    fun testp strs =
        if testing = 0 then 1
        else (print (String.concat strs); 1)

    (* psize: product of an int list, 1 for the empty list.
     * asize: sum of an int list, 0 for the empty list. *)
    fun psize xs = foldl op* 1 xs
    fun asize xs = foldl op+ 0 xs

    (* mkImg:dict*string*E.params*var list*sum_id list*()*image*var*int*int*int
    * ->var*lowIL.assgn
    * The image "imgarg" is probed at positions
    * Σ_{sx} V_alpha[pos0::px]
    * sumPos() iterates over the summation indices and creates a mapp for the indicies
    * once mapp(j->2 k->0) is created sumPos() calls createImgVar()  to get the addr of Σ_k V_{i}[T_j,T_k]
    * createImgVar() uses  mkpos(), getPosAddr() and getImgAddr() to get imgvar 
    fun mkImg(mappOrig,lhs,params,args,sx,(_,v_alpha,pos0::px),v,imgarg,lb,range0,range1)=let

        val dim=ImageInfo.dim v
        val ptyTy=DstTy.AddrTy v
        val sizes=ImageInfo.sizes v
        val (vBase,base)=assgn(DstOp.baseAddr v,[imgarg],"baseAddr",ptyTy)   (*base address*)
        val (vShapeShift,ShapeShiftcode)= mkInt(psize (ImageInfo.voxelShape v)) (*shift of the image field.*)

        (*Since the image is loaded as a vector
        * we evaluate the first position just once
        * Σ_{ij..} V[T+i,T+j...]-> Σ_{j..} V[T+j...]
        * and we drop the first summation index
        * Additionally, summation indices are reversed
        * that inner loop is second(y) axis and outer loop is third(z)axis
        val (vPos0,Pos0code,sxx)=let
            val E.Add[E.Tensor(t1,ix1),_ ]=pos0
            val (vA,A)=indexTensor(mappOrig,("",params,args,t1,ix1,intTy))
            val (vB,B)= mkInt lb
            val (vC,C)= mkAddInt [vA,vB]
            val _ =testp["\nsxx\n original:"]
            val sxx= List.map(fn (E.V sid,_,_)=> sid) sx
            val _ =List.map (fn e=> testp["-",iTos e]) sxx
            val sxx=List.rev(List.tl(sxx))
            val _ =testp["\n used:"]
            val _ =List.map (fn e=> testp["-",iTos e]) sxx


        (* mkpos:ein_exp list*var list*DstIL.assgn list
        * transform ein_exp to low-il
        * returns var for the position
        fun mkpos(e,mapp,rest,code)=(case e
            of [] => (rest,code)
            | ((E.Add[ E.Tensor(t1,ix1),E.Value v1])::es)=> let
                val (vA,A)=indexTensor(mapp,("",params,args,t1,ix1,intTy))
                val j=find(v1,mapp)
                val (rest',code')=(case j
                    of 0 => (vA,A)
                    | _ => let
                        val (vB,B)= mkInt j
                        val (vC,C)=mkAddInt[vA,vB]
                    in (vC,A@B@C) end
                    (*end case*))
                in mkpos(es,mapp,rest@[rest'],code@code') end
            | e1::_ => raise Fail("Incorrect pos for Image: "^P.printbody e1)
            (*end case*))

        (* getPosAddr:var list->var*DstIL.assgn list
        * create position addr based on image info's shapeshift,image info's sizes,and args
        * args are the variables for this specific positions. V_([x,y])
        * returns vPosAddr,PosAddrcode
        fun getPosAddr args=(case (sizes,args)
            of([ _ ],[i]) =>mkProdInt[vShapeShift,i] (*1-d*)
            | ([x, _ ],[i,j]) =>let                  (*2-d*)
                val (vA,A)= mkInt x
                val (vB,B)= mkProdInt [vA,j]
                val (vC,C)= mkAddInt [i,vB]
                val (vD,D)= mkProdInt[vShapeShift,vC]
                in (vD,A@B@C@D) end
            | ([x,y,_],[i,j,k])  =>let                (*3-d*)
                val (vA,A)= mkInt y
                val (vB,B)= mkProdInt [vA,k]
                val (vC,C)= mkAddInt [j,vB]
                val (vD,D)= mkInt x
                val (vE,E)= mkProdInt [vD,vC]
                val (vF,F)= mkAddInt [i,vE]
                val( vG,G)= mkProdInt[vShapeShift,vF]
                in (vG,A@B@C@D@E@F@G) end
            (*end case*))

        (* getImgAddr:int list *var->var*DstIL.assgn list
        * creates image address with ^position address,imgType, and base address
        * imgType are  image specific indices V[0,1](_)
        * ->returns (vImgAddr,ImgAddrcode)
        fun getImgAddr (imgType,vPosAddr) =(case imgType
            of [] => mkAddPtr([vBase,vPosAddr],ptyTy)
            | [0] => mkAddPtr([vBase,vPosAddr],ptyTy)
            | [_] => let
                val (vA,A)= mkAddPtr([vBase,vPosAddr],ptyTy)
                val (vB,B)= mkInt (asize imgType)
                val (vC,C)= mkAddPtr([vB, vA],ptyTy)
                in (vC,A@B@C)end
            | [i,j] => let
                val [a,b]=ImageInfo.voxelShape v
                (*val _=print(String.concat[Int.toString i,"-",Int.toString j,"\nvoxel",Int.toString a,"-",Int.toString b])*)
                val (vA,A)= mkAddPtr([vBase,vPosAddr],ptyTy)
                val (vB,B)= mkInt ((b*j)+i)
                val (vC,C)= mkAddPtr([vB, vA],ptyTy)
                in (vC,A@B@C)end
                (*end case*))

        (* createImgVar:dict->var*DstIL.assgn list
        *  gets low-il var for loading an image address
        fun createImgVar mapp=let
            val (vA,A)= mkpos(px,mapp,[],[])     (*transforms the probed position to low-il*)
            val posArgs=[vPos0]@vA                              (*adds intial position to ^*)
            val (vPosAddr,PosAddrcode)=getPosAddr posArgs                (*position address*)
            val imgType=List.map (fn (e1)=> mapIndex(e1,mapp)) v_alpha (*img specific index*)
            val (vImgAddr,ImgAddrcode)=getImgAddr (imgType,vPosAddr)          (*img address*)
            val (vD,D)=assgn(DstOp.imgLoad(v,dim,range1),[vImgAddr],"imgLoad",DstTy.tensorTy([range1]))

        (* sumPos:index_id * var list*lowil.assgn list*dict*int
        * ->var*lowil.assgn list
        * sumPos iterates over the summation indices and creates mapp
        fun sumPos([],lft,code,dict,_)=let
            val (lft', code')= createImgVar dict
            in ([lft']@lft,code'@code) end
        | sumPos([sid],lft,code,dict,0)=let
            val n'=lb
            val _=testp["\n insert",iTos sid, "->",iTos n']
            val mapp=insert (sid, n') dict
            val (lft', code')= createImgVar mapp
            in ([lft']@lft,code'@code) end
        | sumPos([sid],lft,code,dict,r)=let
            val n'=lb+r
            val _=testp["\n insert",iTos sid, "->",iTos n']
            val mapp=insert (sid, n') dict
            val (lft', code')= createImgVar mapp
            in sumPos([sid],[lft']@lft,code'@code,dict,r-1) end
        | sumPos(sid::sxx,lft,code,dict,0)=let
            val n'=lb
            val _=testp["\n insert",iTos sid, "->",iTos n']
            val mapp=insert (sid, n') dict
            val (lft',code')=sumPos(sxx,lft,[],mapp,range0)
            in (lft',code'@code) end

        | sumPos(sid::sxx,lft,code,dict,r)=let
            val n'=lb+r
            val _=testp["\n insert",iTos sid, "->",iTos n']
            val mapp=insert (sid, n') dict
            val (lft',code')=sumPos(sxx,lft,[],mapp,range0)


end (* local *)


