Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee/src/compiler/ein/rewrite.sml
ViewVC logotype

View of /branches/charisee/src/compiler/ein/rewrite.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2383 - (download) (annotate)
Thu Jun 13 01:57:34 2013 UTC (6 years, 3 months ago) by cchiw
File size: 8928 byte(s)
added ein
(* rewrite.sml
 *
 * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
 * All rights reserved.
 *)

structure Rewrite : sig

  (* an argument to an Ein application: either a plain variable, or the
   * application of another Ein operator to a list of variables
   *)
    datatype arg
      = Var of Var.var
      | App of Ein.ein * Var.var list

  (* evalEinApp (ein, args) inlines every App argument into the body of ein
   * and merges repeated argument variables.  The result is the rewritten
   * operator paired with its new argument variables: one per distinct
   * variable, in order of first occurrence.
   *)
    val evalEinApp : Ein.ein * arg list -> Ein.ein * Var.var list

  end = struct

    structure VarMap = Var.Map

  (* mkAdd es builds the sum of the terms es, dropping literal `Const 0.0`
   * terms.  An empty sum becomes `Const 0.0`; a single remaining term is
   * returned unchanged.
   *)
    fun mkAdd es = (
          case List.filter
              (fn (Ein.Const x) => Real.!=(x, 0.0)
                | _ => true) es
           of [] => Ein.Const 0.0
            | [e] => e
            | es => Ein.Add es
          (* end case *))

    datatype arg
      = Var of Var.var
      | App of Ein.ein * Var.var list

  (* instantiateIdx body ids renames the index variables of body: index
   * variable i is replaced by the i-th element of ids.  If an index is out of
   * range for ids, a diagnostic is printed and the exception is re-raised.
   *)
    fun instantiateIdx body ids = let
          val subst = Vector.fromList ids
          fun substIdx id = Vector.sub(subst, id)
            handle ex => (print(concat["substIdx ([|", String.concatWith "," (List.map Int.toString ids), "|],", Int.toString id, ")\n"]); raise ex)
          fun apply e = (case e
                 of Ein.Const _ => e
                  | Ein.Tensor(id, mx) => Ein.Tensor(id, List.map substIdx mx)
                  | Ein.Field(id, mx) => Ein.Field(id, List.map substIdx mx)
                  | Ein.Add es => mkAdd(List.map apply es)
                (* NOTE(review): the summand is not renamed, so free indices
                 * inside a Sum keep their original ids.  Applying substIdx
                 * here may fail if the Sum-bound indices lie outside ids --
                 * needs a proper treatment of bound indices; confirm.
                 *)
                  | Ein.Sum(c, esum) => Ein.Sum(c, esum)
                  | Ein.Prod es => Ein.Prod(List.map apply es)
                  | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
                  | Ein.Neg e => Ein.Neg(apply e)
                  | Ein.Delta(i, j) => Ein.Delta(substIdx i, substIdx j)
                  | Ein.Epsilon(i, j, k) => Ein.Epsilon(substIdx i, substIdx j, substIdx k)
                  | Ein.Conv(v, dx, h, i) => Ein.Conv(v, dx, h, substIdx i)
                (* NOTE(review): unlike Tensor/Field, the multiindex of Partial
                 * is not renamed -- confirm this is intended.
                 *)
                  | Ein.Partial mx => Ein.Partial mx
                  | Ein.Probe(e, id) => Ein.Probe(apply e, id)
                  | Ein.Inside(e, id) => Ein.Inside(apply e, id)
                  | Ein.Apply(e1, e2) => Ein.Apply(apply e1, apply e2)
                (* end case *))
          in
            apply body
          end

  (* a substitution maps each tensor id and each field id to a function that,
   * given the multiindex of a reference, produces the replacement term
   *)
    datatype subst = S of {
          tSub : (Ein.multiindex -> Ein.ein_exp) array, (* mapping from tensor ID to instantiation *)
          fSub : (Ein.multiindex -> Ein.ein_exp) array  (* mapping from field ID to instantiation *)
        }

  (* create a substitution with nTens tensor slots and nFlds field slots; an
   * unbound slot raises Fail when used
   *)
    fun newSubst (nTens, nFlds) = S{
            tSub = Array.array(nTens, fn _ => raise Fail "undefined tensor"),
            fSub = Array.array(nFlds, fn _ => raise Fail "undefined field")
          }
    fun bindTensor (S{tSub, ...}, id, f) = Array.update (tSub, id, f)
    fun bindField (S{fSub, ...}, id, f) = Array.update (fSub, id, f)
    fun substTensor (S{tSub, ...}, id, mx) = Array.sub (tSub, id) mx
          handle ex => (print(concat["substTensor(_, ", Int.toString id, ", [|", String.concatWith "," (List.map Int.toString mx), "|])\n"]); raise ex)
    fun substField (S{fSub, ...}, id, mx) = Array.sub (fSub, id) mx

  (* countParams ps returns the number of tensor and field parameters in ps,
   * i.e., the slot counts needed for a substitution over those parameters
   *)
    fun countParams ps = let
          fun count (Ein.TEN, (nT, nF)) = (nT+1, nF)
            | count (Ein.FLD, (nT, nF)) = (nT, nF+1)
          in
            List.foldl count (0, 0) ps
          end

  (* applySubst (subst, e) replaces every Tensor and Field reference in e with
   * the term produced by its binding in subst.  Delta, Epsilon, Conv, Partial
   * and Apply terms are returned unchanged.
   *)
    fun applySubst (subst, e) = let
          fun apply e = (case e
                 of Ein.Const _ => e
                  | Ein.Tensor(id, mx) => substTensor (subst, id, mx)
                  | Ein.Field(id, mx) => substField (subst, id, mx)
                  | Ein.Add es => Ein.Add(List.map apply es)
                  | Ein.Sum(c, esum) => Ein.Sum(c, apply esum)
                  | Ein.Prod es => Ein.Prod(List.map apply es)
                  | Ein.Sub(e1, e2) => Ein.Sub(apply e1, apply e2)
                  | Ein.Neg e => Ein.Neg(apply e)
                  | Ein.Delta _ => e
                  | Ein.Epsilon _ => e
                  | Ein.Conv _ => e
                  | Ein.Partial _ => e
                  | Ein.Probe(e, id) => Ein.Probe(apply e, id)
                  | Ein.Inside(e, id) => Ein.Inside(apply e, id)
                (* NOTE(review): the operands of Apply are not substituted --
                 * confirm this is intended.
                 *)
                  | Ein.Apply _ => e
                (* end case *))
          in
            apply e
          end

  (* evalEinApp (ein, args) rewrites the application of ein to args.
   *
   *   1. renameVars flattens the argument variables: a Var contributes
   *      itself, an App contributes its own argument variables.  It keeps
   *      only the first occurrence of each variable and numbers tensors and
   *      fields separately.
   *   2. The top-level substitution has one slot per tensor/field parameter
   *      of ein.  A Var argument's slot refers to the renumbered variable.
   *      An App argument's slot expands to that argument's body, rewritten
   *      over the renumbered variables, with its indices instantiated from
   *      the multiindex at each use.
   *   3. The result body is ein's body under that substitution.
   *
   * NOTE: if all of the arguments are distinct variables, this function is
   * the identity.
   *)
    fun evalEinApp (Ein.EIN{params, index, body}, args : arg list) = let
        (* compute the distinct argument variables, each with its kind and its
         * new tensor or field id
         *)
          fun renameVars ([], [], nT, nF, uniqueArgs) = (nT, nF, List.rev uniqueArgs)
            | renameVars (param::params, arg::args, nT, nF, uniqueArgs) = let
                fun isUnique (x, uArgs) = not(List.exists (fn (y, _, _) => Var.same(x, y)) uArgs)
                fun continue (nT, nF, uniqueArgs) = renameVars (params, args, nT, nF, uniqueArgs)
              (* record x as a parameter of kind p unless it was already seen;
               * for a repeated variable the accumulator uArgs is returned
               * unchanged
               *)
                fun doVar (p, x, nT, nF, uArgs) = (case p
                       of Ein.TEN =>
                            if isUnique (x, uArgs)
                              then (nT+1, nF, (x, p, nT)::uArgs)
                              else (nT, nF, uArgs)
                        | Ein.FLD =>
                            if isUnique (x, uArgs)
                              then (nT, nF+1, (x, p, nF)::uArgs)
                              else (nT, nF, uArgs)
                      (* end case *))
                in
                  case arg
                   of (Var x) => continue(doVar(param, x, nT, nF, uniqueArgs))
                    | (App(Ein.EIN{params=ps, ...}, xs)) => let
                        fun lp ([], [], nT, nF, uniqueArgs) = continue(nT, nF, uniqueArgs)
                          | lp (p::ps, x::xs, nT, nF, uniqueArgs) = let
                              val (nT, nF, uniqueArgs) = doVar(p, x, nT, nF, uniqueArgs)
                              in
                                lp (ps, xs, nT, nF, uniqueArgs)
                              end
                          | lp _ = raise Fail "param/arg arity mismatch"
                        in
                          lp (ps, xs, nT, nF, uniqueArgs)
                        end
                  (* end case *)
                end
            | renameVars _ = raise Fail "param/arg arity mismatch"
          val (_, _, uniqueArgs) = renameVars(params, args, 0, 0, [])
        (* map each distinct argument variable to its (kind, new id) pair *)
          val vMap = List.foldl
                (fn ((x, k, id), vMap) => VarMap.insert(vMap, x, (k, id)))
                  VarMap.empty uniqueArgs
        (* the top-level substitution needs one slot per parameter of ein, not
         * one per distinct variable: repeated variables and App arguments
         * make those counts differ
         *)
          val subst = newSubst (countParams params)
        (* bind the next slot of subst (slot nT or nF, chosen by x's kind) to a
         * reference to the renumbered variable x; returns the updated counters
         *)
          fun bindVar (subst, x, nT, nF) = (case VarMap.find(vMap, x)
                 of SOME(Ein.TEN, id) => (
                      bindTensor(subst, nT, fn mx => Ein.Tensor(id, mx));
                      (nT+1, nF))
                  | SOME(Ein.FLD, id) => (
                      bindField(subst, nF, fn mx => Ein.Field(id, mx));
                      (nT, nF+1))
                  | NONE => raise Fail(concat["undefined argument variable \"", Var.name x, "\""])
                (* end case *))
        (* walk ein's parameters and the arguments together, rewriting App
         * arguments and initializing the top-level substitution
         *)
          fun rewriteArgs ([], [], _, _) = ()
            | rewriteArgs (_::ps, (Var x)::args, nT, nF) = let
                val (nT, nF) = bindVar (subst, x, nT, nF)
                in
                  rewriteArgs (ps, args, nT, nF)
                end
            | rewriteArgs (p::ps, App(Ein.EIN{params=argParams, body=argBody, ...}, xs)::args, nT, nF) = let
              (* rewrite the argument's body in terms of the renumbered variables *)
                val argBody = let
                      val argSubst = newSubst (countParams argParams)
                      fun doVars ([], _, _) = ()
                        | doVars (x::xs, nT, nF) = let
                            val (nT, nF) = bindVar (argSubst, x, nT, nF)
                            in
                              doVars (xs, nT, nF)
                            end
                      in
                        doVars (xs, 0, 0);
                        applySubst (argSubst, argBody)
                      end
                in
                (* the parameter's slot expands to the rewritten body; the field
                 * case must use the field counter nF
                 *)
                  case p
                   of Ein.TEN => (
                        bindTensor (subst, nT, instantiateIdx argBody);
                        rewriteArgs (ps, args, nT+1, nF))
                    | Ein.FLD => (
                        bindField (subst, nF, instantiateIdx argBody);
                        rewriteArgs (ps, args, nT, nF+1))
                  (* end case *)
                end
            | rewriteArgs _ = raise Fail "param/arg arity mismatch"
          val _ = rewriteArgs (params, args, 0, 0)
          in (
            Ein.EIN{
                params = List.map #2 uniqueArgs,
                index = index,
                body = applySubst (subst, body)
              },
            List.map #1 uniqueArgs
          ) end

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0