Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/ein/mk-operators.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/ein/mk-operators.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 5574 - (view) (download)

1 : jhr 3477 (* mk-operators.sml
2 : jhr 3476 *
3 : jhr 3477 * Functions to create the various Ein operators.
4 :     *
5 : jhr 3476 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
6 :     *
7 :     * COPYRIGHT (c) 2015 The University of Chicago
8 :     * All rights reserved.
9 :     *)
10 :    
11 :     structure MkOperators : sig
12 :    
13 : jhr 3477 type dim = int
14 :     type shape = dim list
15 :     type ids = Ein.index_id list
16 : jhr 3476
17 : jhr 3477 val addRR : Ein.ein
18 :     val addTT : shape -> Ein.ein
19 :     val addTF : dim * shape -> Ein.ein
20 :     val addFF : dim * shape -> Ein.ein
21 :    
22 :     val subRR : Ein.ein
23 :     val subTT : shape -> Ein.ein
24 :     val subTF : dim * shape -> Ein.ein
25 : jhr 3478 val subFT : dim * shape -> Ein.ein
26 : jhr 3477 val subFF : dim * shape -> Ein.ein
27 :    
28 :     val mulRT : shape -> Ein.ein
29 :     val mulRR : Ein.ein
30 :     val mulRF : dim * shape -> Ein.ein
31 :     val mulST : dim * shape -> Ein.ein
32 :     val mulSS : dim -> Ein.ein
33 :     val mulSF : dim * shape -> Ein.ein
34 :    
35 :     val divTR : shape -> Ein.ein
36 :     val divRR : Ein.ein
37 :     val divFR : dim * shape -> Ein.ein
38 :     val divSS : dim -> Ein.ein
39 :     val divFS : dim * shape -> Ein.ein
40 : cchiw 4000 val divTS : dim * shape -> Ein.ein
41 : jhr 4002
42 : cchiw 4409 val halfT : dim -> Ein.ein
43 :     val halfF : dim * dim -> Ein.ein
44 :     val scaleIdT : dim -> Ein.ein
45 :     val scaleIdF : dim * dim -> Ein.ein
46 : jhr 5007
47 :    
48 : jhr 3477 val negTT : shape -> Ein.ein
49 :     val negFF : dim * shape -> Ein.ein
50 :    
51 :     val cross2TT : Ein.ein
52 :     val cross3TT : Ein.ein
53 :     val cross2FF : Ein.ein
54 :     val cross3FF : Ein.ein
55 : cchiw 4000 val cross2TF : Ein.ein
56 :     val cross3TF : Ein.ein
57 :     val cross2FT : Ein.ein
58 :     val cross3FT : Ein.ein
59 : jhr 3477
60 : cchiw 3495 val outerTT : shape * shape -> Ein.ein
61 :     val outerFF : dim * shape * shape -> Ein.ein
62 : cchiw 3496 val outerTF : dim * shape * shape -> Ein.ein
63 :     val outerFT : dim * shape * shape -> Ein.ein
64 : jhr 3477
65 :     val innerTT : shape * ids -> Ein.ein
66 :     val innerFF : shape * dim * ids -> Ein.ein
67 :     val innerFT : shape * dim * ids -> Ein.ein
68 :     val innerTF : shape * dim * ids -> Ein.ein
69 :    
70 :     val colonTT : shape * ids -> Ein.ein
71 :     val colonFF : dim * shape * ids -> Ein.ein
72 :     val colonFT : dim * shape * ids -> Ein.ein
73 :     val colonTF : dim * shape * ids -> Ein.ein
74 :    
75 : cchiw 3494 val normT : shape -> Ein.ein
76 :     val normF : dim * shape -> Ein.ein
77 : jhr 3477
78 :     val normalizeTT : shape -> Ein.ein
79 :     val normalizeFF : dim * shape -> Ein.ein
80 :    
81 :     val traceT : dim -> Ein.ein
82 : cchiw 4273 val traceF : dim * dim * shape -> Ein.ein
83 : jhr 3477
84 :     val transposeT : shape -> Ein.ein
85 :     (* QUESTION: should these be index_kind? *)
86 :     val transposeF : dim * Ein.index_id * Ein.index_id -> Ein.ein
87 :    
88 : jhr 3478 val det2T : Ein.ein
89 :     val det3T : Ein.ein
90 : cchiw 4274 val det2F : dim -> Ein.ein
91 :     val det3F : dim -> Ein.ein
92 : jhr 5007
93 : cchiw 5407 val invR : Ein.ein
94 :     val invS : dim -> Ein.ein
95 : jhr 4414 val inv2T : Ein.ein
96 :     val inv2F : dim -> Ein.ein
97 : jhr 5570
98 : jhr 3477 val expF : dim -> Ein.ein
99 :     val expT : Ein.ein
100 :    
101 : jhr 3980 val powFI : dim * int -> Ein.ein
102 : jhr 5570 val powTI : int -> Ein.ein
103 : jhr 3498 val sqrtR : Ein.ein
104 : jhr 3477 val sqrtF : dim -> Ein.ein
105 : jhr 3498 val cosR : Ein.ein
106 : jhr 3477 val cosF : dim -> Ein.ein
107 : jhr 3498 val acosR : Ein.ein
108 : jhr 3477 val acosF : dim -> Ein.ein
109 : jhr 3498 val sinR : Ein.ein
110 : jhr 3477 val sinF : dim -> Ein.ein
111 : jhr 3498 val asinR : Ein.ein
112 : jhr 3477 val asinF : dim -> Ein.ein
113 : jhr 3498 val tanR : Ein.ein
114 : jhr 3477 val tanF : dim -> Ein.ein
115 : jhr 3498 val atanR : Ein.ein
116 : jhr 3477 val atanF : dim -> Ein.ein
117 :    
118 : cchiw 4281 val modulateTT : shape -> Ein.ein
119 :     val modulateTF : shape * dim -> Ein.ein
120 :     val modulateFT : shape * dim -> Ein.ein
121 :     val modulateFF : shape * dim -> Ein.ein
122 : jhr 3477
123 :     val identity : dim -> Ein.ein
124 : jhr 3478 val zeros : shape -> Ein.ein
125 : cchiw 3991 val sliceT : bool list * int list * Ein.index_bind list * int list -> Ein.ein
126 :     val sliceF : bool list * int list * Ein.index_bind list * int -> Ein.ein
127 : jhr 5574 val concatTensor : shape * int -> Ein.ein
128 :     val concatField : dim * shape * int -> Ein.ein
129 : jhr 5570
130 : cchiw 5238 val lerp3 : shape -> Ein.ein
131 :     val lerp5 : shape -> Ein.ein
132 : jhr 5574 val clampRRT : shape -> Ein.ein
133 :     val clampTTT : shape -> Ein.ein
134 : cchiw 5242 val clerp3 : shape -> Ein.ein
135 :     val clerp5 : shape -> Ein.ein
136 : jhr 5570
137 : jhr 3477 val conv : dim * shape -> Ein.ein
138 :     val probe : shape * dim -> Ein.ein
139 :    
140 :     val curl2d : Ein.ein
141 :     val curl3d : Ein.ein
142 :     val grad : shape -> Ein.ein
143 : cchiw 4159 val gradConstant : shape -> Ein.ein
144 : jhr 3477 val dotimes : dim * shape -> Ein.ein (* ?? *)
145 :     val divergence : dim * shape -> Ein.ein
146 :    
147 : jhr 5574 val cfexpMix : shape * shape list * shape list -> Ein.ein
148 : jhr 5570
149 : jhr 3476 end = struct
150 :    
151 :     structure E = Ein
152 :    
153 : jhr 3477 type dim = int
154 :     type shape = dim list
155 :     type ids = int list
156 : jhr 3701
157 :     (* controls whether tensor operations should be *)
158 :     val canSubst = true
159 :    
160 :     (* A constructor function for tensor variables that (by default) can be substituted for;
161 :     * this behavior is controlled by the canSubst flag.
162 :     *)
163 :     fun mkTEN alpha = E.TEN(canSubst, alpha)
164 :    
165 :     (* a constructor function for tensor parameters that should never be substituted for. *)
166 :     fun mkNoSubstTEN alpha = E.TEN(false, alpha)
167 :    
168 : jhr 3476 fun specialize (alpha, inc) = List.mapi (fn (i, _) => E.V(i + inc)) alpha
169 :    
170 :     fun sumIds (n, inc, alpha) = let
171 : jhr 4324 val vs = List.tabulate(n, fn v => (v+inc))
172 :     in
173 :     ListPair.map (fn(v, i) => (v, 0, i-1)) (vs, alpha)
174 :     end
175 : jhr 3476
176 : cchiw 3978 fun sumIds2 (n, i) = List.tabulate(n, fn v => (v, 0, i))
177 : jhr 5007
178 : jhr 3476 (******************************* Addition *****************************************)
179 :    
180 :     (* Adding tensors : < X{\alpha} + Y_{\alpha}>_{\alpha} *)
181 :     fun addTT alpha = let
182 : jhr 3498 val expindex = specialize(alpha, 0)
183 : jhr 3702 in
184 : jhr 3476 E.EIN{
185 : jhr 3701 params = [mkTEN alpha, mkTEN alpha],
186 : jhr 3476 index = alpha,
187 :     body = E.Opn(E.Add, [E.Tensor(0, expindex), E.Tensor(1, expindex)])
188 : jhr 3498 }
189 :     end
190 : jhr 3476
191 :     val addRR = addTT []
192 :    
193 : jhr 3477 (* Tensor and Fields *)
194 :     fun addTF (dim, shape) =let
195 : jhr 3498 val expindex = specialize(shape, 0)
196 :     in
197 :     E.EIN{
198 : jhr 3701 params = [mkTEN shape, E.FLD dim],
199 : jhr 3498 index = shape,
200 :     body = E.Opn(E.Add, [E.Lift(E.Tensor(0, expindex)), E.Field(1, expindex)])
201 :     }
202 :     end
203 : jhr 3476
204 : jhr 3477 (* Adding Fields : < F{\alpha} + G_{\alpha}>_{\alpha} *)
205 :     fun addFF (dim, shape) =let
206 : jhr 3498 val expindex = specialize(shape, 0)
207 :     in
208 :     E.EIN{
209 :     params = [E.FLD dim, E.FLD dim],
210 :     index = shape,
211 :     body = E.Opn(E.Add, [E.Field(0, expindex), E.Field(1, expindex)])
212 :     }
213 :     end
214 : jhr 3476
215 :     (********************************* Subtraction **************************************)
216 :    
217 :     fun subTT alpha = let
218 : jhr 3498 val expindex = specialize(alpha, 0)
219 :     in
220 : jhr 3476 E.EIN{
221 : jhr 3701 params = [mkTEN alpha, mkTEN alpha],
222 : jhr 3498 index = alpha,
223 :     body = E.Op2(E.Sub, E.Tensor(0, expindex), E.Tensor(1, expindex))
224 :     }
225 :     end
226 : jhr 3476
227 :     val subRR = subTT []
228 :    
229 :     fun subTF (dim, shape) = let
230 : jhr 3498 val expindex = specialize(shape, 0)
231 :     in
232 :     E.EIN{
233 : jhr 3701 params = [mkTEN shape, E.FLD dim],
234 : jhr 3498 index = shape,
235 :     body = E.Opn(E.Add,
236 :     [E.Lift(E.Tensor(0, expindex)), E.Op1(E.Neg, E.Field(1, expindex))])
237 :     }
238 :     end
239 : jhr 3476
240 :     fun subFT (dim, shape) = let
241 : jhr 3498 val expindex = specialize(shape, 0)
242 :     in
243 :     E.EIN{
244 : jhr 3701 params = [mkTEN shape, E.FLD dim],
245 : jhr 3498 index = shape,
246 :     body = E.Op2(E.Sub, E.Field(1, expindex), E.Lift(E.Tensor(0, expindex)))
247 :     }
248 :     end
249 : jhr 3476
250 :     fun subFF (dim, shape) = let
251 : jhr 3498 val expindex = specialize(shape, 0)
252 :     in
253 :     E.EIN{
254 :     params = [E.FLD dim, E.FLD dim],
255 :     index = shape,
256 :     body = E.Op2(E.Sub, E.Field(0, expindex), E.Field(1, expindex))
257 :     }
258 :     end
259 : jhr 3476
260 :     (********************************** Multiplication *************************************)
261 :    
262 :     (* scalar times tensor product: <s * T_{\alpha}>_{\alpha} *)
263 :     fun mulRT alpha = let
264 : jhr 3498 val expindex = specialize(alpha, 0)
265 :     in
266 : jhr 3476 E.EIN{
267 : jhr 3701 params = [mkTEN [], mkTEN alpha],
268 : jhr 3476 index = alpha,
269 : jhr 3477 body = E.Opn(E.Prod, [E.Tensor(0, []), E.Tensor(1, expindex)])
270 : jhr 3498 }
271 :     end
272 : jhr 3476
273 :     val mulRR = mulRT []
274 :    
275 : jhr 3477 fun mulRF (dim, shape) =let
276 : jhr 4324 val expindex = specialize(shape, 0)
277 :     in
278 : jhr 3476 E.EIN{
279 : jhr 4324 params = [mkTEN [], E.FLD dim],
280 :     index = shape,
281 :     body = E.Opn(E.Prod, [E.Lift(E.Tensor(0, [])), E.Field(1, expindex)])
282 :     }
283 :     end
284 : jhr 3476
285 :     fun mulST (dim, shape) =let
286 : jhr 4324 val expindex = specialize(shape, 0)
287 :     in
288 : jhr 3476 E.EIN{
289 : jhr 4324 params = [mkTEN shape, E.FLD dim],
290 :     index = shape,
291 :     body = E.Opn(E.Prod, [E.Lift(E.Tensor(0, expindex)), E.Field(1, [])])
292 :     }
293 :     end
294 : jhr 3476
295 :     fun mulSS dim = E.EIN{
296 : jhr 3477 params = [E.FLD dim, E.FLD dim],
297 : jhr 3476 index = [],
298 : jhr 3477 body = E.Opn(E.Prod, [E.Field(0, []), E.Field(1, [])])
299 : jhr 4324 }
300 : jhr 3476
301 : jhr 3477 fun mulSF(dim, shape) =let
302 : jhr 4324 val expindex = specialize(shape, 0)
303 :     in
304 :     E.EIN{
305 :     params = [E.FLD dim, E.FLD dim],
306 :     index = shape,
307 :     body = E.Opn(E.Prod, [E.Field(0, []), E.Field(1, expindex)])
308 :     }
309 :     end
310 : jhr 3476
311 :     (************************************ Division ************************************)
312 :    
313 : jhr 3477 fun divTR alpha = let
314 : jhr 3498 val expindex = specialize(alpha, 0)
315 :     in
316 : jhr 3476 E.EIN{
317 : jhr 3701 params = [mkTEN alpha, mkTEN []],
318 : jhr 3476 index = alpha,
319 : jhr 3477 body = E.Op2(E.Div, E.Tensor(0, expindex), E.Tensor(1, []))
320 : jhr 3498 }
321 :     end
322 : jhr 3476
323 : jhr 3477 val divRR = divTR []
324 : jhr 3476
325 : jhr 3477 fun divFR (dim, shape) = let
326 : jhr 3498 val expindex = specialize(shape, 0)
327 :     in
328 :     E.EIN{
329 : jhr 3701 params = [E.FLD dim, mkTEN []],
330 : jhr 3498 index = shape,
331 :     body = E.Op2(E.Div, E.Field(0, expindex), E.Lift(E.Tensor(1, [])))
332 :     }
333 :     end
334 : jhr 3477
335 : jhr 3476 fun divSS dim = E.EIN{
336 : jhr 3498 params = [E.FLD dim, E.FLD dim],
337 :     index = [],
338 :     body = E.Op2(E.Div, E.Field(0, []), E.Field(1, []))
339 :     }
340 : jhr 3476
341 : jhr 3477 fun divFS(dim, shape) = let
342 : jhr 3498 val expindex = specialize(shape, 0)
343 :     in
344 :     E.EIN{
345 :     params = [E.FLD dim, E.FLD dim],
346 :     index = shape,
347 : jhr 3512 body = E.Opn(E.Prod, [E.Field(0, expindex), E.Op2(E.Div, E.Const 1, E.Field(1, []))])
348 : jhr 3498 }
349 :     end
350 : jhr 4002
351 : cchiw 4000 fun divTS(dim, shape) = let
352 :     val expindex = specialize(shape,0)
353 : jhr 4002 in
354 : jhr 4317 E.EIN{
355 :     params = [mkTEN shape, E.FLD dim],
356 :     index = shape,
357 :     body = E.Op2(E.Div, E.Lift(E.Tensor(0, expindex)), E.Field(1, []))
358 :     }
359 : cchiw 4000 end
360 : jhr 3476
361 : cchiw 4409 (* Divide Scalars*)
362 :     fun halfT(n) = E.EIN{
363 :     params = [mkTEN [n, n]],
364 :     index = [n,n],
365 :     body = E.Op2(E.Div, E.Tensor(0, [E.V 0 , E.V 1]), E.Const 2)
366 :     }
367 : jhr 5007
368 : cchiw 4409 fun halfF(dim,n) = E.EIN{
369 :     params = [E.FLD(dim)],
370 :     index = [n,n],
371 :     body = E.Op2(E.Div, E.Field(0, [E.V 0 , E.V 1]), E.Const 2)
372 :     }
373 :    
374 :     (*scale by delta*)
375 :     fun scaleIdT(n) = E.EIN{
376 :     params = [mkTEN []],
377 :     index = [n,n],
378 :     body = E.Opn(E.Prod, [E.Tensor(0, []), E.Delta(E.V 0, E.V 1)])
379 :     }
380 :    
381 :     (*scale by delta*)
382 :     fun scaleIdF(dim,n) = E.EIN{
383 :     params = [E.FLD(dim)],
384 :     index = [n,n],
385 :     body = E.Opn(E.Prod, [E.Field(0, []), E.Delta(E.V 0, E.V 1)])
386 :     }
387 :    
388 :    
389 : jhr 3477 (************************************* Negation **********************************)
390 :    
391 : jhr 3476 fun negTT alpha = let
392 : jhr 4324 val expindex = specialize(alpha, 0)
393 : jhr 3477 in
394 :     E.EIN {
395 : jhr 3701 params = [mkTEN alpha], index = alpha,
396 : jhr 3498 body = E.Op1(E.Neg, E.Tensor(0, expindex))
397 : jhr 3702 }
398 : jhr 3498 end
399 : jhr 3476
400 : jhr 3477 fun negFF (dim, shape) = let
401 : jhr 3498 val expindex = specialize(shape, 0)
402 :     in
403 :     E.EIN{
404 :     params = [E.FLD dim], index = shape,
405 :     body = E.Op1(E.Neg, E.Field(0, expindex))
406 :     }
407 : jhr 3477 end
408 : jhr 3476
409 : jhr 3477 (****************************** cross product ***********************************)
410 : jhr 3476
411 : jhr 3477 (* 2-d cross product Eps_{ij}U_i V_j *)
412 : jhr 3476 val cross2TT = E.EIN{
413 : jhr 3701 params = [mkTEN [2], mkTEN [2]],
414 : jhr 3477 index = [],
415 : cchiw 3978 body = E.Sum([(0, 0, 1), (1, 0, 1)],
416 : cchiw 4289 E.Opn(E.Prod, [E.Eps2(E.V 0, E.V 1), E.Tensor(0, [E.V 0]), E.Tensor(1, [E.V 1])]))
417 : jhr 3498 }
418 : jhr 3476
419 : jhr 3477 (* crossProduct is on 3D vectors ..vec3 t8=t0 × t1; *)
420 :     val cross3TT = E.EIN{
421 : jhr 3701 params = [mkTEN [3], mkTEN [3]],
422 : jhr 3477 index = [3],
423 : cchiw 3978 body = E.Sum([(1, 0, 2), (2, 0, 2)],
424 : jhr 3498 E.Opn(E.Prod, [
425 : cchiw 4289 E.Epsilon(E.V 0, E.V 1, E.V 2),
426 : jhr 3498 E.Tensor(0, [E.V 1]),
427 :     E.Tensor(1, [E.V 2])
428 :     ]))
429 :     }
430 : jhr 3477
431 :     (* Field Cross Product *)
432 : jhr 3476 val cross2FF = E.EIN{
433 : jhr 3477 params = [E.FLD(2), E.FLD(2)], index = [],
434 : cchiw 3978 body = E.Sum([(0, 0, 1), (1, 0, 1)],
435 : cchiw 4289 E.Opn(E.Prod, [E.Eps2(E.V 0, E.V 1), E.Field(0, [E.V 0]), E.Field(1, [E.V 1])]))
436 : jhr 3498 }
437 : jhr 3476
438 : jhr 3477 (* Field Cross Product *)
439 : jhr 3476 val cross3FF = E.EIN{
440 : jhr 3701 params = [E.FLD(3), E.FLD(3)], index= [3],
441 : cchiw 3978 body = E.Sum([(1, 0, 2), (2, 0, 2)],
442 : jhr 4324 E.Opn(E.Prod, [
443 :     E.Epsilon(E.V 0, E.V 1, E.V 2),
444 :     E.Field(0, [E.V 1]),
445 :     E.Field(1, [E.V 2])
446 :     ]))
447 : jhr 3498 }
448 : jhr 3476
449 : cchiw 4000 (*Field and Tensor Cross product *)
450 :     val cross2FT = E.EIN{
451 :     params = [E.FLD(2), mkTEN [2]], index= [],
452 : jhr 4324 body = E.Sum([(0, 0, 1), (1, 0, 1)],
453 :     E.Opn(E.Prod, [
454 :     E.Eps2(E.V 0, E.V 1),
455 :     E.Field(0, [E.V 0]),
456 :     E.Lift(E.Tensor(1, [E.V 1]))
457 :     ]))
458 : jhr 4317 }
459 : cchiw 4000
460 :     val cross3FT = E.EIN{
461 :     params = [E.FLD(3), mkTEN[3]], index= [3],
462 : jhr 4002 body = E.Sum([(1, 0, 2), (2, 0, 2)],
463 : jhr 4324 E.Opn(E.Prod, [
464 :     E.Epsilon(E.V 0, E.V 1, E.V 2),
465 :     E.Field(0, [E.V 1]),
466 :     E.Lift(E.Tensor(1, [E.V 2]))
467 :     ]))
468 : jhr 4317 }
469 : cchiw 4000
470 :     val cross2TF = E.EIN{
471 : jhr 4317 params = [mkTEN[2], E.FLD(2)], index = [],
472 :     body = E.Sum([(0, 0, 1), (1, 0, 1)],
473 : jhr 4324 E.Opn(E.Prod, [
474 :     E.Eps2(E.V 0, E.V 1),
475 :     E.Lift(E.Tensor(0, [E.V 0])),
476 :     E.Field(1, [E.V 1])
477 :     ]))
478 : jhr 4317 }
479 : cchiw 4000
480 :     val cross3TF = E.EIN{
481 : jhr 4317 params = [mkTEN[3], E.FLD(3)], index= [3],
482 :     body = E.Sum([(1, 0, 2),(2,0,2)],
483 : jhr 4324 E.Opn(E.Prod, [
484 :     E.Epsilon(E.V 0, E.V 1, E.V 2),
485 :     E.Lift(E.Tensor(0, [E.V 1])),
486 :     E.Field(1, [E.V 2])
487 :     ]))
488 : jhr 4317 }
489 : cchiw 4000
490 : jhr 3477 (******************** outer product ********************************)
491 : jhr 3498
492 : jhr 4414 fun outerTT (alpha, beta) = let
493 :     val expIdxA = specialize (alpha, 0)
494 :     val expIdxB = specialize (beta, length alpha)
495 :     in
496 :     E.EIN{
497 :     params = [mkTEN alpha, mkTEN beta],
498 :     index = alpha@beta,
499 :     body = E.Opn(E.Prod, [E.Tensor(0, expIdxA), E.Tensor(1, expIdxB)])
500 :     }
501 :     end
502 : jhr 3476
503 : cchiw 3495 (*Assumes same dimension vector field *)
504 : jhr 3498 fun outerFF (dim, alpha, beta) =let
505 : jhr 4414 val expIdxA = specialize (alpha, 0)
506 :     val expIdxB = specialize (beta, length alpha)
507 :     in
508 :     E.EIN{
509 :     params = [E.FLD dim, E.FLD dim],
510 :     index = alpha@beta,
511 :     body = E.Opn(E.Prod, [E.Field(0, expIdxA), E.Field(1, expIdxB)])
512 :     }
513 :     end
514 : jhr 3477
515 : jhr 3498 fun outerTF (dim, alpha, beta) =let
516 : jhr 4414 val expIdxA = specialize (alpha, 0)
517 :     val expIdxB = specialize (beta, length alpha)
518 :     in
519 :     E.EIN{
520 :     params = [mkTEN alpha, E.FLD dim],
521 :     index = alpha@beta,
522 :     body = E.Opn(E.Prod, [E.Lift(E.Tensor(0, expIdxA)), E.Field(1, expIdxB)])
523 :     }
524 :     end
525 : jhr 3702
526 : jhr 3498 fun outerFT (dim, alpha, beta) =let
527 : jhr 4414 val expIdxA = specialize(alpha, 0)
528 :     val expIdxB = specialize(beta, length alpha)
529 :     in
530 :     E.EIN{
531 :     params = [E.FLD dim, mkTEN alpha],
532 :     index = alpha@beta,
533 :     body = E.Opn(E.Prod, [E.Field(0, expIdxA), E.Lift(E.Tensor(1, expIdxB))])
534 :     }
535 :     end
536 : jhr 3477
537 :     (*************************** inner product **********************************)
538 : jhr 3476 (* generic inner product: <T_{\alpha i} * T_{i \beta}>_{\alpha \beta} *)
539 : jhr 3477 fun innerTT (shape1, i::beta) = let
540 : jhr 3498 val alpha = List.take(shape1, length shape1 - 1)
541 :     val expindexA = specialize(alpha, 0)
542 :     val expindexB = specialize(beta, length alpha)
543 : cchiw 3978 val sid = length alpha + length beta
544 :     val sx = E.V sid
545 :     val s'' = [(sid, 0, i-1)]
546 : jhr 3498 in
547 : jhr 3476 E.EIN{
548 : jhr 3702 params = [mkTEN shape1, mkTEN(i :: beta)],
549 : jhr 3476 index = alpha@beta,
550 : jhr 3477 body = E.Sum(s'', E.Opn(E.Prod, [
551 : cchiw 3978 E.Tensor(0, expindexA@[sx]), (* T_{\alpha i} *)
552 :     E.Tensor(1, [sx]@expindexB ) (* T'_{i \beta} *)
553 : jhr 3498 ]))
554 :     }
555 :     end
556 : jhr 4414 | innerTT _ = raise Fail "Wrong shape for inner product"
557 : jhr 3476
558 : jhr 3477 (* generic inner product: <T_{\alpha i} * T_{i \beta}>_{\alpha \beta} *)
559 :     fun innerFF (shape1, dim, i::beta) = let
560 : jhr 3498 val alpha = List.take(shape1, length shape1 - 1)
561 :     val expindexA = specialize(alpha, 0)
562 :     val expindexB = specialize(beta, length alpha)
563 : cchiw 3978 val sid = length alpha + length beta
564 :     val sx = E.V sid
565 : jhr 3498 in
566 :     E.EIN{
567 :     params = [E.FLD dim, E.FLD dim],
568 : jhr 3702 index = alpha @ beta,
569 : jhr 3498 body = E.Sum([(sid, 0, i-1)],
570 :     E.Opn(E.Prod, [
571 : cchiw 3978 E.Field(0, expindexA @ [sx]), (* F_{\alpha i} *)
572 :     E.Field(1, [sx] @ expindexB) (* F'_{i \beta} *)
573 : jhr 3498 ]))
574 :     }
575 :     end
576 : jhr 3477 | innerFF _ = raise Fail "Wrong shape for innerProductField"
577 : jhr 3476
578 : jhr 3477 fun innerFT (shape1, dim, i::beta) = let
579 : jhr 3498 val alpha = List.take(shape1, length shape1-1)
580 :     val expindexA = specialize(alpha, 0)
581 :     val expindexB = specialize(beta, length alpha)
582 : cchiw 3978 val sid = length alpha + length beta
583 :     val sx = E.V sid
584 : jhr 3498 in
585 :     E.EIN{
586 : jhr 3702 params = [E.FLD dim, mkTEN(i::beta)],
587 :     index = alpha @ beta,
588 : jhr 3498 body = E.Sum([(sid, 0, i-1)],
589 :     E.Opn(E.Prod, [
590 : cchiw 3978 E.Field(0, expindexA @ [sx]), (* F_{\alpha i} *)
591 :     E.Lift(E.Tensor(1, [sx] @ expindexB )) (* F'_{i \beta} *)
592 : jhr 3498 ]))
593 :     }
594 :     end
595 : jhr 3477 | innerFT _ = raise Fail "Wrong shape for innerProductFieldTensor"
596 : jhr 3476
597 : jhr 3477 fun innerTF (shape1, dim, i::beta) = let
598 : jhr 3498 val alpha = List.take(shape1, length shape1 - 1)
599 :     val expindexA = specialize(alpha, 0)
600 :     val expindexB = specialize(beta, length alpha)
601 : jhr 4324 val sid = length alpha + length beta
602 : cchiw 3978 val sx = E.V sid
603 : jhr 3498 in
604 :     E.EIN{
605 : jhr 3702 params = [mkTEN shape1, E.FLD dim],
606 :     index = alpha @ beta,
607 : jhr 3498 body = E.Sum([(sid, 0, i-1)],
608 :     E.Opn(E.Prod, [
609 : cchiw 3978 E.Lift(E.Tensor(0, expindexA @ [sx])), (* F_{\alpha i} *)
610 :     E.Field(1, [sx] @ expindexB) (* F'_{i \beta} *)
611 : jhr 3498 ]))
612 :     }
613 :     end
614 : jhr 3477 | innerTF _ = raise Fail "Wrong shape for innerProductTensorField"
615 : jhr 3476
616 : jhr 3477 (*************************** colon product **********************************)
617 : jhr 3476
618 : jhr 3477 (* <T_{\alpha i j} * B{i j \beta }>_\alpha \beta *)
619 :     fun colonTT (shape1, i::j::beta) = let
620 : jhr 3498 val lenAlpha = length shape1 - 2
621 :     val alpha = List.take(shape1, lenAlpha)
622 :     val expindexA = specialize(alpha, 0)
623 :     val expindexB = specialize(beta, lenAlpha)
624 :     val sumi = lenAlpha + length beta
625 :     val s' = [E.V sumi, E.V(sumi+1)]
626 : cchiw 3978 val sx = [(sumi, 0, i-1), ((sumi+1), 0, j-1)]
627 : jhr 3498 in
628 : jhr 3476 E.EIN{
629 : jhr 3702 params = [mkTEN shape1, mkTEN(i::j::beta)],
630 : jhr 3476 index = alpha@beta,
631 : jhr 3477 body = E.Sum(sx,
632 : jhr 3498 E.Opn(E.Prod, [E.Tensor(0, expindexA@s'), E.Tensor(1, s'@expindexB)]))
633 :     }
634 :     end
635 : jhr 3476
636 : jhr 3477 (* <F_{\alpha i j} * G_{i j \beta }>_\alpha \beta *)
637 :     fun colonFF (dim, shape1, i::j::beta) = let
638 : jhr 3498 val lenAlpha = length shape1 - 2
639 :     val alpha = List.take(shape1, lenAlpha)
640 :     val expindexA = specialize(alpha, 0)
641 :     val expindexB = specialize(beta, lenAlpha)
642 :     val sumi = lenAlpha + length beta
643 :     val s' = [E.V sumi, E.V(sumi+1)]
644 : cchiw 3978 val sx = [(sumi, 0, i-1), ((sumi+1), 0, j-1)]
645 : jhr 3498 in
646 : jhr 3476 E.EIN{
647 : jhr 3477 params = [E.FLD dim, E.FLD dim],
648 : jhr 3476 index = alpha@beta,
649 : jhr 3477 body = E.Sum(sx,
650 : jhr 3498 E.Opn(E.Prod, [E.Field(0, expindexA@s'), E.Field(1, s'@expindexB)]))
651 :     }
652 :     end
653 : jhr 3476
654 :    
655 : jhr 3477 (* <F_{\alpha i j} * T_{i j \beta }>_\alpha \beta *)
656 :     fun colonFT (dim, shape1, i::j::beta) = let
657 : jhr 3498 val lenAlpha = length shape1 - 2
658 :     val alpha = List.take(shape1, lenAlpha)
659 :     val expindexA = specialize(alpha, 0)
660 :     val expindexB = specialize(beta, lenAlpha)
661 : cchiw 3978 val sid = lenAlpha + length beta
662 :     val s' = [E.V sid, E.V(sid+1)]
663 :     val sx = [(sid, 0, i-1), ((sid+1), 0, j-1)]
664 : jhr 3498 in
665 : jhr 3476 E.EIN{
666 : jhr 3701 params = [E.FLD dim, mkTEN shape1],
667 : jhr 3498 index = alpha@beta,
668 :     body = E.Sum(sx,
669 :     E.Opn(E.Prod, [E.Field(0, expindexA@s'), E.Lift(E.Tensor(1, s'@expindexB))]))
670 :     }
671 :     end
672 : jhr 3476
673 : jhr 3477 (* <T_{\alpha i j} * G{i j \beta }>_\alpha \beta *)
674 :     fun colonTF (dim, shape1, i::j::beta) = let
675 : jhr 3498 val lenAlpha = length shape1 - 2
676 :     val alpha = List.take(shape1, lenAlpha)
677 :     val expindexA = specialize(alpha, 0)
678 :     val expindexB = specialize(beta, lenAlpha)
679 : cchiw 3978 val sid = lenAlpha + length beta
680 :     val s' = [E.V sid, E.V(sid+1)]
681 :     val sx = [(sid, 0, i-1), ((sid+1), 0, j-1)]
682 : jhr 3498 in
683 : jhr 3476 E.EIN{
684 : jhr 3702 params = [mkTEN(i::j::beta), E.FLD dim],
685 : jhr 3498 index = alpha@beta,
686 :     body = E.Sum(sx,
687 :     E.Opn(E.Prod, [E.Lift(E.Tensor(0, expindexA@s')), E.Field(1, s'@expindexB)]))
688 :     }
689 :     end
690 : jhr 3476
691 : jhr 3477 (******************** Norm ********************************)
692 : jhr 3476
693 : jhr 5007 fun normT [] = E.EIN{
694 :     params = [mkTEN []], index = [], body = E.Op1(E.Abs, E.Tensor(0, []))
695 : jhr 3498 }
696 : jhr 5007 | normT alpha = let
697 :     val expIdx = specialize(alpha, 0)
698 :     val sx = sumIds(length alpha, 0, alpha)
699 :     in
700 :     E.EIN{
701 :     params = [mkTEN alpha],
702 :     index = [],
703 :     body = E.Op1(E.Sqrt,
704 :     E.Sum(sx, E.Opn(E.Prod, [E.Tensor(0, expIdx), E.Tensor(0, expIdx)])))
705 :     }
706 :     end
707 : jhr 3702
708 : cchiw 5006 fun normF (dim, []) =
709 : jhr 5007 E.EIN{params = [E.FLD dim],index = [], body = E.Op1(E.Abs, E.Field(0, []))}
710 : cchiw 4331 | normF (dim, alpha) = let
711 : jhr 3498 val expIdx = specialize(alpha, 0)
712 :     val sx = sumIds(length alpha, 0, alpha)
713 :     in
714 :     E.EIN{
715 :     params = [E.FLD dim],
716 :     index = [],
717 :     body = E.Op1(E.Sqrt,
718 :     E.Sum(sx, E.Opn(E.Prod, [E.Field(0, expIdx), E.Field(0, expIdx)])))
719 :     }
720 :     end
721 : jhr 3476
722 : cchiw 3494 fun normalizeTT alpha = let
723 : jhr 3498 val expindex = specialize(alpha, 0)
724 :     val len = length alpha
725 :     val expindexDot = specialize(alpha, len)
726 :     val sx = sumIds(len, len, alpha)
727 :     val f = E.Tensor(0, expindex)
728 :     val g = E.Tensor(1, expindexDot)
729 :     in
730 : jhr 3476 E.EIN{
731 : jhr 4132 params = [mkTEN alpha, mkTEN alpha],
732 : jhr 3498 index = alpha,
733 :     body = E.Opn(E.Prod, [
734 : jhr 3512 f, E.Op2(E.Div, E.Const 1, E.Op1(E.Sqrt, E.Sum(sx, E.Opn(E.Prod, [g, g]))))
735 : jhr 3498 ])
736 :     }
737 :     end
738 : jhr 3476
739 : cchiw 3494 fun normalizeFF (dim, alpha as i::_) = let
740 : jhr 3498 val expindex = specialize(alpha, 0)
741 :     val len = length alpha
742 :     val expindexDot = specialize(alpha, len)
743 :     val sx = sumIds(len, len, alpha)
744 :     val f = E.Field(0, expindex)
745 :     val g = E.Field(1, expindexDot)
746 :     in
747 : jhr 3476 E.EIN{
748 : jhr 3477 params = [E.FLD dim, E.FLD dim],
749 : cchiw 3545 index = alpha,
750 : jhr 3477 body = E.Opn(E.Prod, [
751 : jhr 3512 f, E.Op2(E.Div, E.Const 1, E.Op1(E.Sqrt, E.Sum(sx, E.Opn(E.Prod, [g, g]))))
752 : jhr 3498 ])
753 :     }
754 :     end
755 : jhr 3476
756 : jhr 3477 (************************* trace *************************)
757 :    
758 :     (* Trace: <M_{i, i}> This one Sx represents both i's*)
759 : jhr 3476 fun traceT dim = E.EIN{
760 : jhr 3701 params = [mkTEN [dim, dim]], index = [],
761 : cchiw 3978 body = E.Sum([(0, 0, dim-1)], E.Tensor(0, [E.V 0, E.V 0]))
762 : jhr 3498 }
763 : jhr 3476
764 : jhr 3477 (* Trace: <Sigma_i F_{\alpha i, i}> This one Sx represents both i's *)
765 : cchiw 4273 fun traceF (dim, d2, alpha) = let
766 : jhr 3498 val expindex = specialize(alpha, 0)
767 : cchiw 3978 val sid = length alpha
768 :     val sx = E.V sid
769 : jhr 3498 in
770 : jhr 3476 E.EIN{
771 : jhr 3477 params = [E.FLD dim],
772 : jhr 3476 index = alpha,
773 : cchiw 4273 body = E.Sum([(sid, 0, d2-1)], E.Field(0, expindex@[sx, sx]))
774 : jhr 3498 }
775 :     end
776 : jhr 3476
777 : jhr 3477 (************************* tranpose *************************)
778 : jhr 3476
779 : jhr 3477 fun transposeT alpha = E.EIN{
780 : jhr 3701 params = [mkTEN alpha],
781 : jhr 3498 index = List.rev alpha,
782 :     body = E.Tensor(0, [E.V 1, E.V 0])
783 :     }
784 : jhr 3476
785 : jhr 3477 (* Transpose Field F_{ji} *)
786 :     fun transposeF (dim, i, j) = E.EIN{
787 : jhr 3498 params = [E.FLD dim],
788 :     index = [i, j],
789 :     body = E.Field(0, [E.V 1, E.V 0])
790 :     }
791 : jhr 3476
792 : jhr 3477 (************************* determinant *************************)
793 : jhr 3476
794 : jhr 3478 val det2T = E.EIN{
795 : jhr 3701 params = [mkNoSubstTEN [2, 2]],
796 : jhr 3498 index = [],
797 :     body = E.Op2(E.Sub,
798 :     E.Opn(E.Prod, [E.Tensor(0, [E.C 0, E.C 0]), E.Tensor(0, [E.C 1, E.C 1])]),
799 :     E.Opn(E.Prod, [E.Tensor(0, [E.C 0, E.C 1]), E.Tensor(0, [E.C 1, E.C 0])]))
800 :     }
801 : jhr 3476
802 : jhr 3478 val det3T = let
803 : jhr 3498 val a = E.Tensor(0, [E.C 0, E.C 0])
804 :     val b = E.Tensor(0, [E.C 0, E.C 1])
805 :     val c = E.Tensor(0, [E.C 0, E.C 2])
806 :     val d = E.Tensor(0, [E.C 1, E.C 0])
807 :     val e = E.Tensor(0, [E.C 1, E.C 1])
808 :     val f = E.Tensor(0, [E.C 1, E.C 2])
809 :     val g = E.Tensor(0, [E.C 2, E.C 0])
810 :     val h = E.Tensor(0, [E.C 2, E.C 1])
811 :     val i = E.Tensor(0, [E.C 2, E.C 2])
812 :     in
813 :     E.EIN{
814 : jhr 3701 params = [mkNoSubstTEN [3, 3]],
815 : jhr 3498 index = [],
816 :     body = E.Op2(E.Sub,
817 :     E.Opn(E.Add, [
818 :     E.Opn(E.Prod, [a, e, i]), E.Opn(E.Prod, [b, f, g]), E.Opn(E.Prod, [c, d, h])
819 :     ]),
820 :     E.Opn(E.Add, [
821 :     E.Opn(E.Prod, [c, e, g]), E.Opn(E.Prod, [b, d, i]), E.Opn(E.Prod, [a, f, h])
822 :     ]))
823 :     }
824 :     end
825 : jhr 3477
826 : jhr 4324 fun det2F dim = E.EIN{
827 :     params = [E.FLD dim],
828 : jhr 3498 index = [],
829 :     body = E.Op2(E.Sub,
830 :     E.Opn(E.Prod, [E.Field(0, [E.C 0, E.C 0]), E.Field(0, [E.C 1, E.C 1])]),
831 :     E.Opn(E.Prod, [E.Field(0, [E.C 0, E.C 1]), E.Field(0, [E.C 1, E.C 0])]))
832 :     }
833 : jhr 3477
834 : jhr 4324 fun det3F dim = E.EIN{
835 :     params = [E.FLD dim],
836 : jhr 3498 index = [],
837 : cchiw 3978 body = E.Sum([(0, 0, 2)],
838 : jhr 3498 E.Opn(E.Prod, [
839 : jhr 3716 E.Field(0, [E.C 0, E.V 0]),
840 : cchiw 3978 E.Sum([(1, 0, 2)],
841 : jhr 3498 E.Opn(E.Prod, [
842 : jhr 3716 E.Field(0, [E.C 1, E.V 1]),
843 : cchiw 3978 E.Sum([(2, 0, 2)],
844 : cchiw 4289 E.Opn(E.Prod, [E.Epsilon(E.V 0, E.V 1, E.V 2), E.Field(0, [E.C 2, E.V 2])]))
845 : jhr 3498 ]))
846 :     ]))
847 :     }
848 : jhr 4414
849 : cchiw 4409 (************************* Inverse *************************)
850 : jhr 5570
851 : jhr 5421 fun mkInvS e = E.Op2(E.Div, E.Const 1, e(0, []))
852 :     fun invS dim = E.EIN{params = [E.FLD dim], index= [], body = mkInvS E.Field}
853 :     val invR = E.EIN{params = [mkTEN([])], index= [], body = mkInvS E.Tensor}
854 : jhr 5570
855 : jhr 4414 fun mkInv2x2 f = let
856 :     fun mkFCx (ix, jx) = f (0, [E.C ix, E.C jx])
857 :     fun mkFVx (ix, jx) = f (0, [E.V ix, E.V jx])
858 : cchiw 4420 val f00 = mkFCx (0, 0)
859 :     val f11 = mkFCx (1, 1)
860 :     val f01 = mkFCx (0, 1)
861 :     val f10 = mkFCx (1, 0)
862 : jhr 4414 val i = 0 and j = 1 and k = 2
863 :     val fij = mkFVx (i, j)
864 :     val fkk = mkFVx (k, k)
865 :     (* numerator*)
866 :     val en = E.Op2(E.Sub,
867 :     E.Opn(E.Prod, [E.Sum([(k, 0, 1)], fkk), E.Delta(E.V i, E.V j)]),
868 :     fij)
869 :     (* denominator *)
870 :     val d1 = E.Opn(E.Prod, [f00, f11])
871 :     val d2 = E.Opn(E.Prod, [f01, f10])
872 :     val dn = E.Op2(E.Sub, d1, d2)
873 :     in
874 :     E.Op2(E.Div, en, dn)
875 :     end
876 : jhr 3476
877 : jhr 4414 fun inv2F dim = E.EIN{params = [E.FLD dim], index= [2, 2], body = mkInv2x2 E.Field}
878 :    
879 :     val inv2T = let
880 :     val shape = [2, 2]
881 :     in
882 :     E.EIN{params = [mkTEN shape], index = shape, body = mkInv2x2 E.Tensor}
883 :     end
884 : jhr 5570
885 : jhr 3477 (************************* Exponential **************************)
886 :     fun expF dim = E.EIN{params = [E.FLD dim], index = [], body = E.Op1(E.Exp, E.Field(0, []))}
887 : jhr 3701 val expT = E.EIN{params = [mkNoSubstTEN []], index = [], body = E.Op1(E.Exp, E.Tensor(0, []))}
888 : jhr 3476
889 : jhr 3477 (************************* Lifted single-argument math functions *************************)
890 :     local
891 : jhr 3498 fun tensorFn rator = E.EIN{
892 : cchiw 4087 params = [mkTEN []],
893 : jhr 3498 index = [],
894 : cchiw 3969 body = E.Op1(rator, E.Tensor(0, []))
895 : jhr 3498 }
896 :     fun liftFn rator dim = E.EIN{
897 :     params = [E.FLD dim],
898 :     index = [],
899 :     body = E.Op1(rator, E.Field(0, []))
900 :     }
901 : jhr 3477 in
902 : cchiw 3978 fun powFI (dim, n) = E.EIN{
903 : jhr 3498 params = [E.FLD dim],
904 :     index = [], body = E.Op1(E.PowInt n, E.Field(0, []))
905 :     }
906 : jhr 5570 fun powTI n = E.EIN{
907 :     params = [mkTEN []],
908 :     index = [], body = E.Op1(E.PowInt n, E.Tensor(0, []))
909 :     }
910 : jhr 3498 val sqrtR = tensorFn E.Sqrt
911 : jhr 3477 val sqrtF = liftFn E.Sqrt
912 : jhr 3498 val cosR = tensorFn E.Cosine
913 :     val cosF = liftFn E.Cosine
914 :     val acosR = tensorFn E.ArcCosine
915 : jhr 3477 val acosF = liftFn E.ArcCosine
916 : jhr 3498 val sinR = tensorFn E.Sine
917 :     val sinF = liftFn E.Sine
918 :     val asinR = tensorFn E.ArcSine
919 : jhr 3477 val asinF = liftFn E.ArcSine
920 : jhr 3498 val tanR = tensorFn E.Tangent
921 :     val tanF = liftFn E.Tangent
922 :     val atanR = tensorFn E.ArcTangent
923 : jhr 3477 val atanF = liftFn E.ArcTangent
924 :     end (* local *)
925 : jhr 3476
926 : jhr 3477 (************************* other tensor ops *************************)
927 : jhr 4324
928 : cchiw 4322 fun modulateTT shape = let
929 : jhr 4324 val expindex = specialize(shape, 0)
930 :     in
931 : cchiw 4322 E.EIN{
932 : jhr 4324 params = [mkTEN shape, mkTEN shape],
933 : cchiw 4322 index = shape,
934 :     body = E.Opn(E.Prod, [E.Tensor(0, expindex), E.Tensor(1, expindex)])
935 : jhr 4324 }
936 :     end
937 : jhr 3476
938 : cchiw 4281 fun modulateFF(shape, dim) = let
939 : jhr 4324 val expindex = specialize(shape, 0)
940 :     in
941 : cchiw 4281 E.EIN{
942 : jhr 4324 params = [E.FLD dim, E.FLD dim],
943 : cchiw 4281 index = shape,
944 :     body = E.Opn(E.Prod,[E.Field(0, expindex), E.Field(1, expindex)])
945 : jhr 4324 }
946 :     end
947 :    
948 : cchiw 4281 fun modulateTF(shape, dim) = let
949 : jhr 4324 val expindex = specialize(shape, 0)
950 :     in
951 : cchiw 4281 E.EIN{
952 : jhr 4324 params = [mkTEN shape, E.FLD dim],
953 : cchiw 4281 index = shape,
954 :     body = E.Opn(E.Prod,[E.Lift(E.Tensor(0, expindex)), E.Field(1, expindex)])
955 : jhr 4324 }
956 :     end
957 :    
958 : cchiw 4281 fun modulateFT(shape, dim) = let
959 : jhr 4324 val expindex = specialize(shape, 0)
960 :     in
961 : cchiw 4281 E.EIN{
962 : jhr 4324 params = [E.FLD dim, mkTEN shape],
963 : cchiw 4281 index = shape,
964 : jhr 4324 body = E.Opn(E.Prod, [E.Field(0, expindex), E.Lift(E.Tensor(1, expindex))])
965 :     }
966 :     end
967 : cchiw 4281
968 : jhr 3477 fun identity dim = E.EIN{
969 : cchiw 4289 params = [], index = [dim, dim], body = E.Delta(E.V 0, E.V 1)
970 : jhr 3498 }
971 : jhr 3476
972 : jhr 3478 fun zeros shape = E.EIN{
973 : cchiw 4555 params = [], index = shape, body = E.Zero (specialize(shape, 0))
974 : jhr 3498 }
975 : jhr 3478
976 : jhr 3477 (* QUESTION: why do we need the const list? The indices are implicit in the position of
977 :     * of the mask element! Likewise, the result type can be determined from the argTy and
978 :     * mask.
979 :     *)
980 : cchiw 3991 fun sliceT (mask, const, rstTy, argTy) = let
981 : jhr 4317 fun iter ([], _, cnt) = []
982 :     | iter (true::es, c::cs, cnt) = (E.C c)::iter(es, cs, cnt)
983 :     | iter (false::es, cs, cnt) = (E.V cnt)::iter(es, cs, cnt+1)
984 :     val ix = iter(mask, const, 0)
985 :     in
986 :     E.EIN{params = [E.TEN(true, argTy)], index = rstTy, body = E.Tensor(0, ix)}
987 :     end
988 : jhr 3476
989 : cchiw 3991 fun sliceF (mask, const, rstTy, dim) = let
990 : jhr 4317 fun iter ([], _, cnt) = []
991 :     | iter (true::es, c::cs, cnt) = (E.C c)::iter(es, cs, cnt)
992 :     | iter (false::es, cs, cnt) = (E.V cnt)::iter(es, cs, cnt+1)
993 :     val ix = iter(mask, const, 0)
994 : jhr 4229 in
995 : jhr 4317 E.EIN{params = [E.FLD dim], index = rstTy, body = E.Field(0, ix)}
996 :     end
997 : jhr 5570
998 : jhr 5574 fun concatBody (expression, shape, nflds, idshift) = let
999 :     val expindex = specialize(shape, 1)
1000 :     val exps = List.tabulate (nflds, fn n => E.Opn(E.Prod, [expression(n+idshift, expindex), E.Delta(E.C n, E.V 0)]))
1001 :     in
1002 :     E.Opn(E.Add, exps)
1003 :     end
1004 :    
1005 :     fun concatTensor (shape, nflds) =
1006 :     E.EIN{
1007 :     params = List.tabulate (nflds, fn _=> mkTEN shape),
1008 :     index = nflds::shape,
1009 :     body = concatBody (E.Tensor, shape, nflds, 0)
1010 :     }
1011 :    
1012 :     fun concatField (dim, shape, nflds) =
1013 :     E.EIN{
1014 :     params = List.tabulate (nflds, fn _=> E.FLD dim),
1015 :     index = nflds::shape,
1016 :     body = concatBody (E.Field, shape, nflds, 0)
1017 :     }
1018 :    
1019 : jhr 5258 (* Lerp<ty>(a, b, t) -- computes a + t*(b-a), where a and b have type ty
1020 :     * and t has type real
1021 :     *)
1022 :     fun lerp3 alpha = let
1023 : cchiw 5241 val expindex = specialize(alpha, 0)
1024 :     val a = E.Tensor(0, expindex)
1025 :     val b = E.Tensor(1, expindex)
1026 :     val c = E.Tensor(2, [])
1027 :     val e3 = E.Op2(E.Sub, b, a)
1028 :     val e5 = E.Opn(E.Prod, [c, e3])
1029 :     in
1030 : cchiw 5238 E.EIN{
1031 :     params = [mkTEN alpha, mkTEN alpha, mkTEN []],
1032 :     index = alpha,
1033 :     body = E.Opn(E.Add, [a, e5])
1034 : jhr 5258 }
1035 : cchiw 5241 end
1036 : jhr 3476
1037 : jhr 5258 fun lerp5 alpha = let
1038 : cchiw 5241 val expindex = specialize(alpha, 0)
1039 :     val a = E.Tensor(0, expindex)
1040 :     val b = E.Tensor(1, expindex)
1041 :     val c = E.Tensor(2, [])
1042 :     val d = E.Tensor(3, [])
1043 :     val e = E.Tensor(4, [])
1044 :     val e1 = E.Op2(E.Sub, d, c)
1045 :     val e2 = E.Op2(E.Sub, e, c)
1046 :     val e3 = E.Op2(E.Sub, b, a)
1047 :     val e4 = E.Op2(E.Div, e1, e2)
1048 :     val e5 = E.Opn(E.Prod, [e4, e3])
1049 :     in
1050 : cchiw 5238 E.EIN{
1051 :     params = [mkTEN alpha, mkTEN alpha, mkTEN [], mkTEN [], mkTEN []],
1052 :     index = alpha,
1053 :     body = E.Opn(E.Add, [a, e5])
1054 : jhr 5258 }
1055 : cchiw 5241 end
1056 : jhr 5570
1057 : jhr 5258 (* clamps x to the range lo..hi, where lo and hi are scalars and x *)
1058 : cchiw 5241 fun clampRRT alpha = let
1059 :     val expindex = specialize(alpha, 0)
1060 :     val a = E.Tensor(0, [])
1061 :     val b = E.Tensor(1, [])
1062 :     val c = E.Tensor(2, expindex)
1063 :     in
1064 :     E.EIN{
1065 :     params = [mkTEN [], mkTEN [], mkTEN alpha],
1066 :     index = alpha,
1067 :     body = E.Op3(E.Clamp, a, b, c)
1068 : jhr 5258 }
1069 : cchiw 5241 end
1070 : jhr 5570
1071 : jhr 5258 (* clamps x[alpha] to the range lo[alpha]..hi[alpha] *)
1072 : cchiw 5241 fun clampTTT alpha = let
1073 :     val expindex = specialize(alpha, 0)
1074 :     val a = E.Tensor(0, expindex)
1075 :     val b = E.Tensor(1, expindex)
1076 :     val c = E.Tensor(2, expindex)
1077 :     in
1078 :     E.EIN{
1079 : cchiw 5242 params = [mkTEN alpha, mkTEN alpha, mkTEN alpha],
1080 : cchiw 5241 index = alpha,
1081 :     body = E.Op3(E.Clamp, a, b, c)
1082 : jhr 5258 }
1083 : cchiw 5241 end
1084 : jhr 5570
1085 : jhr 5258 fun clerp3 alpha = let
1086 : cchiw 5242 val expindex = specialize(alpha, 0)
1087 :     val a = E.Tensor(0, expindex)
1088 :     val b = E.Tensor(1, expindex)
1089 :     val c = E.Tensor(2, [])
1090 :     val e3 = E.Op2(E.Sub, b, a)
1091 :     val e5 = E.Opn(E.Prod, [c, e3])
1092 :     val elerp = E.Opn(E.Add, [a, e5])
1093 :     in
1094 : jhr 5258 E.EIN{
1095 :     params = [mkTEN alpha, mkTEN alpha, mkTEN []],
1096 :     index = alpha,
1097 :     body = E.Op3(E.Clamp, a, b, elerp)
1098 :     }
1099 : cchiw 5242 end
1100 : jhr 5570
1101 : jhr 5258 fun clerp5 alpha = let
1102 : cchiw 5242 val expindex = specialize(alpha, 0)
1103 :     val a = E.Tensor(0, expindex)
1104 :     val b = E.Tensor(1, expindex)
1105 :     val c = E.Tensor(2, [])
1106 :     val d = E.Tensor(3, [])
1107 :     val e = E.Tensor(4, [])
1108 :     val e1 = E.Op2(E.Sub, d, c)
1109 :     val e2 = E.Op2(E.Sub, e, c)
1110 :     val e3 = E.Op2(E.Sub, b, a)
1111 :     val e4 = E.Op2(E.Div, e1, e2)
1112 :     val e5 = E.Opn(E.Prod, [e4, e3])
1113 :     val elerp = E.Opn(E.Add, [a, e5])
1114 :     in
1115 : jhr 5258 E.EIN{
1116 :     params = [mkTEN alpha, mkTEN alpha, mkTEN [], mkTEN [], mkTEN []],
1117 :     index = alpha,
1118 :     body = E.Op3(E.Clamp, a, b, elerp)
1119 :     }
1120 : cchiw 5242 end
1121 : jhr 5570
1122 : cchiw 5238 (******************** other field ops ********************************)
1123 :    
1124 : jhr 3477 (* FLD here is bounded to image field, and dimension of h *)
1125 :     fun conv (dim, shape) =let
1126 : jhr 3498 val expindex = specialize(shape, 0)
1127 :     in
1128 :     E.EIN{
1129 : jhr 3645 params = [E.IMG(dim, shape), E.KRN],
1130 : jhr 3498 index = shape,
1131 : jhr 4229 body = E.Conv(0, expindex, 1, [])
1132 : jhr 3498 }
1133 :     end
1134 : jhr 3476
1135 : jhr 3477 (* Probe: <F(x)>_{\alpha} *)
1136 :     fun probe (alpha, dim) = let
1137 : jhr 3498 val expindex = specialize(alpha, 0)
1138 :     in
1139 :     E.EIN{
1140 : jhr 5570 params = [E.FLD dim, mkNoSubstTEN [dim]], index = alpha,
1141 : jhr 3498 body = E.Probe(E.Field(0, expindex), E.Tensor(1, []))
1142 :     }
1143 :     end
1144 : jhr 3476
1145 : jhr 3477 (***************************** derivative ****************************)
1146 : jhr 3476
1147 : jhr 3477 (* \EinExp{\sum_{ij}\mathcal{E}_{ij} \frac{F_j}{\partial x_i} *)
1148 : jhr 4184 val curl2d = E.EIN{
1149 : jhr 3498 params = [E.FLD 2],
1150 :     index = [],
1151 : cchiw 3978 body = E.Sum([(0, 0, 1), (1, 0, 1)],
1152 : jhr 3498 E.Opn(E.Prod, [
1153 : cchiw 4289 E.Eps2(E.V 0, E.V 1),
1154 : jhr 3498 E.Apply(E.Partial[E.V 0], E.Field(0, [E.V 1]))
1155 :     ]))
1156 : jhr 4184 }
1157 : jhr 3476
1158 : jhr 3477 val curl3d = E.EIN{
1159 : jhr 3701 params = [mkTEN [3]],
1160 : jhr 3498 index = [3],
1161 : cchiw 3978 body = E.Sum([(1, 0, 2), (2, 0, 2)],
1162 : jhr 3498 E.Opn(E.Prod, [
1163 : cchiw 4289 E.Epsilon(E.V 0, E.V 1, E.V 2),
1164 : jhr 3498 E.Apply(E.Partial[E.V 1], E.Field(0, [E.V 2]))
1165 :     ]))
1166 :     }
1167 : jhr 3476
1168 : jhr 4324 fun gradConstant (alpha as a::_) = E.EIN{
1169 : cchiw 4159 params = [E.FLD a],
1170 :     index = [],
1171 :     body = E.Apply(E.Partial [(E.C 0)], E.Field(0, []))
1172 : jhr 4324 }
1173 : cchiw 4159
1174 : jhr 3477 (*< d F / d_i>_i *)
1175 : cchiw 4159 fun grad (alpha as a::_) = let
1176 : jhr 3498 val expindex = specialize(alpha, 0)
1177 :     in
1178 :     E.EIN{
1179 :     params = [E.FLD a],
1180 :     index = alpha,
1181 :     body = E.Apply(E.Partial expindex, E.Field(0, []))
1182 :     }
1183 :     end
1184 : jhr 3477
1185 :     (*< Sigma d F_alpha / d x_i>ALpha i CHANGE HERE *)
1186 :     fun dotimes (dim, alpha) = let
1187 : jhr 3498 val n = length alpha
1188 :     val i' = List.tabulate (n, fn x => E.V x)
1189 :     in
1190 :     E.EIN{
1191 :     params = [E.FLD dim], index = alpha@[dim],
1192 :     body = E.Apply(E.Partial[E.V n], E.Field(0, i'))
1193 :     }
1194 :     end
1195 : jhr 3477
1196 :     (* <d F_i /d_i> *)
1197 :     fun divergence (dim, alpha) = let
1198 : jhr 3498 val expindex = specialize(alpha, 0)
1199 : cchiw 3978 val sid = length alpha
1200 :     val sumIndex = E.V sid
1201 : jhr 3498 val sumIndexL = [sumIndex]
1202 :     val S = expindex@sumIndexL
1203 :     in
1204 :     E.EIN{
1205 :     params = [E.FLD dim],
1206 :     index = alpha,
1207 : cchiw 3978 body = E.Sum([(sid, 0, dim-1)], E.Apply(E.Partial sumIndexL, E.Field(0, S)))
1208 : jhr 3498 }
1209 :     end
1210 : jhr 3477
1211 : jhr 5570 fun cfexpMix (alpha_f, alphas_tf, alphas_tt) = let
1212 :     val n_tf = length(alphas_tf)
1213 :     val tterm_tf = List.tabulate(n_tf, fn id => (id+1, E.F))
1214 :     val n_tt = length(alphas_tt)
1215 :     val shift_tf = n_tf+1
1216 :     val tterm_tt = List.tabulate(n_tt, fn id => (id+shift_tf,E.T))
1217 :     val fldtem = E.Tensor(0, specialize(alpha_f, 0))
1218 :     val bodyterm = E.OField(E.CFExp (tterm_tf@tterm_tt), fldtem , E.Partial [])
1219 :     val param_f = [mkTEN alpha_f]
1220 :     val param_tt = List.map (fn talpha => mkNoSubstTEN talpha) alphas_tt
1221 :     val param_tf = List.map (fn talpha => mkNoSubstTEN talpha) alphas_tf
1222 :     in
1223 :     E.EIN {
1224 :     params = param_f@param_tf@param_tt,
1225 :     index = alpha_f,
1226 :     body = bodyterm
1227 :     }
1228 :     end
1229 :    
1230 : jhr 3476 end (* mkOperators *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0