Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3395 - (download) (annotate)
Tue Nov 10 18:23:07 2015 UTC (3 years, 8 months ago) by cchiw
File size: 12817 byte(s)
val-num in evalkrn
(* evalKrn
 * Evaluate an EIN kernel expression to low-IL ops: iterate over the range of the
 * support, determine the differentiation level of each kernel, and call the
 * evaluation of the kernel's polynomial segments.
 *)
structure EvalKrnSet = struct
    local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure Var = LowIL.Var
    structure E=Ein
    structure H=HelperSet

    in
 
    (* Debugging switch: when nonzero, testp prints its messages. *)
    val testing = 0

    (* Local aliases for the HelperSet operations used below. *)
    fun lookup arg = H.lookup arg
    fun insert arg = H.insert arg
    fun find arg = H.find arg
    fun intToReal arg = H.intToReal arg
    fun assgnCons arg = H.assgnCons arg
    fun indexTensor arg = H.indexTensor arg
    fun mkAddVec arg = H.mkAddVec arg
    fun mkSubSca arg = H.mkSubSca arg
    fun mkProdVec arg = H.mkProdVec arg
    (* adapter: HelperSet.mkDotVec takes `last` before the two-element operand list *)
    fun mkDotVec (set, a, b, last) = H.mkDotVec (set, last, [a, b])

    fun iTos i = Int.toString i
    fun err msg = raise Fail msg
    val realTy = DstTy.TensorTy []
    val intTy = DstTy.IntTy

    (* testp msgs: print the concatenation of msgs when testing is nonzero.
     * Always returns 1, so callers bind it with `val _ = testp [...]`.
     *)
    fun testp msgs =
          if testing = 0
            then 1
            else (print (String.concat msgs); 1)
 
     (* ratToFloat r
      * Convert a rational to a FloatLit.float value.  Zero maps to
      * FloatLit.zero false and integers (denom = 1) go through FloatLit.fromInt.
      * Otherwise we do long division with a cutoff when we get to 12 digits:
      * num and denom are scaled by powers of ten (tracked in exp) so that
      * num <= denom <= 10*num, divLp generates the digits of num/denom, and the
      * value is built with FloatLit.fromDigits.
      * NOTE(review): the exact meaning of the digit list and exp is defined by
      * FloatLit.fromDigits; confirm the convention there.
      *)
     fun ratToFloat r = (case Rational.explode r
         of {sign=0, ...} => FloatLit.zero false
         | {sign, num, denom=1} => FloatLit.fromInt(IntInf.fromInt sign * num)
         | {sign, num, denom} => let
            (* normalize so that num <= denom; exp starts at 1 and counts the
             * factors of ten applied to denom
             *)
             val (denom, exp) = let
                 fun lp (n, denom) = if (denom < num)
                    then lp(n+1, denom*10)
                    else (denom, n)
                    in
                        lp (1, denom)
                    end
             (* normalize so that num <= denom <= 10*num; each factor of ten
              * applied to num decrements exp
              *)
             val (num, exp) = let
                fun lp (n, num) = if (10*num < denom)
                    then lp(n-1, 10*num)
                    else (num, n)
                    in
                        lp (exp, num)
                    end
             (* divide num/denom, computing the resulting digits.  divLp (n, a)
              * returns the quotient of a by denom together with the digits that
              * follow it.  Below depth 12 it recurses on 10*remainder and
              * propagates a carry when the following digit came back as 10.  At
              * depth 12 it stops and rounds using the next digit (>= 5 rounds up).
              *)
             fun divLp (n, a) = let
                val (q, r) = IntInf.divMod(a, denom)
                in
                    if (r = 0) then (q, [])
                    else if (n < 12) then let
                        val (d, dd) = divLp(n+1, 10*r)
                        in
                            if (d < 10)
                            then (q, (IntInf.toInt d)::dd)
                            else (q+1, 0::dd) (* carry from rounding *)
                        end
                    else if (IntInf.div(10*r, denom) < 5)
                        then (q, [])
                        else (q+1, []) (* round up *)
                end
             (* the first digit is num div denom, followed by the fractional digits *)
             val digits = let
                val (d, dd) = divLp (0, num)
                in
                    (IntInf.toInt d)::dd
                end
             in
                FloatLit.fromDigits{isNeg=(sign < 0), digits=digits, exp=exp}
            end
        (* end case *))
 
         
    (* expandEvalKernel (setOrig, pre, d, h, k, x)
     *
     * Expand an EvalKernel operation into d-wide vector operations.  The
     * parameters are
     *   setOrig -- the set threaded through lowSet.filter and the HelperSet
     *              mk* helpers; the updated set is returned
     *   pre     -- a name for this kernel evaluation (currently unused)
     *   d       -- the vector width of the operation, which should be equal
     *              to twice the support of the kernel
     *   h       -- the kernel
     *   k       -- the derivative of the kernel to evaluate
     *   x       -- the variable the polynomial is evaluated at; it is combined
     *              with the width-d coefficient vectors by mkProdVec
     *
     * The result is a triple (set, v, code).  v is the variable holding the
     * result, and code is the list of low-IL assignments that compute it, with
     * the coefficient-vector definitions first.
     *
     * The generated code is computing
     *
     *   result = a_0 + x*(a_1 + x*(a_2 + ... x*a_n) ... )
     *
     * as a d-wide vector operation, where n is the degree of the kth derivative
     * of h and the a_i are coefficient vectors that have an element for each
     * piece of h.  The computation is implemented as follows
     *
     *   m_n     = x * a_n
     *   s_{n-1} = a_{n-1} + m_n
     *   m_{n-1} = x * s_{n-1}
     *   ...
     *   s_1     = a_1 + m_2
     *   m_1     = x * s_1
     *   result  = a_0 + m_1
     *
     * When n = 0 the result is the coefficient vector a_0 itself, and no
     * arithmetic is generated.
     *
     * Note that the coefficient vectors are flipped (cf high-to-low/probe.sml).
     *)
    fun expandEvalKernel (setOrig, pre, d, h, k, x) = let
          val {segs, ...} = Kernel.curve (h, k)
        (* degree of polynomial *)
          val deg = List.length (hd segs) - 1
        (* segs has one entry per piece of the kernel; each entry is that piece's
         * list of polynomial coefficients.  Reverse the pieces and convert to
         * vectors for indexed access.
         *)
          val segs = Vector.fromList (List.rev (List.map Vector.fromList segs))
        (* literal for the coefficient of degree pow in piece number piece *)
          fun coefficient pow piece =
                Literal.Float (ratToFloat (Vector.sub (Vector.sub (segs, piece), pow)))
          val ty = DstTy.vecTy d
        (* P0 .. Pn hold the coefficient vectors a_0 .. a_n *)
          val coeffs = List.tabulate (deg+1, fn i => Var.new ("P" ^ Int.toString i, ty))
        (* Offer each literal assignment (lhs, rhs) to lowSet.filter.  When filter
         * returns NONE, the assignment is kept and lhs is used.  When it returns
         * SOME v, the assignment is dropped and v is used in place of lhs.
         * Returns (vars, code, opset), with vars and code in input order.
         *)
          fun filterLit ([], vars, code, opset) = (vars, code, opset)
            | filterLit ((lhs, rhs)::rest, vars, code, opset) = let
                val (opset, found) = lowSet.filter (opset, (lhs, rhs))
                in
                  (case found
                   of NONE => filterLit (rest, vars @ [lhs], code @ [(lhs, rhs)], opset)
                    | SOME v => filterLit (rest, vars @ [v], code, opset)
                  (* end case *))
                end
        (* Define the coefficient vectors.  P_i is the cons of the d literals of
         * degree i (one per piece), after the literals have been filtered
         * through filterLit.  The set is threaded from P_0 up to P_n.
         *)
          fun mkCoeffVec (p, (i, code, opset)) = let
                val lits = List.tabulate (d, coefficient i)
                val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.realTy))
                val litCode = ListPair.map (fn (v, lit) => (v, DstIL.LIT lit)) (vars, lits)
                val (vars, litCode, opset) = filterLit (litCode, [], [], opset)
                in
                  (i+1, code @ litCode @ [(p, DstIL.CONS (Var.ty p, vars))], opset)
                end
          val (_, coeffVecs, setOrig) = List.foldl mkCoeffVec (0, [], setOrig) coeffs
          val coeffVecsCode = List.map DstIL.ASSGN coeffVecs
        (* Horner evaluation over [a_n, ..., a_1, a_0] (highest degree first).
         * Each step replaces the two leading entries hi and lo by lo + x*hi.
         * When a single entry remains, it holds the result.  In particular, a
         * degree-0 polynomial yields a_0 itself rather than x*a_0.
         *)
          fun horner (set, [v], code) = (set, v, code)
            | horner (set, hi::lo::rest, code) = let
                val (setA, vA, codeA) = mkProdVec (set, d, [x, hi])
                val (setB, vB, codeB) = mkAddVec (setA, d, [lo, vA])
                in
                  horner (setB, vB::rest, code @ codeA @ codeB)
                end
            | horner (_, [], _) = raise Fail "0 items in Kernel coeffs"
          val (setC, vC, code) = horner (setOrig, List.rev coeffs, [])
          in
            (setC, vC, coeffVecsCode @ code)
          end
                
                
    (* mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1)
     *
     * Generate low-IL code that evaluates every kernel in krns over the kernel
     * support.  Returns (set, ids, code):
     *   set  -- the final set after threading setOrig through all helpers
     *   ids  -- one result variable per kernel, in the order of krns
     *   code -- position code, then cons code, then kernel-evaluation code
     *
     * Parameters (types are not visible here; descriptions follow the uses below):
     *   setOrig  -- set threaded through the HelperSet helpers and expandEvalKernel
     *   mappOrig -- index mapping used by evalDels and as the starting
     *               dictionary for mkpos
     *   lhs, params, args -- passed through to indexTensor when lowering positions
     *   krns     -- list of (id, dels, pos) triples
     *   h        -- the kernel passed to expandEvalKernel
     *   sid      -- index key bound for the first kernel; each later kernel
     *               uses the next integer
     *   lb       -- offset added to the summation index value
     *   range0   -- each kernel's index is bound to lb+range0, ..., lb+1, lb
     *   range1   -- vector width passed to expandEvalKernel
     *
     * Local helpers:
     *
     * evalDels (mapp, dels) -> int
     *   sum of H.deltaToInt (mapp, i, j) over the (i, j) pairs in dels; used as
     *   the differentiation level of a kernel
     *
     * mkSimpleOp (setO, mapp, lhs, params, args, pos) -> (set, var, code)
     *   lower a position of the form E.Sub (E.Tensor (t1, ix1), E.Value v1)
     *
     * mkpos (setO, k, fin, l, dict, i, code, n) -> (set, fin, code)
     *   bind summation index n and evaluate the position of each kernel
     *
     * consfn (setO, vals, n, rest, code) -> (set, rest, code)
     *   cons each kernel's position values into a vector
     *
     * evalK (setO, kernels, vecs, newId, code) -> (set, newId, code)
     *   evaluate each kernel's polynomial segments with expandEvalKernel
     *)
    fun mkkrns(setOrig,mappOrig,lhs,params,args,  krns,h,sid,lb,range0,range1)=let

        (* differentiation level: sum of the integer values of the deltas in dels *)
        fun evalDels(mapp,dels)=List.foldl(fn(x,y)=>x+y) 0   (List.map (fn(i,j)=>H.deltaToInt(mapp,i,j))  dels)

        (* Lower pos = E.Sub (E.Tensor (t1, ix1), E.Value v1) under mapp.  First
         * index t1 at ix1.  If v1 maps to a nonzero j, convert j to real and
         * combine the two values with mkSubSca (presumably t1[ix1] - j).
         * Other position forms are not matched by this function.
         *)
        fun mkSimpleOp(setO,mapp,lhs,params,args,  E.Sub(E.Tensor(t1,ix1),E.Value v1))=let
            val (setA,vA,A)=indexTensor(setO,mapp,(lhs,params,args,  t1,ix1,realTy))
            in (case (find(v1,mapp))
                of 0=>(setA,vA,A)
                | j=>let
                    val (setB,vB,B)= intToReal(setA, j)
                    val (setC,vC,C)=mkSubSca(setB, [vA,vB])
                    in (setC,vC,A@B@C) end
                (*end case*))
            end
      
        (* For kernel number n, lower its position once per binding of n:
         * to lb+i for i = range0 down to 1, each time in the unchanged dict,
         * and finally to lb (i = 0).  l collects the lowered values, highest
         * offset first.  At i = 0 they are appended to fin as one list.  The
         * next kernel then starts with n+1, i = range0, and a dict in which n
         * is bound to 0.
         *)
        fun mkpos(setO,k,fin,l,dict, i,code,n)= (case (k,i)
            of ([],_)=>(setO,fin,code)
            (* i = 0: last binding for this kernel *)
            | ((_,_,pos1)::ks,0)=>let
                val _=testp ["\n insert",iTos n, "->",iTos lb]
                val mapp=insert (n, lb) dict
                val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args,  pos1)
                val e=l@[l']
                val mapp'=insert (n, 0) dict
                in
                    mkpos(opset',ks,fin@[e],[],mapp',range0,code@code',n+1)
                end
            (* i <> 0: bind n to lb+i and stay on the same kernel *)
            | ((_,_,pos1)::es,_) =>let
                val _=testp ["\n insert",iTos n, "->",iTos (lb+i)]
                val mapp= insert (n, lb+i) dict
                val (opset',l', code')=mkSimpleOp(setO,mapp,lhs,params,args,  pos1)
                in
                    mkpos(opset',k,fin,l@[l'],dict,i-1,code@code',n)
                end
            (*end case*))

        (* Cons each kernel's value list, reversed so the lowest offset comes
         * first, via assgnCons.  The string "h<n>" is passed for n = 0, 1, ...
         * (presumably a name hint).  rest collects the resulting variables.
         *)
        fun consfn(setO,[],_,rest, code)=(setO,rest,code)
          | consfn(setO,e1::es,n,rest,code)=let
                val (setA,  vA,A)= assgnCons(setO,  "h"^iTos n,[length(e1)],List.rev e1)
                in
                    consfn(setA,es, n+1,rest@[vA], code@A)
                end
                
        (* Pair each kernel with its position vector x and expand it with
         * expandEvalKernel, using width range1 and the kernel's differentiation
         * level dk.  directionX (the constant tensor index in pos, else 0) only
         * feeds the name string.  Fails if the two lists differ in length.
         *)
        fun evalK(setO,[],[],newId,code)=(setO,newId,code)
          | evalK(setO,kn::kns,x::xs,newId,code)=let
                val (_,dk,pos) =kn
                val directionX= (case pos
                    of E.Sub(E.Tensor (_,[E.C directionX]),_)=>directionX
                    | _ => 0
                    (*end case*))
                val name=String.concat["h",iTos directionX,"_",iTos dk]
                val _ =testp["\n",Var.toString x," = ",name]
                val (opsetK,id,kcode)= expandEvalKernel (setO,name,range1,h,dk, x)
                in
                    evalK(opsetK,kns,xs,newId@[id],code@kcode)
                end
          | evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
                
        (* replace each kernel's delta list by its differentiation level *)
        val newkrns= List.map (fn (id,d1,pos)=>(id,evalDels(mappOrig,d1),pos)) krns
        val _ = testp["\n\n ***** Differentiation value of kernels ****\n "]
        (* positions, then one cons'd vector per kernel, then kernel evaluation *)
        val (set2,lftkrn,poscode)=mkpos(setOrig,newkrns,[],[],mappOrig,range0,[],sid)
        val (set3,lft,conscode)=consfn(set2,lftkrn,0,[],[])
        val (set4,ids, evalKcode)=evalK(set3,newkrns, lft,[],[])
        val _ = List.map (fn e=>testp["\n IDS",Var.toString e,","]) ids
        (*returns list in order h0, h1, h2*)
        in
            (set4, ids, poscode@conscode@evalKcode)
        end
                

end (* local *)

end 

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0