(* derivative.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Derivative : sig

    val mkApply : Ein.ein_exp * Ein.ein_exp -> Ein.ein_exp option

  end = struct

    structure E = Ein

    fun err str = raise Fail (String.concat["Ill-formed EIN Operator: ", str])

    fun mkAdd exps = E.Opn(E.Add, exps)
    fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)
    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)
    fun mkNeg e = E.Op1(E.Neg, e)

  (* Build a product of `args`, using the expression returned by EinFilter.mkProd
   * when it yields SOME, and a plain Prod node otherwise.
   * NOTE(review): presumably NONE means "no simplification applied" -- confirm.
   *)
    fun filterProd args = (case EinFilter.mkProd args
           of SOME e => e
            | NONE => mkProd args
          (* end case *))

  (* A singleton product is just its only factor. *)
    fun rewriteProd [a] = a
      | rewriteProd exps = E.Opn(E.Prod, exps)

  (* Product rule (chain rule over a list of factors).  Given the non-empty
   * factor list [e1, e2, ..., en] and a single-index partial p0, builds
   *
   *     p0(e1 * ... * en) = (e2 * ... * en) * p0(e1) + e1 * p0(e2 * ... * en)
   *
   * recursing on the tail.  An empty factor list is an error.
   *)
    fun prodAppPartial ([], _) = err "Empty App Partial"
      | prodAppPartial ([e1], p0) = E.Apply(p0, e1)
      | prodAppPartial (e::es, p0) = let
          val l = prodAppPartial (es, p0)
          val e2' = filterProd [e, l]
          val e1' = filterProd (es @ [E.Apply(p0, e)])
          in
            mkAdd[e1', e2']
          end

  (* Differentiate a unary operator: builds an expression equal to
   *
   *     Partial dx (op1 e1)
   *
   * The differentiation rule is applied for the first index d0 of dx
   * (inner = Partial[d0] e1 is the chain-rule factor).  Any remaining
   * indices dn are re-applied to the result by iterDn.  Neg distributes the
   * full dx directly.  dx must be non-empty; the d0::dn binding fails otherwise.
   *)
    fun applyOp1 (op1, e1, dx) = let
          val d0::dn = dx
          val px = E.Partial dx
          val inner = E.Apply(E.Partial[d0], e1)
          val square = mkProd [e1, e1]
          val one = E.Const 1
        (* 1 / sqrt(1 - e1*e1), shared by the ArcSine and ArcCosine rules *)
          val e2 = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))
          fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
          in
            case op1
             of E.Neg => mkNeg(E.Apply(px, e1))
              | E.Exp => iterDn (mkProd [inner, E.Op1(E.Exp, e1)])
              | E.Sqrt => let
                (* d sqrt(e) = (1/2) * e' / sqrt(e) *)
                  val half = mkDiv (E.Const 1, E.Const 2)
                  val e3 = mkDiv (inner, E.Op1(op1, e1))
                  in
                    case dn
                     of [] => mkProd [half, e3]
                      | _ => mkProd [half, E.Apply(E.Partial dn, e3)]
                    (* end case *)
                  end
              | E.Cosine => iterDn (mkProd [mkNeg (E.Op1(E.Sine, e1)), inner])
              | E.ArcCosine => iterDn (mkProd [mkNeg e2, inner])
              | E.Sine => iterDn (mkProd [E.Op1(E.Cosine, e1), inner])
              | E.ArcSine => iterDn (mkProd [e2, inner])
              | E.Tangent =>
                  iterDn (mkProd [
                      mkDiv(one, mkProd[E.Op1(E.Cosine, e1), E.Op1(E.Cosine, e1)]),
                      inner
                    ])
              | E.ArcTangent => iterDn (mkProd [mkDiv(one, mkAdd[one, square]), inner])
              | E.PowInt n => iterDn (mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), inner])
              | E.PowReal r => raise Fail "applyOp1: PowReal" (* FIXME *)
            (* Rule used: Sum_sx1(e1 * e1') / PowEmb(sx1, n1)(e1).
             * NOTE(review): correctness depends on the definition of PowEmb in Ein
             * (this matches the derivative of a norm) -- confirm for general n1.
             *)
              | E.PowEmb(sx1, n1) =>
                  iterDn (mkDiv (E.Sum(sx1, mkProd [e1, inner]), E.Op1(E.PowEmb(sx1, n1), e1)))
            (* end case *)
          end

  (* Differentiate a binary operator (Sub or Div): builds an expression equal to
   *
   *     Partial dx (e1 op2 e2)
   *
   * As in applyOp1, the rule is applied for the first index d0 of dx.  The
   * remaining indices dn are re-applied with iterDn in EVERY case.  Previously
   * Sub, division by a constant, and division by a non-field dropped dn, which
   * lost derivative indices for multi-index partials.
   *
   * EinFilter.partitionField [e2] is used to split the denominator into
   * (pre, h).  Judging by the (_, []) cases, h holds the field factors and pre
   * the factors that are constant in space.
   *)
    fun applyOp2 (op2, e1, e2, dx) = let
          val (d0::dn) = dx
          val p0 = E.Partial [d0]
          val inner1 = E.Apply(E.Partial[d0], e1)
          val inner2 = E.Apply(E.Partial[d0], e2)
          val zero = E.Const 0
          fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
          in
            case op2
             of E.Sub => iterDn (mkSub (inner1, inner2))
              | E.Div => (case (e1, e2)
                   of (_, E.Const c) => iterDn (mkDiv (inner1, E.Const c))
                    | (E.Const 1, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let
                            (* Quotient Rule: d(1 / (pre*h)) = -h' / (pre*h*h) *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkProd [E.Const ~1, h']
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                    | (E.Const c, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let
                            (* Quotient Rule: d(c / (pre*h)) = -(c*h') / (pre*h*h) *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkNeg (mkProd [E.Const c, h'])
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                    | _ => (case EinFilter.partitionField [e2]
                         of (_, []) => iterDn (mkDiv (inner1, e2)) (* Division by a real *)
                          | (pre, h) => let
                            (* Quotient Rule: d(g / (pre*h)) = (g'*h - g*h') / (pre*h*h) *)
                              val g' = inner1
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkSub (mkProd (g' :: h), mkProd[e1, h'])
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                  (* end case *))
            (* end case *)
          end

  (* Differentiate an n-ary operator (Add or Prod): builds an expression equal to
   *
   *     Partial dx (opn es)
   *
   * Add distributes the full dx over every summand.  For Prod, the factors
   * are split by EinFilter.partitionField into pre (non-field factors, left
   * as coefficients) and post (field factors).  The product rule is applied to
   * post for index d0, and the remaining indices dn are re-applied by iterDn.
   * A product with no field factors has zero derivative.
   *)
    fun applyOpn (opn, es, dx) = let
          val (d0::dn) = dx
          val p0 = E.Partial [d0]
          fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
          in
            case opn
             of E.Add => mkAdd (List.map (fn a => E.Apply(E.Partial dx, a)) es)
              | E.Prod => let
                  val (pre, post) = EinFilter.partitionField es
                  in
                    case post
                     of [] => E.Const 0 (* no fields in expression *)
                      | _ => iterDn (filterProd (pre @ [prodAppPartial (post, p0)]))
                    (* end case *)
                  end
            (* end case *)
          end

  (* Rewrite one application of a partial-derivative operator:
   * mkApply (Partial dx, e) pushes Partial dx one level into e.
   *
   * Returns:
   *   - SOME zero for constants, deltas, epsilons, and Lift nodes;
   *   - SOME e' for Conv (dx is appended to the convolution's derivative
   *     indices), for nested Apply (the two partials are merged into
   *     Partial(dx @ d2)), for Sum (the derivative moves inside the sum), and
   *     for Op1/Op2/Opn (via the differentiation rules above);
   *   - NONE for a bare Field.
   *
   * Raises Fail (via err) for expressions that must not appear under a
   * derivative at this stage, and for an empty index list (Partial []).  The
   * empty case previously failed with an uninformative Bind exception.
   *)
    fun mkApply (E.Partial [], _) = err "Apply of empty Partial"
      | mkApply (px as E.Partial dx, e) = let
          val zero = E.Const 0
          in
            case e
             of E.Const _ => SOME zero
              | E.ConstR _ => SOME zero
              | E.Tensor _ => err "Tensor without Lift"
              | E.Delta _ => SOME zero
              | E.Epsilon _ => SOME zero
              | E.Eps2 _ => SOME zero
              | E.Field _ => NONE
              | E.Lift _ => SOME zero
              | E.Conv(v, alpha, h, d2) => SOME(E.Conv(v, alpha, h, d2@dx))
              | E.Partial _ => err("Apply of Partial")
              | E.Apply(E.Partial d2, e2) => SOME(E.Apply(E.Partial(dx@d2), e2))
              | E.Apply _ => err "Apply of non-Partial expression"
              | E.Probe _ => err "Apply of Probe"
              | E.Value _ => err "Value used before expand"
              | E.Img _ => err "Probe used before expand"
              | E.Krn _ => err "Krn used before expand"
              | E.Sum(sx, e1) => SOME(E.Sum(sx, E.Apply(px, e1)))
              | E.Op1(op1, e1) => SOME(applyOp1(op1, e1, dx))
              | E.Op2(op2, e1, e2) => SOME(applyOp2(op2, e1, e2, dx))
              | E.Opn(opn, es) => SOME(applyOpn(opn, es, dx))
            (* end case *)
          end

  end
Click to toggle
does not end with </html> tag
does not end with </body> tag
The output has ended thus: dx)) | E.Opn(opn, es) => SOME(applyOpn(opn, es, dx)) (* end case *) end end