Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3785 - (download) (annotate)
Wed Apr 27 20:24:39 2016 UTC (3 years, 4 months ago) by cchiw
File size: 13393 byte(s)
add IR.SEQ wrapper to probe-ein
(* probe-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure ProbeEin : sig

  (* expand avail (y, rhs)
   * Expects rhs to be an EIN application (any other rhs form is not matched and raises Match).
   * If the EIN body is a field probe -- possibly under a summation, or under a summation of a
   * product of one other term with the probe -- the probe is expanded into voxel-load and
   * kernel-evaluation operations and the resulting assignments are added to avail.
   * Otherwise (y, rhs) is added to avail unchanged.
   *)
    val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

   (* This file expands probed fields.
    * See the ProbeEin tex file for examples.
    * The original field is an EIN operator of the form <V_alpha * H^(deltas)>(midIR.var list).
    * Param_ids record the position of an argument in the midIR.var list.
    * Index_ids keep track of the shape of an image or of a differentiation.
    * Mu binds an Index_id.
    * The following names are used throughout:
    *   dim:    dimension of field V
    *   s:      support of kernel H
    *   alpha:  the alpha in <V_alpha * H^(deltas)>
    *   dx:     the dx in <V_alpha * nabla_dx H>
    *   deltas: the deltas in <V_alpha * h^(deltas) h^(deltas)>
    *   Vid:    param_id for V
    *   hid:    param_id for H
    *   nid:    param_id for the integer position
    *   fid:    param_id for the fractional position
    *   img:    image info about V
    *)

  (* builds an EIN operator from its parameter list, result index, and body *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* increments the use count of a mid-IR variable (not called within this structure) *)
    fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)

  (* getRHSDst x
   * Returns the (operator, arguments) pair of x's defining OP; raises Fail if x is not
   * defined by an OP.  (Not called within this structure.)
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* checkImg imgArg
   * Returns imgArg if it is defined by a LoadImage operation; raises Fail otherwise.
   * (Not called within this structure.)
   *)
    fun checkImg imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage _, _) => imgArg
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getImagInfo e
   * Resolves the image variable e to (image, info, border):
   *   - if e is defined by LoadImage, returns (e, info, NONE);
   *   - if e is a Clamp/Mirror/Wrap border-control wrapper, returns the wrapped image
   *     variable, its info, and the matching IndexCtl mode.
   * Raises Fail for default border control (not supported here) or if e is not an image.
   *)
    fun getImagInfo e = (case IR.Var.getDef e
           of IR.OP(Op.LoadImage(Ty.ImageTy info, _), []) => (e, info, NONE)
            | IR.OP(Op.BorderCtlDefault _, [_]) => raise Fail "Default border control"
            | IR.OP(Op.BorderCtlClamp info, [imgArg]) => (imgArg, info, SOME IndexCtl.Clamp)
            | IR.OP(Op.BorderCtlMirror info, [imgArg]) => (imgArg, info, SOME IndexCtl.Mirror)
            | IR.OP(Op.BorderCtlWrap info, [imgArg]) => (imgArg, info, SOME IndexCtl.Wrap)
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString e,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getKernelDst hArg
   * Returns the kernel h of hArg's defining Kernel operation; raises Fail otherwise.
   *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => h
            | rhs => raise Fail (String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* handleArgs (avail, Vid, hid, tid, args)
   * Uses the param_ids of the image, kernel, and position tensor to select the matching
   * mid-IR variables from args, and transforms the position to index space.
   * Returns (vI, vH, vN, vF, vP, info, border, dim), where
   *   vI     - the image variable (any border-control wrapper removed)
   *   vH     - the kernel variable
   *   vN, vF - the integer and fractional parts of the index-space position
   *   vP     - the mid-IR variable for the transpose of the transformation matrix
   *   info   - the image info; border - optional index control; dim - image dimension
   *)
    fun handleArgs (avail, Vid, hid, tid, args) = let
          val (vI, info, border) = getImagInfo (List.nth (args, Vid))
          val vH = List.nth (args, hid)
          val (vN, vF, vP) = T.worldToIndex{
                  avail = avail, info = info, img = vI, pos = List.nth (args, tid)
                }
          val dim = ImageInfo.dim info
          in
            (vI, vH, vN, vF, vP, info, border, dim)
          end

  (* liftKrn (avail, dir, dx, dim, h, vF, s)
   * Emits the kernel evaluations for one axis of a field reconstruction.
   *   dir - 0-based axis (component of vF); dx - ein index_ids representing differentiation;
   *   dim - image dimension; h - kernel; vF - fractional position; s - kernel support.
   * Builds the 2*s kernel positions from component dir of vF (vF itself when dim = 1) and
   * evaluates h at them once for each k in 0 .. length dx (presumably the k-th derivative).
   * Returns the single evaluation variable when dx is empty, otherwise a CONS of all of them.
   *)
    fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
          val range = 2*s
        (* build position vector for EvalKernel *)
          val vX = if (dim = 1)
                then vF   (* position is a real type *)
                else AvailRHS.addAssign (
                  avail, "vxindexed_dir"^Int.toString(dir)^"_", Ty.realTy,
                  IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))
          val vPos = AvailRHS.addAssign (
                avail, "kernelpos_dir"^Int.toString(dir)^"_", Ty.TensorTy[range],
                IR.OP(Op.BuildPos s, [vX]))
          val nKernEvals = List.length dx + 1
          fun mkEval k = AvailRHS.addAssign (
                avail, "mkeval_dir"^Int.toString(dir)^"_del"^Int.toString k,
                Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
          val vKs = List.tabulate(nKernEvals, fn k => mkEval k)
          in
            case vKs
             of [v] => v (* scalar result *)
              | _ => let
                  val consTy = Ty.TensorTy[nKernEvals, range]
                  in
                    AvailRHS.addAssign (
                      avail, "kernelCons_dir_"^Int.toString(dir), consTy, IR.CONS(vKs, consTy))
                  end
            (* end case *)
          end

  (* mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
   * Emits the load of the 2*s-wide voxel neighborhood of image vI around the integer
   * position vN, returning the variable holding the loaded voxels (a tensor of shape
   * shape @ [2s, ..., 2s] with dim support axes).  Uses LoadVoxelsWithCtl when a
   * border-control mode is given.  (alpha is not used.)
   *)
    fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border) = let
        (* lower bound offset of the support, 1-s, as an int literal *)
          val vLb = AvailRHS.addAssign (avail, "lit", Ty.intTy, IR.LIT(Literal.Int (1-(IntInf.fromInt s))))
        (* n_i + lb: the lowest voxel index along axis i *)
          fun f i = let
                val vA = AvailRHS.addAssign (avail, "lit", Ty.intTy, IR.LIT(Literal.Int (IntInf.fromInt i)))
                val vB = AvailRHS.addAssign (
                      avail, "subscript", Ty.intTy, IR.OP(Op.Subscript(Ty.TensorTy[dim]), [vN, vA]))
                in
                  AvailRHS.addAssign (avail, "add", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))
                end
        (* image positions *)
          val s' = 2*s
          val supportshape = List.tabulate(dim, fn _ => s')
          val ldty = Ty.TensorTy (shape@supportshape)
          val vNs = List.tabulate(dim, fn n => f n)
        (* the variable's type must agree with the SEQ it is bound to (one int per axis) *)
          val seqTy = Ty.SeqTy(Ty.IntTy, SOME dim)
          val vSq = AvailRHS.addAssign (avail, "seq", seqTy, IR.SEQ(vNs, seqTy))
          val op1 = (case border
                 of NONE => Op.LoadVoxels (info, s)
                  | SOME b => Op.LoadVoxelsWithCtl (info, s, b)
                (* end case *))
          in
            AvailRHS.addAssign (avail, "ldvox", ldty, IR.OP(op1, [vI, vSq]))
          end

  (* fieldReconstruction (avail, sx, alpha, shape, dx, Vid, Vidnew, kid, hid, tid, args)
   * Expands the body for the probed field.
   *   sx     - first free index id, used for the dim summation indices
   *   Vid, hid, tid - param_ids of image, kernel, and position in args
   *   Vidnew - param id the loaded voxels will have in the new EIN operator
   *   kid    - base param id: the kernel term for axis dir (1..dim) is param kid+dir
   * Returns (vLd::vKs, vP, exp): the new arguments (loaded voxels followed by one kernel
   * evaluation per axis), the transformation-matrix variable, and the ein expression
   * Sum(esum, Prod(img :: kernels)).
   *)
    fun fieldReconstruction (avail, sx, alpha, shape, dx, Vid, Vidnew, kid, hid, tid, args) = let
          val (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)
          val h = getKernelDst vH
          val s = Kernel.support h
        (* summation indices sx .. sx+dim-1, each ranging over 1-s .. s *)
          val vs = List.tabulate (dim, fn i => i + sx)
          val esum = List.map (fn i => (E.V i, 1-s, s)) vs
        (* represent the image in the ein expression, indexed by the summation indices *)
          val imgexp = E.Img(Vidnew, alpha, List.map (fn i => E.Value i) vs, s, E.None)
          (*val imgexp = E.Tensor (Vidnew, alpha@(List.map (fn i => E.V i) vs))*)
        (* create load voxel operator for image *)
          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
        (* build one kernel term per axis, from axis dim down to axis 1 *)
          fun createKrn (0, krnexp, vAs) = (krnexp, vAs)
            | createKrn (dir, krnexp, vAs) = let
                val dir' = dir-1
              (* ein expression *)
                val deltas = List.map (fn e => (E.C dir', e)) dx
                val kexp0 = E.Krn(kid+dir, deltas, dir)
              (* evalkernel operators; the position component is the 0-based axis dir',
               * matching the E.C dir' used for the derivative deltas above
               *)
                val vA = liftKrn (avail, dir', dx, dim, h, vF, s)
                in
                  createKrn (dir', kexp0::krnexp, vA::vAs)
                end
        (* final ein expression body to represent field reconstruction *)
          val (krnexp, vKs) = createKrn (dim, [], [])
          val exp = E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
          in
            (vLd::vKs, vP, exp)
          end

  (* getsumshift (sx, n)
   * Returns a fresh/unused index id: n when there are no summation indices, otherwise one
   * more than the id of the last summation index in sx (assumes the last entry has the
   * largest id -- NOTE(review): confirm).  Raises Bind if that entry is not an E.V index.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = let
          val (E.V v, _, _) = List.hd (List.rev sx)
          in
            v+1
          end

  (* formBody : drops empty summations (recursively) and unwraps a one-element product *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body)
   * Builds Sum(sx, Prod(...)) of the P terms and body, simplified with formBody.  When there
   * are three or four P terms the body is placed last (to match the vis branch WorldtoSpace
   * functions); otherwise it is placed first.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [P0, P1, P2] => [P0, P1, P2, body]
                  | [P0, P1, P2, P3] => [P0, P1, P2, P3, body]
                  | _ => body::Ps
                (* end case *))
          in
            formBody (E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody (body, Ps, newsx, exp)
   * Replaces the probe in body by exp multiplied with the P terms.  The returned flag is
   * true when the probe was the whole body (possibly under a summation) and false when it
   * was the second factor of a two-term product under a summation.  Raises Fail for any
   * other body shape.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx@newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs(Ps, newsx, exp))
            | _ => raise Fail "impossible"
          (* end case *))

  (* replaceProbe (avail, (y, einapp), probe, sx)
   * Replaces probe in the EIN application with its expanded field reconstruction, in place:
   *   - the position is transformed to index space (handleArgs / T.worldToIndex);
   *   - T.imageToWorld supplies the P-matrix terms (Ps) and rewritten derivative indices
   *     used to express the result in world space;
   *   - the transformation matrix, image, and one kernel per axis are appended to the
   *     params, and the matching variables to the args.
   * The rewritten assignment to y is added to avail.
   *)
    fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
        (* param ids for the transform matrix P, the loaded voxels, and the kernel terms *)
          val pid = length params
          val Vidnew = pid+1
          val kid = Vidnew      (* kernel for axis dir (1..dim) is param kid+dir *)
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val E.IMG(dim, shape) = List.nth(params, Vid)
          val freshIndex = getsumshift (sx, length index)
          val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
          val sxn = freshIndex + length dx'   (* next available index id *)
          val (args', vP, probe') = fieldReconstruction (
                avail, sxn, alpha, shape, dx', Vid, Vidnew, kid, hid, tid, args)
        (* add new params: transformation matrix (pid), image (Vidnew), and kernels *)
          val pP = E.TEN(true, [dim, dim])
          val pV = List.nth(params, Vid)
          val pK = List.tabulate(dim, fn _ => E.KRN)
          val params' = params @ (pP::pV::pK)
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ (vP::args')))
          in
            AvailRHS.addAssignToList (avail, einapp)
          end

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
   * Builds the outer EIN operator (transform T*P*P..Ps) for a lifted probe.  Its params are
   * the dim x dim transformation matrix (param 0) and the lifted field value (param 1).
   * Returns (splitvar, ein0, sizes, dx, alpha'): the arrangeBody flag, the operator, the
   * shape of the lifted field value, and the cleaned derivative and alpha indices.
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
        (* need to rewrite dx *)
          val sxx = sx@newsx
        (* QUESTION: what is the significance of "9" and "7" in this code?
         * NOTE(review): the pattern below discards the first and third Conv fields of the
         * result, so these look like placeholder param ids; confirm that CleanIndex.clean
         * does not depend on them.
         *)
          val (_, sizes, E.Conv(_, alpha', _, dxClean)) = (case sxx
                 of [] => ([], index, E.Conv(9, alpha, 7, dx'))
                  | _ => CleanIndex.clean(E.Conv(9, alpha, 7, dx'), index, sxx)
                (* end case *))
        (* drop constant indices from alpha' and append dx' *)
          fun filterAlpha [] = dx'
            | filterAlpha (E.C _::es) = filterAlpha es
            | filterAlpha (e1::es) = e1::(filterAlpha es)
          val exp = E.Tensor(tid, filterAlpha alpha')
          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin(params, index, body')
          in
            (splitvar, ein0, sizes, dxClean, alpha')
          end

  (* liftProbe (avail, (y, einapp), probe, sx)
   * Floats the reconstructed field term out of the EIN application: a new variable (vT) is
   * bound to the field reconstruction, and y is bound to the transform T*P*P..Ps applied to
   * [vP, vT] (split with FloatEin/EinSums when arrangeBody reports a bare probe).  All the
   * resulting assignments are added to avail.
   *)
    fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val freshIndex = getsumshift(sx, length index)
          val E.IMG(dim, shape) = List.nth(params, Vid)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val vT = V.new ("TPP", Ty.TensorTy(sizes))
        (* reconstruct the lifted probe: its params are the image (param 0) followed by one
         * kernel per axis (params 1..dim), and its args are the loaded voxels followed by
         * the kernel evaluations, so the image is param 0 and the kernel base id is 0.
         *)
          val imgid = 0
          val kid = 0
          val params' = List.nth(params, Vid)::(List.tabulate(dim, fn _ => E.KRN))
          val sxn = length sizes   (* next available index id *)
          val (args', vP, probe') = fieldReconstruction (
                avail, sxn, alpha', shape, dx, Vid, imgid, kid, hid, tid, args)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')
        (* transform T*P*P..Ps *)
          val rtn0 = if splitvar
                then FloatEin.transform(y, EinSums.transform ein0, [vP, vT])
                else [(y, IR.EINAPP(ein0, [vP, vT]))]
          in
            List.app (fn e => AvailRHS.addAssignToList(avail, e)) ((vT, einApp1)::rtn0)
          end

  (* expand avail e
   * At this point we only have simple ein ops.  Looks to see if the expression has a probe;
   * if so, replaces it (replaceProbe) or floats it out (liftProbe), otherwise e is added to
   * avail unchanged.
   *)
    fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) =>
                replaceProbe (avail, e, body, [])     (* no dx *)
            | E.Probe(E.Conv _, _) =>
                liftProbe (avail, e, body, [])        (* scans dx for constant *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                replaceProbe (avail, e, p, sx)        (* no dx *)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                liftProbe (avail, e, p, sx)           (* scalar field *)
            | E.Sum(sx, E.Probe p) =>
                replaceProbe (avail, e, E.Probe p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, E.Probe p])) =>
                replaceProbe (avail, e, E.Probe p, sx)
            | _ => AvailRHS.addAssignToList (avail, e)
          (* end case *))

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0