Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/ein/specialize.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/ein/specialize.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2383 - (view) (download)

1 : cchiw 2383 (* specialize.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
structure Specialize : sig

    val transform : GenericEin.ein * (int * int list) list -> Ein.ein

  end = struct

    structure G = GenericEin
    structure E = Ein

  (* transform (ein, inst) specializes the shape-polymorphic Einstein expression `ein`
   * to one particular shape.
   *
   * `inst` must supply exactly one entry per element of the generic index list,
   * matched positionally:
   *   - a multi-index G.MX takes an entry (n, ds): it is bound to n consecutive
   *     de Bruijn numbers, and the first n elements of ds become n E.IX entries
   *     of the specialized index;
   *   - a G.IX / G.SX takes an entry (_, [d]) (the integer component is ignored)
   *     and yields a single E.IX d / E.SX d, bound to one de Bruijn number.
   * NOTE(review): the ds values are presumably index ranges/dimensions -- confirm.
   *
   * The body is then rebuilt, replacing each generic index position by the de
   * Bruijn number(s) it was bound to.
   *
   * Raises Fail "multiindex/instantiation mismatch" when the index list and inst do
   * not line up, and Fail "unexpected multiindex" when Delta, Epsilon or Conv
   * refers to a position that was not bound to exactly one number.  List.take
   * raises Subscript when an MX entry's ds has fewer than n elements.
   *)
    fun transform (G.EIN{params, index, body}, inst) = let
        (* slot k lists the de Bruijn numbers bound to generic index position k;
         * the ~1 placeholder is overwritten at every position once expand succeeds *)
          val slots = Array.array (length index, [~1])
          fun lookup k = Array.sub (slots, k)
          fun lookupAll ks = List.concat (List.map lookup ks)
          fun lookupOne k = (case lookup k
                 of [d] => d
                  | _ => raise Fail "unexpected multiindex"
                (* end case *))
        (* walk the generic index and the instantiation together, left to right.
         * pos  = position in the generic index,
         * next = next unused de Bruijn number,
         * acc  = specialized index built so far, most recent entry first *)
          fun expand (_, _, [], [], acc) = List.rev acc
            | expand (pos, next, G.MX :: ixs, (n, ds) :: insts, acc) = (
                Array.update (slots, pos, List.tabulate (n, fn k => next + k));
                expand (pos + 1, next + n, ixs, insts,
                  List.revAppend (List.map E.IX (List.take (ds, n)), acc)))
            | expand (pos, next, G.IX :: ixs, (_, [d]) :: insts, acc) = (
                Array.update (slots, pos, [next]);
                expand (pos + 1, next + 1, ixs, insts, E.IX d :: acc))
            | expand (pos, next, G.SX :: ixs, (_, [d]) :: insts, acc) = (
                Array.update (slots, pos, [next]);
                expand (pos + 1, next + 1, ixs, insts, E.SX d :: acc))
            | expand _ = raise Fail "multiindex/instantiation mismatch"
          val index' = expand (0, 0, index, inst, [])
        (* translate a generic body term into an Ein term, remapping index references *)
          fun specialize (G.Const r) = E.Const r
            | specialize (G.Tensor (id, ks)) = E.Tensor (id, lookupAll ks)
            | specialize (G.Field (id, ks)) = E.Field (id, lookupAll ks)
            | specialize (G.Add terms) = E.Add (List.map specialize terms)
            | specialize (G.Sum (count, term)) = E.Sum (count, specialize term)
            | specialize (G.Prod terms) = E.Prod (List.map specialize terms)
            | specialize (G.Sub (a, b)) = E.Sub (specialize a, specialize b)
            | specialize (G.Neg a) = E.Neg (specialize a)
            | specialize (G.Delta (i, j)) = E.Delta (lookupOne i, lookupOne j)
            | specialize (G.Epsilon (i, j, k)) =
                E.Epsilon (lookupOne i, lookupOne j, lookupOne k)
            | specialize (G.Conv (img, dx, h, k)) = E.Conv (img, dx, h, lookupOne k)
            (* NOTE(review): derivative indices are passed through without remapping;
             * an earlier commented-out version applied the slot mapping here --
             * confirm which is intended *)
            | specialize (G.Partial ks) = E.Partial ks
            | specialize (G.Probe (a, b)) = E.Probe (specialize a, b)
            | specialize (G.Inside (a, x)) = E.Inside (specialize a, x)
            | specialize (G.Apply (a, b)) = E.Apply (specialize a, specialize b)
        (* parameter kinds carry over one-to-one *)
          fun specializeParam G.TEN = E.TEN
            | specializeParam G.FLD = E.FLD
          in
            E.EIN{
                params = List.map specializeParam params,
                index = index',
                body = specialize body
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0