Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/ein16/src/compiler/high-il/derivative-ein.sml
ViewVC logotype

View of /branches/ein16/src/compiler/high-il/derivative-ein.sml

Parent Directory | Revision Log


Revision 4236 - (download) (annotate)
Wed Jul 20 03:02:00 2016 UTC (3 years, 6 months ago) by cchiw
File size: 8044 byte(s)
added generic cases for trace|det and added test cases
(* derivativeEin: symbolic differentiation of EIN expressions.
 *
 * The entry point is mkapply, which rewrites one application of a partial
 * derivative operator, E.Apply(E.Partial dx, e), by pushing the derivative
 * one step into e (sum, product, quotient and chain rules).  Most rules
 * differentiate with respect to the first index d0 of dx only and re-wrap
 * the result in E.Partial dn for the remaining indices dn.
 *)
structure derivativeEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter
    structure G=EpsHelpers
    structure Eq=EqualEin
    structure R=RationalEin

    in

    (* Debug switch: when non-zero, testp prints its trace messages. *)
    val testing=0

    (* err: string -> 'a
     * Raises Fail with the message prefixed by "Ill-formed EIN Operator: ".
     *)
    fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ",str])

    (* Thin wrappers over Filter helpers. *)
    fun mkProd e= F.mkProd e
    fun filterField e=F.filterField e
    fun flatProd e=F.rewriteProd e

    (* Thin wrappers over the Ein smart constructors. *)
    fun setConst e = E.setConst e
    fun setNeg e  =  E.setNeg e
    fun setExp e  =  E.setExp e
    fun setDiv e= E.setDiv e
    fun setSub e= E.setSub e
    fun setProd e= E.setProd e
    fun setAdd e= E.setAdd e

    (* testp: string list -> int
     * Prints the concatenation of n when `testing` is non-zero; always
     * returns 1.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))


    (* prodAppPartial: ein_exp list * ein_exp -> ein_exp
     * Product rule.  p0 is a single-index partial operator (E.Partial [d0]
     * at the call site in applyopn).  For es = e1 :: rest:
     *     p0(e1 * rest) = (rest * p0 e1) + (e1 * p0 rest)
     * Raises Fail on an empty list.
     *)
    fun prodAppPartial(es,p0)=(case es
        of []          => err "Empty App Partial"
        | [e1]         => E.Apply(p0,e1)
        | (e1::rest)   => let
            val dRest = prodAppPartial(rest,p0)
            val (_,e1TimesDRest) = mkProd[e1,dRest]
            val (_,restTimesDe1) = mkProd(rest@[E.Apply(p0, e1)])
            in
                setAdd[restTimesDe1,e1TimesDRest]
            end
        (* end case *))

    (* applyop1: E.Op1 operator * ein_exp * mu list -> int * ein_exp
     * Chain rule for a unary operator applied to e1.  The outer derivative
     * is multiplied by inner = E.Apply(E.Partial [d0], e1), and the remaining
     * indices dn are re-applied as E.Partial dn.  E.Neg instead pushes the
     * whole partial dx under the negation.
     * Raises Bind if dx is empty.
     *)
    fun applyop1(op1,e1,dx)=
        let
            val (d0::dn)=dx
            val px = E.Partial dx
            val inner=E.Apply(E.Partial [d0],e1)
            val square = setProd [e1,e1]
            val one= E.B(E.Const 1)
            (* 1/sqrt(1 - e1*e1): derivative factor shared by ArcSine/ArcCosine *)
            val invSqrt = setDiv(one, E.Op1(E.Sqrt, E.Op2(E.Sub,one, square)))
            fun iterDn e=(case dn
                of [] => (1,e)
                | _   =>  (1,E.Apply(E.Partial dn, e))
                (*end case*))
        in (case op1
            of E.Neg                 => (1,setNeg(E.Apply(px,e1)))
            | E.Exp                  => iterDn(setProd[inner,setExp e1])
            | E.Sqrt                 => let
                (* (1/2) * inner / sqrt(e1) *)
                val half = setDiv(setConst 1 ,setConst 2)
                val e3=    setDiv(inner,E.Op1(op1, e1))
                in (case dn
                    of [] =>  (1,setProd [half,e3])
                    | _   =>  (1,setProd [half,E.Apply(E.Partial dn, e3)])
                    (*end case*))
                end
            | E.PowInt n             =>
                    iterDn(setProd[setConst n,E.Op1(E.PowInt(n-1),e1),inner])
            | E.PowEmb(sx1,n1)       => iterDn(setDiv(E.Sum(sx1,setProd[e1,inner]),E.Op1(E.PowEmb(sx1,n1),e1)))
            | E.Cosine               => iterDn(setProd [setNeg(E.Op1(E.Sine ,e1)),inner])
            | E.Sine                 => iterDn(setProd [E.Op1(E.Cosine ,e1),inner])
            | E.ArcCosine            => iterDn(setProd [setNeg invSqrt,inner])
            | E.ArcSine              => iterDn(setProd [invSqrt,inner])
            | E.Tangent              => iterDn(setProd [setDiv(one,setProd[E.Op1(E.Cosine, e1),E.Op1(E.Cosine, e1)]),inner])
            | E.ArcTangent           => iterDn(setProd [setDiv(one,setAdd[one,square]),inner])
        (*end case*))
        end

    (* applyop2: E.Op2 operator * ein_exp * ein_exp * mu list -> int * ein_exp
     * Differentiates E.Op2(op2,e1,e2) with respect to dx.  The rule is
     * applied for the first index d0, and the remaining indices dn are
     * re-applied to every result as E.Partial dn.  For E.Div the quotient
     * rule is used, with the denominator split by filterField into
     * non-field factors (pre) and field factors (h):
     *     p0(g / (pre*h)) = (p0 g * h - g * p0 h) / (pre*h*h)
     * with cheaper special cases for constant or lifted denominators and
     * constant numerators.
     * Raises Bind if dx is empty.
     *)
    fun applyop2(op2,e1,e2,dx)=
        let
            val _ = testp["\n\n applyop2 ",P.printbody(e1),"/",P.printbody(e2)]
            val (d0::dn)=dx
            val p0=E.Partial [d0]
            val inner1=E.Apply(p0,e1)
            val inner2=E.Apply(p0,e2)
            val zero= E.B(E.Const 0)
            (* re-apply the remaining partial indices; previously skipped for
             * Sub and constant-denominator Div, losing higher-order terms *)
            fun iterDn e=(case dn
                of [] =>  e
                | _   =>  E.Apply(E.Partial dn, e)
                (*end case*))
            (* denominator of the quotient rule: pre * h * h *)
            fun denom(pre,h) = setProd(pre@h@h)
            val op2'= (case op2
                of E.Sub => iterDn(setSub(inner1,inner2))
                | E.Div  => (case (e1,e2)
                    (* denominator constant w.r.t. fields: (p0 e1) / e2 *)
                    of (_,E.B(E.Const _))    => iterDn(setDiv(inner1,e2))
                    | (_, E.Lift _ )         => iterDn(setDiv(inner1,e2))
                    | (E.B(E.Const 1),_)     => (case filterField[e2]
                        of (_,[])     => zero
                        | (pre,h)     => let
                            (* -(p0 h) / (pre*h*h) *)
                            val h'=E.Apply(p0,flatProd(h))
                            val num= setProd [setConst( ~1),h']
                            in iterDn(setDiv(num,denom(pre,h))) end
                        (*end case*))
                    | (E.B(E.Const c),_)     => (case filterField[e2]
                        of (_,[])     => zero
                        | (pre,h)     => let
                            (* -(c * p0 h) / (pre*h*h) *)
                            val h'=E.Apply(p0,flatProd(h))
                            val num= setNeg(setProd[setConst c ,h'])
                            in iterDn(setDiv(num,denom(pre,h))) end
                        (*end case*))
                    | _                      => (case filterField[e2]
                        of (_,[])     => iterDn(setDiv(inner1,e2)) (*Division by a real*)
                        | (pre,h)     => let
                            (* (p0 e1 * h - e1 * p0 h) / (pre*h*h) *)
                            val h'=E.Apply(p0,flatProd(h))
                            val num=setSub(setProd(inner1::h),setProd[e1,h'])
                            in iterDn(setDiv(num,denom(pre,h))) end
                        (*end case*))
                    (*end case*))
                (*end case*))
            val _ = testp["\n\n ===> ",P.printbody(op2')]
        in (1,op2')
        end

    (* applyopn: E.Opn operator * ein_exp list * mu list -> int * ein_exp
     * E.Add:  distributes the whole partial dx over each summand.
     * E.Prod: splits es with filterField into non-field factors (pre) and
     *         field factors (post).  Without field factors the result is 0.
     *         Otherwise the product rule (prodAppPartial) is applied to post
     *         for index d0, and the remaining indices dn are re-applied.
     * Raises Bind if dx is empty.
     *)
    fun applyopn(opn,es,dx)=         let
        val (d0::dn)=dx
        val p0=E.Partial [d0]
        fun iterDn e=(case dn
            of [] => (1,e)
            | _   =>  (1,E.Apply(E.Partial dn, e))
            (*end case*))
        in (case opn
            of E.Add                  =>
                (1,setAdd (List.map (fn(a)=>E.Apply(E.Partial dx,a)) es))
            | E.Prod                  => let
                val (pre, post)= filterField es
                in (case post
                    of []=> (1,setConst 0)(*no fields in expression*)
                    | _=>let
                        val (_,prod')= mkProd(pre@[prodAppPartial(post,p0)])
                        in iterDn prod' end
                    (*end case*))
                end
            (*end case*))
        end

    (* mkapply: ein_exp * ein_exp -> int * ein_exp
     * Rewrites E.Apply(px, e) one step, where px must be E.Partial dx.
     * Returns (c, g), where g is the rewritten expression.  c is 0 when e is
     * an E.Field, which is returned unchanged; every other path returns 1.
     * Raises Fail for an empty partial, a non-Partial operator, and
     * expression forms that should have been removed before this pass.
     *)
    fun mkapply(px as E.Partial dx,e)=let
        val _ = if null dx then err "Apply of empty Partial" else ()
        val zero =E.B(E.Const 0)
        in (case e
            of E.B _                    => (1,zero)
            | E.Tensor _                => err("Tensor without Lift")
            | E.G  _                    => (1,zero)
            | E.Field _                 => (0,e)
            | E.Lift _                  => (1,zero)
            | E.Conv(v,alpha,h,d2)      => (1,E.Conv(v,alpha,h,d2@dx))
            | E.Partial _               => err("Apply of Partial")
            | E.Apply(E.Partial d2,e2)  => (1,E.Apply(E.Partial(dx@d2),e2))
            | E.Apply _                 => err "Apply of non-Partial expression"
            | E.Probe _                 => err("Apply of Probe")
            | E.Value _                 => err("Value used before expand")
            | E.Img _                   => err("Img used before expand")
            | E.Krn _                   => err("Krn used before expand")
            | E.Sum(sx,e1)              => (1,E.Sum(sx,E.Apply(px,e1)))
            | E.Op1(op1,e1)             => applyop1(op1,e1,dx)
            | E.Op2(op2,e1,e2)          => applyop2(op2,e1,e2,dx)
            | E.Opn(opn,es)             => applyopn(opn,es,dx)
            (*end case*))
        end
      | mkapply(_,_) = err "Apply of non-Partial operator"

    end (* local *)

end (* structure derivativeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0