Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/ein/rewrite.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/ein/rewrite.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2383 - (view) (download)

1 : cchiw 2383 (* rewrite.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure Rewrite : sig
8 :    
9 :     datatype arg
10 :     = Var of Var.var
11 :     | App of Ein.ein * Var.var list
12 :    
13 :     val evalEinApp : Ein.ein * arg list -> Ein.ein * Var.var list
14 :    
15 :     end = struct
16 :    
17 :     structure VarMap = Var.Map
18 :    
19 :     fun mkAdd es = (
20 :     case List.filter
21 :     (fn
22 :     (Ein.Const x) => Real.!=(x, 0.0)
23 :     | _ => true) es
24 :     of [] => Ein.Const 0.0
25 :     | [e] => e
26 :     | es => Ein.Add es
27 :     (* end case *))
28 :    
29 :     datatype arg
30 :     = Var of Var.var
31 :     | App of Ein.ein * Var.var list
32 :    
33 :     (* apply a substitution to an ein_exp ID*)
34 :     fun instantiateIdx body ids = let
35 :     val subst = Vector.fromList ids
36 :     fun substIdx id = Vector.sub(subst, id)
37 :     handle ex => (print(concat["substIdx ([|", String.concatWith "," (List.map Int.toString ids), "|],", Int.toString id, ")\n"]); raise ex)
38 :     fun apply e = (case e
39 :     of Ein.Const _ => e
40 :     | Ein.Tensor(id, mx) => Ein.Tensor(id, List.map substIdx mx)
41 :     | Ein.Field(id, mx) => Ein.Field(id, List.map substIdx mx)
42 :     | Ein.Add es => mkAdd(List.map apply es)
43 :    
44 :     | Ein.Sum(c,esum)=> Ein.Sum(c, esum)
45 :     | Ein.Prod es => Ein.Prod(List.map apply es)
46 :     | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
47 :     | Ein.Neg e => Ein.Neg(apply e)
48 :     | Ein.Delta(i, j) => Ein.Delta(substIdx i, substIdx j)
49 :     | Ein.Epsilon(i, j, k) => Ein.Epsilon(substIdx i, substIdx j, substIdx k)
50 :     | Ein.Conv(v, dx, h, i) => Ein.Conv(v,dx, h, substIdx i)
51 :     | Ein.Partial mx => Ein.Partial mx
52 :     | Ein.Probe(e, id) => Ein.Probe(apply e, id)
53 :     | Ein.Inside(e, id) => Ein.Inside(apply e, id)
54 :     | Ein.Apply(e1, e2)=> Ein.Apply(apply e1, apply e2)
55 :     (* end case *))
56 :     in
57 :     apply body
58 :     end
59 :    
60 :     datatype subst = S of {
61 :     tSub : (Ein.multiindex -> Ein.ein_exp) array, (* mapping from tensor ID to instantiation *)
62 :     fSub : (Ein.multiindex -> Ein.ein_exp) array (* mapping from field ID to instantiation *)
63 :     }
64 :    
65 :     fun newSubst (nTens, nFlds) = S{
66 :     tSub = Array.array(nTens, fn _ => raise Fail "undefined tensor"),
67 :     fSub = Array.array(nFlds, fn _ => raise Fail "undefined field")
68 :     }
69 :     fun bindTensor (S{tSub, ...}, id, f) = Array.update (tSub, id, f)
70 :     fun bindField (S{fSub, ...}, id, f) = Array.update (fSub, id, f)
71 :     fun substTensor (S{tSub, ...}, id, mx) = Array.sub (tSub, id) mx
72 :     handle ex => (print(concat["substTensor(_, ", Int.toString id, ", [|", String.concatWith "," (List.map Int.toString mx), "|])\n"]); raise ex)
73 :     fun substField (S{fSub, ...}, id, mx) = Array.sub (fSub, id) mx
74 :    
75 :     (* apply a substitution to an ein_exp term *)
76 :     fun applySubst (subst, e) = let
77 :     fun apply e = (case e
78 :     of Ein.Const _ => e
79 :     | Ein.Tensor(id, mx) => substTensor (subst, id, mx)
80 :     | Ein.Field(id, mx) => substField (subst, id, mx)
81 :     | Ein.Add es => Ein.Add(List.map apply es)
82 :     | Ein.Sum (c,esum)=>Ein.Sum(c, apply esum)
83 :     | Ein.Prod es => Ein.Prod(List.map apply es)
84 :     | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
85 :     | Ein.Neg e => Ein.Neg(apply e)
86 :     | Ein.Delta _ => e
87 :     | Ein.Epsilon _ => e
88 :     | Ein.Conv _ => e
89 :     | Ein.Partial _ => e
90 :     | Ein.Probe(e, id) => Ein.Probe(apply e, id)
91 :     | Ein.Inside(e, id) => Ein.Inside(apply e, id)
92 :     | Ein.Apply _=>e (*newbie? ???*)
93 :     (* end case *))
94 :     in
95 :     apply e
96 :     end
97 :    
98 :     (* NOTE: if all of the arguments are distinct variables, this function is the identity *)
99 :    
100 :     fun evalEinApp (Ein.EIN{params, index, body}, args : arg list) = let
101 :     fun renameVars ([], [], nT, nF, uniqueArgs) = (nT, nF, List.rev uniqueArgs)
102 :     | renameVars (param::params, arg::args, nT, nF, uniqueArgs) = let
103 :     fun isUnique (x, uArgs) = not(List.exists (fn (y, _, _) => Var.same(x, y)) uArgs)
104 :     fun continue (nT, nF, uniqueArgs) = renameVars (params, args, nT, nF, uniqueArgs)
105 :     fun doVar (p, x, nT, nF, uArgs) = (case p
106 :     of Ein.TEN =>
107 :     if isUnique (x, uArgs)
108 :     then (nT+1, nF, (x, p, nT)::uArgs)
109 :     else (nT, nF, uniqueArgs)
110 :     | Ein.FLD =>
111 :     if isUnique (x, uArgs)
112 :     then (nT, nF+1, (x, p, nF)::uArgs)
113 :     else (nT, nF, uniqueArgs)
114 :     (* end case *))
115 :     in
116 :     case arg
117 :     of (Var x) => continue(doVar(param, x, nT, nF, uniqueArgs))
118 :     | (App(Ein.EIN{params=ps, ...}, xs)) => let
119 :     fun lp ([], [], nT, nF, uniqueArgs) = continue(nT, nF, uniqueArgs)
120 :     | lp (p::ps, x::xs, nT, nF, uniqueArgs) = let
121 :     val (nT, nF, uniqueArgs) = doVar(p, x, nT, nF, uniqueArgs)
122 :     in
123 :     lp (ps, xs, nT, nF, uniqueArgs)
124 :     end
125 :     | lp _ = raise Fail "param/arg arity mismatch"
126 :     in
127 :     lp (ps, xs, nT, nF, uniqueArgs)
128 :     end
129 :     (* end case *)
130 :     end
131 :     val (nT, nF, uniqueArgs) = renameVars(params, args, 0, 0, [])
132 :     (* build a map from unique argument variables to (kind, id) pairs *)
133 :     val vMap = List.foldl
134 :     (fn ((x, k, id), vMap) => VarMap.insert(vMap, x, (k, id)))
135 :     VarMap.empty uniqueArgs
136 :     (* allocate the top-level substitution *)
137 :     val subst = newSubst (nT, nF)
138 :     (* add a mapping for a variable argument to a substitution *)
139 :     fun bindVar (subst, x, nT, nF) = (case VarMap.find(vMap, x)
140 :     of SOME(Ein.TEN, id) => (
141 :     bindTensor(subst, nT, fn mx => Ein.Tensor(id, mx));
142 :     (nT+1, nF))
143 :     | SOME(Ein.FLD, id) => (
144 :     bindField(subst, nF, fn mx => Ein.Field(id, mx));
145 :     (nT, nF+1))
146 :     | NONE => raise Fail(concat["undefined argument variable \"", Var.name x, "\""])
147 :     (* end case *))
148 :     (* rewrite arguments and intialize the top-level substitution *)
149 :     fun rewriteArgs ([], [], _, _) = ()
150 :     | rewriteArgs (_::params, (Var x)::args, nT, nF) = let
151 :     val (nT, nF) = bindVar (subst, x, nT, nF)
152 :     in
153 :     rewriteArgs (params, args, nT, nF)
154 :     end
155 :     | rewriteArgs (p::ps, App(Ein.EIN{params, body, ...}, xs)::args, nT, nF) = let
156 :     (* rewrite the argument body first *)
157 :     val body = let
158 :     (* allocate a new substitution for the argument body *)
159 :     val subst = let
160 :     fun f (Ein.TEN, (nT, nF)) = (nT+1, nF)
161 :     | f (Ein.FLD, (nT, nF)) = (nT, nF+1)
162 :     in
163 :     newSubst (List.foldl f (0, 0) params)
164 :     end
165 :     (* initialize the substitution *)
166 :     fun doVars ([], _, _) = ()
167 :     | doVars (x::xs, nT, nF) = let
168 :     val (nT, nF) = bindVar (subst, x, nT, nF)
169 :     in
170 :     doVars (xs, nT, nF)
171 :     end
172 :     in
173 :     doVars (xs, 0, 0);
174 :     applySubst (subst, body)
175 :     end
176 :     fun mkBody mx = instantiateIdx body mx
177 :     in
178 :     case p
179 :     of Ein.TEN => (
180 :     bindTensor (subst, nT, instantiateIdx body);
181 :     rewriteArgs (ps, args, nT+1, nF))
182 :     | Ein.FLD => (
183 :     bindField (subst, nT, instantiateIdx body);
184 :     rewriteArgs (ps, args, nT, nF+1))
185 :     (* end case *)
186 :     end
187 :     val _ = rewriteArgs (params, args, 0, 0)
188 :     in (
189 :     Ein.EIN{
190 :     params = List.map #2 uniqueArgs,
191 :     index = index,
192 :     body = applySubst (subst, body)
193 :     },
194 :     List.map #1 uniqueArgs
195 :     ) end
196 :    
197 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0