47 |
(* end case *) |
(* end case *) |
48 |
end |
end |
49 |
|
|
50 |
|
(* filter function shifts constant/greeks to outside product*) |
51 |
|
fun filter([],pre,post)=(pre,post) |
52 |
|
| filter(E.Const c::es, pre, post)=filter(es, pre@[E.Const c],post) |
53 |
|
| filter(E.Delta d::es,pre,post)=filter(es,pre@[E.Delta d],post) |
54 |
|
| filter(E.Value v::es, pre, post)=filter(es, pre@[E.Value v],post) |
55 |
|
| filter(E.Epsilon e::es, pre, post)=filter(es, pre@[E.Epsilon e],post) |
56 |
|
| filter(E.Tensor(id,[])::es, pre, post)=filter(es, pre@[E.Tensor(id,[])],post) |
57 |
|
| filter(E.Prod p::es, pre, post)=filter(p@es,pre,post) |
58 |
|
| filter(e::es, pre, post)= filter(es, pre, post@[e]) |
59 |
|
|
60 |
|
|
61 |
|
|
62 |
fun rmEpsIndex(_,_,[])=[] |
fun prodPartial ([e1],p1)= E.Prod[E.Partial p1,e1] |
63 |
| rmEpsIndex([],[],cs)=cs |
| prodPartial((e1::e2),p1)=let |
64 |
| rmEpsIndex([],m ,e1::cs)=[e1]@rmEpsIndex(m,[],cs) |
val l= prodPartial(e2,p1) |
65 |
| rmEpsIndex(i::ix,rest ,(E.V c)::cs)= |
val (_,e2')= mkProd[e1,l] |
66 |
if(i=c) then rmEpsIndex(rest@ix,[],cs) |
val (_,e1')=mkProd(e2@ [E.Partial p1, e1]) |
67 |
else rmEpsIndex(ix,rest@[i],(E.V c)::cs) |
in |
68 |
|
E.Add[e1',e2'] |
69 |
|
end |
70 |
|
|
71 |
|
fun prodAppPartial ([e1],p1)= E.Apply(E.Partial p1,e1) |
72 |
|
| prodAppPartial((e1::e2),p1)=let |
73 |
|
val l= prodAppPartial(e2,p1) |
74 |
|
val (_,e2')= mkProd[e1,l] |
75 |
|
val (_,e1')=mkProd(e2@ [E.Apply(E.Partial p1, e1)]) |
76 |
|
in |
77 |
|
E.Add[e1',e2'] |
78 |
|
end |
79 |
|
|
80 |
|
|
81 |
|
|
82 |
|
|
83 |
|
(*remove eps Index*) |
84 |
|
fun rmEpsIndex(i,[],rest)=rest |
85 |
|
| rmEpsIndex(i, (E.V c ,lb, ub)::es,rest)= |
86 |
|
if (i=c) then rest@es |
87 |
|
else rmEpsIndex(i, es, rest@[(E.V c, lb, ub)]) |
88 |
|
|
89 |
|
|
90 |
|
(*remove index variable from list*) |
91 |
|
fun rmIndex(_,_,[])=[] |
92 |
|
| rmIndex([],[],cs)=cs |
93 |
|
| rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs) |
94 |
|
| rmIndex(i::ix,rest ,(c,lb,ub)::cs)= |
95 |
|
if(i=c) then rmIndex(rest@ix,[],cs) |
96 |
|
else rmIndex(ix,rest@[i],(c,lb,ub)::cs) |
97 |
|
|
98 |
|
|
99 |
|
|
100 |
|
|
101 |
(* Transform eps to deltas*) |
(* Transform eps to deltas*) |
108 |
|
|
109 |
(*remove index from original index list*) |
(*remove index from original index list*) |
110 |
|
|
111 |
val s'= rmEpsIndex([i,s,t,u,v],[],count) |
val s'= rmEpsIndex(i,count,[]) |
112 |
val s''=[E.V s, E.V t ,E.V u, E.V v] |
|
113 |
val deltas= E.Sub( |
val deltas= E.Sub( |
114 |
E.Sum(s'',E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3)), |
E.Prod([E.Delta(E.V s,E.V u), E.Delta(E.V t,E.V v)] @e3), |
115 |
E.Sum(s'',E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3))) |
E.Prod([E.Delta(E.V s,E.V v), E.Delta(E.V t,E.V u)]@e3)) |
116 |
|
|
117 |
in (case (eps,es,s') |
in (case (eps,es,s') |
118 |
of ([],[],[]) =>(1,deltas) |
of ([],[],[]) =>(1,deltas) |
149 |
end |
end |
150 |
|
|
151 |
val (es,rest)=findeps([],e) |
val (es,rest)=findeps([],e) |
|
|
|
152 |
in |
in |
153 |
dist(es,[],rest) |
dist(es,[],rest) |
154 |
end |
end |
155 |
|
|
156 |
|
|
|
fun rmIndex(_,_,[])=[] |
|
|
| rmIndex([],[],cs)=cs |
|
|
| rmIndex([],m ,e1::cs)=[e1]@rmIndex(m,[],cs) |
|
|
| rmIndex(i::ix,rest ,c::cs)= |
|
|
if(i=c) then rmIndex(rest@ix,[],cs) |
|
|
else rmIndex(ix,rest@[i],c::cs) |
|
157 |
|
|
158 |
(* Apply deltas to tensors/fields*) |
(* Apply deltas to tensors/fields*) |
159 |
fun reduceDelta(E.Sum(c,E.Prod p))=let |
fun reduceDelta(E.Sum(c,E.Prod p))=let |
162 |
| findDeltas(dels,rest,E.Epsilon eps::es)=findDeltas(dels,rest@[E.Epsilon eps],es) |
| findDeltas(dels,rest,E.Epsilon eps::es)=findDeltas(dels,rest@[E.Epsilon eps],es) |
163 |
| findDeltas(dels,rest,es)= (dels,rest,es) |
| findDeltas(dels,rest,es)= (dels,rest,es) |
164 |
|
|
|
|
|
165 |
fun distribute(change,d,dels,[],done)=(change,dels@d,done) |
fun distribute(change,d,dels,[],done)=(change,dels@d,done) |
166 |
| distribute(change,[],[],e,done)=(change,[],done@e) |
| distribute(change,[],[],e,done)=(change,[],done@e) |
167 |
| distribute(change,E.Delta(i,j)::ds,dels,E.Tensor(id,[tx])::es,done)= |
| distribute(change,E.Delta(i,j)::ds,dels,E.Tensor(id,[tx])::es,done)= |
177 |
val index=rmIndex(change,[],c) |
val index=rmIndex(change,[],c) |
178 |
|
|
179 |
in |
in |
180 |
(change, E.Sum(index,E.Prod (eps@dels'@done))) |
(length change, E.Sum(index,E.Prod (eps@dels'@done))) |
181 |
end |
end |
182 |
|
|
183 |
|
|
184 |
fun mkApplySum(E.Apply(E.Partial d,E.Sum(c,e)))=(print "apply sum";case e |
(*Apply Sum*) |
185 |
|
fun mkApplySum(E.Apply(E.Partial d,E.Sum(c,e)))=(case e |
186 |
of E.Tensor(a,[])=>(1,E.Const 0.0) |
of E.Tensor(a,[])=>(1,E.Const 0.0) |
187 |
| E.Const _ =>(1,E.Const 0.0) |
| E.Const _ =>(1,E.Const 0.0) |
188 |
|
| E.Delta _ =>(1,E.Const 0.0) |
189 |
|
| E.Value _ =>(1,E.Const 0.0) |
190 |
|
| E.Epsilon _ =>(1,E.Const 0.0) |
191 |
|
|
192 |
| E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, E.Sum(c,e))) l)) |
| E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, E.Sum(c,e))) l)) |
193 |
| E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, E.Sum(c,e2)), E.Apply(E.Partial d, E.Sum(c,e3)))) |
| E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, E.Sum(c,e2)), E.Apply(E.Partial d, E.Sum(c,e3)))) |
194 |
|
| E.Conv (fid,alpha,tid, delta)=> let |
195 |
|
val e'=E.Conv(fid,alpha, tid, delta@d) |
196 |
|
in (1,E.Sum(c,e')) end |
197 |
| E.Prod [e1]=>(1,E.Apply(E.Partial d,E.Sum(c,e1))) |
| E.Prod [e1]=>(1,E.Apply(E.Partial d,E.Sum(c,e1))) |
198 |
| E.Prod(E.Tensor(a,[])::e1::[])=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Sum(c,e1))]) |
| E.Prod es=> let |
199 |
|
val (pre, post)= filter(es,[],[]) |
200 |
| E.Prod(E.Tensor(a,[])::e2)=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Sum(c,E.Prod e2))]) |
val x1= prodAppPartial(post,d) |
201 |
|
in (case x1 |
202 |
| E.Prod es=>(let |
of E.Add a=> (1,E.Add(List.map (fn e => E.Sum(c,E.Prod(pre@[e]))) a)) |
203 |
fun prod [e1] =E.Apply(E.Partial d,e1) |
| _ => (1,E.Sum(c, E.Prod(pre@[x1]))) |
204 |
| prod (E.Epsilon eps1::es) = (E.Apply(E.Partial d, E.Prod (E.Epsilon eps1::es))) |
(*end case*)) |
205 |
| prod (E.Delta e1::es) = (E.Apply(E.Partial d, E.Prod (E.Delta e1::es))) |
end |
|
| prod (E.Prod e1::es)=prod(e1@es) |
|
|
| prod(e1::e2)=(let |
|
|
val l= prod(e2) |
|
|
val (_, a)= mkProd[e1,l] |
|
|
val lr=e2 @[E.Apply(E.Partial d,e1)] |
|
|
val(_,b) =mkProd lr |
|
|
in E.Add[b,a] |
|
|
end) |
|
|
val chainrule=prod es |
|
|
in (1,E.Sum(c, chainrule)) end) |
|
206 |
|_=>(0,E.Apply(E.Partial d,E.Sum(c,e))) |
|_=>(0,E.Apply(E.Partial d,E.Sum(c,e))) |
207 |
(* end case*)) |
(* end case*)) |
208 |
|
|
209 |
fun mkApply2(E.Apply(E.Partial d,e))=(print "aa";case e |
(*Apply*) |
210 |
|
fun mkApply(E.Apply(E.Partial d,e))=(case e |
211 |
of E.Tensor(a,[])=>(1,E.Const 0.0) |
of E.Tensor(a,[])=>(1,E.Const 0.0) |
212 |
| E.Const _ =>(1,E.Const 0.0) |
| E.Const _ =>(1,E.Const 0.0) |
213 |
|
| E.Delta _ =>(1,E.Const 0.0) |
214 |
|
| E.Value _ =>(1,E.Const 0.0) |
215 |
|
| E.Epsilon _ =>(1,E.Const 0.0) |
216 |
|
| E.Conv (fid,alpha,tid, delta)=> let |
217 |
|
val e'=E.Conv(fid,alpha, tid, delta@d) |
218 |
|
in (1,e') end |
219 |
| E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, e)) l)) |
| E.Add l => (1,E.Add(List.map (fn e => E.Apply(E.Partial d, e)) l)) |
220 |
| E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, e2), E.Apply(E.Partial d, e3))) |
| E.Sub(e2, e3) =>(1, E.Sub(E.Apply(E.Partial d, e2), E.Apply(E.Partial d, e3))) |
221 |
| E.Apply(E.Partial e1,e2)=>(1,E.Apply(E.Partial(d@e1), e2)) |
| E.Div(e2, e3) =>(1, E.Div(E.Apply(E.Partial d, e2), e3)) |
222 |
|
| E.Apply(E.Partial d2,e2)=>(1,E.Apply(E.Partial(d@d2), e2)) |
223 |
| E.Prod [e1]=>(1,E.Apply(E.Partial d,e1)) |
| E.Prod [e1]=>(1,E.Apply(E.Partial d,e1)) |
224 |
| E.Prod(E.Tensor(a,[])::e1::[])=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,e1)]) |
| E.Prod es=> let |
225 |
| E.Prod(E.Tensor(a,[])::e2)=>(1,E.Prod[E.Tensor(a,[]),E.Apply(E.Partial d,E.Prod e2)]) |
val (pre, post)= filter(es,[],[]) |
226 |
| E.Prod es=> (let |
val (_,x)=mkProd(pre@[prodAppPartial(post,d)]) |
227 |
fun prod [e1] =(0,E.Apply(E.Partial d,e1)) |
in (1,x) end |
228 |
| prod (E.Epsilon eps1::es) = (0,E.Apply(E.Partial d, E.Prod (E.Epsilon eps1::es))) |
|_=>(0,E.Apply(E.Partial d,e)) |
229 |
| prod (E.Delta e1::es) = (0,E.Apply(E.Partial d, E.Prod (E.Delta e1::es))) |
(* end case*)) |
|
| prod (E.Prod e1::es)=prod(e1@es) |
|
|
| prod(E.Tensor t::e2)=(let |
|
|
val (_,l)= prod(e2) val m= E.Prod[E.Tensor t,l] |
|
|
val lr=e2 @[E.Apply(E.Partial d,E.Tensor t)] val(b,a) =mkProd lr |
|
|
in (1,E.Add[a,m]) |
|
|
end) |
|
|
| prod(E.Field f::e2)=(let |
|
|
val (_,l)= prod(e2) val m= E.Prod[E.Field f,l] |
|
|
val lr=e2 @[E.Apply(E.Partial d,E.Field f)] val(b,a) =mkProd lr |
|
|
in (1,E.Add[a,m]) |
|
|
end) |
|
|
| prod e = (0,E.Apply(E.Partial d, E.Prod e)) |
|
230 |
|
|
231 |
|
(*Sum Apply*) |
232 |
|
|
|
val (a,b)= prod es |
|
233 |
|
|
234 |
in (a, b) end) |
fun mkSumApply(E.Sum(c,E.Apply(E.Partial d,e)))=(case e |
235 |
|_=>(0,E.Apply(E.Partial d,e)) |
|
|
(* end case*)) |
|
236 |
|
|
|
fun mkSumApply2(E.Sum(c,E.Apply(E.Partial d,e)))=(print "in here ";case e |
|
237 |
of E.Const _=>(1,E.Const 0.0) |
of E.Const _=>(1,E.Const 0.0) |
238 |
| E.Tensor(_,[])=> (1,E.Const 0.0) |
| E.Tensor(_,[])=> (1,E.Const 0.0) |
239 |
| E.Field _=>(0,E.Sum(c,E.Apply(E.Partial d,e))) |
| E.Delta _ =>(1,E.Const 0.0) |
240 |
| E.Apply(E.Partial e1,e2)=>(1,E.Sum(c,E.Apply(E.Partial(d@e1),e2))) |
| E.Value _ =>(1,E.Const 0.0) |
241 |
|
| E.Epsilon _ =>(1,E.Const 0.0) |
242 |
|
| E.Conv (fid,alpha,tid, delta)=> let |
243 |
|
val e'=E.Conv(fid,alpha, tid, delta@d) |
244 |
|
in (1,E.Sum(c,e')) end |
245 |
|
| E.Apply(E.Partial d1,e2)=>(1,E.Sum(c,E.Apply(E.Partial(d@d1),e2))) |
246 |
| E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(E.Partial d, e))) l)) |
| E.Add l => (1,E.Add(List.map (fn e => E.Sum(c,E.Apply(E.Partial d, e))) l)) |
247 |
| E.Sub(e2, e3) => |
| E.Sub(e1, e2) => (1, E.Sub(E.Sum(c,E.Apply(E.Partial d, e1)), E.Sum(c,E.Apply(E.Partial d, e2)))) |
|
(*(0,E.Sub(e2,e3)) |
|
|
*) |
|
|
(print "sub";(1, E.Sub(E.Sum(c,E.Apply(E.Partial d, e2)), E.Sum(c,E.Apply(E.Partial d, e3))))) |
|
|
|
|
|
| E.Prod [e1]=>(print "one";(1,E.Sum(c,E.Apply(E.Partial d,e1)))) |
|
248 |
|
|
249 |
|
| E.Prod [e1]=>(1,E.Sum(c,E.Apply(E.Partial d,e1))) |
250 |
|
|
|
| E.Prod(E.Tensor(a,[])::e2::[])=>("in scalar";(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,e2))])) |
|
251 |
|
|
252 |
| E.Prod(E.Tensor(a,[])::e2)=>("in scalar";(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,E.Prod e2))])) |
| E.Prod(E.Tensor(a,[])::e2)=>(1, E.Prod[E.Tensor(a,[]),E.Sum(c,E.Apply(E.Partial d,E.Prod e2))]) |
253 |
|
|
254 |
| E.Prod es =>(print "in prod";let |
| E.Prod es =>(let |
255 |
fun prod (change,rest, sum,partial,[]) = (change,E.Sum(sum,E.Apply(E.Partial partial,E.Prod rest))) |
fun prod (change,rest, sum,partial,[]) = (change,E.Sum(sum,E.Apply(E.Partial partial,E.Prod rest))) |
256 |
| prod (change,rest, sum,partial,E.Epsilon(i,j,k)::ps)= let |
| prod (change,rest, sum,partial,E.Epsilon(i,j,k)::ps)= let |
257 |
fun matchprod(2,_,_,_)= 1 (*matched 2*) |
fun matchprod(2,_,_,_)= 1 (*matched 2*) |
287 |
|
|
288 |
| prod (change,rest,sum, partial,e::es)= prod(change,rest@[e],sum,partial,es) |
| prod (change,rest,sum, partial,e::es)= prod(change,rest@[e],sum,partial,es) |
289 |
|
|
290 |
val (change,exp) = prod(0,[],c, d, es) |
in prod(0,[],c, d, es) |
291 |
|
|
292 |
|
|
|
in |
|
|
(change,exp) |
|
293 |
end) |
end) |
294 |
| _=>(print "nope";(0,E.Sum(c,E.Apply(E.Partial d,e)))) |
| _=>(0,E.Sum(c,E.Apply(E.Partial d,e))) |
295 |
(* end case*)) |
(* end case*)) |
296 |
|
|
|
(* |
|
|
E.Sum(c,Apply(d,e)) |
|
|
try E.Sum(c,e)=> E.Sum(c',e') |
|
|
==> E.Sum(c',E.Apply(d,e')) |
|
|
E.Apply(d,e')=> E.Apply(d',e'') |
|
|
==>E.Sum(c',E.Apply(d',e'') |
|
|
*) |
|
297 |
|
|
298 |
(*Apply normalize to each term in product list |
(*Apply normalize to each term in product list |
299 |
or Apply normalize to tail of each list*) |
or Apply normalize to tail of each list*) |
305 |
of E.Const _=> body |
of E.Const _=> body |
306 |
| E.Tensor _ =>body |
| E.Tensor _ =>body |
307 |
| E.Field _=> body |
| E.Field _=> body |
|
| E.Kernel _ =>body |
|
308 |
| E.Delta _ => body |
| E.Delta _ => body |
309 |
| E.Value _ =>body |
| E.Value _ =>body |
310 |
| E.Epsilon _=>body |
| E.Epsilon _=>body |
311 |
|
| E.Conv _=>body |
312 |
|
|
313 |
| E.Neg e => E.Neg(rewriteBody e) |
| E.Neg e => E.Neg(rewriteBody e) |
314 |
| E.Add es => let val (change,body')= mkAdd(List.map rewriteBody es) |
| E.Add es => let val (change,body')= mkAdd(List.map rewriteBody es) |
316 |
| E.Sub (a,b)=> E.Sub(rewriteBody a, rewriteBody b) |
| E.Sub (a,b)=> E.Sub(rewriteBody a, rewriteBody b) |
317 |
| E.Div (a, b) => E.Div(rewriteBody a, rewriteBody b) |
| E.Div (a, b) => E.Div(rewriteBody a, rewriteBody b) |
318 |
| E.Partial _=>body |
| E.Partial _=>body |
319 |
| E.Conv (V, alpha)=> E.Conv(rewriteBody V, alpha) |
| E.Krn(tid,deltas,pos)=> E.Krn(tid,deltas, (rewriteBody pos)) |
320 |
| E.Probe(u,v)=> E.Probe(rewriteBody u, rewriteBody v) |
| E.Img(fid,alpha,pos)=> E.Img(fid,alpha, (List.map rewriteBody pos)) |
321 |
| E.Image es => E.Image(List.map rewriteBody es) |
|
322 |
|
|
323 |
(*Product*) |
(*************Product**************) |
324 |
| E.Prod [e1] => rewriteBody e1 |
| E.Prod [e1] => rewriteBody e1 |
325 |
| E.Prod(e1::(E.Add(e2))::e3)=> |
| E.Prod((E.Add(e2))::e3)=> |
326 |
|
(changed := true; E.Add(List.map (fn e=> E.Prod([e]@e3)) e2)) |
327 |
|
| E.Prod((E.Sub(e2,e3))::e4)=> |
328 |
|
(changed :=true; E.Sub(E.Prod([e2]@e4), E.Prod([e3]@e4 ))) |
329 |
|
|
330 |
|
| E.Prod((E.Div(e2,e3))::e4)=> |
331 |
|
(changed :=true; E.Div(E.Prod([e2]@e4), e3 )) |
332 |
|
|
333 |
|
| E.Prod(e1::E.Add(e2)::e3)=> |
334 |
(changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2)) |
(changed := true; E.Add(List.map (fn e=> E.Prod([e1, e]@e3)) e2)) |
335 |
| E.Prod(e1::(E.Sub(e2,e3))::e4)=> |
| E.Prod(e1::E.Sub(e2,e3)::e4)=> |
336 |
(changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 ))) |
(changed :=true; E.Sub(E.Prod([e1, e2]@e4), E.Prod([e1,e3]@e4 ))) |
337 |
| E.Prod [E.Partial r1,E.Conv(f,deltas)]=> |
|
338 |
(changed:=true;E.Conv(f,deltas@r1)) |
|
339 |
| E.Prod (E.Partial r1::E.Conv(f,deltas)::ps)=> |
|
340 |
(changed:=true; |
|
|
let val (change,e)=mkProd([E.Conv(f,deltas@r1)]@ps) |
|
|
in e end) |
|
341 |
| E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=> |
| E.Prod[(E.Epsilon(e1,e2,e3)), E.Tensor(_,[E.V i1,E.V i2])]=> |
342 |
if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0)) |
if(e2=i1 andalso e3=i2) then (changed :=true;E.Const(0.0)) |
343 |
else body |
else body |
344 |
| E.Prod [E.Partial r1, E.Tensor(_,[])]=> (changed:=true;E.Const(0.0)) |
| E.Prod [E.Partial r1, E.Tensor(_,[])]=> (changed:=true;E.Const(0.0)) |
345 |
|
| E.Prod [E.Partial r1,E.Partial r2]=> |
346 |
|
(changed:=true;E.Partial(r1@r2)) |
347 |
|
|
348 |
| E.Prod(E.Partial r1::E.Partial r2::p)=> |
| E.Prod(E.Partial r1::E.Partial r2::p)=> |
349 |
(changed:=true;E.Prod([E.Partial(r1@r2)]@p)) |
(changed:=true;E.Prod([E.Partial(r1@r2)]@p)) |
|
| E.Prod [E.Partial _, _] =>body |
|
|
|
|
|
| E.Prod (E.Partial p1::es)=> (let |
|
|
fun prod [e1] =E.Apply(E.Partial p1,e1) |
|
|
| prod(e1::e2)=(let |
|
|
val l= prod(e2) val m= E.Prod[e1,l] |
|
|
val lr=e2 @[E.Apply(E.Partial p1,e1)] val(b,a) =mkProd lr |
|
|
in E.Add[a,m] |
|
|
end) |
|
|
in (changed:=true;prod es) end) |
|
|
|
|
350 |
| E.Prod(E.Sum(c1,E.Prod(E.Epsilon e1::es1))::E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es)=>let |
| E.Prod(E.Sum(c1,E.Prod(E.Epsilon e1::es1))::E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es)=>let |
351 |
val (change,e,rest)=epsToDels(E.Sum(c1@c2, E.Prod([E.Epsilon e1, E.Epsilon e2]@es1@es2@es))) |
val (change,e,rest)=epsToDels(E.Sum(c1@c2, E.Prod([E.Epsilon e1, E.Epsilon e2]@es1@es2@es))) |
352 |
in(case (change,e, rest) |
in(case (change,e, rest) |
353 |
of (1,[e1],_)=> e1 |
of (1,[e1],_)=> (changed:=true;e1) |
354 |
| _=>body) |
| _=>let |
355 |
|
val e1=rewriteBody(E.Sum(c1,E.Prod(E.Epsilon e1::es1))) |
356 |
|
val es'=rewriteBody(E.Prod(E.Sum(c2,E.Prod(E.Epsilon e2::es2))::es)) |
357 |
|
val (_,e)=(case es' of E.Prod p=>mkProd([e1]@p) |
358 |
|
|_=> mkProd([e1]@e) |
359 |
|
(*end case*)) |
360 |
|
in e |
361 |
end |
end |
362 |
|
(*end case*)) |
363 |
|
end |
364 |
|
| E.Prod[e1,e2]=> body |
365 |
| E.Prod(e::es)=>let |
| E.Prod(e::es)=>let |
366 |
val e'=rewriteBody e |
val e'=rewriteBody e |
367 |
val e2=rewriteBody(E.Prod es) |
val e2=rewriteBody(E.Prod es) |
368 |
val(a,b)=(case e2 of E.Prod p'=> mkProd([e']@p') |
val(_,b)=(case e2 |
369 |
|
of E.Prod p'=> mkProd([e']@p') |
370 |
|_=>mkProd [e',e2]) |
|_=>mkProd [e',e2]) |
371 |
in b |
in b |
372 |
end |
end |
373 |
|
|
374 |
(*Apply*) |
(**************Apply**************) |
375 |
|
|
376 |
| E.Apply(E.Partial d,E.Sum(c,e))=>let |
(* Apply, Sum*) |
377 |
val(c,e')=mkApplySum(E.Apply(E.Partial d,E.Sum(c, rewriteBody e))) |
| E.Apply(E.Partial d,E.Sum e)=>let |
378 |
val e''=(case e' |
val s'=rewriteBody(E.Sum e) |
379 |
of E.Apply(d,E.Sum s)=>E.Apply(d,rewriteBody(E.Sum s)) |
val (c, e')=(case s' |
380 |
|_=> e') |
of E.Sum e1=> mkApplySum(E.Apply(E.Partial d,s')) |
381 |
in (print "bb";case c of 1=>(changed:=true;e'') |
| _=>(0, E.Apply(E.Partial d, s')) |
382 |
|_=> e'')end |
(*end case*)) |
383 |
| E.Apply(E.Partial [],e)=> e |
in (case c |
384 |
|
of 1=>(changed:=true;e') |
385 |
|
|_=> e' |
386 |
|
(*end case*)) |
387 |
|
end |
388 |
|
|
389 |
|
| E.Apply(E.Partial [],e)=> e |
390 |
|
| E.Apply(E.Partial p,E.Probe(E.Conv(fid,alpha,tid,d),x))=> |
391 |
|
(changed:=true;E.Probe(E.Conv(fid,alpha,tid,d@p),x)) |
392 |
|
| E.Apply(E.Partial p,E.Conv(fid,alpha,tid,d))=> |
393 |
|
(changed:=true;E.Conv(fid,alpha,tid,d@p)) |
394 |
| E.Apply(E.Partial p, e)=>let |
| E.Apply(E.Partial p, e)=>let |
395 |
|
|
396 |
val body'=E.Apply(E.Partial p, rewriteBody e) |
val body'=E.Apply(E.Partial p, rewriteBody e) |
397 |
val (c, e')=mkApply2(body') |
val (c, e')=mkApply(body') |
398 |
in (case c of 1=>(changed:=true;e') |
in (case c |
399 |
|
of 1=>(changed:=true;e') |
400 |
| _ =>e') end |
| _ =>e') end |
|
| E.Apply(e1,e2)=>E.Apply(rewriteBody e1, rewriteBody e2) |
|
401 |
|
|
402 |
|
| E.Apply(e1,e2)=>((E.Apply(rewriteBody e1, rewriteBody e2)) |
403 |
|
) |
404 |
|
|
405 |
|
|
406 |
(* Sum *) |
(************** Sum *****************) |
407 |
| E.Sum([],e)=> (changed:=true;rewriteBody e) |
| E.Sum([],e)=> (changed:=true;rewriteBody e) |
408 |
| E.Sum(_,E.Const c)=>(changed:=true;E.Const c) |
| E.Sum(_,E.Const c)=>(changed:=true;E.Const c) |
409 |
| E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l)) |
| E.Sum(c,(E.Add l))=> (changed:=true;E.Add(List.map (fn e => E.Sum(c,e)) l)) |
410 |
| E.Sum(c,E.Sub(e1,e2))=>(changed:=true; E.Sub(E.Sum(c,e1),E.Sum(c,e2))) |
| E.Sum(c,E.Sub(e1,e2))=>(changed:=true; E.Sub(E.Sum(c,e1),E.Sum(c,e2))) |
411 |
|
| E.Sum(c,E.Div(e1,e2))=>(changed:=true; E.Div(E.Sum(c,e1),E.Sum(c,e2))) |
412 |
|
| E.Sum(c, E.Prod(E.Const e::es))=>(changed:=true;E.Prod[E.Const e,E.Sum(c, E.Prod es)]) |
413 |
|
|
414 |
|
| E.Sum(c, E.Prod(E.Value v::es))=>(changed:=true; E.Prod [E.Value v, E.Sum(c, E.Prod es)]) |
415 |
|
| E.Sum(c, E.Prod(E.Tensor(id,[])::es))=> (changed:=true;E.Prod [E.Tensor(id,[]), E.Sum(c, E.Prod es)]) |
416 |
| E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=> |
| E.Sum(c,E.Prod(E.Epsilon eps1::E.Epsilon eps2::ps))=> |
417 |
let val (i,e,rest)=epsToDels(body) |
let val (i,e,rest)=epsToDels(body) |
418 |
in (print "eps to dels \n ";case (i, e,rest) |
in (case (i, e,rest) |
419 |
of (1,[e1],r) =>(print "changed\n";changed:=true;e1) |
of (1,[e1],r) =>(changed:=true;e1) |
420 |
|(0,eps,[])=>(print "non";body) |
|(0,eps,[])=>body |
421 |
|(0,eps,rest)=>(let |
|(0,eps,rest)=> let |
422 |
val p'=rewriteBody(E.Prod rest) |
val p'=rewriteBody(E.Prod rest) |
423 |
val p''= (case p' of E.Prod p=>p |e=>[e]) |
val p''= (case p' of E.Prod p=>p |e=>[e]) |
424 |
val(a,b)= mkProd (eps@p'') |
val(_,b)= mkProd (eps@p'') |
425 |
in E.Sum(c,b) end |
in E.Sum(c,b) end |
426 |
) |
|_=>body |
427 |
|_=>body) |
(*end case*)) |
428 |
end |
end |
429 |
|
|
430 |
| E.Sum(c1,E.Prod(E.Epsilon eps1::E.Sum(c2,E.Prod(E.Epsilon eps2::s2))::ps))=>let |
| E.Sum(c1,E.Prod(E.Epsilon eps1::E.Sum(c2,E.Prod(E.Epsilon eps2::s2))::ps))=>let |
437 |
|
|
438 |
| E.Sum(c, E.Prod(E.Delta d::es))=>let |
| E.Sum(c, E.Prod(E.Delta d::es))=>let |
439 |
val (change,a)=reduceDelta(body) |
val (change,a)=reduceDelta(body) |
440 |
val (change',body')=(case a |
in (case (change,a) |
441 |
of E.Prod p=> mkProd p |
of (0, _)=> E.Sum(c,rewriteBody(E.Prod([E.Delta d]@es))) |
442 |
|_=> (0,a)) |
| (_, E.Prod p)=>let |
443 |
in (case change of []=>body'|_=>(changed:=true;body')) end |
val (_, p') = mkProd p |
444 |
|
in (changed:=true;p') end |
445 |
| E.Sum(c,E.Apply(E.Partial _,e))=>let |
| _ => (changed:=true;a ) |
446 |
val (change,exp)=mkSumApply2(body) |
(*end case*)) |
447 |
val exp'=(case exp |
end |
|
of E.Const c => E.Const c |
|
|
| E.Sum(c',E.Apply(d',e')) => (let |
|
|
val s'=rewriteBody(E.Sum(c',e')) |
|
|
in (case s' |
|
|
of E.Sum([],e'')=> rewriteBody (E.Apply(d',e'')) |
|
|
| E.Sum(s'',e'') => E.Sum(s'',rewriteBody(E.Apply(d',e''))) |
|
|
| _ => E.Apply(d',s')) |
|
448 |
|
|
|
end) |
|
449 |
|
|
450 |
|
| E.Sum(c,E.Apply(E.Partial p,e))=>let |
451 |
|
val (change,exp)=mkSumApply(body) |
452 |
|
val exp'=(case change |
453 |
|
of 1=> (changed:=true;exp) |
454 |
|
| _ => E.Sum(c,rewriteBody(E.Apply(E.Partial p,e)))) |
455 |
|
in exp' end |
456 |
|
|
|
| _ =>exp |
|
|
(* end case *)) |
|
457 |
|
|
458 |
in (case change of 1=>(changed:=true;exp') |_=>exp') |
| E.Sum(c,e)=>E.Sum(c,rewriteBody e) |
|
end |
|
459 |
|
|
460 |
|
(*Probe*) |
461 |
|
| E.Probe(E.Sum(c,s),x)=>(changed:=true;E.Sum(c,E.Probe(s,x))) |
462 |
|
| E.Probe(E.Neg e1,x)=>(changed:=true;E.Neg(E.Probe(e1,x))) |
463 |
|
| E.Probe(E.Add es,x) => |
464 |
|
(changed:=true;E.Add(List.map (fn(e1)=>E.Probe(e1,x)) es)) |
465 |
|
| E.Probe(E.Sub (a,b),x)=> |
466 |
|
(changed:=true;E.Sub(rewriteBody(E.Probe(a,x)), rewriteBody(E.Probe(b,x)))) |
467 |
|
| E.Probe(E.Div (a,b),x) => |
468 |
|
(changed:=true;E.Div(rewriteBody(E.Probe(a, x)),b)) |
469 |
|
|
|
| E.Sum(c,e)=>E.Sum(c,rewriteBody e) |
|
470 |
|
|
471 |
|
|
472 |
|
(* |
473 |
|
| E.Probe(E.Prod([E.Sum s] @es),x) |
474 |
|
| E.Probe(E.Prod([E.Neg e] @es),x) |
475 |
|
| E.Probe(E.Prod([E.Apply e] @es),x) needs to be rewritten |
476 |
|
*) |
477 |
|
|
478 |
|
|
479 |
|
(*Should be taken care of in next rule. |
480 |
|
| E.Probe(E.Prod([E.Add e] @es),x) |
481 |
|
| E.Probe(E.Prod([E.Sub (e1,e2)] @es),x)=> |
482 |
|
| E.Probe(E.Prod([E.Div e] @es),x)=> |
483 |
|
*) |
484 |
|
|
485 |
|
|
486 |
|
|
487 |
|
| E.Probe(E.Prod p, x)=>let |
488 |
|
val (p',x')= (rewriteBody (E.Prod p), rewriteBody x) |
489 |
|
fun probeprod([],rest) = |
490 |
|
(print "err-Did not find field/Conv"; body) |
491 |
|
| probeprod(E.Const c::es,rest)= |
492 |
|
(changed:=true;probeprod(es,rest@[E.Const c])) |
493 |
|
| probeprod(E.Tensor t::es,rest)= |
494 |
|
(changed:=true;probeprod(es,rest@[E.Tensor t])) |
495 |
|
| probeprod(E.Krn e::es, rest)= |
496 |
|
(changed:=true;probeprod(es, rest@[E.Krn e])) |
497 |
|
| probeprod(E.Delta e::es, rest)= |
498 |
|
(changed:=true;probeprod(es, rest@[E.Delta e])) |
499 |
|
| probeprod(E.Value e::es, rest)= |
500 |
|
(changed:=true;probeprod(es, rest@[E.Value e])) |
501 |
|
| probeprod(E.Epsilon e::es, rest)= |
502 |
|
(changed:=true;probeprod(es, rest@[E.Epsilon e])) |
503 |
|
| probeprod(E.Partial e::es, rest)= |
504 |
|
(changed:=true;probeprod(es, rest@[E.Partial e])) |
505 |
|
| probeprod(E.Field f::es,rest)= |
506 |
|
(changed:=true;E.Prod(rest@[E.Probe(E.Field f, x')] @es)) |
507 |
|
| probeprod(E.Conv f::es,rest)= |
508 |
|
(changed:=true;E.Prod(rest@[E.Probe(E.Conv f, x')] @es)) |
509 |
|
| probeprod(E.Prod p::es , rest)= |
510 |
|
(changed:=true;probeprod(p@es,rest)) |
511 |
|
| probeprod(_,[])=body |
512 |
|
| probeprod(e1::es, rest)=let |
513 |
|
val e'= rewriteBody(E.Prod(e1::es)) |
514 |
|
val e''= rewriteBody(E.Probe(e',x')) |
515 |
|
in (changed:=true;E.Prod(rest@[e''])) |
516 |
|
end |
517 |
|
in (case p' |
518 |
|
of E.Prod pro=>probeprod(p,[]) |
519 |
|
|_=> E.Probe(p',x') |
520 |
|
(*end case*)) |
521 |
|
end |
522 |
|
| E.Probe(u,v)=> (E.Probe(rewriteBody u, rewriteBody v)) |
523 |
(*end case*)) |
(*end case*)) |
524 |
|
|
525 |
fun loop body = let |
fun loop body = let |
527 |
|
|
528 |
in |
in |
529 |
if !changed |
if !changed |
530 |
then (changed := false;(*print(P.printbody body');*) print "\n => \n" ;loop body') |
then (changed := false ;loop body') |
531 |
else body' |
else body' |
532 |
end |
end |
533 |
|
val z=print "hi" |
534 |
|
val u= print(Int.toString( length(params))); |
535 |
val b = loop body |
val b = loop body |
536 |
in |
in |
537 |
((Ein.EIN{params=params, index=index, body=b})) |
Ein.EIN{params=params, index=index, body=b} |
538 |
end |
end |
539 |
end |
end |
540 |
|
|