Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

View of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3666 - (download) (annotate)
Sat Feb 6 16:45:48 2016 UTC (3 years, 6 months ago) by jhr
File size: 9707 byte(s)
working on merge
(* ein-to-low.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

(*
 * genfn: does a preliminary scan of the body of an EIN.EIN for vectorization potential.
 *   - If there is a field, then the expression is passed to FieldToLow.
 *   - If there is a tensor, then the expression is passed to the handle*() functions,
 *     which check whether the tensor indices match the iteration indices;
 *     e.g., <A_ij+B_ij>_ij (match) vs. <A_ji+B_ij>_ij (no match).
 *
 *     (1) If the indices match, then the expression is passed to the Iter->ToVec functions,
 *         which create LowIR vector operators.
 *     (2) Otherwise it is passed to Iter->ScaToLow, which creates LowIR scalar operators.
 *
 * Note: the Iter function creates LowIR.CONS and therefore binds the indices in the EIN.body.
 *)

structure EinToLow : sig

    val expand : LowIR.var * Ein.ein * LowIR.var list -> LowIR.assignment list

  end = struct

    structure Var = LowIR.Var
    structure E = Ein
    structure Op = LowOps
    structure Mk = MkLowIR
    structure ToVec = EinToVector

  (* `dropIndex alpha` splits off the last element of `alpha`.  The result is the
   * triple (len, i, alpha'), where alpha = alpha'@[i] and len = length(alpha').
   * Raises `Fail "dropIndex[]"` when `alpha` is empty.
   *)
    fun dropIndex alpha = let
	(* walk the list, counting the elements that precede the last one and
	 * accumulating them in reverse order.
	 *)
	  fun split (_, [], _) = raise Fail "dropIndex[]"
	    | split (len, [last], revPrefix) = (len, last, List.rev revPrefix)
	    | split (len, x::xs, revPrefix) = split (len+1, xs, x::revPrefix)
	  in
	    split (0, alpha, [])
	  end

    (* matchLast : E.alpha * int -> E.alpha option
     * Tests whether the last index of `alpha` is `E.V n`.  If so, the result is
     * SOME of `alpha` with that last index removed; otherwise (including when
     * `alpha` is empty) the result is NONE.
     *)
    fun matchLast (alpha, n) = let
	(* scan to the final element, accumulating the preceding elements in reverse *)
	  fun walk ([], _) = NONE
	    | walk ([E.V v], revPrefix) = if (v = n) then SOME(List.rev revPrefix) else NONE
	    | walk ([_], _) = NONE
	    | walk (ix::ixs, revPrefix) = walk (ixs, ix::revPrefix)
	  in
	    walk (alpha, [])
	  end

    (* matchFindLast : E.alpha * int -> E.alpha option * E.mu option
     * Answers two questions about the index variable `n`:
     *   - first component: is the last index of `alpha` equal to `E.V n`?  If so,
     *     it is SOME of `alpha` without its last index; otherwise NONE.
     *   - second component: does `E.V n` occur anywhere before the last position?
     *     If so, it is SOME of such an occurrence; otherwise NONE.
     * An empty `alpha` yields (NONE, NONE).
     *)
    fun matchFindLast (alpha, n) = let
	  fun isN (E.V idx) = (idx = n)
	    | isN _ = false
	  in
	    case List.rev alpha
	     of [] => (NONE, NONE)
	      | last::revRest => let
		(* search the indices that precede the last one (in reversed order) *)
		  val elsewhere = List.find isN revRest
		  in
		    if isN last
		      then (SOME(List.rev revRest), elsewhere)
		      else (NONE, elsewhere)
		  end
	    (* end case *)
	  end

  (* unroll the body of an Ein expression.  The arguments are
   *   avail  -- available RHS forms (for redundant computation elimination)
   *   shape  -- the shape of the tensor computed by the expression
   *   index  -- the shape of the iteration structure
   *   bodyFn -- the function for generating the body
   *   args   -- additional arguments to pass to bodyFn
   *
   * NOTE: this is currently a stub; every call raises `Fail "unimplemented"`.
   *)
    fun unroll (avail, shape, index, bodyFn, args) = raise Fail "unimplemented"

  (* the general case: the body is expanded to scalar code by unrolling over the
   * full `index` (used as both the result shape and the iteration structure),
   * with `EinToScalar.expand` generating each element.
   *)
    fun scalarExpand (params, body, index, args) = let
	  val avail = AvailRHS.new()
	  val bodyArgs = (params, body, args)
	  in
	    unroll (avail, index, index, EinToScalar.expand, bodyArgs)
	  end

    (* build a vector parameter of kind `ToVec.Proj vecIndex` for the `id`'th element of
     * `args`, indexed by `ix`.  `List.nth` raises Subscript if `id` is out of range.
     *)
    fun createP (args, vecIndex, id, ix) = let
	  val arg = List.nth(args, id)
	  in
	    ToVec.Param{id = id, arg = arg, ix = ix, kind = ToVec.Proj vecIndex}
	  end
    (* build a parameter of kind `ToVec.Indx` for the `id`'th element of `args`, indexed
     * by `ix`.  `List.nth` raises Subscript if `id` is out of range.
     *)
    fun createI (args, id, ix) = let
	  val arg = List.nth(args, id)
	  in
	    ToVec.Param{id = id, arg = arg, ix = ix, kind = ToVec.Indx}
	  end

  (* generate low-IL code for scaling a non-scalar tensor; `sId` is the scalar
   * parameter's ID and `vId` is the tensor parameter's ID.  `shape` is the index
   * list of the tensor operand.  When the last index of `shape` is the last
   * iteration index, the last dimension is handled by a `VScale` vector operation
   * and only the remaining dimensions are unrolled; otherwise we fall back to
   * scalar code.
   *)
    fun expandScale (sId, vId, shape, params, body, index, args) = let
	(* lastIx is the position of the last iteration index, width is its dimension,
	 * and outer is the iteration structure without it.
	 *)
	  val (lastIx, width, outer) = dropIndex index
	  in
	    case matchLast(shape, lastIx)
	     of NONE => scalarExpand (params, body, index, args)
	      | SOME ix => let
		  val scalar = createI (args, sId, [])
		  val vector = createP (args, width, vId, ix)
		  val avail = AvailRHS.new()
		  in
		    unroll (
		      avail, index, outer,
		      ToVec.op2, (width, Op.VScale width, scalar, vector))
		  end
	    (* end case *)
	  end

  (* handle potential sum-of-products (i.e., inner products); otherwise fall back to the
   * general scalar case.  Two body shapes are recognized:
   *   - a sum over one index of the product of two non-scalar tensors, which maps to
   *     `ToVec.dotV` when the summation index is the last index of both tensors and
   *     occurs nowhere else in them;
   *   - a sum over two indices of such a product, which maps to `ToVec.sumDotV` when
   *     either summation index satisfies that condition (the first is tried first);
   *     the other summation index is passed along to `ToVec.sumDotV`.
   * The vector width is taken to be `ub+1`, where `ub` is the upper bound of the
   * summation index (NOTE(review): this presumably assumes a lower bound of 0 -- confirm).
   *)
    fun expandInner (params, body, index, args) = let
	  fun fallback () = scalarExpand (params, body, index, args)
	(* if `v` is the last index of both `alpha` and `beta` and occurs nowhere else in
	 * either, return the two index lists with `v` removed.
	 *)
	  fun lastOfBoth (alpha, beta, v) = (
		case (matchFindLast(alpha, v), matchFindLast(beta, v))
		 of ((SOME ixA, NONE), (SOME ixB, NONE)) => SOME(ixA, ixB)
		  | _ => NONE
		(* end case *))
	  in
	    case body
	     of E.Sum(
		  [(E.V v, _, ub)],
		  E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
		) => (case lastOfBoth (alpha, beta, v)
		   of SOME(ixA, ixB) => let
			val avail = AvailRHS.new()
			val width = ub+1
			val vecA = createP (args, width, id1, ixA)
			val vecB = createP (args, width, id2, ixB)
			in
			  unroll (avail, index, index, ToVec.dotV, (width, vecA, vecB))
			end
		    | NONE => fallback ()
		  (* end case *))
	      | E.Sum(
		  [sx1 as (E.V v1, _, ub1), sx2 as (E.V v2, _, ub2)],
		  E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
		) => let
		(* try to vectorize over summation index `v`; `other` is the remaining
		 * summation index.
		 *)
		  fun tryDot (v, ub, other) = (case lastOfBoth (alpha, beta, v)
			 of SOME(ixA, ixB) => let
			      val width = ub+1
			      val vecA = createP (args, width, id1, ixA)
			      val vecB = createP (args, width, id2, ixB)
			      val avail = AvailRHS.new()
			      in
				SOME(unroll (
				  avail, index, index,
				  ToVec.sumDotV, (other, width, vecA, vecB)))
			      end
			  | NONE => NONE
			(* end case *))
		  in
		    case tryDot (v1, ub1, sx2)
		     of SOME res => res
		      | NONE => (case tryDot (v2, ub2, sx1)
			   of SOME res => res
			    | NONE => fallback ()
			  (* end case *))
		    (* end case *)
		  end
	      | _ => fallback ()
	    (* end case *)
	  end

  (* expand an Ein expression that has a non-scalar result.  The body is matched against
   * a handful of forms (subtraction, n-ary addition, negation, scalar*tensor scaling, and
   * the binary product of two non-scalar tensors).  When the last iteration index is
   * also the last index of the tensor operands, the last dimension is handled by a LowIR
   * vector operation and only the remaining dimensions (index') are unrolled.  Tensor
   * operands that do not line up this way cause a fall back to `scalarExpand`; bodies of
   * any other form are passed to `expandInner`.
   *
   * Note that `dropIndex` raises Fail on an empty `index`, so this function must only be
   * called with a non-empty iteration structure.
   *)
    fun nonScalar (params, body, index, args) = (case body
	   of E.Op2(E.Sub, E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)) => let
		val (n, vecIX, index') = dropIndex index
		in
		  case (matchLast(alpha, n), matchLast(beta, n))
		   of (SOME ix1, SOME ix2) => let
		      (* n is the last index of both alpha and beta: vector subtraction *)
			val vecA = createP (args, vecIX, id1, ix1)
			val vecB = createP (args, vecIX, id2, ix2)
			in
			  unroll (
			    AvailRHS.new(), index, index',
			    ToVec.op2, (vecIX, Op.VSub vecIX, vecA, vecB))
			end
		    | _  => scalarExpand (params, body, index, args)
		  (* end case *)
		end
	    | E.Opn(E.Add, es as E.Tensor(_, _::_)::_) => let
		val (n, vecIX, index') = dropIndex index
	      (* check that every term of the addition is a tensor whose last index is n.
	       * If all of them are, the collected vector parameters are passed to
	       * ToVec.addV; as soon as one term fails the check (or is not a tensor), the
	       * whole expression is expanded to scalar code instead.
	       *)
		fun sample ([], rest) =
		      unroll (AvailRHS.new(), index, index', ToVec.addV, (vecIX, List.rev rest))
		  | sample (E.Tensor(id1, alpha)::ts, rest) = (case matchLast(alpha, n)
		       of SOME ix1 => sample(ts, createP (args, vecIX, id1, ix1)::rest)
			| _ => scalarExpand (params, body, index, args)
		      (* end case *))
		  | sample _ = scalarExpand (params, body, index, args)
		in
		  sample (es, [])
		end
	    | E.Op1(E.Neg, E.Tensor(id, alpha as (_::_))) => let
		val (n, vecIX, index') = dropIndex index
		in
		  case matchLast (alpha, n)
		   of SOME ix1 => let
			val avail = AvailRHS.new()
		      (* the real literal -1, which is passed to ToVec.negV along with the vector *)
			val vA = Mk.intToRealLit (avail, ~1)
			val vecB = createP (args, vecIX, id, ix1)
			in
			  unroll (avail, index, index', ToVec.negV, (vA, vecIX, vecB))
			end
		    | _ => scalarExpand (params, body, index, args)
		  (* end case *)
		end
	  (* scalar * tensor, in either operand order *)
	    | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, shp as _::_)]) =>
		expandScale (s, v, shp, params, body, index, args)
	    | E.Opn(E.Prod, [E.Tensor(v, shp as _::_), E.Tensor(s , [])]) =>
		expandScale (s, v, shp, params, body, index, args)
	    | E.Opn(E.Prod, [E.Tensor(id1 , alpha as _::_), E.Tensor(id2, beta as _::_)]) => let
		val (n, vecIX, index') = dropIndex index
		val avail = AvailRHS.new()
		in
		  case (matchFindLast(alpha, n), matchFindLast(beta, n))
		   of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
		      (* n is the last index of alpha, beta and nowhere else, possible modulate *)
			val vecA = createP (args, vecIX, id1, ix1)
			val vecB = createP (args, vecIX, id2, ix2)
			in
			  unroll (
			    avail, index, index',
			    ToVec.op2, (vecIX, Op.VMul vecIX, vecA, vecB))
			end
		    | ((NONE, NONE), (SOME ix2, NONE)) => let
		      (* n is the last index of beta and nowhere else, possible scaleVector *)
			val vecA = createI (args, id1, alpha)
			val vecB = createP (args, vecIX, id2, ix2)
			in
			  unroll (
			    avail, index, index',
			    ToVec.op2, (vecIX, Op.VScale vecIX, vecA, vecB))
			end
		    | ((SOME ix1, NONE), (NONE, NONE)) => let
		      (* n is the last index of alpha and nowhere else, possible scaleVector *)
			val vecA = createI (args, id2, beta)
			val vecB = createP (args, vecIX, id1, ix1)
			in
			  unroll (
			    avail, index, index',
			    ToVec.op2, (vecIX, Op.VScale vecIX, vecA, vecB))
			end
		    | _ => scalarExpand (params, body, index, args)
		  (* end case *)
		end
	    |  _ => expandInner (params, body, index, args)
	  (* end case *))

  (* expand : LowIR.var * Ein.ein * LowIR.var list -> LowIR.assignment list
   * Expands the application of an Ein operator to `args`, scanning the body for
   * vectorization potential.  The expansion strategy is chosen from the shape of
   * `index`; the left-hand side of the first assignment returned by
   * `AvailRHS.getAssignments` is then replaced by `y`, and the assignments are wrapped
   * in `LowIR.ASSGN` by `List.revMap`.  Raises Bind if there are no assignments.
   *)
    fun expand (y, Ein.EIN{params, index, body}, args : LowIR.var list) = let
(* FIXME: We really only care if the last dimension of a non-scalar shape is "3", since it
 * causes poor performance in the generated code.  Really, this should be handled by the LowIR
 * to TreeIR transform, where we take into account machine vector widths.
 *)
	(* shapes for which we always generate scalar code *)
	  fun forceScalar [3, 3] = true
	    | forceScalar [3, 3, 3] = true
	    | forceScalar _ = false
	  val (avail, _) = if forceScalar index
		then scalarExpand (params, body, index, args)
		else (case index
		   of [] => expandInner (params, body, index, args)
		    | _ => nonScalar (params, body, index, args)
		  (* end case *))
	  in
	    case AvailRHS.getAssignments avail
	     of (_, asgn) :: rest => List.revMap LowIR.ASSGN ((y, asgn)::rest)
	      | [] => raise Bind
	    (* end case *)
	  end

    end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0