Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

View of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3641 - (download) (annotate)
Mon Feb 1 03:53:00 2016 UTC (3 years, 9 months ago) by jhr
File size: 10389 byte(s)
principleEvec
(* ein-to-low.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

(*
 * genfn: does a preliminary scan of the body of an EIN.EIN for vectorization potential.
 * If there is a field, it is passed to FieldToLow.
 * If there is a tensor, it is passed to the handle*() functions, which check whether
 * the operand indices match the result indices, i.e., <A_ij+B_ij>_ij vs. <A_ji+B_ij>_ij:
 *
 *     (1) If the indices match, it is passed to Iter with a VecToLow function,
 *         which creates LowIR vector operators.
 *     (2) Otherwise it is passed to Iter with a ScaToLow function,
 *         which creates LowIR scalar operators.
 *
 * Note: the Iter function creates LowIR.CONS and therefore binds the indices in the EIN body.
 *)

structure EinToLow : sig

  (* `expand (y, e, args)` translates the EIN application `y = e(args)` into a
   * list of LowIR assignments, the last of which binds `y`.
   *)
    val expand : LowIR.var * Ein.ein * LowIR.var list -> LowIR.assignment list

  end = struct

    structure Var = LowIR.Var
    structure E = Ein
    structure H = Helper
    structure Op = LowOps

    fun iter e = Iter.prodIter e
    fun intToReal n = H.intToReal n

  (* `dropIndex alpha` returns the (len, i, alpha') where
   * len = length(alpha') and alpha = alpha'@[i].
   * Raises Fail "dropIndex[]" when alpha is empty.
   *)
    fun dropIndex alpha = let
          fun drop ([], _, _) = raise Fail "dropIndex[]"
            | drop ([idx], n, idxs') = (n, idx, List.rev idxs')
            | drop (idx::idxs, n, idxs') = drop (idxs, n+1, idx::idxs')
          in
            drop (alpha, 0, [])
          end

  (* `matchLast (alpha, n)` returns SOME alpha' when alpha = alpha' @ [E.V n],
   * and NONE otherwise (including when alpha is empty).
   *)
    fun matchLast (alpha, n) = (case List.rev alpha
           of (E.V v)::es => if (n = v) then SOME(List.rev es) else NONE
            | _ => NONE
          (* end case *))

  (* `matchFindLast (alpha, n)` returns a pair (last, other) where
   *   last  = SOME alpha' when alpha = alpha' @ [E.V n], and NONE otherwise;
   *   other = an occurrence of E.V n among the indices of alpha *other than*
   *           its last one, or NONE if there is none.
   * The handlers below only vectorize over n when the result is (SOME _, NONE),
   * i.e., n is the last index and appears nowhere else.
   *)
    fun matchFindLast (alpha, n) = let
          fun find es = List.find (fn (E.V idx') => (n = idx') | _ => false) es
          in
            case List.rev alpha
             of (E.V v)::es => if (n = v)
                  then (SOME(List.rev es), find es)
                  else (NONE, find es)
              | _::es => (NONE, find es)
              | [] => (NONE, NONE)
            (* end case *)
          end

  (* `runGeneralCase info` expands the EIN expression without any vector
   * projections.  `info` is the pair (e, args) of the Ein.ein and its argument
   * variables.  The whole result index space is iterated and each element is
   * produced by ScaToLow.generalfn.  Returns whatever Iter.prodIter returns; the
   * caller (expand) uses its first component as the AvailRHS table.
   * NOTE(review): ScaToLow.generalfn is assumed to take the (e, args) pair that
   * the original call site listed; confirm against ScaToLow.
   *)
    fun runGeneralCase (info as (E.EIN{index, ...}, _)) =
          iter (AvailRHS.new(), index, index, ScaToLow.generalfn, info)

  (* `createP (params, args, vecIX, id, ix)` describes tensor argument `id`
   * (indexed by `ix`) as a VecToLow.Proj operand; at every call site here
   * `vecIX` is the vector width.
   *)
    fun createP (params, args, vecIX, id, ix) =
          VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Proj vecIX)

  (* `createI (params, args, id, ix)` describes tensor argument `id` (indexed by
   * `ix`) as a VecToLow.Indx operand, i.e., one that is not projected.
   *)
    fun createI (params, args, id, ix) =
          VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Indx)

  (* `handleScale (info, sid, tid, alpha)` generates low-IL code for scaling the
   * non-scalar tensor argument `tid` (indexed by `alpha`) by the scalar argument
   * `sid`.  It vectorizes with Op.VScale over the last result index when that
   * index is also the last index of `alpha`; otherwise it uses runGeneralCase.
   * Requires a non-empty result index (dropIndex raises Fail otherwise).
   *)
    fun handleScale (info as (E.EIN{params, index, ...}, args), sid, tid, alpha) = let
          val (n, vecIX, index') = dropIndex index
          in
            case matchLast (alpha, n)
             of SOME ix => let
                  val vecA = createI (params, args, sid, [])
                  val vecB = createP (params, args, vecIX, tid, ix)
                  val nextfnargs = (vecIX, Op.VScale vecIX, vecA, vecB)
                  in
                    iter (AvailRHS.new(), index, index', VecToLow.op2, nextfnargs)
                  end
              | NONE => runGeneralCase info
            (* end case *)
          end

  (* `handleSumProd1 info` generates low-IL code for a dot product
   *     Sigma_{v} A_{..v} B_{..v}
   * When the summation variable v is the last index of both tensors and occurs
   * nowhere else in them, the sum is vectorized with VecToLow.dotV (vector width
   * ub+1).  Any other body falls back to runGeneralCase.
   *)
    fun handleSumProd1 (info as (E.EIN{params, index, body}, args)) = (case body
           of E.Sum([(E.V v, _, ub)], E.Opn(E.Prod, [E.Tensor(id1, alpha), E.Tensor(id2, beta)])) => (
                case (matchFindLast(alpha, v), matchFindLast(beta, v))
                 of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
                      val vecIX = ub+1
                      val vecA = createP (params, args, vecIX, id1, ix1)
                      val vecB = createP (params, args, vecIX, id2, ix2)
                      in
                        iter (AvailRHS.new(), index, index, VecToLow.dotV, (vecIX, vecA, vecB))
                      end
                  | _ => runGeneralCase info
                (* end case *))
            | _ => runGeneralCase info
          (* end case *))

  (* `handleSumProd2 info` generates low-IL code for a double dot product
   *     Sigma_{i, j} A_ij B_ij
   * It first tries to vectorize over the first summation variable (keeping the
   * second as an outer sum), then over the second (keeping the first).  The
   * vectorized variable must be the last index of both tensors and occur nowhere
   * else in them.  Otherwise it falls back to runGeneralCase.
   *)
    fun handleSumProd2 (info as (E.EIN{params, index, body}, args)) = (case body
           of E.Sum([sx1 as (E.V v1, _, ub1), sx2 as (E.V v2, _, ub2)],
                E.Opn(E.Prod, [E.Tensor(id1, alpha), E.Tensor(id2, beta)])) => let
              (* try to vectorize over v (with upper bound ub), leaving sx as the outer sum *)
                fun check (v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))
                       of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
                            val vecIX = ub+1
                            val vecA = createP (params, args, vecIX, id1, ix1)
                            val vecB = createP (params, args, vecIX, id2, ix2)
                            val nextfnargs = (sx, vecIX, vecA, vecB)
                            in
                              SOME(iter (AvailRHS.new(), index, index, VecToLow.sumDotV, nextfnargs))
                            end
                        | _ => NONE
                      (* end case *))
                in
                  case check (v1, ub1, sx2)
                   of SOME result => result
                    | NONE => (case check (v2, ub2, sx1)
                         of SOME result => result
                          | NONE => runGeneralCase info
                        (* end case *))
                  (* end case *)
                end
            | _ => runGeneralCase info
          (* end case *))

  (* `nonScalar info` expands an EIN expression whose result has at least one index.
   * It recognizes element-wise subtraction, n-ary addition, negation, scaling and
   * element-wise products of tensor arguments.  When the last result index is also
   * the last index of the operands, it vectorizes along that index; every other
   * case falls back to runGeneralCase.
   * Requires a non-empty result index (dropIndex raises Fail otherwise); expand
   * only calls it in that case.
   *)
    fun nonScalar (info as (E.EIN{params, index, body}, args)) = let
        (* n: the code matches E.V n as the index variable of the last result
         *    dimension (n = number of leading dimensions);
         * vecIX: the size of that last dimension, used as the vector width;
         * index': the remaining (outer) result dimensions that Iter loops over.
         *)
          val (n, vecIX, index') = dropIndex index
        (* vectorized binary operation `rator` on operands vecA and vecB *)
          fun vecOp2 (rator, vecA, vecB) =
                iter (AvailRHS.new(), index, index', VecToLow.op2, (vecIX, rator, vecA, vecB))
          in
            case body
             of E.Op2(E.Sub, E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)) => (
                  case (matchLast(alpha, n), matchLast(beta, n))
                   of (SOME ix1, SOME ix2) => vecOp2 (
                        Op.VSub vecIX,
                        createP (params, args, vecIX, id1, ix1),
                        createP (params, args, vecIX, id2, ix2))
                    | _ => runGeneralCase info
                  (* end case *))
              | E.Opn(E.Add, es as E.Tensor(_, _::_)::_) => let
                (* every summand must be a tensor whose last index is E.V n; the
                 * projected operands are accumulated in reverse order.
                 *)
                  fun collect ([], operands) =
                        iter (AvailRHS.new(), index, index', VecToLow.addV, (vecIX, List.rev operands))
                    | collect (E.Tensor(id, alpha)::ts, operands) = (case matchLast(alpha, n)
                         of SOME ix => collect (ts, createP (params, args, vecIX, id, ix) :: operands)
                          | NONE => runGeneralCase info
                        (* end case *))
                    | collect _ = runGeneralCase info
                  in
                    collect (es, [])
                  end
              | E.Op1(E.Neg, E.Tensor(id, alpha as _::_)) => (
                  case matchLast (alpha, n)
                   of SOME ix => let
                      (* negation is implemented as scaling by the real constant -1 *)
                        val (avail, vA) = intToReal (AvailRHS.new(), ~1)
                        val vecB = createP (params, args, vecIX, id, ix)
                        in
                          iter (avail, index, index', VecToLow.negV, (vA, vecIX, vecB))
                        end
                    | NONE => runGeneralCase info
                  (* end case *))
              | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, shp as _::_)]) =>
                  handleScale (info, s, v, shp)
              | E.Opn(E.Prod, [E.Tensor(v, shp as _::_), E.Tensor(s, [])]) =>
                  handleScale (info, s, v, shp)
              | E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)]) => (
                  case (matchFindLast(alpha, n), matchFindLast(beta, n))
                   of ((SOME ix1, NONE), (SOME ix2, NONE)) =>
                      (* n is the last index of alpha and beta and nowhere else: modulate *)
                        vecOp2 (
                          Op.VMul vecIX,
                          createP (params, args, vecIX, id1, ix1),
                          createP (params, args, vecIX, id2, ix2))
                    | ((NONE, NONE), (SOME ix2, NONE)) =>
                      (* n is the last index of beta and nowhere else: scale beta *)
                        vecOp2 (
                          Op.VScale vecIX,
                          createI (params, args, id1, alpha),
                          createP (params, args, vecIX, id2, ix2))
                    | ((SOME ix1, NONE), (NONE, NONE)) =>
                      (* n is the last index of alpha and nowhere else: scale alpha *)
                        vecOp2 (
                          Op.VScale vecIX,
                          createI (params, args, id2, beta),
                          createP (params, args, vecIX, id1, ix1))
                    | _ => runGeneralCase info
                  (* end case *))
              | _ => runGeneralCase info
            (* end case *)
          end

  (* `expand (y, e, args)` scans the body of `e` for vectorization potential,
   * generates the LowIR code for it, and binds the final result to `y`.
   * Raises Fail if no assignments were generated.
   *)
    fun expand (y, e as E.EIN{index, body, ...}, args : LowIR.var list) = let
          val info = (e, args)
          val (avail, _) = (case (index, body)
(* QUESTION: what is special about "3" here?
 * NOTE(review): results of shape [3,3] and [3,3,3] deliberately bypass
 * vectorization; the reason is not evident from this file.
 *)
                 of ([3, 3], _) => runGeneralCase info
                  | ([3, 3, 3], _) => runGeneralCase info
                  | (_::_, _) => nonScalar info
                (* scalar result: look for (double) dot products of non-scalar tensors *)
                  | (_, E.Sum([_], E.Opn(E.Prod, [E.Tensor(_, _::_), E.Tensor(_, _::_)]))) =>
                      handleSumProd1 info
                  | (_, E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_, _::_), E.Tensor(_, _::_)]))) =>
                      handleSumProd2 info
                  | _ => runGeneralCase info
                (* end case *))
          in
          (* The head of getAssignments is treated as the most recent assignment,
           * i.e., the one computing the result (List.revMap restores program
           * order).  NOTE(review): this relies on AvailRHS returning assignments
           * in reverse order -- confirm.  Binding `y` directly to that RHS avoids
           * computing the same RHS twice; no earlier assignment can refer to the
           * dropped variable.
           *)
            case AvailRHS.getAssignments avail
             of (_, rhs) :: rest => List.revMap LowIR.ASSGN ((y, rhs) :: rest)
              | [] => raise Fail "EinToLow.expand: no assignments generated"
            (* end case *)
          end

    end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0