Home My Page Projects Code Snippets Project Openings diderot

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/ein/deriv-ein.sml
 [diderot] / branches / vis15 / src / compiler / ein / deriv-ein.sml

Diff of /branches/vis15/src/compiler/ein/deriv-ein.sml

revision 5571, Wed May 30 22:11:07 2018 UTC revision 5572, Thu May 31 12:54:04 2018 UTC
# Line 11  Line 11
11
12  structure DerivativeEin : sig  structure DerivativeEin : sig
13
15      val differentiate: Ein.mu list * Ein.ein_exp -> Ein.ein_exp      val differentiate: Ein.mu list * Ein.ein_exp -> Ein.ein_exp
16
17    end  = struct    end  = struct
18
19      structure E = Ein      structure E = Ein
20
21      fun err str = raise Fail (String.concat["Ill-formed EIN Operator: ", str])      fun err msg = raise Fail ("Ill-formed EIN Operator: " ^ msg)
22
24      fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)      fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)
# Line 61  Line 62
62              iterA(es, [])              iterA(es, [])
63            end            end
64

65    (* chain rule *)    (* chain rule *)
66      fun prodAppPartial ([], _) = err "Empty App Partial"      fun prodAppPartial ([], _) = err "Empty App Partial"
67        | prodAppPartial ([e1], p0) = E.Apply(p0, e1)        | prodAppPartial ([e1], p0) = E.Apply(p0, e1)
# Line 82  Line 82
82            val half = mkDiv (E.Const 1, E.Const 2)            val half = mkDiv (E.Const 1, E.Const 2)
83            val square = mkProd [e1, e1]            val square = mkProd [e1, e1]
84            val e2 = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))            val e2 = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))
85            val ee = case op1            in
86                case op1
87               of E.Neg       => mkNeg del               of E.Neg       => mkNeg del
88                | E.Exp       => mkProd [del, E.Op1(E.Exp, e1)]                | E.Exp       => mkProd [del, E.Op1(E.Exp, e1)]
89                | E.Sqrt      => let                | E.Sqrt => mkProd [half, mkDiv (del, E.Op1(op1, e1))]
val e3 = mkDiv (del, E.Op1(op1, e1))
in mkProd [half, e3]
end
90                | E.Cosine    => mkProd [mkNeg (E.Op1(E.Sine, e1)), del]                | E.Cosine    => mkProd [mkNeg (E.Op1(E.Sine, e1)), del]
91                | E.ArcCosine =>  mkProd [mkNeg e2, del]                | E.ArcCosine =>  mkProd [mkNeg e2, del]
92                | E.Sine      =>  mkProd [E.Op1(E.Cosine, e1), del]                | E.Sine      =>  mkProd [E.Op1(E.Cosine, e1), del]
93                | E.ArcSine   =>  mkProd [e2, del]                | E.ArcSine   =>  mkProd [e2, del]
94                | E.Tangent   =>                | E.Tangent => mkProd [
95                    mkProd [mkDiv(one, mkProd[E.Op1(E.Cosine, e1), E.Op1(E.Cosine, e1)]), del]                      mkDiv(one, mkProd[E.Op1(E.Cosine, e1), E.Op1(E.Cosine, e1)]), del
96                | E.ArcTangent=>                    ]
97                    mkProd [mkDiv(one, mkAdd[one, square]), del]                | E.ArcTangent => mkProd [mkDiv(one, mkAdd[one, square]), del]
98                | E.PowInt n  => mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), del]                | E.PowInt n  => mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), del]
99                | E.Abs       => mkProd [del,  E.Op1(E.Sgn, e1)]                | E.Abs       => mkProd [del,  E.Op1(E.Sgn, e1)]
100                | E.Sgn       => raise Fail "unhandled case: ask charisee"                | E.Sgn       => raise Fail "unhandled case: ask charisee"
101            (* end case *)            (* end case *)
in ee
102          end          end
103
104      (*apply derivative with apply expression*)      (*apply derivative with apply expression*)
# Line 110  Line 107
107              val del = E.Apply(E.Partial[d0], e1)              val del = E.Apply(E.Partial[d0], e1)
108              fun iterDn e = if null dn then e else E.Apply(E.Partial dn, e)              fun iterDn e = if null dn then e else E.Apply(E.Partial dn, e)
109              val single = applyOp1Single (op1, e1, del)              val single = applyOp1Single (op1, e1, del)
110              (* end case *)              in
111              in iterDn single                iterDn single
112              end              end
113  (*---------------------------------------------------------------------------------------------------------*)  (*---------------------------------------------------------------------------------------------------------*)
114      fun applyOp2Single (op2, e1, dele1, e2, dele2) = (case op2      fun applyOp2Single (op2, e1, dele1, e2, dele2) = (case op2
# Line 129  Line 126
126          val dele2 = E.Apply(E.Partial[d0], e2)          val dele2 = E.Apply(E.Partial[d0], e2)
127          fun iterDn e = if null dn then e else E.Apply(E.Partial dn, e)          fun iterDn e = if null dn then e else E.Apply(E.Partial dn, e)
128          val single = applyOp2Single (op2, e1, dele1, e2, dele2)          val single = applyOp2Single (op2, e1, dele1, e2, dele2)
129          (* end case *)            in
130          in iterDn single              iterDn single
131      end      end
132  (*---------------------------------------------------------------------------------------------------------*)  (*---------------------------------------------------------------------------------------------------------*)
133      (* differentiate *)      (* differentiate *)
# Line 144  Line 141
141              | E.Eps2 _              => E.Const 0              | E.Eps2 _              => E.Const 0
142              | E.Field _             => body              | E.Field _             => body
143              | E.Tensor _            => E.Const 0              | E.Tensor _            => E.Const 0
144              | E.Poly (e1, n, dx)  => E.Poly(e1, n, dx@px)              | E.Poly(e, n, dx) => E.Poly(e, n, dx@px)
145              | E.Lift(e1)            => E.Lift(differentiate(px, e1))              | E.Lift e => E.Lift(differentiate(px, e))
146              | E.Sum(op1, e1)        => let              | E.Sum(op1, e) => (case differentiate(px, e)
val e2 = differentiate(px, e1)
in (case e2
147                      of E.Opn(E.Add, ps) => iterAA(List.map (fn e1=>E.Sum(op1, e1)) ps)                      of E.Opn(E.Add, ps) => iterAA(List.map (fn e1=>E.Sum(op1, e1)) ps)
148                      | _                 => E.Sum(op1, e2)                    | e' => E.Sum(op1, e')
149                  (*end case*))                  (*end case*))
150                  end              | E.Op1(op1, e1) => applyOp1Single(op1, e1, differentiate(px, e1))
151              | E.Op1(op1, e1) =>     (* applyOp1 (op1, e1, px)*)              | E.Op2(op2, e1, e2) =>
applyOp1Single(op1, e1, differentiate(px, e1))
| E.Op2(op2, e1, e2) =>  (*applyOp2 (op2, e1, e2, px)*)
152                  applyOp2Single (op2, e1, differentiate(px, e1), e2,differentiate(px, e2))                  applyOp2Single (op2, e1, differentiate(px, e1), e2,differentiate(px, e2))
153              | E.Opn(E.Prod, [e1])        => raise Fail(EinPP.expToString(e1))              | E.Opn(E.Prod, [e1]) => raise Fail(EinPP.expToString e1)
154              | E.Opn(E.Prod, e1::es)        =>  let              | E.Opn(E.Prod, e1::es)        =>  let
155                  val (d0::dn) = px                  val (d0::dn) = px
156                  val e1' = differentiate ([d0], e1)                  val e1' = differentiate ([d0], e1)
# Line 166  Line 159
159                  val B = iterPP(e1'::es)                  val B = iterPP(e1'::es)
160                  val e = iterAA([A,B])                  val e = iterAA([A,B])
161                  fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)                  fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
162                  in iterDn e  end                  in
163                      iterDn e
164                    end
165              | E.Opn(opn, es)        =>             let              | E.Opn(opn, es)        =>             let
166                  val xx = List.map (fn e1=> differentiate (px, e1)) es                  val xx = List.map (fn e1=> differentiate (px, e1)) es
167                  in iterAA(xx) end                  in
168              | _    => raise Fail(EinPP.expToString(body))                    iterAA(xx)
169                    end
170                | _ => raise Fail(EinPP.expToString body)
171          (* end case*))          (* end case*))
172
173     end     end

Legend:
 Removed from v.5571 changed lines Added in v.5572