Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/charisee_dev/src/compiler/high-to-mid/float-ein.sml
ViewVC logotype

View of /branches/charisee_dev/src/compiler/high-to-mid/float-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3673 - (download) (annotate)
Thu Feb 11 20:03:26 2016 UTC (3 years, 6 months ago) by cchiw
File size: 6011 byte(s)
ASF
(* float-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure FloatEin : sig

  (* transform (y, rhs) expects rhs to be an EINAPP.  When the EIN body is a
   * probe, or a summation directly over a probe, the assignment y = rhs is
   * returned unchanged.  Otherwise probe subexpressions, and operator operands
   * that are themselves operators, sums, or probes, are lifted into separate
   * assignments; the result contains those lifted assignments together with
   * the rewritten assignment to y, which comes last.
   * Raises Fail if rhs is not an EINAPP.
   *)
    val transform : MidIL.var * MidIL.rhs -> MidIL.assignment list

  end = struct

    structure IR = MidIL
    structure V = IR.Var
    structure Ty = MidILTypes
    structure E = Ein
    structure cleanP = cleanParams
    structure cleanI = cleanIndex

  (* Local shorthands.  All calls to the cleaning passes in this structure go
   * through these wrappers instead of mixing in the structure-qualified names.
   *)
    fun mkEin e = Ein.mkEin e
    fun mkEinApp (rator, args) = IR.EINAPP(rator, args)
    fun cleanParams e = cleanP.cleanParams e
    fun cleanIndex e = cleanI.cleanIndex e

  (* cut (name, origProbe, params, index, sx, argsOrig, avail, newvx)
   * Handles a probe whose field is selected by a single component index
   * (an E.C constant in the cases `rewrite` sends here) and whose derivative
   * indices are E.V 0 .. E.V (newvx-1).  The probe is rebuilt with the
   * component index replaced by the variable E.V newvx, that computation is
   * registered with avail, and the original expression is replaced by a
   * reference to a new tensor parameter indexed by c1 followed by tshape.
   * Returns (replacement expression, extended params, extended args).
   *)
    fun cut (name, origProbe, params, index, sx, argsOrig, avail, newvx) = let
        (* clean and rewrite the current body *)
          val (tshape, sizes, body) = cleanIndex (origProbe, index, sx)
          val id = length params
          val liftParams = params @ [E.TEN(1, sizes)]
        (* M only serves as the nominal lhs for cleanParams; the variable that
         * cleanParams returns is discarded below
         *)
          val M = V.new (concat[name, "_l_", Int.toString id], Ty.TensorTy sizes)
        (* partial binding: raises Bind if cleanParams does not yield an EINAPP *)
          val (_, IR.EINAPP(ein, args)) =
                cleanParams (M, body, liftParams, sizes, argsOrig @ [M])
        (* partial binding: expects the cleaned body to be a probe of a
         * convolution with exactly one component index
         *)
          val Ein.EIN{
                  body = E.Probe(E.Conv(img, [c1], h, dx), pos),
                  index = index0,
                  params = params0
                } = ein
        (* shift the component index in the probe body from constant to variable *)
(* FIXME: this code is specialized to 3D *)
          val index1 = index0 @ [3]
          val unshiftedBody = E.Probe(E.Conv(img, [E.V newvx], h, dx), pos)
        (* clean to get body indices in order *)
          val (_, _, body1) = cleanIndex (unshiftedBody, index1, [])
          val lhs1 = V.new ("L", Ty.TensorTy index1)
          val ein1 = mkEin (params0, index1, body1)
          val lhs2 = AvailRHS.addAssign avail (lhs1, mkEinApp(ein1, args))
        (* index the lifted tensor at the component c1 *)
          val Re = E.Tensor(id, c1 :: tshape)
          val Rparams = params @ [E.TEN(1, index1)]
          val Rargs = argsOrig @ [lhs2]
          in
            (Re, Rparams, Rargs)
          end

  (* lift (name, e, params, index, sx, args, avail)
   * Lifts the subexpression e into its own EIN application.  The index and
   * params of e are cleaned, an assignment of the cleaned EIN application to a
   * fresh variable M is registered with avail, and e is replaced by a reference
   * to a new tensor parameter whose argument is the variable returned by
   * AvailRHS.addAssign.
   * Returns (replacement expression, extended params, extended args).
   *)
    fun lift (name, e, params, index, sx, args, avail) = let
          val (tshape, sizes, body) = cleanIndex (e, index, sx)
          val id = length params
          val Rparams = params @ [E.TEN(1, sizes)]
          val Re = E.Tensor(id, tshape)
          val M = V.new (concat[name, "_l_", Int.toString id], Ty.TensorTy sizes)
          val (_, einapp) = cleanParams (M, body, Rparams, sizes, args @ [M])
          val var = AvailRHS.addAssign avail (M, einapp)
          val Rargs = args @ [var]
          in
            (Re, Rparams, Rargs)
          end

  (* isOp e is true when e is an operator application, a summation, or a probe;
   * such operands of operators are lifted by filterOps in transform
   *)
    fun isOp e = (case e
           of E.Op1 _ => true
            | E.Op2 _ => true
            | E.Opn _ => true
            | E.Sum _ => true
            | E.Probe _ => true
            | _ => false
          (* end case *))

    fun transform (y, IR.EINAPP(ein as Ein.EIN{body=E.Probe _, ...}, args)) =
          [IR.ASSGN(y, IR.EINAPP(ein, args))]
      | transform (y, IR.EINAPP(ein as Ein.EIN{body=E.Sum(_, E.Probe _), ...}, args)) =
          [IR.ASSGN(y, IR.EINAPP(ein, args))]
      | transform (y, IR.EINAPP(Ein.EIN{params, index, body}, args)) = let
          val avail = AvailRHS.new()
        (* filterOps lifts every element of es for which isOp holds, preserving
         * the order and length of es; returns (es', params', args')
         *)
          fun filterOps (es, params, args, index, sx) = let
                fun filter ([], es', params, args) = (rev es', params, args)
                  | filter (e::es, es', params, args) = if isOp e
                      then let
                        val (e', params', args') =
                              lift ("op1_e3", e, params, index, sx, args, avail)
                        in
                          filter (es, e'::es', params', args')
                        end
                      else filter (es, e::es', params, args)
                in
                  filter (es, [], params, args)
                end
        (* rewrite (sx, exp, params, args) walks exp recursively, replacing
         * probes with references to lifted tensors; sx accumulates the
         * summation indices of the enclosing E.Sum nodes.
         * Returns (exp', params', args').
         *)
          fun rewrite (sx, exp, params, args) = (case exp
                 of E.Probe(E.Conv(_, [E.C _], _, []), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 0)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 1)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0, E.V 1]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 2)
                  | E.Probe(E.Conv(_, [E.C _], _, [E.V 0, E.V 1, E.V 2]), _) =>
                      cut ("cut", exp, params, index, sx, args, avail, 3)
                  | E.Probe _ => lift ("probe", exp, params, index, sx, args, avail)
                  | E.Sum(_, E.Probe _) => lift ("probe", exp, params, index, sx, args, avail)
(* Disabled alternative for the E.Probe case, kept for reference:
            E.Probe _ => let
            val (params',body',args',code) = ProbeEin.replaceProbeF(params,index,sx,exp,args)
            val _ =List.map (fn e=> AvailRHS.addAssignNoSearch  avail e) code
            in
                (body',params',args')
            end
*)
                  | E.Op1(op1, e1) => let
                      val (e1', params', args') = rewrite (sx, e1, params, args)
                    (* filterOps preserves length, so the singleton pattern matches *)
                      val ([e1''], params', args') = filterOps ([e1'], params', args', index, sx)
                      in
                        (E.Op1(op1, e1''), params', args')
                      end
                  | E.Op2(op2, e1, e2) => let
                      val (e1', params', args') = rewrite (sx, e1, params, args)
                      val (e2', params', args') = rewrite (sx, e2, params', args')
                      val ([e1', e2'], params', args') =
                            filterOps ([e1', e2'], params', args', index, sx)
                      in
                        (E.Op2(op2, e1', e2'), params', args')
                      end
                  | E.Opn(opn, es) => let
                      fun iter ([], es', params, args) = (List.rev es', params, args)
                        | iter (e::es, es', params, args) = let
                            val (e', params', args') = rewrite (sx, e, params, args)
                            in
                              iter (es, e'::es', params', args')
                            end
                      val (es, params, args) = iter (es, [], params, args)
                      val (es, params, args) = filterOps (es, params, args, index, sx)
                      in
                        (E.Opn(opn, es), params, args)
                      end
                  | E.Sum(sx1, e) => let
                      val (e', params', args') = rewrite (sx1 @ sx, e, params, args)
                      in
                        (E.Sum(sx1, e'), params', args')
                      end
                  | _ => (exp, params, args)
                (* end case *))
          val (body', params', args') = rewrite ([], body, params, args)
          val einapp = cleanParams (y, body', params', index, args')
          val c = IR.ASSGN einapp
          in
          (* c is consed onto the assignments recorded in avail and the whole
           * list is reversed, so the assignment to y comes last
           *)
            List.rev (c :: AvailRHS.getAssignments avail)
          end
      | transform (_, _) = raise Fail "FloatEin.transform: expected an EINAPP right-hand side"

  end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0