Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/specialize.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/specialize.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2383 - (download) (annotate)
Thu Jun 13 01:57:34 2013 UTC (8 years, 1 month ago) by cchiw
File size: 3304 byte(s)
added ein
(* specialize.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure Specialize : sig

  (* transform (e, inst)
   * Specializes the shape-polymorphic Einstein expression e to a particular shape.
   * inst supplies one entry per element of e's index list, in the same order:
   *   - for a multi-index (MX): (n, dims), where n is the order of the instantiated
   *     index and dims must contain exactly n entries;
   *   - for a single index (IX or SX): (_, [d]) -- the list must be a singleton.
   * Raises Fail "multiindex/instantiation mismatch" when inst does not match the
   * index list, and Fail "unexpected multiindex" when a position that requires a
   * single index (Delta, Epsilon, Conv) was instantiated as a multi-index.
   *)
    val transform : GenericEin.ein * (int*int list) list -> Ein.ein

  end = struct

    structure G = GenericEin
    structure E = Ein

  (* This function takes a shape-polymorphic Einstein expression and specializes it to a
   * particular shape.  For each multi-index (MX) in the argument index list, the order of
   * the instantiated index is supplied.
   *
   * The generic index list is walked once (lp below) to build both
   *   - a remapping array from each generic index position to the list of specialized
   *     index numbers it expands to (an MX of order n becomes n consecutive numbers,
   *     an IX/SX becomes exactly one), and
   *   - the specialized index list.
   * The body is then rewritten, replacing every generic index by its remapped indices.
   *)
    fun transform (G.EIN{params, index, body}, inst) = let
        (* mapping[i] = specialized index numbers for generic index position i.
         * [~1] is only a placeholder: lp must consume the whole index list, so it
         * overwrites every position before the body is rewritten.
         *)
          val mapping = Array.array(length index, [~1])
          fun mapIx ix = Array.sub(mapping, ix)
        (* remap a list of generic indices, flattening multi-index expansions *)
          fun mapIxs ixs = List.concat(List.map mapIx ixs)
        (* remap an index that must expand to exactly one specialized index *)
          fun mapSingleIx ix = (case mapIx ix
                 of [ix'] => ix'
                  | _ => raise Fail "unexpected multiindex"
                (* end case *))
          fun mismatch () = raise Fail "multiindex/instantiation mismatch"
        (* lp (i, j, ixs, insts)
         *   i: current position in the generic index list (and in mapping)
         *   j: next unused specialized index number
         * Fills in mapping and returns the specialized index list.
         *)
          fun lp (_, _, [], []) = []
            | lp (i, j, G.MX::ixs, (n, dims)::insts) =
              (* require exactly n dimensions, mirroring the singleton check on IX/SX *)
                if length dims <> n
                  then mismatch ()
                  else (
                    Array.update(mapping, i, List.tabulate(n, fn k => j+k));
                    List.map E.IX dims @ lp(i+1, j+n, ixs, insts))
            | lp (i, j, G.IX::ixs, (_, [d])::insts) = (
                Array.update(mapping, i, [j]);
                E.IX d :: lp(i+1, j+1, ixs, insts))
            | lp (i, j, G.SX::ixs, (_, [d])::insts) = (
                Array.update(mapping, i, [j]);
                E.SX d :: lp(i+1, j+1, ixs, insts))
            | lp _ = mismatch ()
        (* must run before rewrite, which reads mapping *)
          val index' = lp (0, 0, index, inst)
        (* rewrite the body using the mapping *)
          fun rewrite e = (case e
                 of G.Const r => E.Const r
                  | G.Tensor(tid, ix) => E.Tensor(tid, mapIxs ix)
                  | G.Field(fid, ix) => E.Field(fid, mapIxs ix)
                  | G.Add es => E.Add(List.map rewrite es)
                  | G.Sum(count, e') => E.Sum(count, rewrite e')
                  | G.Prod es => E.Prod(List.map rewrite es)
                  | G.Sub(e1, e2) => E.Sub(rewrite e1, rewrite e2)
                  | G.Neg e' => E.Neg(rewrite e')
                  | G.Delta(ix, jx) => E.Delta(mapSingleIx ix, mapSingleIx jx)
                  | G.Epsilon(ix, jx, kx) =>
                      E.Epsilon(mapSingleIx ix, mapSingleIx jx, mapSingleIx kx)
                  | G.Conv(V, dx, h, ix) => E.Conv(V, dx, h, mapSingleIx ix)
                (* NOTE(review): unlike Tensor/Field, Partial's indices are passed
                 * through unchanged; remapping them via List.concat(List.map mapIx ix)
                 * was considered -- confirm which behavior is intended.
                 *)
                  | G.Partial ix => E.Partial ix
                  | G.Probe(a, b) => E.Probe(rewrite a, b)
                  | G.Inside(e', x) => E.Inside(rewrite e', x)
                  | G.Apply(a, b) => E.Apply(rewrite a, rewrite b)
                (* end case *))
        (* convert a generic parameter kind to its specialized counterpart *)
          fun cvtParam G.TEN = E.TEN
            | cvtParam G.FLD = E.FLD
          in
            E.EIN{
                params = List.map cvtParam params,
                index = index',
                body = rewrite body
              }
          end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0