Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/low-opt/low-contract.sml
ViewVC logotype

View of /branches/vis15/src/compiler/low-opt/low-contract.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3799 - (download) (annotate)
Mon May 2 22:04:19 2016 UTC (4 years, 5 months ago) by jhr
File size: 4717 byte(s)
  added contraction for ProjectLast and TensorIndex to low-contract.sml
(* low-contract.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *
 * Contraction phase for LowIR.
 *)

structure LowContract : sig

    val transform : LowIR.program -> LowIR.program

  end = struct

    structure IR = LowIR
    structure Op = LowOps
    structure Ty = LowTypes
    structure V = IR.Var
    structure ST = Stats

  (********** Counters for statistics **********)
    val cntAddNeg               = ST.newCounter "low-contract:add-neg"
    val cntSubNeg               = ST.newCounter "low-contract:sub-neg"
    val cntSubSame              = ST.newCounter "low-contract:sub-same"
    val cntNegNeg               = ST.newCounter "low-contract:neg-neg"
    val cntIntToReal            = ST.newCounter "low-contract:int-to-real"
    val cntTensorIndex          = ST.newCounter "low-contract:tensor-index"
    val cntProjectLast          = ST.newCounter "low-contract:project-last"
    val cntUnused               = ST.newCounter "low-contract:unused"
    val firstCounter            = cntAddNeg
    val lastCounter             = cntUnused

  (* unused-variable elimination for LowIR; it reports through `cntUnused` *)
    structure UnusedElim = UnusedElimFn (
        structure IR = IR
        val cntUnused = cntUnused)

  (* the current use count of a variable *)
    fun useCount (IR.V{useCnt, ...}) = !useCnt

  (* adjust a variable's use count.  `use x` bumps the count of `x` and returns `x`,
   * so it can be applied inline when building a new right-hand side that mentions `x`.
   *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
    fun decUse (IR.V{useCnt, ...}) = (useCnt := !useCnt - 1)
    fun use x = (incUse x; x)

  (* if `x` is defined by an operator application, return `SOME(rator, args)`;
   * otherwise return NONE.
   *)
    fun getRHSOpt x = (case V.getDef x
           of IR.OP arg => SOME arg
            | _ => NONE
          (* end case *))

  (* Try to contract the assignment `lhs = rhs`.  Returns `SOME stms` with the
   * replacement assignment(s) when a rule fires, or NONE to leave the assignment
   * unchanged.  Every rule keeps use counts accurate: variables that drop out of the
   * right-hand side are `decUse`d and newly referenced variables are `use`d.
   *)
(* TODO: further tensor selection operations (TensorIndex and ProjectLast are handled below) *)
    fun doAssign (lhs, IR.OP rhs) = (case rhs
           of (Op.IAdd, [a, b]) => (case (getRHSOpt a, getRHSOpt b)
                 of (_, SOME(Op.INeg, [c])) => (
                    (* rewrite "a + (-c)" to "a - c" *)
                      ST.tick cntAddNeg;
                      decUse b;
                      SOME[(lhs, IR.OP(Op.ISub, [a, use c]))])
                  | (SOME(Op.INeg, [c]), _) => (
                    (* rewrite "(-c) + b" to "b - c"; b keeps its single use here *)
                      ST.tick cntAddNeg;
                      decUse a;
                      SOME[(lhs, IR.OP(Op.ISub, [b, use c]))])
                  | _ => NONE
                (* end case *))
            | (Op.ISub, [a, b]) => if V.same(a, b)
                then ( (* rewrite "a - a" to 0 *)
                  ST.tick cntSubSame;
                  (* a and b are the same variable, but it occurred twice in the
                   * argument list, so both uses go away.
                   *)
                  decUse a; decUse b;
                  SOME[(lhs, IR.LIT(Literal.Int 0))])
                else (case getRHSOpt b
                   of SOME(Op.INeg, [c]) => (
                      (* rewrite "a - (-c)" to "a + c" *)
                        ST.tick cntSubNeg;
                        decUse b;
                        SOME[(lhs, IR.OP(Op.IAdd, [a, use c]))])
                    | _ => NONE
                  (* end case *))
            | (Op.INeg, [a]) => (case getRHSOpt a
                 of SOME(Op.INeg, [b]) => (
                    (* rewrite "-(-b)" to "b" *)
                      ST.tick cntNegNeg;
                      decUse a;
                      SOME[(lhs, IR.VAR(use b))])
                  | _ => NONE
                (* end case *))
            | (Op.IntToReal, [a]) => (case V.getDef a
                 of IR.LIT(Literal.Int n) => (
                    (* rewrite to a real literal *)
                      ST.tick cntIntToReal;
                      decUse a;
                      SOME[(lhs, IR.LIT(Literal.Real(RealLit.fromInt n)))])
                  | _ => NONE
                (* end case *))
            | (Op.TensorIndex(Ty.TensorTy dims, idxs), [t]) => let
              (* Walk the index path through nested CONS definitions.  Invariant: at a
               * call `get (ix::ixs, d::ds, x)`, `x` has shape `d::ds` (descending into a
               * CONS element strips the leading dimension), and the index list has the
               * same length as the shape.
               *)
                fun get ([], [], x) = SOME[(lhs, IR.VAR(use x))]
                  | get (ix::ixs, d::ds, x) = (case V.getDef x
                       of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
                        | _ =>
                          (* x is not a CONS: index the rest of the path into x,
                           * whose shape is still the full `d::ds`.
                           *)
                            SOME[(lhs, IR.OP(Op.TensorIndex(Ty.tensorTy(d::ds), ix::ixs), [use x]))]
                      (* end case *))
                  | get _ = raise Fail "malformed TensorIndex"
                in
                  case V.getDef t
                   of IR.CONS _ => (ST.tick cntTensorIndex; decUse t; get(idxs, dims, t))
                    | _ => NONE
                  (* end case *)
                end
            | (Op.ProjectLast(Ty.TensorTy dims, idxs), [t]) => let
              (* Like TensorIndex, but `idxs` has one fewer entry than the shape: the
               * walk stops when a single dimension remains.  Invariant: at a call
               * `get (ix::ixs, d::ds, x)`, `x` has shape `d::ds`.
               *)
                fun get ([], [_], x) = SOME[(lhs, IR.VAR(use x))]
                  | get (ix::ixs, d::ds, x) = (case V.getDef x
                       of IR.CONS(ys, _) => get(ixs, ds, List.nth(ys, ix))
                        | _ =>
                          (* x is not a CONS: project the rest of the path from x,
                           * whose shape is still the full `d::ds`.
                           *)
                            SOME[(lhs, IR.OP(Op.ProjectLast(Ty.tensorTy(d::ds), ix::ixs), [use x]))]
                      (* end case *))
                  | get _ = raise Fail "malformed ProjectLast"
                in
                  case V.getDef t
                   of IR.CONS _ => (ST.tick cntProjectLast; decUse t; get(idxs, dims, t))
                    | _ => NONE
                  (* end case *)
                end
            | _ => NONE
          (* end case *))
      | doAssign _ = NONE

  (* no contractions are performed on multi-assignments *)
    fun doMAssign _ = NONE

  (* the generic rewriting engine, instantiated with the LowIR contraction rules *)
    structure Rewrite = RewriteFn (
      struct
        structure IR = IR
        val doAssign = doAssign
        val doMAssign = doMAssign
        val elimUnusedVars = UnusedElim.reduce
      end)

    val transform = Rewrite.transform

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0