Home My Page Projects Code Snippets Project Openings diderot

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/specialize.sml
 [diderot] / branches / charisee / src / compiler / ein / specialize.sml

View of /branches/charisee/src/compiler/ein/specialize.sml

Thu Jun 13 01:57:34 2013 UTC (8 years ago) by cchiw
File size: 3304 byte(s)
`added ein`
```
(* specialize.sml
*
* COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
*)

structure Specialize : sig

  (* transform (gein, inst) specializes the shape-polymorphic Einstein expression gein.
   * inst has one entry per element of gein's index list, in the same order:
   *   - for a multi-index (MX), (n, dims) expands it into n single indices whose
   *     dimensions are the first n elements of dims;
   *   - for a single index (IX or SX), (_, [d]) supplies its one dimension d
   *     (the integer component is ignored).
   * Raises Fail "multiindex/instantiation mismatch" when inst does not line up with the
   * index list, and Fail "unexpected multiindex" when a construct that needs a single
   * index refers to a multi-index that did not expand to exactly one index.
   *)
    val transform : GenericEin.ein * (int*int list) list -> Ein.ein

  end = struct

    structure G = GenericEin
    structure E = Ein

  (* This function takes a shape-polymorphic Einstein expression and specializes it to a
   * particular shape.  For each multi-index (MX) in the argument index list, the order of
   * the instantiated index is supplied.
   *
   * Outline:
   *   1. lp walks the generic index list together with inst, filling `mapping` (entry i is
   *      the list of specialized index numbers that generic index i expands to) and
   *      producing the specialized index list.
   *   2. rewrite translates the body, replacing every generic index reference via `mapping`.
   *)
    fun transform (G.EIN{params, index, body}, inst) = let
        (* one entry per generic index; the [~1] placeholder is overwritten by lp for every
         * position before rewrite runs (lp raises otherwise) *)
          val mapping = Array.array(length index, [~1])
          fun mapIx ix = Array.sub(mapping, ix)
        (* constructs such as Delta/Epsilon/Conv take a single index, so the generic index
         * must have expanded to exactly one specialized index *)
          fun mapSingleIx ix = (case mapIx ix
                 of [ix'] => ix'
                  | _ => raise Fail "unexpected multiindex"
                (* end case *))
          fun mismatch () = raise Fail "multiindex/instantiation mismatch"
        (* lp (i, j, ixs, insts): i is the position in the generic index list (and in
         * mapping); j is the next unused specialized index number.  Records the expansion
         * of each generic index in mapping and returns the specialized index list. *)
          fun lp (_, _, [], []) = []
            | lp (i, j, G.MX::r, (n, dims)::r') =
              (* report a short dimension list (or a negative count) as a mismatch rather
               * than letting Subscript/Size escape *)
                if n < 0 orelse length dims < n
                  then mismatch ()
                  else (
                    Array.update(mapping, i, List.tabulate(n, fn k => j+k));
                    List.map E.IX (List.take(dims, n)) @ lp(i+1, j+n, r, r'))
            | lp (i, j, G.IX::r, (_, [d])::r') = (
                Array.update(mapping, i, [j]);
                E.IX d :: lp(i+1, j+1, r, r'))
            | lp (i, j, G.SX::r, (_, [d])::r') = (
                Array.update(mapping, i, [j]);
                E.SX d :: lp(i+1, j+1, r, r'))
            | lp _ = mismatch ()
        (* must run before rewrite: it populates mapping *)
          val index' = lp (0, 0, index, inst)
        (* rewrite the body using the mapping; components other than index references
         * (tensor/field ids, Sum's count, Conv's V/dx/h, Probe's b, Inside's x) are
         * passed through unchanged *)
          fun rewrite e = (case e
                 of G.Const r => E.Const r
                  | G.Tensor(tid, ix) => E.Tensor(tid, List.concat(List.map mapIx ix))
                  | G.Field(fid, ix) => E.Field(fid, List.concat(List.map mapIx ix))
                  | G.Sum(count, e') => E.Sum(count, rewrite e')
                  | G.Prod es => E.Prod(List.map rewrite es)
                  | G.Sub(e1, e2) => E.Sub(rewrite e1, rewrite e2)
                  | G.Neg e' => E.Neg(rewrite e')
                  | G.Delta(ix, jx) => E.Delta(mapSingleIx ix, mapSingleIx jx)
                  | G.Epsilon(ix, jx, kx) => E.Epsilon(mapSingleIx ix, mapSingleIx jx, mapSingleIx kx)
                  | G.Conv(V, dx, h, ix) => E.Conv(V, dx, h, mapSingleIx ix)
                  (* NOTE(review): ix is passed through without remapping; an earlier
                   * commented-out alternative remapped it with List.concat of List.map
                   * mapIx over ix -- confirm which behavior is intended *)
                  | G.Partial ix => E.Partial ix
                  | G.Probe(a, b) => E.Probe(rewrite a, b)
                  | G.Inside(e', x) => E.Inside(rewrite e', x)
                  | G.Apply(a, b) => E.Apply(rewrite a, rewrite b)
                (* end case *))
          in
            E.EIN{
                params = List.map (fn G.TEN => E.TEN | G.FLD => E.FLD) params,
                index = index',
                body = rewrite body
              }
          end

  end
```