Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-il/derivative-ein.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-il/derivative-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3353 - (download) (annotate)
Wed Oct 28 23:08:21 2015 UTC (3 years, 8 months ago) by cchiw
File size: 5829 byte(s)
sync files
structure derivativeEin = struct

    local

    structure E = Ein
    structure P=Printer
    structure F=Filter
    structure G=EpsHelpers
    structure Eq=EqualEin
    structure R=RationalEin

    in

    (* Debug switch: when non-zero, testp prints its message. *)
    val testing=0

    (* err: string -> 'a
     * Raises Fail with "Ill-formed EIN Operator: " followed by str.
     * The ": " separator keeps the prefix and the detail from running together.
     *)
    fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ",str])

    (* Thin wrappers around the Filter helpers used below. *)
    fun mkProd e= F.mkProd e
    fun filterField e=F.filterField e
    fun flatProd e=F.rewriteProd e

    (* testp: string list -> int
     * When testing is non-zero, prints the concatenation of n.
     * Always returns 1.
     *)
    fun testp n=(case testing
        of 0=> 1
        | _ =>(print(String.concat n);1)
        (*end case*))

    (* prodAppPartial: ein_exp list * ein_exp -> ein_exp
     * Product rule for the single-index partial p0 (an E.Partial value).
     * For es = e1::rest it builds
     *     Add[ (rest) * Apply(p0,e1),  e1 * prodAppPartial(rest,p0) ]
     * with both products formed by mkProd.  A singleton list yields Apply(p0,e1).
     * Raises Fail on an empty list.
     *)
    fun prodAppPartial(es,p0)=(case es
        of []      => err "Empty App Partial"
        | [e1]     => E.Apply(p0,e1)
        | (e1::e2) => let

            val l= prodAppPartial(e2,p0)
            val (_,e2')= mkProd[e1,l]
            val (_,e1')=mkProd(e2@ [E.Apply(p0, e1)])
            in
                E.Add[e1',e2']
            end
        (* end case *))

    (* mkapply: ein_exp * ein_exp -> int * ein_exp
     * Rewrites Apply(px, e), where px = Partial dx, by pushing the derivative
     * one step into e.
     *
     * Returns:
     *   (1, e') when a rewrite happened.
     *   (0, e) for Field and Probe operands. The operand itself is returned;
     *     presumably the caller keeps the original Apply when the flag is 0
     *     (TODO confirm).
     *
     * With dx = d0::dn, rules needing the chain/product/quotient rule
     * differentiate only with respect to d0 (p0 = Partial [d0]).  They then
     * wrap the result in Apply(Partial dn, _) when dn is non-empty.
     *
     * Raises Fail in these cases:
     *   - operands that must be lifted or expanded first;
     *   - empty products or sums;
     *   - Apply of a non-Partial operator inside e;
     *   - an empty index list;
     *   - a non-Partial px.
     *)
    fun mkapply(px as E.Partial (dx as (d0::dn)),e)=let
        (* single-index partial used by the chain/product/quotient rules *)
        val p0=E.Partial [d0]

        (* apply the remaining indices dn (if any) to a once-differentiated term *)
        fun iterDn e2=(case dn
            of [] => (1,e2)
            | _   =>  (1,E.Apply(E.Partial dn, e2))
            (*end case*))

        in (case e
            of E.Tensor _               => err("Tensor without Lift")
            | E.Partial _               => err("Apply of Partial")
            | E.Krn _                   => err("Krn used before expand")
            | E.Value _                 => err("Value used before expand")
            | E.Img _                   => err("Probe used before expand")
            | E.Prod []                 => err("Apply of empty product")
            | E.Add []                  => err("Apply of empty Addition")
            (* terms with no spatial dependence differentiate to zero *)
            | E.Const _                 => (1,E.Const 0)
            | E.ConstR _                => (1,E.Const 0)
            | E.Lift _                  => (1,E.Const 0)
            | E.Delta _                 => (1,E.Const 0)
            | E.Epsilon _               => (1,E.Const 0)
            | E.Eps2 _                  => (1,E.Const 0)
            (* left alone: flag 0 *)
            | E.Field _                 => (0,e)
            | E.Probe _                 => (0,e)
            (* fold the new indices into an existing derivative *)
            | E.Conv(v,alpha,h,d2)      => (1,E.Conv(v,alpha,h,d2@dx))
            | E.Apply(E.Partial d2,e2)  => (1,E.Apply(E.Partial(dx@d2),e2))
            | E.Apply _                 => err "Apply of non-Partial expression"
            (* linearity *)
            | E.Sum(sx,e1)              => (1,E.Sum(sx,E.Apply(px,e1)))
            | E.Neg e1                  => (1,E.Neg(E.Apply(px,e1)))
            | E.Add es                  =>
                (1,E.Add (List.map (fn(a)=>E.Apply(px,a)) es))
            | E.Sub (e1,e2)             => (1,E.Sub(E.Apply(px,e1),E.Apply(px,e2)))
            (* chain rule for elementary functions *)
            | E.Cosine e1               =>
                iterDn(E.Prod[E.Neg(E.Sine e1),E.Apply(p0,e1)])
            | E.Sine e1                 =>
                iterDn(E.Prod[E.Cosine e1,E.Apply(p0,e1)])
            | E.ArcCosine e1            =>
                iterDn(E.Prod[E.Neg(E.Div(E.Const 1, E.Sqrt(E.Sub(E.Const 1, E.Prod[e1,e1])))),E.Apply(p0,e1)])
            | E.ArcSine e1              =>
                iterDn(E.Prod[(E.Div(E.Const 1, E.Sqrt(E.Sub(E.Const 1, E.Prod[e1,e1])))),E.Apply(p0,e1)])
            | E.Sqrt s                  => let
                (* d sqrt(s) = (1/2) * ds / sqrt(s) *)
                val half=E.Div(E.Const 1 ,E.Const 2)
                val e2=E.Div(E.Apply(p0,s),e)

                in (case dn
                    of []=>  (1,E.Prod[half, e2])
                    | _  =>  (1,E.Prod[half,E.Apply(E.Partial dn, e2)])
                    (*end case*))
                end
            | E.Prod p                  => let
                (* Product rule, applied only to the field factors *)
                val (pre, post)= filterField p
                in (case post
                    of []=> (1,E.Const 0)(*no fields in expression*)
                    | _=>let

                        val (_,e2)= mkProd(pre@[prodAppPartial(post,p0)])
                        in iterDn(e2)
                        end
                    (*end case*))
                end
            | E.Div (e1,E.Const e2)     => (1,E.Div(E.Apply(px,e1),E.Const e2))
            | E.Div (E.Const 1,b)       =>
                (case filterField[b]
                    of (_,[]) => (1,E.Const 0)
                    | (pre,h) => let
                            (* Quotient Rule: d(1/(pre*h)) = -h' / (pre*h*h) *)
                            val h'=E.Apply(p0,flatProd(h))
                            val num= E.Prod[E.Const ~1,h']
                            val e2=E.Div(num,E.Prod(pre@h@h))
                            in iterDn(e2)
                            end
                    (*end case*))
            | E.Div (E.Const c,b) =>
                (case filterField[b]
                    of (_,[]) => (1,E.Const 0)
                    | (pre,h) => let
                        (* Quotient Rule: d(c/(pre*h)) = -(c*h') / (pre*h*h) *)
                        val h'=E.Apply(p0,flatProd(h))
                        val num= E.Neg(E.Prod[E.Const c ,h'])
                        val e2=E.Div(num,E.Prod(pre@h@h))
                        in iterDn(e2)
                        end
                    (*end case*))
            | E.Div (g,b) =>
                (case filterField[b]
                    of (_,[]) => (1,E.Div(E.Apply(px,g),b)) (*Division by a real*)
                    | (pre,h) => let
                        (* Quotient Rule: d(g/(pre*h)) = (g'*h - g*h') / (pre*h*h) *)
                        val g'=E.Apply(p0,g)
                        val h'=E.Apply(p0,flatProd(h))
                        val num=E.Sub(E.Prod([g']@h),E.Prod[g,h'])
                        val e2=E.Div(num,E.Prod(pre@h@h))
                        in iterDn(e2)
                        end
                    (*end case*))
            (*end case*))
        end
      | mkapply(E.Partial [],_) = err "Apply of Partial with no indices"
      | mkapply(_,_) = err "Apply of non-Partial operator"

    end (* local *)

end (* structure derivativeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0