Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-opt/apply.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-opt/apply.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3978 - (download) (annotate)
Wed Jun 15 19:07:40 2016 UTC (3 years, 2 months ago) by cchiw
File size: 5418 byte(s)
changed ein expressions, rewrote matchEps, added translation
(* apply.sml
 *
 * Apply EIN operator arguments to EIN operator.
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2015 The University of Chicago
 * All rights reserved.
 *)

structure Apply : sig

  (* apply (e1, place, e2) substitutes the EIN operator e2 for parameter number
   * `place` of e1 and returns (changed, e1'), where
   *
   *   - the parameter list of e1' is the parameter list of e1 with entry `place`
   *     replaced by all of e2's parameters (this splice is performed even when
   *     no term is substituted);
   *   - every E.Tensor/E.Field in e1's body whose parameter id is `place` is
   *     replaced by e2's body, with e2's index variables and parameter ids
   *     renamed into e1's context, and the remaining parameter ids of e1 are
   *     renumbered to account for the splice;
   *   - `changed` is true iff at least one such replacement was made.
   *
   * Raises Fail "argument/parameter mismatch" when a replaced term's index list
   * is not the same length as e2's index list, Fail "expression before expand"
   * on E.Value/E.Img/E.Krn terms, and Subscript (from List.take/List.drop) when
   * `place` is not in 0 .. length(params of e1) - 1.
   *)
    val apply : Ein.ein * int * Ein.ein -> (bool * Ein.ein)

  end = struct

    structure E = Ein

    structure IMap = IntRedBlackMap

  (* map the id `i` through `dict`; ids that `dict` does not mention are
   * offset by `shift`.
   *)
    fun mapId (i, dict, shift) = (case IMap.find(dict, i)
           of NONE => i + shift
            | SOME j => j
          (* end case *))

  (* map the index variable `ix` through `dict`, which maps variable ids to
   * index expressions (E.V or E.C); unmapped variables become E.V(ix + shift).
   *)
    fun mapIndex (ix, dict, shift) = (case IMap.find(dict, ix)
           of NONE => E.V(ix + shift)
            | SOME j => j
          (* end case *))

  (* like mapId, but every id is expected to be in `dict`; a missing id is
   * reported on stdout and then offset by `shift` (best-effort recovery).
   *)
    fun mapId2 (i, dict, shift) = (case IMap.find(dict, i)
           of NONE => (
                print(concat["Error: ", Int.toString i, " is out of range\n"]);
                i + shift)
            | SOME j => j
          (* end case *))

  (* rewriteSubst (e, subId, mx, sumShift) rewrites the body `e` of the operator
   * being substituted so that it fits a use site whose index list is `mx`:
   *   - index variable n (0 <= n < length mx) becomes the n'th element of mx;
   *   - any other index variable ix (e.g., a summation index of `e`) becomes
   *     E.V(ix + shift'), where shift' is the larger of `sumShift` and the
   *     offset between the last element of mx and its position in mx;
   *   - parameter ids are renumbered through `subId`.
   *)
    fun rewriteSubst (e, subId, mx, sumShift) = let
        (* build the position -> index map for mx, also returning the offset of
         * the last element of mx (0 when mx is empty)
         *)
          fun bindArgs ([], _, dict, shift) = (dict, shift)
            | bindArgs (mu::mus, n, dict, _) = let
                val shift = (case mu of E.V ix => ix - n | E.C i => i - n)
                in
                  bindArgs (mus, n+1, IMap.insert(dict, n, mu), shift)
                end
          val (subMu, shift) = bindArgs (mx, 0, IMap.empty, 0)
          val shift' = Int.max(sumShift, shift)
          fun mapMu (E.V i) = mapIndex(i, subMu, shift')
            | mapMu c = c
          fun mapAlpha mx = List.map mapMu mx
        (* map an index that must stay a variable (summation, Delta, and
         * Epsilon positions); binding it to a constant index is an error
         *)
          fun mapSingle i = (case mapIndex(i, subMu, shift')
                 of E.V v => v
                  | _ => raise Fail "expected index variable in substitution"
                (* end case *))
          fun mapSum sx = List.map (fn (v, a, b) => (mapSingle v, a, b)) sx
          fun mapParam id = mapId2(id, subId, 0)
          fun rewrite e = (case e
                 of E.Const _ => e
                  | E.ConstR _ => e
                  | E.Tensor(id, mx) => E.Tensor(mapParam id, mapAlpha mx)
                  | E.Delta(i, j) => E.Delta(mapSingle i, mapSingle j)
                  | E.Epsilon(i, j, k) => E.Epsilon(mapSingle i, mapSingle j, mapSingle k)
                  | E.Eps2(i, j) => E.Eps2(mapSingle i, mapSingle j)
                  | E.Field(id, mx) => E.Field(mapParam id, mapAlpha mx)
                  | E.Lift e1 => E.Lift(rewrite e1)
                  | E.Conv(v, mx, h, ux) => E.Conv(mapParam v, mapAlpha mx, mapParam h, mapAlpha ux)
                  | E.Partial mx => E.Partial(mapAlpha mx)
                  | E.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
                  | E.Probe(f, pos) => E.Probe(rewrite f, rewrite pos)
                  | E.Value _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Sum(sx, esum) => E.Sum(mapSum sx, rewrite esum)
                  | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                  | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
                (* end case *))
          in
            rewrite e
          end

  (* rewriteParams (params, params2, place) splices params2 in place of entry
   * `place` of params and returns (params', origId, subId), where
   *   - origId renumbers the ids of the parameters that follow `place` (ids
   *     before `place` do not move and so are not in the map);
   *   - subId maps each parameter id j of the substituted operator to place+j.
   * Raises Subscript (from List.take/List.drop) unless 0 <= place < length params.
   *)
    fun rewriteParams (params, params2, place) = let
          val front = List.take(params, place)
          val back = List.drop(params, place+1)
          val n2 = length params2
        (* the map { from+k -> to+k | 0 <= k < count } *)
          fun offsetMap (count, from, to) = let
                fun add (k, dict) = if (k < count)
                      then add (k+1, IMap.insert(dict, from+k, to+k))
                      else dict
                in
                  add (0, IMap.empty)
                end
          in
            (front @ params2 @ back,
             offsetMap (length back, place+1, place+n2),
             offsetMap (n2, 0, place))
          end

  (* substitute e2 for parameter `place` of e1 (see the signature comment) *)
    fun apply (E.EIN{params, index, body}, place, E.EIN{params=params2, index=index2, body=body2}) = let
          val changed = ref false
          val (params', origId, substId) = rewriteParams (params, params2, place)
        (* the index id from which the substituted body's own (non-argument)
         * index variables are renumbered.  It starts one past the free indices
         * of e1 (assuming these are numbered 0 .. length index - 1) and is moved
         * one past the last summation index of each Sum we enter.
         *)
          val sumIndex = ref (length index)
        (* one past the last summation index of a Sum.  Using the last index
         * itself would rename summation index 0 of a scalar argument's body onto
         * the enclosing Sum's index, producing nested sums over one variable.
         *)
          fun nextIndex sx = let val (v, _, _) = List.last sx in v + 1 end
        (* replace a use of parameter `place` that has index list mx *)
          fun substitute mx = if (length mx = length index2)
                then (
                  changed := true;
                  rewriteSubst (body2, substId, mx, !sumIndex))
                else raise Fail "argument/parameter mismatch"
        (* renumber a parameter id of e1 other than `place` *)
          fun renumber id = mapId (id, origId, 0)
          fun rewrite b = (case b
                 of E.Tensor(id, mx) =>
                      if (id = place) then substitute mx else E.Tensor(renumber id, mx)
                  | E.Field(id, mx) =>
                      if (id = place) then substitute mx else E.Field(renumber id, mx)
                  | E.Lift e1 => E.Lift(rewrite e1)
                (* NOTE(review): a Conv whose `v` or `h` id equals `place` is not
                 * substituted; mapId leaves `place` unchanged, so it would refer to
                 * whatever parameter now sits at position `place` of params' --
                 * confirm that this cannot happen.
                 *)
                  | E.Conv(v, mx, h, ux) => E.Conv(renumber v, mx, renumber h, ux)
                  | E.Apply(e1, e2) => E.Apply(rewrite e1, rewrite e2)
                  | E.Probe(f, pos) => E.Probe(rewrite f, rewrite pos)
                  | E.Value _ => raise Fail "expression before expand"
                  | E.Img _ => raise Fail "expression before expand"
                  | E.Krn _ => raise Fail "expression before expand"
                  | E.Sum(sx, esum) => (
(* QUESTION: should we flag a change here? *)
                      sumIndex := nextIndex sx;
                      E.Sum(sx, rewrite esum))
                  | E.Op1(op1, e1) => E.Op1(op1, rewrite e1)
                  | E.Op2(op2, e1, e2) => E.Op2(op2, rewrite e1, rewrite e2)
                  | E.Opn(opn, es) => E.Opn(opn, List.map rewrite es)
                  | _ => b
                (* end case *))
          val body' = rewrite body
          in
(* QUESTION: can we do the following?
            if (! changed) then SOME(E.EIN{params=params', index=index, body=body'}) else NONE
*)
            (!changed, E.EIN{params=params', index=index, body=body'})
          end

    end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0