Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2611 - (download) (annotate)
Mon May 5 21:21:12 2014 UTC (5 years, 2 months ago) by cchiw
File size: 8080 byte(s)
InnerProduct, DoubleDot:
(* Generates LowIL code for an Ein probe (a sum over image and kernel terms) after substitution *)
structure genKrn = struct

  (* genKrn: lowers an Ein probe of the form Sum(sx, Prod es), where es holds
   * Img and Krn terms, into a list of LowIL assignments.  evalField ties the
   * pieces together; the generated code
   *   1. loads the image samples covered by the summation range (mkImg),
   *   2. evaluates each kernel at every offset of that range (mkkrns2), and
   *   3. contracts the image samples against the kernel values (prodImgKrn).
   *)

    local
        structure E = Ein
        structure DstOp = LowOps
        structure evalKrn = evalKrn
        structure SrcIL = MidIL
        structure SrcOp = MidOps
        structure SrcSV = SrcIL.StateVar
        structure SrcTy = MidILTypes
        structure VTbl = SrcIL.Var.Tbl
        structure DstIL = LowIL
        structure DstTy = LowILTypes
        structure Var = LowIL.Var
    in

    (* ---------- index environments ----------
     * An environment maps an index id to SOME value, or to NONE when the id
     * is unbound.  It is represented directly as a function.
     *)

    (* insert (key, value) d: extend d so that key maps to value; every other
     * id is still looked up in d.
     *)
    fun insert (key, value) d = fn s =>
          if s = key then SOME value else d s

    (* lookup k d: the binding of k in d, if any *)
    fun lookup k d = d k

    (* find (v, mapp): the value bound to v in mapp.
     * Raises Fail "Outside Bound" when v is unbound.
     *)
    fun find (v, mapp) = (case lookup v mapp
           of NONE => raise Fail "Outside Bound"
            | SOME s => s
          (* end case *))

    (* the environment with no bindings *)
    val empty = fn key => NONE

    (* Debug switch: set to 1 to make mkkrns2, prodImgKrn and evalField print
     * diagnostic output.
     *)
    val testing = 0

    (* ---------- scalar position arithmetic ---------- *)

    (* mkSimpleOp (mapp, e, args): generate code for the scalar expression e,
     * which must be Sub(e1, e2) or Add[e1, e2] whose operands are Tensor,
     * Value or Const terms (any other form raises Match).
     *   - Tensor operands are read with gHelper.mkSca;
     *   - Value operands become the constant bound to their index in mapp
     *     (find raises Fail when it is unbound);
     *   - Const operands become that constant.
     * Returns (x, code): x holds the result and code computes it.
     *)
    fun mkSimpleOp (mapp, e, args) = let
          fun operand (E.Tensor(t1, ix1)) = gHelper.mkSca(mapp, (t1, ix1, args))
            | operand (E.Value v1) =
                gHelper.aaV(DstOp.C (find(v1, mapp)), [], "Const", DstTy.TensorTy([]))
            | operand (E.Const c) =
                (* emit the constant itself; this used to emit a hard-coded 9 *)
                gHelper.aaV(DstOp.C c, [], "Const", DstTy.TensorTy([]))
          (* code for both operands (left first), then the combining operator *)
          fun binop (rator, name, e1, e2) = let
                val (vA, A) = operand e1
                val (vB, B) = operand e2
                val (vD, D) = gHelper.aaV(rator, [vA, vB], name, DstTy.TensorTy([]))
                in
                  (vD, A @ B @ D)
                end
          in
            case e
             of E.Sub(e1, e2) => binop(DstOp.subSca, "Subsca", e1, e2)
              | E.Add[e1, e2] => binop(DstOp.addSca, "addsca", e1, e2)
            (* end case *)
          end

    (* consfn (vss, rest, code, dim, n): for each variable list e in vss, emit
     * a Cons packing the elements of e, in reverse order, into a vector of
     * length (length e).  Each new vector variable and its code are put in
     * FRONT of rest and code, so both come back in reverse order of vss.
     * dim is unused; n numbers the generated variable names.
     *
     * TODO: the type argument given to DstOp.cons still needs fixing
     * (original note: FIX TYPE ON CONS TYPE HERE).
     *)
    fun consfn ([], rest, code, dim, n) = (rest, code)
      | consfn (e::es, rest, code, dim, n) = let
          val gg = length e
          val (vA, A) = gHelper.aaV(
                DstOp.cons(DstTy.TensorTy [gg], 0), List.rev e,
                "Cons " ^ Int.toString n ^ ":--", DstTy.TensorTy([gg]))
          in
            consfn (es, vA :: rest, A @ code, dim, n + 1)
          end

    (* sortK (imgs, krns, es): split off the Img and Krn terms of es, appending
     * their payloads (in order) to imgs and krns respectively; returns
     * (imgs, krns).  Any other term raises Match.
     *)
    fun sortK (a, b, []) = (a, b)
      | sortK (a, b, E.Krn k1 :: es) = sortK (a, b @ [k1], es)
      | sortK (a, b, E.Img img1 :: es) = sortK (a @ [img1], b, es)

    (* Not referenced in this structure; kept since it is visible outside it. *)
    val bV = ref 0

    (* sumP (a, b, last): dot product of the length-last vectors a and b,
     * generated as an elementwise product (prodVec) followed by sumVec.
     *)
    fun sumP (a, b, last) = let
          val (vD, D) = gHelper.aaV(DstOp.prodVec last, [a, b], "prodV", DstTy.TensorTy([last]))
          val (vE, E) = gHelper.aaV(DstOp.sumVec last, [vD], "sumVec", DstTy.intTy)
          in
            (vE, D @ E)
          end

    (* ccons (rest, shape): cons the variables in rest into a tensor of the
     * given shape.  Not used within this structure.
     *)
    fun ccons (rest, shape) =
          gHelper.aaV(DstOp.cons(DstTy.TensorTy shape, 0), rest, "Cons", DstTy.TensorTy(shape))

    (* ---------- images ---------- *)

    (* mkImg (mappOrig, sx, [(fid, ix, px)], v, args): generate the image loads
     * needed by the probe.  Only the first entry (E.V vid, lb, ub) of the
     * summation range sx is used, and its range lb..ub applies to every axis,
     * so R = ub - lb + 1 samples are needed per axis.  px holds one position
     * expression per axis (dim = length px); the last one must have the form
     * Add[Tensor t, _] and its address is t + lb.
     *
     * For every combination of offsets lb..ub bound to the dim-1 index ids
     * vid, vid+1, ..., one imgLoad of R values is generated (a 1-d image
     * needs just one load).  Returns the loaded vectors, ordered by ascending
     * offsets, and the code (the constant lb first).
     *)
    fun mkImg (mappOrig, sx, [(fid, ix, px)], v, args) = let
          val (E.V vid, lb, ub) = hd sx
          val top = ub - lb            (* offsets run 0..top, i.e. lb..ub *)
          val R = top + 1              (* samples per axis *)
          val dim = length px
          val (vlb, BB) = gHelper.mkC lb

          (* createImgVar mapp: address and load one run of R samples with the
           * index ids bound as in mapp.
           *)
          fun createImgVar mapp = let
                (* last axis: tensor component + lb (second addend unused) *)
                fun mkpos ([E.Add[E.Tensor(t1, ix1), _]], rest, code) = let
                      val (vA, A) = gHelper.mkSca(mapp, (t1, ix1, args))
                      val (vC, C) = gHelper.aaV(DstOp.addSca, [vA, vlb], "addsca", DstTy.TensorTy([]))
                      in
                        (rest @ [vC], code @ A @ C)
                      end
                  (* other axes: a simple Add/Sub position expression *)
                  | mkpos (pos1::es, rest, code) = let
                      val (vF, code1) = mkSimpleOp(mapp, pos1, args)
                      in
                        mkpos (es, rest @ [vF], code @ code1)
                      end
                val ix1 = List.map (fn e1 => gHelper.mapIndex(e1, mapp)) ix
                val (vF, F) = mkpos (px, [], [])
                val imgType = DstTy.imgIndex ix1
                val (vA, A) = gHelper.aaV(DstOp.imgAddr(v, imgType, dim), vF, "Imageaddress", DstTy.intTy)
                val (vB, B) = gHelper.aaV(DstOp.imgLoad(v, dim, R), [vA], "imgLoad---", DstTy.tensorTy([R]))
                in
                  (vB, F @ A @ B)
                end

          (* sumI1 (lft, ix, i, n, code, n'): enumerate offsets for index id n',
           * with i counting down from top to 0 and n' bound to lb + i in ix.
           * While n > 0 it recurses into the next index id n'+1; at n = 0 it
           * emits one load per offset.  New loads and code are put in front
           * of lft and code, so they come out in ascending offset order.
           *)
          fun sumI1 (lft, ix, 0, 0, code, n') = let
                val (lft', code') = createImgVar (insert (n', lb) ix)
                in
                  (lft' :: lft, code' @ code)
                end
            | sumI1 (lft, ix, i, 0, code, n') = let
                (* bind lb + i like the other clauses; this used to bind i - 1,
                 * which only agrees with lb + i when lb = ~1 *)
                val (lft', code') = createImgVar (insert (n', i + lb) ix)
                in
                  sumI1 (lft' :: lft, ix, i - 1, 0, code' @ code, n')
                end
            | sumI1 (lft, ix, 0, n, code, n') =
                sumI1 (lft, insert (n', lb) ix, top, n - 1, code, n' + 1)
            | sumI1 (lft, ix, i, n, code, n') = let
                val (lft', code') = sumI1 (lft, insert (n', i + lb) ix, top, n - 1, code, n' + 1)
                in
                  sumI1 (lft', ix, i - 1, n, code', n')
                end

          val (lft, code) =
                if dim < 2
                  (* 1-d image: no outer index ids to enumerate, so one load
                   * suffices.  sumI1 cannot be used here: it would start
                   * with n = ~1, which never reaches 0, and recurse forever. *)
                  then let val (v1, c1) = createImgVar mappOrig in ([v1], c1) end
                  else sumI1 ([], mappOrig, top, dim - 2, [], vid)
          in
            (lft, BB @ code)
          end

    (* ---------- kernels ---------- *)

    (* mkkrns2 (mappOrig, sx, k1, h, args): evaluate the kernel h for each
     * kernel term (id, d1, pos) of k1.  d1 is turned into a differentiation
     * value with gHelper.evalDelta.  pos is evaluated with mkSimpleOp at each
     * offset lb..ub of the first summation range in sx, using index id sid
     * for the first term, sid+1 for the next, and so on.  The R' positions
     * of each term are packed into a vector (consfn) and passed, with the
     * differentiation value, to evalKrn.expandEvalKernel.  Returns the
     * kernel-value variables and the code.
     *
     * NOTE(review): consfn returns its vectors in reverse order, so evalK
     * pairs the first kernel term with the position vector of the last
     * term.  This only matters when the differentiation values differ;
     * confirm the pairing is intended.
     *)
    fun mkkrns2 (mappOrig, sx, k1, h, args) = let
          val k = List.map (fn (id, d1, pos) => (id, gHelper.evalDelta(d1, mappOrig), pos)) k1
          val (E.V sid, lb, ub) = hd sx
          val R = ub - lb              (* largest offset *)
          val R' = R + 1               (* number of offsets *)
          fun mm e = Int.toString e
          val _ = (case testing
                 of 1 => let
                      val _ = print "Differentiation value of kernels:"
                      val _ = List.map (fn (id, v, pos) => print (Int.toString v)) k
                      val _ = print (String.concat["\n ub:", mm ub, "lb:", mm lb, "Range", mm R])
                      in 1 end
                  | _ => 1
                (* end case *))

          (* q (ks, fin, l, ix, i, code, n'): evaluate the position of the head
           * term at offset lb + i, for i = R down to 0, accumulating into l.
           * At i = 0 the finished list is appended to fin and processing moves
           * on to the next term with index id n' + 1.
           *)
          fun q ([], fin, l, ix, i, code, n') = (fin, code)
            | q ((id1, d, pos1)::ks, fin, l, ix, 0, code, n') = let
                val (l', code') = mkSimpleOp(insert (n', lb) ix, pos1, args)
                (* later terms see this term's index id bound to 0 *)
                val mapp' = insert (n', 0) ix
                in
                  q (ks, fin @ [l @ [l']], [], mapp', R, code @ code', n' + 1)
                end
            | q (k::ks, fin, l, ix, i, code, n') = let
                val (_, _, pos1) = k
                val (l', code') = mkSimpleOp(insert (n', lb + i) ix, pos1, args)
                in
                  q (k::ks, fin, l @ [l'], ix, i - 1, code @ code', n')
                end

          val (lftkrn, code) = q (k, [], [], mappOrig, R, [], sid)
          val (lft, code') = consfn (lftkrn, [], [], R, 0)

          (* evalK (kns, xs, n, code, ids): expand the evaluation of each kernel
           * term against the position vector at the same place in xs *)
          fun evalK ([], [], n, code, newId) = (newId, code)
            | evalK ((_, dk, _)::kns, x::xs, n, code, newId) = let
                val (id, kcode) = evalKrn.expandEvalKernel (R', h, dk, x, n)
                in
                  evalK (kns, xs, n + 1, code @ kcode, newId @ [id])
                end

          val (ids, evalKcode) = evalK (k, lft, 0, [], [])
          in
            (ids, code @ code' @ evalKcode)
          end

    (* ---------- contraction ---------- *)

    (* prodImgKrn (imgArg, krnArg, R): contract the loaded image vectors in
     * imgArg (each of length R) against the kernel-value vectors in krnArg
     * using dot products (sumP).  Handles one, two or three kernels:
     *   [hx]          imgArg must be a single vector, dotted with hx;
     *   [hy, hx]      each image vector is dotted with hy, the results are
     *                 packed into one vector, which is dotted with hx;
     *   [hz, hy, hx]  each image vector is dotted with hz and the results
     *                 are packed into vectors of R; those are dotted with
     *                 hy, packed, and finally dotted with hx.
     * Returns the result variable and its code.
     *)
    fun prodImgKrn (imgArg, krnArg, R) = let
          val tyV = DstTy.TensorTy[R]
          val _ = (case testing
                 of 0 => 1
                  | _ => (print ("Number of Assignments  in prodImgArg returned" ^ Int.toString (length imgArg)); 1)
                (* end case *))

          (* dhz (es, conslist, rest, code, hz, r): dot each vector with hz and
           * group the results R at a time into vectors; r counts the slots
           * left in the current group.  A trailing incomplete group is
           * discarded. *)
          fun dhz ([], conslist, rest, code, _, _) = (conslist, code)
            | dhz (e::es, conslist, rest, code, hz, 0) = let
                val (vA, A) = sumP(e, hz, R)
                val (vD, D) = gHelper.aaV(DstOp.cons(DstTy.intTy, R), rest @ [vA], "Cons", tyV)
                in
                  dhz (es, conslist @ [vD], [], code @ A @ D, hz, R - 1)
                end
            | dhz (e::es, conslist, rest, code, hz, r) = let
                val (vA, A) = sumP(e, hz, R)
                in
                  dhz (es, conslist, rest @ [vA], code @ A, hz, r - 1)
                end

          (* dhy (es, rest, code, hy): dot each vector with hy and pack all of
           * the results into a single vector *)
          fun dhy ([], rest, code, hy) = let
                val n = length rest
                val (vD, D) = gHelper.aaV(DstOp.cons(DstTy.intTy, n), rest, "Cons", tyV)
                in
                  (vD, code @ D)
                end
            | dhy (e::es, rest, code, hy) = let
                val (vA, A) = sumP(e, hy, R)
                in
                  dhy (es, rest @ [vA], code @ A, hy)
                end
          in
            case krnArg
             of [hx] => let
                  val [i] = imgArg
                  in
                    sumP(i, hx, R)
                  end
              | [hy, hx] => let
                  val (vD, code) = dhy (imgArg, [], [], hy)
                  val (vE, E) = sumP(vD, hx, R)
                  in
                    (vE, code @ E)
                  end
              | [hz, hy, hx] => let
                  val (vZ, codeZ) = dhz (imgArg, [], [], [], hz, R - 1)
                  val (vY, codeY) = dhy (vZ, [], [], hy)
                  val (vE, E) = sumP(vY, hx, R)
                  in
                    (vE, codeZ @ codeY @ E)
                  end
            (* end case *)
          end

    (* evalField (mapp, (E.Sum(sx, E.Prod e), v, h, args)): generate code for
     * the probe Sum(sx, Prod e), where e holds the Img and Krn terms.  v is
     * handed to the image operations and h to the kernel evaluation.
     * Returns the result variable and the code: image loads, then kernel
     * evaluation, then the contraction.  Any other expression raises Match.
     *)
    fun evalField (mapp, (E.Sum(sx, E.Prod e), v, h, args)) = let
          val _ = (case testing
                 of 0 => 1
                  | _ => (print "\n\n ************** new direction **********\n\n Outer Bound:"; 1)
                (* end test *))
          val (img1, k1) = sortK ([], [], e)
          val (_, lb, ub) = hd sx
          val R = (ub - lb) + 1
          val (imgArg, imgCode) = mkImg (mapp, sx, img1, v, args)
          val (krnArg, krnCode) = mkkrns2 (mapp, sx, k1, h, args)
          val (vA, A) = prodImgKrn (imgArg, krnArg, R)
          in
            (vA, imgCode @ krnCode @ A)
          end

    end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0