Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-opt/derivative.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-opt/derivative.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3978 - (view) (download)

1 : jhr 3515 (* derivative.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2015 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 :     structure Derivative : sig
10 :    
11 :     val mkApply : Ein.ein_exp * Ein.ein_exp -> Ein.ein_exp option
12 :    
13 :     end = struct
14 :    
15 :     structure E = Ein
16 :    
17 :     fun err str=raise Fail (String.concat["Ill-formed EIN Operator: ", str])
18 :    
19 :     fun mkAdd exps = E.Opn(E.Add, exps)
20 :     fun mkSub (e1, e2) = E.Op2(E.Sub, e1, e2)
21 :     fun mkProd exps = E.Opn(E.Prod, exps)
22 :     fun mkDiv (e1, e2) = E.Op2(E.Div, e1, e2)
23 :     fun mkNeg e = E.Op1(E.Neg, e)
24 :    
25 : jhr 3520 fun filterProd args = (case EinFilter.mkProd args
26 : jhr 3515 of SOME e => e
27 :     | NONE => mkProd args
28 :     (* end case *))
29 :    
30 :     fun rewriteProd [a] = a
31 :     | rewriteProd exps = E.Opn(E.Prod, exps)
32 :    
33 :     (* chain rule *)
34 :     fun prodAppPartial ([], _) = err "Empty App Partial"
35 :     | prodAppPartial ([e1], p0) = E.Apply(p0, e1)
36 :     | prodAppPartial (e::es, p0) = let
37 :     val l = prodAppPartial (es, p0)
38 :     val e2' = filterProd [e, l]
39 :     val e1' = filterProd (es @ [E.Apply(p0, e)])
40 :     in
41 :     mkAdd[e1', e2']
42 :     end
43 :    
44 :     fun applyOp1 (op1, e1, dx) = let
45 :     val d0::dn = dx
46 :     val px = E.Partial dx
47 :     val inner = E.Apply(E.Partial[d0], e1)
48 :     val square = mkProd [e1, e1]
49 :     val one = E.Const 1
50 :     val e2 = mkDiv(one, E.Op1(E.Sqrt, mkSub(one, square)))
51 :     fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
52 :     in
53 :     case op1
54 :     of E.Neg => mkNeg(E.Apply(px, e1))
55 :     | E.Exp => iterDn (mkProd [inner, E.Op1(E.Exp, e1)])
56 :     | E.Sqrt => let
57 :     val half = mkDiv (E.Const 1, E.Const 2)
58 :     val e3 = mkDiv (inner, E.Op1(op1, e1))
59 :     in
60 :     case dn
61 :     of [] => mkProd [half, e3]
62 :     | _ => mkProd [half, E.Apply(E.Partial dn, e3)]
63 :     (* end case *)
64 :     end
65 :     | E.Cosine => iterDn (mkProd [mkNeg (E.Op1(E.Sine, e1)), inner])
66 :     | E.ArcCosine => iterDn (mkProd [mkNeg e2, inner])
67 :     | E.Sine => iterDn (mkProd [E.Op1(E.Cosine, e1), inner])
68 :     | E.ArcSine => iterDn (mkProd [e2, inner])
69 :     | E.Tangent =>
70 :     iterDn (mkProd [mkDiv(one, mkProd[E.Op1(E.Cosine, e1), E.Op1(E.Cosine, e1)]), inner])
71 :     | E.ArcTangent =>
72 :     iterDn (mkProd [mkDiv(one, mkAdd[one, square]), inner])
73 :     | E.PowInt n => iterDn (mkProd [E.Const n, E.Op1(E.PowInt(n-1), e1), inner])
74 :     (* end case *)
75 :     end
76 :    
77 :     fun applyOp2 (op2, e1, e2, dx) = let
78 :     val (d0::dn) = dx
79 :     val p0 = E.Partial [d0]
80 :     val inner1 = E.Apply(E.Partial[d0], e1)
81 :     val inner2 = E.Apply(E.Partial[d0], e2)
82 :     val zero = E.Const 0
83 :     fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
84 :     in
85 :     case op2
86 :     of E.Sub => mkSub (inner1, inner2)
87 :     | E.Div => (case (e1, e2)
88 :     of (_, E.Const e2) => mkDiv (inner1, E.Const e2)
89 : jhr 3520 | (E.Const 1, _) => (case EinFilter.partitionField [e2]
90 : jhr 3515 of (_, []) => zero
91 :     | (pre, h) => let (* Quotient Rule *)
92 :     val h' = E.Apply(p0, rewriteProd h)
93 :     val num = mkProd [E.Const ~1, h']
94 :     in
95 :     iterDn (mkDiv (num, mkProd (pre @ h @ h)))
96 :     end
97 :     (* end case *))
98 : jhr 3520 | (E.Const c, _) => (case EinFilter.partitionField [e2]
99 : jhr 3515 of (_, []) => zero
100 :     | (pre, h) => let (* Quotient Rule *)
101 :     val h' = E.Apply(p0, rewriteProd h)
102 :     val num = mkNeg (mkProd [E.Const c, h'])
103 :     in
104 :     iterDn (mkDiv (num, mkProd (pre@h@h)))
105 :     end
106 :     (* end case *))
107 : jhr 3520 | _ => (case EinFilter.partitionField [e2]
108 : jhr 3515 of (_, []) => mkDiv (inner1, e2) (* Division by a real *)
109 :     | (pre, h) => let (* Quotient Rule *)
110 :     val g' = inner1
111 :     val h' = E.Apply(p0, rewriteProd h)
112 :     val num = mkSub (mkProd (g' :: h), mkProd[e1, h'])
113 :     in
114 :     iterDn (mkDiv (num, mkProd (pre@h@h)))
115 :     end
116 :     (* end case *))
117 :     (* end case *))
118 :     (* end case *)
119 :     end
120 :    
121 :     fun applyOpn (opn, es, dx) = let
122 :     val (d0::dn) = dx
123 :     val p0 = E.Partial [d0]
124 :     fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
125 :     in
126 :     case opn
127 :     of E.Add => mkAdd (List.map (fn a => E.Apply(E.Partial dx, a)) es)
128 :     | E.Prod => let
129 : jhr 3520 val (pre, post) = EinFilter.partitionField es
130 : jhr 3515 in
131 :     case post
132 :     of [] => E.Const 0 (* no fields in expression *)
133 :     | _ => iterDn (filterProd (pre @ [prodAppPartial (post, p0)]))
134 :     (* end case *)
135 :     end
136 :     (* end case *)
137 :     end
138 :    
139 :     (* rewrite Apply nodes*)
140 :     fun mkApply (px as E.Partial dx, e) = let
141 :     val (d0::dn) = dx
142 :     val p0 = E.Partial[d0]
143 :     fun iterDn e2 = if null dn then e2 else E.Apply(E.Partial dn, e2)
144 :     val zero = E.Const 0
145 :     in
146 :     case e
147 :     of E.Const _ => SOME zero
148 :     | E.ConstR _ => SOME zero
149 :     | E.Tensor _ => err "Tensor without Lift"
150 :     | E.Delta _ => SOME zero
151 :     | E.Epsilon _ => SOME zero
152 :     | E.Eps2 _ => SOME zero
153 :     | E.Field _ => NONE
154 :     | E.Lift _ => SOME zero
155 :     | E.Conv(v, alpha, h, d2) => SOME(E.Conv(v, alpha, h, d2@dx))
156 :     | E.Partial _ => err("Apply of Partial")
157 :     | E.Apply(E.Partial d2, e2) => SOME(E.Apply(E.Partial(dx@d2), e2))
158 :     | E.Apply _ => err "Apply of non-Partial expression"
159 :     | E.Probe _ => err "Apply of Probe"
160 :     | E.Value _ => err "Value used before expand"
161 :     | E.Img _ => err "Probe used before expand"
162 :     | E.Krn _ => err "Krn used before expand"
163 :     | E.Sum(sx, e1) => SOME(E.Sum(sx, E.Apply(px, e1)))
164 :     | E.Op1(op1, e1) => SOME(applyOp1(op1, e1, dx))
165 :     | E.Op2(op2, e1, e2) => SOME(applyOp2(op2, e1, e2, dx))
166 :     | E.Opn(opn, es) => SOME(applyOpn(opn, es, dx))
167 :     (* end case *)
168 :     end
169 :    
170 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0