Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/derivative.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-opt/derivative.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5285 - (download) (annotate)
Thu Aug 10 16:50:56 2017 UTC (3 years, 2 months ago) by cchiw
File size: 7245 byte(s)
added sgn to ein
(* derivative.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *
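 * Symbolic differentiation of EIN expressions: Derivative.mkApply rewrites the
 * application of a partial-derivative operator (E.Partial) to an EIN expression
 * by pushing the derivative inward using the sum, product, quotient, and chain
 * rules.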
 *)

structure Derivative : sig

  (* mkApply (px, e) rewrites the application of the partial-derivative
   * operator px (an E.Partial carrying a non-empty list of derivative
   * indices) to the EIN expression e.  It returns
   *   SOME e'  when the derivative could be pushed into e; e' may itself
   *            contain further Apply nodes (presumably the caller keeps
   *            rewriting them -- confirm against the caller), or
   *   NONE     when e is an E.Field, which is left untouched here.
   * Raises Fail for ill-formed input, e.g., an empty or non-Partial
   * operator, Apply of a Probe, or Tensor/Zero without Lift.
   *)
    val mkApply : Ein.ein_exp * Ein.ein_exp -> Ein.ein_exp option

  end  = struct

    structure E = Ein

  (* report an EIN expression that cannot be differentiated here *)
    fun err str = raise Fail (String.concat["Ill-formed EIN Operator: ", str])

  (* smart constructors for EIN operators *)
    fun mkAdd exps = E.Opn(E.Add, exps)
    fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)
    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)
    fun mkNeg e = E.Op1(E.Neg, e)

  (* build a product, using EinFilter.mkProd's simplified form when it has one *)
    fun filterProd args = (case EinFilter.mkProd args
           of SOME e => e
            | NONE => mkProd args
          (* end case *))

  (* build a product, avoiding a Prod node around a single factor *)
    fun rewriteProd [a] = a
      | rewriteProd exps = mkProd exps

  (* split a derivative multi-index into its first index and the rest;
   * an empty multi-index is ill-formed.
   *)
    fun splitDx (d0 :: dn) = (d0, dn)
      | splitDx [] = err "Apply of empty Partial"

  (* re-apply the remaining derivative indices dn (if any) to e *)
    fun applyDn ([], e) = e
      | applyDn (dn, e) = E.Apply(E.Partial dn, e)

  (* product rule: given the non-empty factor list [e1, e2, ..., en] of a
   * product and a first-order partial p0, build the derivative as
   *   (e2 * ... * en) * p0(e1)  +  e1 * p0(e2 * ... * en)
   * where the second term is computed recursively.
   *)
    fun prodAppPartial ([], _) = err "Empty App Partial"
      | prodAppPartial ([e1], p0) = E.Apply(p0, e1)
      | prodAppPartial (e::es, p0) = let
          val dRest = prodAppPartial (es, p0)
          val eTimesDRest = filterProd [e, dRest]
          val restTimesDe = filterProd (es @ [E.Apply(p0, e)])
          in
            mkAdd[restTimesDe, eTimesDRest]
          end

  (* derivative of op1(e1) with respect to the multi-index dx = d0::dn.
   * The rule for op1 is applied along the first index d0 only (inner is
   * the derivative of e1 along d0); the remaining indices dn are then
   * re-applied to the result.  With e' = inner, the rules are:
   *   -e      ->  -(d^dx e)            all of dx applied directly
   *   exp e   ->  e' * exp e
   *   sqrt e  ->  1/2 * e' / sqrt e    the 1/2 stays outside dn
   *   cos e   ->  -sin e * e'
   *   acos e  ->  -1/sqrt(1 - e^2) * e'
   *   sin e   ->  cos e * e'
   *   asin e  ->  1/sqrt(1 - e^2) * e'
   *   tan e   ->  1/(cos e * cos e) * e'
   *   atan e  ->  1/(1 + e^2) * e'
   *   e^n     ->  n * e^(n-1) * e'
   *   |e|     ->  e' * sgn e           not differentiable at e = 0
   *   sgn e   ->  0
   *)
    fun applyOp1 (op1, e1, dx) = let
          val (d0, dn) = splitDx dx
          val one = E.Const 1
          val inner = E.Apply(E.Partial[d0], e1)
          val square = mkProd [e1, e1]
        (* 1 / sqrt(1 - e1^2), shared by the arcsine and arccosine rules *)
          val invSqrt = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))
          in
            case op1
             of E.Neg => mkNeg(E.Apply(E.Partial dx, e1))
              | E.Exp => applyDn (dn, mkProd [inner, E.Op1(E.Exp, e1)])
              | E.Sqrt => let
                  val half = mkDiv (one, E.Const 2)
                  val e3 = mkDiv (inner, E.Op1(E.Sqrt, e1))
                  in
                    mkProd [half, applyDn (dn, e3)]
                  end
              | E.Cosine => applyDn (dn, mkProd [mkNeg (E.Op1(E.Sine, e1)), inner])
              | E.ArcCosine => applyDn (dn, mkProd [mkNeg invSqrt, inner])
              | E.Sine => applyDn (dn, mkProd [E.Op1(E.Cosine, e1), inner])
              | E.ArcSine => applyDn (dn, mkProd [invSqrt, inner])
              | E.Tangent => let
                  val cosE = E.Op1(E.Cosine, e1)
                  in
                    applyDn (dn, mkProd [mkDiv(one, mkProd[cosE, cosE]), inner])
                  end
              | E.ArcTangent =>
                  applyDn (dn, mkProd [mkDiv(one, mkAdd[one, square]), inner])
              | E.PowInt n =>
                  applyDn (dn, mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), inner])
            (* NOTE(review): the original marked this rule with "fix here";
             * sgn is used as the derivative of |e|, which is undefined at 0.
             *)
              | E.Abs => applyDn (dn, mkProd [inner, E.Op1(E.Sgn, e1)])
              | E.Sgn => E.Const 0
            (* end case *)
          end

  (* derivative of op2(e1, e2) with respect to the multi-index dx = d0::dn.
   * The rule is applied along d0 and the remaining indices dn are
   * re-applied to the result in every branch.  Without that, higher-order
   * derivatives of differences and of quotients by constants would lose
   * the indices dn.  For quotients, EinFilter.partitionField splits the
   * denominator into (pre, h), where h holds the field factors; when h is
   * empty the denominator is treated as field-independent.  The quotient
   * rule divides by pre * h * h.
   *)
    fun applyOp2 (op2, e1, e2, dx) = let
          val (d0, dn) = splitDx dx
          val p0 = E.Partial [d0]
          val inner1 = E.Apply(p0, e1)
          val inner2 = E.Apply(p0, e2)
          val zero = E.Const 0
          fun denom (pre, h) = mkProd (pre @ h @ h)
          in
            case op2
             of E.Sub => applyDn (dn, mkSub (inner1, inner2))
              | E.Div => (case (e1, e2)
                   of (_, E.Const _) => applyDn (dn, mkDiv (inner1, e2))
                    | (E.Const 1, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let (* d(1/(pre*h)) = -h' / (pre*h*h) *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkProd [E.Const ~1, h']
                              in
                                applyDn (dn, mkDiv (num, denom (pre, h)))
                              end
                        (* end case *))
                    | (E.Const c, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let (* d(c/(pre*h)) = -(c*h') / (pre*h*h) *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkNeg (mkProd [E.Const c, h'])
                              in
                                applyDn (dn, mkDiv (num, denom (pre, h)))
                              end
                        (* end case *))
                    | _ => (case EinFilter.partitionField [e2]
                         of (_, []) => applyDn (dn, mkDiv (inner1, e2))
                          | (pre, h) => let (* quotient rule: (g'*h - g*h') / (pre*h*h) *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkSub (mkProd (inner1 :: h), mkProd[e1, h'])
                              in
                                applyDn (dn, mkDiv (num, denom (pre, h)))
                              end
                        (* end case *))
                  (* end case *))
            (* end case *)
          end

  (* derivative of an n-ary operator with respect to the multi-index dx.
   * Sums are differentiated term by term with the full multi-index.  For
   * products, the field factors (post) are differentiated along d0 with the
   * product rule, the non-field factors (pre) are kept as coefficients, and
   * the remaining indices dn are re-applied; a product with no field
   * factors has derivative 0.
   *)
    fun applyOpn (opn, es, dx) = let
          val (d0, dn) = splitDx dx
          in
            case opn
             of E.Add => mkAdd (List.map (fn a => E.Apply(E.Partial dx, a)) es)
              | E.Prod => (case EinFilter.partitionField es
                   of (_, []) => E.Const 0 (* no fields in expression *)
                    | (pre, post) =>
                        applyDn (dn, filterProd (pre @ [prodAppPartial (post, E.Partial [d0])]))
                  (* end case *))
            (* end case *)
          end

  (* rewrite Apply nodes: push the partial-derivative operator into e *)
    fun mkApply (E.Partial [], _) = err "Apply of empty Partial"
      | mkApply (px as E.Partial dx, e) = let
          val zero = E.Const 0
          in
            case e
             of E.Const _ => SOME zero
              | E.ConstR _ => SOME zero
              | E.Tensor _ => err "Tensor without Lift"
              | E.Zero _ => err "Zero without Lift"
              | E.Delta _ => SOME zero
              | E.Epsilon _ => SOME zero
              | E.Eps2 _ => SOME zero
              | E.Field _ => NONE
              | E.Lift _ => SOME zero
            (* fold the derivative indices into the convolution *)
              | E.Conv(v, alpha, h, d2) => SOME(E.Conv(v, alpha, h, d2@dx))
              | E.Partial _ => err("Apply of Partial")
            (* merge nested derivatives into a single multi-index *)
              | E.Apply(E.Partial d2, e2) => SOME(E.Apply(E.Partial(dx@d2), e2))
              | E.Apply _ => err "Apply of non-Partial expression"
              | E.Probe _ => err "Apply of Probe"
              | E.Value _ => err "Value used before expand"
              | E.Img _ => err "Probe used before expand"
              | E.Krn _ => err "Krn used before expand"
            (* differentiation commutes with summation *)
              | E.Sum(sx, e1) => SOME(E.Sum(sx, E.Apply(px, e1)))
              | E.Op1(op1, e1) => SOME(applyOp1(op1, e1, dx))
              | E.Op2(op2, e1, e2) => SOME(applyOp2(op2, e1, e2, dx))
              | E.Op3 _ => SOME zero (*assume clamp is not lifted*)
              | E.Opn(opn, es) => SOME(applyOpn(opn, es, dx))
            (* end case *)
          end
      | mkApply _ = err "Apply of non-Partial operator"

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0