Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/derivative.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-opt/derivative.sml

Parent Directory | Revision Log


Revision 3978 - (download) (annotate)
Wed Jun 15 19:07:40 2016 UTC (3 years, 3 months ago) by cchiw
File size: 5669 byte(s)
changed ein expressions, rewrote matchEps, added translation
(* derivative.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Derivative : sig

  (* mkApply (px, e) rewrites the application of the partial-derivative
   * operator px (expected to be an E.Partial) to the EIN expression e.
   * The result is
   *   SOME e'  when the derivative can be pushed into, or eliminated from, e;
   *   NONE     when e is an E.Field, which is not rewritten here.
   * Raises Fail (with an "Ill-formed EIN Operator" message) on ill-formed input,
   * e.g. an empty index list, a non-Partial operator, or nodes that should
   * have been expanded earlier (Probe, Value, Img, Krn).
   *)
    val mkApply : Ein.ein_exp * Ein.ein_exp -> Ein.ein_exp option

  end  = struct

    structure E = Ein

    fun err str = raise Fail (String.concat["Ill-formed EIN Operator: ", str])

  (* smart constructors for the EIN operators built below *)
    fun mkAdd exps = E.Opn(E.Add, exps)
    fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)
    fun mkProd exps = E.Opn(E.Prod, exps)
    fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)
    fun mkNeg e = E.Op1(E.Neg, e)

  (* build a product, using EinFilter.mkProd's result when it returns one *)
    fun filterProd args = (case EinFilter.mkProd args
           of SOME e => e
            | NONE => mkProd args
          (* end case *))

  (* a product of a single factor is just that factor *)
    fun rewriteProd [a] = a
      | rewriteProd exps = E.Opn(E.Prod, exps)

  (* split the index list of a Partial into its first index and the rest;
   * an empty index list is ill-formed.
   *)
    fun splitDx [] = err "Partial with empty index list"
      | splitDx (d0::dn) = (d0, dn)

  (* applyRest (dn, e) re-applies the remaining derivative indices dn to e.
   * The rewrites below differentiate by hand w.r.t. the first index only,
   * so every result must be wrapped with the rest of the indices.
   *)
    fun applyRest ([], e) = e
      | applyRest (dn, e) = E.Apply(E.Partial dn, e)

  (* product rule: given the factors of a product and a single-index partial
   * p0, build p0(e1 * ... * en) as a sum of products:
   *   d(e * rest) = rest * de + e * d(rest)
   *)
    fun prodAppPartial ([], _) = err "Empty App Partial"
      | prodAppPartial ([e1], p0) = E.Apply(p0, e1)
      | prodAppPartial (e::es, p0) = let
          val dRest = prodAppPartial (es, p0)
          val eTimesDRest = filterProd [e, dRest]
          val restTimesDe = filterProd (es @ [E.Apply(p0, e)])
          in
            mkAdd[restTimesDe, eTimesDRest]
          end

  (* derivative of the unary operator op1 applied to e1, w.r.t. the index list dx.
   * Apart from Neg (which is linear and takes the full dx), each case applies
   * the chain rule for the first index and then re-applies the remaining indices.
   *)
    fun applyOp1 (op1, e1, dx) = let
          val (d0, dn) = splitDx dx
          val px = E.Partial dx
        (* derivative of the argument w.r.t. the first index only *)
          val inner = E.Apply(E.Partial[d0], e1)
          val square = mkProd [e1, e1]
          val one = E.Const 1
        (* 1 / sqrt(1 - e1^2), shared by ArcSine and ArcCosine *)
          val e2 = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))
          fun iterDn e = applyRest (dn, e)
          in
            case op1
             of E.Neg => mkNeg(E.Apply(px, e1))
              | E.Exp => iterDn (mkProd [inner, E.Op1(E.Exp, e1)])
              | E.Sqrt => let
                  val half = mkDiv (E.Const 1, E.Const 2)
                  val e3 = mkDiv (inner, E.Op1(op1, e1))
                  in
                  (* the constant 1/2 is kept outside the remaining derivative *)
                    case dn
                     of [] => mkProd [half, e3]
                      | _ => mkProd [half, E.Apply(E.Partial dn, e3)]
                    (* end case *)
                  end
              | E.Cosine => iterDn (mkProd [mkNeg (E.Op1(E.Sine, e1)), inner])
              | E.ArcCosine => iterDn (mkProd [mkNeg e2, inner])
              | E.Sine => iterDn (mkProd [E.Op1(E.Cosine, e1), inner])
              | E.ArcSine => iterDn (mkProd [e2, inner])
              | E.Tangent =>
                  iterDn (mkProd [mkDiv(one, mkProd[E.Op1(E.Cosine, e1), E.Op1(E.Cosine, e1)]), inner])
              | E.ArcTangent =>
                  iterDn (mkProd [mkDiv(one, mkAdd[one, square]), inner])
              | E.PowInt n => iterDn (mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), inner])
            (* end case *)
          end

  (* derivative of the binary operator op2 applied to (e1, e2), w.r.t. dx.
   * Every branch differentiates w.r.t. the first index and then re-applies the
   * remaining indices (iterDn); previously the Sub branch and the
   * division-by-constant/real branches dropped them.
   *)
    fun applyOp2 (op2, e1, e2, dx) = let
          val (d0, dn) = splitDx dx
          val p0 = E.Partial [d0]
          val inner1 = E.Apply(p0, e1)
          val inner2 = E.Apply(p0, e2)
          val zero = E.Const 0
          fun iterDn e = applyRest (dn, e)
          in
            case op2
             of E.Sub => iterDn (mkSub (inner1, inner2))
              | E.Div => (case (e1, e2)
                   of (_, E.Const c) => iterDn (mkDiv (inner1, E.Const c))
                  (* EinFilter.partitionField presumably splits factors into
                   * (non-field, field) parts; no field part means the
                   * denominator is a spatial constant.
                   *)
                    | (E.Const 1, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let (* Quotient Rule: d(1/h) = -h'/h^2 *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkProd [E.Const ~1, h']
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                    | (E.Const c, _) => (case EinFilter.partitionField [e2]
                         of (_, []) => zero
                          | (pre, h) => let (* Quotient Rule: d(c/h) = -c*h'/h^2 *)
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkNeg (mkProd [E.Const c, h'])
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                    | _ => (case EinFilter.partitionField [e2]
                         of (_, []) => iterDn (mkDiv (inner1, e2)) (* Division by a real *)
                          | (pre, h) => let (* Quotient Rule: (g'h - gh')/h^2 *)
                              val g' = inner1
                              val h' = E.Apply(p0, rewriteProd h)
                              val num = mkSub (mkProd (g' :: h), mkProd[e1, h'])
                              in
                                iterDn (mkDiv (num, mkProd (pre @ h @ h)))
                              end
                        (* end case *))
                  (* end case *))
            (* end case *)
          end

  (* derivative of the n-ary operator opn applied to es, w.r.t. dx.
   * Add is linear, so the full dx is applied to each summand; Prod uses the
   * product rule on the field factors only.
   *)
    fun applyOpn (opn, es, dx) = let
          val (d0, dn) = splitDx dx
          val p0 = E.Partial [d0]
          in
            case opn
             of E.Add => mkAdd (List.map (fn a => E.Apply(E.Partial dx, a)) es)
              | E.Prod => (case EinFilter.partitionField es
                   of (_, []) => E.Const 0 (* no fields in expression *)
                    | (pre, post) =>
                        applyRest (dn, filterProd (pre @ [prodAppPartial (post, p0)]))
                  (* end case *))
            (* end case *)
          end

  (* rewrite Apply nodes *)
    fun mkApply (px as E.Partial dx, e) = let
        (* reject an empty index list up front, whatever e is *)
          val _ = splitDx dx
          val zero = E.Const 0
          in
            case e
             of E.Const _ => SOME zero
              | E.ConstR _ => SOME zero
              | E.Tensor _ => err "Tensor without Lift"
              | E.Delta _ => SOME zero
              | E.Epsilon _ => SOME zero
              | E.Eps2 _ => SOME zero
              | E.Field _ => NONE
              | E.Lift _ => SOME zero
              | E.Conv(v, alpha, h, d2) => SOME(E.Conv(v, alpha, h, d2@dx))
              | E.Partial _ => err("Apply of Partial")
              | E.Apply(E.Partial d2, e2) => SOME(E.Apply(E.Partial(dx@d2), e2))
              | E.Apply _ => err "Apply of non-Partial expression"
              | E.Probe _ => err "Apply of Probe"
              | E.Value _ => err "Value used before expand"
              | E.Img _ => err "Probe used before expand"
              | E.Krn _ => err "Krn used before expand"
              | E.Sum(sx, e1) => SOME(E.Sum(sx, E.Apply(px, e1)))
              | E.Op1(op1, e1) => SOME(applyOp1(op1, e1, dx))
              | E.Op2(op2, e1, e2) => SOME(applyOp2(op2, e1, e2, dx))
              | E.Opn(opn, es) => SOME(applyOpn(opn, es, dx))
            (* end case *)
          end
      | mkApply _ = err "Apply of non-Partial operator"

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0