Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml
ViewVC logotype

View of /branches/charisee/src/compiler/mid-to-low/evalKrn-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3553 - (download) (annotate)
Thu Jan 7 17:39:31 2016 UTC (3 years, 7 months ago) by cchiw
File size: 11768 byte(s)
spacing
 (* evalKrn
 * Evaluate an EIN kernel expression to low-IL ops:
 * iterate over the range of the support, determine the differentiation level
 * of each kernel, and evaluate the kernel's polynomial segments.
 *)
structure EvalKrnSet = struct
    local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure Var = LowIL.Var
    structure E=Ein
    structure H=HelperSet

    in

    (* Thin aliases for HelperSet operations.  As used in this file, the
     * op-set-threading helpers (indexTensor, intToReal, assgnCons, mkAddVec,
     * mkSubSca, mkProdVec) return a triple (op-set, result variable, code).
     *)
    fun lookup e =H.lookup e
    fun insert e=H.insert e
    fun find e=H.find e
    fun intToReal n=H.intToReal n
    fun assgnCons e=H.assgnCons e
    fun indexTensor e=H.indexTensor e
    fun mkAddVec e=H.mkAddVec e
    fun mkSubSca e= H.mkSubSca e
    fun mkProdVec e =H.mkProdVec e
    (* note the argument reordering relative to H.mkDotVec *)
    fun mkDotVec (setO,a,b,last) =H.mkDotVec (setO,last,[a,b])
    fun iTos n=Int.toString n
    fun err str=raise Fail (str)
    val realTy=DstTy.TensorTy []
    val intTy=DstTy.IntTy
    (* debug printing switch; testp prints only when testing is true and always returns 1 *)
    val testing=false
    fun testp n =if  (testing)  then  (print (String.concat n) ;1)  else 1

    (* convert a rational to a FloatLit.float value.  We do this by long division
     * with a cutoff when we get to 12 digits; the next digit decides rounding.
     *)
     fun ratToFloat r =  (case Rational.explode r
         of {sign=0, ...} => FloatLit.zero false
         | {sign, num, denom=1} => FloatLit.fromInt (IntInf.fromInt sign * num)
         | {sign, num, denom} => let
             (* normalize so that num <= denom *)
             val (denom, exp)  = let
                 fun lp  (n, denom)  = if  (denom < num)
                    then lp (n+1, denom*10)
                    else  (denom, n)
                    in
                        lp  (1, denom)
                    end
              (* normalize so that num <= denom < 10*num *)
             val (num, exp)  = let
                fun lp  (n, num)  = if  (10*num < denom)
                    then lp (n-1, 10*num)
                    else  (num, n)
                    in
                        lp  (exp, num)
                    end
              (* divide num/denom, computing the resulting digits *)
             fun divLp  (n, a)  = let
                val (q, r)  = IntInf.divMod (a, denom)
                in
                    if  (r = 0)  then  (q, [])
                    else if  (n < 12)  then let
                        val (d, dd)  = divLp (n+1, 10*r)
                        in
                            if  (d < 10)
                            then  (q,  (IntInf.toInt d) ::dd)
                            else  (q+1, 0::dd)   (* carry from a round-up below *)
                        end
                    else if  (IntInf.div (10*r, denom)  < 5)
                        then  (q, [])
                        else  (q+1, [])   (* round up *)
                end
             val digits = let
                val (d, dd)  = divLp  (0, num)
                in
                     (IntInf.toInt d) ::dd
                end
             in
                FloatLit.fromDigits{isNeg= (sign < 0) , digits=digits, exp=exp}
            end
         (* end case *) )

    (* expandEvalKernel (setOrig, pre, d, h, k, x)
     * Expand the evaluation of the k'th derivative of kernel h into vector
     * operations.  The parameters are
     *	setOrig	-- op-set threaded through lowSet.filter and the HelperSet ops;
     *		   when lowSet.filter returns SOME v, v is reused in place of a
     *		   new literal assignment
     *	pre	-- name supplied by the caller (not used by this function)
     *	d	-- the vector width of the operation, which should be equal
     *		   to twice the support of the kernel
     *	h	-- the kernel
     *	k	-- the derivative of the kernel to evaluate
     *	x	-- the d-wide vector variable the polynomial is evaluated at
     *
     * Returns (set, v, code) where the generated code computes
     *
     *	v = a_0 + x*(a_1 + x*(a_2 + ... + x*a_n))
     *
     * as a d-wide vector operation, where n is the degree of the k'th derivative
     * of h and the a_i are coefficient vectors that have an element for each
     * piece of h.  The computation uses Horner's rule:
     *
     *	m_n	= x * a_n
     *	s_{n-1}	= a_{n-1} + m_n
     *	m_{n-1}	= x * s_{n-1}
     *	...
     *	m_1	= x * s_1
     *	v	= a_0 + m_1
     *
     * When n = 0 (a piecewise-constant function) v is a_0 itself and no
     * arithmetic is emitted.  Note that the coefficient vectors are flipped
     * (cf high-to-low/probe.sml).
     *)
    fun expandEvalKernel  (setOrig,pre,d, h, k, x)  = let
         val {segs, ...} = Kernel.curve  (h, k)
         (* degree of polynomial *)
         val deg = List.length (hd segs)  - 1
         (* segs has one entry per piece (2*support); each entry is that piece's
          * list of polynomial coefficients.  Convert to vectors for direct
          * indexing, reversing the order of the pieces.
          *)
         val segs = Vector.fromList  (List.rev  (List.map Vector.fromList segs))
         (* literal for coefficient number term of the given piece *)
         fun coefficient term piece =
                 Literal.Float (ratToFloat  (Vector.sub  (Vector.sub (segs, piece) , term)))
         val ty = DstTy.vecTy d
         (* coeffs = [P0, P1, ..., Pdeg]; Pi holds the coefficient vector a_i *)
         val coeffs = List.tabulate  (deg+1,fn i => Var.new ("P"^Int.toString i, ty))
         (* Run each literal assignment through lowSet.filter.  On SOME v the
          * variable v is used and the assignment is dropped; on NONE the
          * assignment is kept.  Returns (vars in order, kept code, op-set).
          *)
         fun filterLit ([], vars, code, opset) = (vars, code, opset)
          | filterLit ((lhs, rhs) ::es, vars, code, opset) = let
               val (opset, var)  = lowSet.filter (opset, (lhs, rhs))
               in  (case var
                    of NONE=> filterLit (es, vars@[lhs], code@[(lhs, rhs)], opset)
                    | SOME v =>filterLit (es, vars@[v], code, opset)
                 (*end case*) )
              end
         (* Assignments defining P0..Pdeg, in that order: d scalar literals per
          * coefficient (one per piece) followed by a CONS into the vector.
          *)
         val (coeffVecs, setOrig) = let
             fun mk (p, (i, code, opset)) = let
                    val lits = List.tabulate (d, coefficient i)
                    val vars = List.tabulate (d, fn _ => Var.new ("_f", DstTy.realTy))
                    val code0 = ListPair.map (fn (v, lit) => (v, DstIL.LIT lit)) (vars, lits)
                    val (vars, code0, opset) = filterLit (code0, [], [], opset)
                    val code = code @ code0 @ [(p, DstIL.CONS (Var.ty p, vars))]
                    in
                        (i+1, code, opset)
                    end
             val (_, code, opset) = List.foldl mk (0, [], setOrig) coeffs
             in
                 (code, opset)
             end
         val coeffVecsCode = List.map DstIL.ASSGN coeffVecs
         (* Horner step for each remaining (lower-degree) coefficient c:
          * acc <- c + x*acc
          *)
         fun horner (setO, acc, [], code) = (setO, acc, code)
           | horner (setO, acc, c::cs, code) = let
               val (setA, vA, A) = mkProdVec (setO, d, [x, acc])
               val (setB, vB, B) = mkAddVec (setA, d, [c, vA])
               in
                   horner (setB, vB, cs, code@A@B)
               end
         (* start from the highest-degree coefficient; degree 0 yields P0 itself *)
         val (setC, vC, code) = (case List.rev coeffs
                of [] => raise Fail "0 items in Kernel coeffs"
                 | top::rest => horner (setOrig, top, rest, [])
                (* end case *))
         val _ = testp ["\n coeffVecs code :", Int.toString (length coeffVecsCode),
                        " other code code: ", Int.toString (length code)]
         in
             (setC, vC, coeffVecsCode@code)
         end

    (* mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1)
     * Lower a list of kernel applications to low-IL code.
     *	setOrig	-- op-set threaded through all generated operations
     *	mappOrig	-- index map used for the deltas and the position expressions
     *	lhs, params, args -- passed through to indexTensor
     *	krns	-- list of (id, dels, pos): dels are delta index pairs whose
     *		   values (via H.deltaToInt) sum to the differentiation level;
     *		   pos is the position expression
     *	h	-- the kernel
     *	sid	-- summation index bound for the first kernel (next ones use sid+1, ...)
     *	lb	-- offset added to the summation index value
     *	range0	-- each kernel's position is evaluated at lb+range0 down to lb
     *	range1	-- vector width passed to expandEvalKernel
     * Returns (set, ids, code) with one result variable per kernel, in the
     * order of krns.
     *)
    fun mkkrns (setOrig, mappOrig, lhs, params, args, krns, h, sid, lb, range0, range1) = let

        (* differentiation level: sum of H.deltaToInt over the delta pairs *)
        fun evalDels (mapp, dels) =
            List.foldl (fn ((i, j), acc) => acc + H.deltaToInt (mapp, i, j)) 0 dels

        (* Lower a position expression of the form T[ix] - v.  The tensor is
         * indexed under mapp; when find (v, mapp) is 0 the subtraction is
         * skipped, otherwise that integer is converted to a real and subtracted.
         *)
        fun mkSimpleOp (setO, mapp, lhs, params, args, E.Op2 (E.Sub, E.Tensor (t1, ix1), E.Value v1)) = let
              val (setA, vA, A) = indexTensor (setO, mapp, (lhs, params, args, t1, ix1, realTy))
              in (case find (v1, mapp)
                  of 0 => (setA, vA, A)
                   | j => let
                       val (setB, vB, B) = intToReal (setA, j)
                       val (setC, vC, C) = mkSubSca (setB, [vA, vB])
                       in (setC, vC, A@B@C) end
                  (* end case *))
              end
          | mkSimpleOp _ = err "mkSimpleOp: unsupported kernel position expression"

        (* mkpos (setO, k, fin, l, dict, i, code, n)
         * For each kernel in k, bind summation index n to lb+i for
         * i = range0, range0-1, ..., 0 and lower the kernel position at each
         * binding.  l collects the current kernel's variables (lb+range0 first);
         * when i reaches 0 the list is appended to fin, index n is set to 0 in
         * the map, and the next kernel continues with index n+1.
         * Returns (set, fin, code).
         *)
        fun mkpos (setO, k, fin, l, dict, i, code, n) = (case (k, i)
            of ([], _) => (setO, fin, code)
             | ((_, _, pos1)::ks, 0) => let
                 val _ = testp ["\n insert", iTos n, "->", iTos lb]
                 val mapp = insert (n, lb) dict
                 val (opset', l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)
                 val mapp' = insert (n, 0) dict
                 in
                     mkpos (opset', ks, fin@[l@[l']], [], mapp', range0, code@code', n+1)
                 end
             | ((_, _, pos1)::_, _) => let
                 val _ = testp ["\n insert", iTos n, "->", iTos (lb+i)]
                 val mapp = insert (n, lb+i) dict
                 val (opset', l', code') = mkSimpleOp (setO, mapp, lhs, params, args, pos1)
                 in
                     mkpos (opset', k, fin, l@[l'], dict, i-1, code@code', n)
                 end
            (* end case *))

        (* Build one variable h<n> per kernel with assgnCons (shape [length e1]),
         * reversing each list so it runs from lb up to lb+range0.
         *)
        fun consfn (setO, [], _, rest, code) = (setO, rest, code)
          | consfn (setO, e1::es, n, rest, code) = let
              val (setA, vA, A) = assgnCons (setO, "h"^iTos n, [length e1], List.rev e1)
              in
                  consfn (setA, es, n+1, rest@[vA], code@A)
              end

        (* Evaluate each kernel (with its differentiation level dk) at the
         * matching position vector; the two lists must have equal length.
         *)
        fun evalK (setO, [], [], newId, code) = (setO, newId, code)
          | evalK (setO, (_, dk, pos)::kns, x::xs, newId, code) = let
              val directionX = (case pos
                  of E.Op2 (E.Sub, E.Tensor (_, [E.C directionX]), _) => directionX
                   | _ => 0
                  (* end case *))
              val name = String.concat ["h", iTos directionX, "_", iTos dk]
              val _ = testp ["\n", Var.toString x, " = ", name]
              val (opsetK, id, kcode) = expandEvalKernel (setO, name, range1, h, dk, x)
              in
                  evalK (opsetK, kns, xs, newId@[id], code@kcode)
              end
          | evalK _ = raise Fail "Non-equal variable list, error in mkKrns"

        val newkrns = List.map (fn (id, d1, pos) => (id, evalDels (mappOrig, d1), pos)) krns
        val _ = testp ["\n\n ***** Differentiation value of kernels ****\n "]
        val (set2, lftkrn, poscode) = mkpos (setOrig, newkrns, [], [], mappOrig, range0, [], sid)
        val (set3, lft, conscode) = consfn (set2, lftkrn, 0, [], [])
        val (set4, ids, evalKcode) = evalK (set3, newkrns, lft, [], [])
        val _ = List.app (fn e => ignore (testp ["\n IDS", Var.toString e, ","])) ids
        in
            (* ids are in kernel order: h0, h1, h2, ... *)
            (set4, ids, poscode@conscode@evalKcode)
        end


end  (* local *)

end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0