Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2680 - (download) (annotate)
Wed Aug 6 00:51:53 2014 UTC (4 years, 11 months ago) by cchiw
File size: 8467 byte(s)
update
(* genKrn: lowers an Ein field probe of the form Sum(sx, Prod [images and
 * kernels]), after index substitution, into LowIL assignments.
 *)
structure genKrn = struct
    local

    structure DstOp = LowOps
    structure DstTy = LowILTypes
    structure DstIL = LowIL
    structure E = Ein
    structure evalKrn =evalKrn
    structure S3=step3
 structure tS= toStringEin

    in

(* Debug switch: set to 1 to enable the diagnostic printing in this file. *)
val testing=0
(* Rank-0 tensor type used for the scalar arithmetic results below. *)
val Sca=DstTy.TensorTy []
val addR=DstOp.addSca

(* Environments are functions from keys to options.  insert returns a new
 * environment in which key maps to value and every other key is looked up
 * in d, so a later insert of the same key shadows an earlier one.
 *)
fun insert (key, value) d =fn s =>
    if s = key then SOME value
    else d s
fun lookup k d = d k
(* Like lookup, but raises Fail "Outside Bound" when v is unbound. *)
fun find(v, mapp)=(case (lookup v mapp)
    of NONE=> raise Fail "Outside Bound"
    |SOME s => s)


fun errS str=raise Fail(str)
fun mkInt n= S3.mkInt n


(* Scalar subtraction / addition of the variables in rest.  Each returns the
 * result variable together with the code that defines it.
 *)
fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)
fun mkAddSca rest= S3.aaV(addR,rest,"addSca",Sca)

(* Builds a tensor of type TensorTy shape from the variables in rest.
 * Returns the fresh variable and the single assignment defining it.
 *)
fun mkCons(shape, rest)=let
    val ty=DstTy.TensorTy shape
    val a=DstIL.Var.new("Cons" ,ty)
    val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
    (* Diagnostic dump, gated on testing like every other trace in this file. *)
    val _= if testing = 0 then () else print("\n****"^tS.toStringAll(ty,code))
    in
        (a, [code])
    end


(* Emits code for a position expression that is a binary subtraction
 * (E.Sub) or a two-element addition (E.Add).  Each operand must be either
 * an E.Tensor, lowered through S3.mkSca with args, or an E.Value index,
 * whose current integer value is looked up in mapp and emitted as a
 * constant.  Returns (result variable, code); any other shape raises Fail.
 *)
fun mkSimpleOp(mapp,e,args)=let
    fun subP e=(case e
        of E.Tensor(t1,ix1)=>S3.mkSca(mapp,(t1,ix1,args))
        | E.Value v1=> mkInt (find(v1,mapp))
        | _ => errS("ill-formed Kernel position")
    (*end case*))
    in (case e
        of E.Sub(e1,e2)=> let
            val (vA,A)=subP e1
            val (vB,B)=subP e2
            val (vD,D)=mkSubSca [vA,vB]
            in (vD,A@B@D) end
        | E.Add[e1,e2]=> let
            val (vA,A)=subP e1
            val (vB,B)=subP e2
            val (vD,D)=mkAddSca [vA,vB]
            in (vD,A@B@D) end
        | _ => raise Fail"Probed position is not subtraction or addition"
             (* end case*))
    end


(* Turns each list of variables in the first argument into a vector with
 * mkCons.  Each inner list is reversed first: mkkrns collects positions
 * from offset lb+top down to lb, so the vector ends up ordered lb..lb+top.
 * The resulting variables and code come out in reverse order of the input
 * lists.  The last argument is unused.
 *)
fun consfn([],rest, code,_)=(rest,code)
  | consfn(e::es,rest,code,dim)=let
        val (vA,A)= mkCons([length(e)],List.rev e)
    in
        consfn(es, [vA]@rest, A@code,dim)
    end


(* Splits the factors of a summation product: E.Img terms are appended to
 * the first list and E.Krn terms to the second, preserving their order.
 * Any other factor raises Fail.
 *)
fun sortK(a,b,[])=(a,b)
    | sortK(a,b,e::es)=(case e
        of E.Krn k1=>sortK(a,b@[k1],es)
        | E.Img img1=>sortK(a@[img1],b,es)
        | _ =>raise Fail"Non Image or Krn in summation expression"
        (*end case*))


(*
fun sumP(a,b,last)=let
    val (vD, D)=S3.aaV(DstOp.prodVec(last,1),[a, b],"prodV",DstTy.TensorTy([last]))
    val (vE, E)=S3.aaV(DstOp.sumVec(last,1),[vD],"sumVec",DstTy.intTy)
    in (vE,D@E) end
*)
(* Dot product of the length-last vectors a and b via DstOp.dotVec.
 * NOTE(review): the result is typed DstTy.intTy, while the other scalar
 * arithmetic in this file uses Sca; confirm which type dotVec produces.
 *)
fun sumP(a,b,last)=let
    val (vE, E)=S3.aaV(DstOp.dotVec(last,1),[a,b],"dotVec",DstTy.intTy)
    in (vE,E) end



(* Images.  starter must hold exactly one image term (fid, ix, px); px has
 * one position expression per image axis, so dim = length px.  Emits the
 * base address (DstOp.baseAddr v) of vNew and the constant lb (via mkInt).
 * Then, for every combination of the outer summation offsets, it emits an
 * address computation and an imgLoad of width R.
 * The last position component must have the form Add[Tensor, _]; it is
 * computed as that tensor plus lb (the second addend is ignored).  Every
 * other component is lowered with mkSimpleOp under the current offsets.
 * Returns (loaded vector variables, code).  sx and ub are not used here.
 *)
fun mkImg(mappOrig,sx,starter,v,vNew,info,sid,lb,ub,top, R)=let
    val (_,ix,px)=(case starter
        of [img]=>img
        | _=> raise Fail"Non summation range")

    val dim=length(px)
    val (vlb,BB)=  mkInt lb

    val (vBase,base)=S3.aaV(DstOp.baseAddr v,[vNew],"baseAddr",Sca)

    (* Emits the position, address and load code for the single
     * combination of summation offsets recorded in mapp.
     *)
    fun createImgVar mapp=let
        fun mkpos(e,rest,code)=(case e
            of [E.Add[E.Tensor(t1,ix1),_]]=> let
                val (vA,A)=S3.mkSca(mapp,(t1,ix1,info))
                val (vC,C)=mkAddSca [vA,vlb]
                in
                    (rest@[vC],code@A@C)
                end
            | pos1::es => let
                val (vF,code1)=mkSimpleOp(mapp,pos1,info)
                in
                    mkpos(es,rest@[vF],code@code1)
                end
            | _=> raise Fail "Non-addition in Image Position"
        (*end case*))

        val (vF,F)= mkpos(px,[],[])
        val imgType=DstTy.imgIndex(List.map (fn (e1)=> S3.mapIndex(e1,mapp)) ix)
        val (vA,A)=S3.aaV(DstOp.imgAddr(v,imgType,dim),[vBase]@vF,"Imageaddress",DstTy.intTy)
        val (vB,B)=S3.aaV(DstOp.imgLoad(v,dim,R),[vA],"imgLoad",DstTy.tensorTy([R]))
        in
            (vB,F@A@B)
        end

    (* sumI1(lft, dict, i, n, code, n') enumerates the outer offsets.
     *   i  - offset counter at this level, runs from top down to 0;
     *        summation id n' is bound to lb+i
     *   n  - number of summation levels still below this one
     *   n' - summation id bound at this level
     * Results are prepended, so the final lists are ordered by increasing
     * offset, outermost level most significant.
     *)
    fun sumI1(lft,dict,0,0,code,n')=let
        val mapp=insert (n', lb) dict
        val (lft', code')= createImgVar mapp
        in ([lft']@lft,code'@code) end
    |  sumI1(lft,dict,i,0,code,n')=let
        (* lb+i, as at the outer levels and in mkkrns; the former i-1 only
         * agreed with that when lb = ~1 *)
        val mapp=insert (n', i+lb) dict
        val (lft', code')=createImgVar mapp
        in sumI1([lft']@lft,dict,i-1,0,code'@code,n') end
    | sumI1(lft,dict,0,n,code,n')=let
        val mapp=insert (n', lb) dict
        in sumI1(lft,mapp,top,n-1,code,n'+1) end
    | sumI1(lft,dict,i,n, code,n')=let
         val mapp=insert (n', i+lb) dict
        val (lft',code')=sumI1(lft,mapp,top,n-1,code,n'+1)
        in sumI1(lft',dict,i-1,n,code',n') end

    (* With fewer than two axes there is no outer summation level: the single
     * width-R load covers the whole window.  Calling sumI1 with n = dim-2 < 0
     * would never reach n = 0 and so would not terminate.
     *)
    val(lft,code)=
        if dim < 2
        then let val (vI, cI)=createImgVar mappOrig in ([vI], cI) end
        else sumI1([],mappOrig,top,dim-2,[],sid)
    in
        (lft,base@BB@code)
    end


(* Kernels.  For each kernel (id, dels, pos) its differentiation level is
 * computed with S3.evalDels.  Its position expression is lowered with
 * mkSimpleOp for each summation offset lb+top down to lb.  Those top+1
 * positions are packed into a vector by consfn, and the kernel is evaluated
 * with evalKrn.expandEvalKernel.  Successive kernels use successive
 * summation ids starting at sid.  Returns (kernel-result variables, code).
 * sx is not used.
 *)
fun mkkrns(mappOrig,sx,dels,h,args, sid,lb,ub,top,R)=let
    val newdels= List.map (fn (id,d1,pos)=>(id,S3.evalDels(mappOrig,d1),pos)) dels
    (* Diagnostic trace; the message used to be built and then discarded. *)
    val _ =(case testing
        of 1=> let
            val mm = Int.toString
            in
                print (String.concat (["Differentiation value of kernels:"]
                    @ List.map (fn (_, d, _) => mm d) newdels
                    @ ["\n ub:", mm ub, "lb:", mm lb, "Range", mm top]))
            end
        | _ => ())
    (* mkpos(k, fin, l, dict, i, code, n): lowers the position of the head of
     * k with summation id n bound to lb+i, for i = top down to 0.  l gathers
     * the positions of the current kernel.  When i = 0 the kernel is
     * finished: l is appended to fin, and id n is bound to 0 in the
     * environment used for the remaining kernels.
     * NOTE(review): the reason for binding the finished id to 0 is not
     * visible here; confirm.
     *)
    fun mkpos(k,fin,l,dict, i,code,n)= (case (k,i)
        of ([],_)=>(fin,code)
        | ((id1,d,pos1)::ks,0)=>let
            val mapp=insert (n, lb) dict
            val (l', code')=mkSimpleOp(mapp,pos1,args)
            val e=l@[l']
            val mapp'=insert (n, 0) dict
            in  mkpos(ks,fin@[e],[],mapp',top,code@code',n+1)
            end
        | (e1::es,_) =>let
            val (id1,d,pos1)=e1
            val mapp= insert (n, lb+i) dict
            val (l', code')=mkSimpleOp(mapp,pos1,args)
            in  mkpos(k,fin,l@[l'],dict,i-1,code@code',n)
            end
        (*end case*))
    (* Pairs each kernel with a position vector and evaluates it; the index n
     * passed to expandEvalKernel counts up from 0.  Raises Fail if the two
     * lists differ in length.
     *)
    fun evalK([],[],_,code,newId)=(newId,code)
      | evalK(kn::kns,x::xs,n,code,newId)=let
        val (_,dk,_) =kn
        val (id,kcode)= evalKrn.expandEvalKernel (R,h, dk, x,n)
        in
            evalK(kns,xs,n+1,code@kcode,newId@[id])
        end
     |evalK _ =raise Fail "Non-equal variable list, error in mkKrns"

    val(lftkrn,code)=mkpos(newdels,[],[],mappOrig,top,[],sid)
    (* NOTE(review): consfn returns the position vectors in reverse kernel
     * order, but evalK zips them with newdels in original order.  So the
     * differentiation level of kernel k is paired with the positions of
     * kernel count-1-k; confirm this is intended.
     *)
    val (lft,code')=consfn((lftkrn),[],[],top)
    val (ids, evalKcode)=evalK(newdels,lft,0,[],[])
    in
        (ids, code@code'@evalKcode)
    end


(* Product of image and kernel values.  krnArg holds one kernel-result
 * vector per axis, matched as [hx], [hy,hx] or [hz,hy,hx]; any other length
 * raises Fail.  imgArg holds the loaded image vectors from mkImg.  The image
 * data is contracted with one kernel vector at a time through dot products
 * (sumP), finishing with hx.
 *)
fun prodImgKrn(imgArg,krnArg,R)=let

    (* Always builds a TensorTy [R] vector; the shape argument is ignored. *)
    fun ConsInt(shape, rest)=let
        val ty=DstTy.TensorTy [R]
        val a=DstIL.Var.new("Cons"  ,ty)
        val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
        in (a, [code])
        end


    (* Dots each vector with hz and packs every R consecutive results into a
     * vector; the counter r counts down to 0 within each group.
     * NOTE(review): a trailing partial group (input length not a multiple
     * of R) is not packed into the result list.
     *)
    fun dhz([],conslist,rest,code,_,_)=(conslist,code)
      | dhz(e::es,conslist,rest,code,hz,0)=let
        val (vA,A)=sumP(e,hz,R)
        val (vD,D)=ConsInt(R,rest@[vA])
        in dhz(es,conslist@[vD],[],code@A@D,hz,R-1)
        end
      | dhz(e::es,conslist,rest,code,hz,r)=let
        val (vA,A)=sumP(e,hz,R)
        in dhz(es,conslist,rest@[vA],code@A,hz,r-1)
    end


    (* Dots every vector with hy and packs all of the results into one
     * vector. *)
    fun dhy([],rest,code,hy)=   let
        val (vD,D)=ConsInt(length(rest),rest)
        in (vD,code@D) end
      | dhy(e::es,rest,code,hy)=let
         val (vA,A)=sumP(e,hy,R)
         in dhy(es,rest@[vA],code@A,hy)end

    (*Create Product by doing case analysis of the dimension*)
    in (case krnArg
        of [hx]=>let
            val [i]=imgArg
            in sumP(i,hx,R)
            end
        | [hy,hx]=>let
            val (vD,code)=dhy(imgArg,[],[],hy)
            val (vE,E)=sumP(vD,hx,R)
            in
                (vE,code@E)
            end
    |   [hz,hy,hx]=>let
        val (vZ,codeZ)=dhz(imgArg,[],[],[],hz,R-1)
        val (vY,codeY)=dhy(vZ,[],[],hy)
        val (vE,E)=sumP(vY,hx,R)
        in
            (vE,codeZ@codeY@E)
        end
    | _ => raise Fail "Kernel dimensions not between 1-3"
        (*end case*))
    end

(* Entry point.  Lowers E.Sum(sx, E.Prod e) using image variable v (vNew is
 * the argument given to baseAddr), kernel h and info.  Only the first
 * summation triple (E.V sid, lb, ub) of sx is inspected; the window width
 * is R = ub-lb+1.  Returns (result variable, code).  Any other expression
 * raises Fail.
 *)
fun  evalField(mapp,(E.Sum(sx,E.Prod e),(v,vNew),h,info))=let
    val _=(case testing
        of 0 => 1
        | _ => (print "\n\n ************** new direction **********\n\n";1)
        (*end test*))

    val (img1,k1)=sortK([],[],e)
    val (E.V sid,lb,ub)=hd(sx)
    val top=(ub-lb)
    val R=top+1

    val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,vNew,info,sid,lb,ub,top,R)
    val (krnArg, krnCode)= mkkrns(mapp,sx,k1,h,info,sid,lb,ub,top,R)
    val (vA,A)=prodImgKrn(imgArg,krnArg,R)
    val _=(case testing
        of 0=> 1
        |_ =>(print ("Number of Assignments  in prodImgArg returned"^Int.toString(length(imgArg)));1)
        (*end case*))
    in
        (vA,imgCode@krnCode@A)
    end

|evalField _=raise Fail "Incorrect Field Expression"

end (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0