Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3657 - (download) (annotate)
Thu Feb 4 19:57:40 2016 UTC (3 years, 7 months ago) by cchiw
File size: 9400 byte(s)
fix arguments to lift()
(* probe-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure ProbeEin : sig

  (* expand (y, rhs) rewrites an assignment whose right-hand side is an EIN application
   * containing a field probe into a list of assignments in which the probe has been
   * replaced by an explicit sum of image samples times kernel evaluations.  Assignments
   * that do not match one of the recognized probe patterns are returned unchanged as a
   * singleton list.
   *)
    val expand : MidIR.assign -> MidIR.assign list

  end = struct

    structure IR = MidIR
    structure Op = MidOps
    structure V = IR.Var
    structure Ty = MidTypes
    structure E = Ein
    structure T = CoordSpaceTransform

   (* This file expands probed fields.
    * See the ProbeEin tex file for examples.
    * Note that the original field is an EIN operator of the form
    *     <V_alpha * H^(deltas)>(midIR.var list)
    * Param_ids record the position of an argument in the midIR.var list.
    * Index_ids keep track of the shape of an image or of a differentiation.
    * Mu binds an Index_id.
    * Generally, we use the following names:
    *   dim    -- dimension of field V
    *   s      -- support of kernel H
    *   alpha  -- the alpha in <V_alpha * H^(deltas)>
    *   dx     -- the dx in <V_alpha * nabla_dx H>
    *   deltas -- the deltas in <V_alpha * h^(deltas) h^(deltas)>
    *   Vid    -- param_id for V
    *   hid    -- param_id for H
    *   nid    -- integer position param_id
    *   fid    -- fractional position param_id
    *   img    -- image info about V
    *)

  (* mkEin (params, index, body) : packages the three components into an Ein.EIN record *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}

  (* getRHSDst x : returns (rator, args) of the MidIR operator that defines x, and
   * raises Fail if x is not defined by an operator.
   * NOTE: not referenced elsewhere in this structure.
   *)
    fun getRHSDst x = (case IR.Var.getDef x
           of IR.OP(rator, args) => (rator, args)
            | rhs => raise Fail(concat[
                  "expected rhs operator for ", IR.Var.toString x,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getImageDst imgArg : returns the image info of the LoadImage operator that defines
   * imgArg, and raises Fail for any other definition.
   *)
    fun getImageDst imgArg = (case IR.Var.getDef imgArg
           of IR.OP(Op.LoadImage(Ty.ImageTy info, _), _) => info
(* FIXME: also border control! *)
            | rhs => raise Fail (String.concat[
                  "expected image for ", IR.Var.toString imgArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* getKernelDst hArg : returns the support of the kernel defined by hArg, and raises
   * Fail if hArg is not defined by a Kernel operator.
   *)
    fun getKernelDst hArg = (case IR.Var.getDef hArg
           of IR.OP(Op.Kernel(h, _), _) => Kernel.support h
            | rhs => raise Fail (String.concat[
                  "expected kernel for ", IR.Var.toString hArg,
                  " but found ", IR.RHS.toString rhs
                ])
          (* end case *))

  (* handleArgs (Vid, hid, tid, args) : uses the param ids of the image (Vid), the
   * kernel (hid) and the position tensor (tid) to look up the corresponding MidIR
   * variables in args.  It also generates the code that transforms the position into
   * index space (T.worldToIndex).
   * Returns (dim, args @ argsT, code, s, P) where
   *   dim   -- dimension of the image
   *   argsT -- the extra variables produced by T.worldToIndex, appended after args
   *   code  -- the assignments produced by T.worldToIndex
   *   s     -- support of the kernel
   *   P     -- the MidIR variable for the (transposed) transformation matrix, as
   *            returned by T.worldToIndex
   *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth (args, Vid)
          val info = getImageDst imgArg
          val s = getKernelDst (List.nth (args, hid))
          val (argsT, P, code) = T.worldToIndex{info = info, img = imgArg, pos = List.nth (args, tid)}
          in
            (ImageInfo.dim info, args @ argsT, code, s, P)
          end

  (* fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) builds the expanded
   * body for a probe of <V_alpha * H^(dx)>:
   *   dimO  -- dimension of the field
   *   s     -- support of the kernel; every summation ranges over 1-s .. s
   *   sx    -- first unused index id; the dimO summation indices are sx .. sx+dimO-1
   *   alpha -- indices into the image V
   *   dx    -- differentiation indices applied to the kernel along every axis
   *   Vid, hid, nid, fid -- param ids of the image, kernel and the two position tensors
   * The result is
   *   Sum_{i_0..i_{dimO-1}} V_alpha(fid + i) * Prod_d H^(dx)(nid[d] - i_d)
   * For dimO = 1, the position tensors are indexed as scalars.
   * NOTE(review): the image is indexed through fid and the kernel through nid; confirm
   * that this matches the order of the tensors produced by T.worldToIndex.
   *)
    fun fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let
        (* 1-d fields: the position tensors are scalars *)
          fun createKRND1 () = let
                val imgpos = [E.Opn(E.Add, [E.Tensor(fid, []), E.Value sx])]
                val deltas = List.map (fn e => (E.C 0, e)) dx
                val rest = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, []), E.Value sx))
                in
                  E.Opn(E.Prod, [E.Img(Vid, alpha, imgpos), rest])
                end
        (* createKRN: builds the image term and one kernel term per axis, from the last
         * axis down to axis 0, using summation index sx+d for axis d.
         *)
          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid, alpha, imgpos) :: rest)
            | createKRN (d, imgpos, rest) = let
                val d' = d - 1
                val cx = E.C d'
                val vsum = E.Value(sx + d')
                val pos0 = E.Opn(E.Add, [E.Tensor(fid, [cx]), vsum])
                val deltas = List.map (fn e => (cx, e)) dx
                val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub, E.Tensor(nid, [cx]), vsum))
                in
                  createKRN (d', pos0 :: imgpos, rest0 :: rest)
                end
        (* The summation indices must be the same ids (sx .. sx+dimO-1) that the body
         * refers to through E.Value.  Binding ids 0 .. dimO-1 would leave the body's
         * indices unbound and capture the free indices whenever sx <> 0.
         *)
          val esum = List.tabulate (dimO, fn d => (E.V(sx + d), 1 - s, s))
          val exp = if (dimO = 1) then createKRND1 () else createKRN (dimO, [], [])
          in
            E.Sum(esum, exp)
          end

  (* getsumshift (sx, n) : returns an index id not used by the n free indices
   * (0 .. n-1) or by any summation index in sx, i.e., one more than the largest id in
   * use.  Raises Fail if sx contains a non-variable index.
   *)
    fun getsumshift (sx, n) = let
          fun maxIdx ((E.V v, _, _), m) = Int.max(v + 1, m)
            | maxIdx _ = raise Fail "getsumshift: expected a summation variable"
          in
            List.foldl maxIdx n sx
          end

  (* formBody : ein_exp -> ein_exp
   * Drops empty summations and unwraps single-factor products.
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs (Ps, sx, body) : forms Sum_sx (Ps * body).  When there are three or four
   * P-factors, they are placed before body (to match the vis branch WorldToSpace
   * functions); otherwise body comes first.
   *)
    fun multiPs (Ps, sx, body) = let
          val exp = (case Ps
                 of [P0, P1, P2] => [P0, P1, P2, body]
                  | [P0, P1, P2, P3] => [P0, P1, P2, P3, body]
                  | _ => body :: Ps
                (* end case *))
          in
            formBody (E.Sum(sx, E.Opn(E.Prod, exp)))
          end

  (* arrangeBody (body, Ps, newsx, exp) : replaces the probe in body with Ps * exp,
   * summed over the new indices newsx.  The boolean result is false only when the probe
   * is the second factor of a product under a summation (E.Sum(sx, eps0 * probe));
   * otherwise it is true.  Any other shape of body raises Fail.
   *)
    fun arrangeBody (body, Ps, newsx, exp) = (case body
           of E.Sum(sx, E.Probe _) => (true, multiPs (Ps, sx @ newsx, exp))
            | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs (Ps, newsx, exp)])))
            | E.Probe _ => (true, multiPs (Ps, newsx, exp))
            | _ => raise Fail(concat["arrangeBody: unexpected body ", EinPP.expToString body])
          (* end case *))

  (* replaceProbe ((y, einapp), probe, sx) : expands probe in place inside the body of
   * einapp.
   *   - handleArgs generates the code that transforms the position into index space.
   *   - T.imageToWorld rewrites the derivative indices.  It yields the new derivative
   *     indices dx', extra summation indices sx', and the P-factors Ps.
   *   - arrangeBody substitutes the reconstructed field for the probe.
   * Three params are appended: two tensors of shape [dim] (ids fid and nid) and one of
   * shape [dim, dim] (id Pid).  The matching arguments are argsA @ [PArg].
   * Returns the transformation code followed by the rewritten assignment to y.
   *)
    fun replaceProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid + 1
          val Pid = nid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, argsA, code, s, PArg) = handleArgs (Vid, hid, tid, args)
          val freshIndex = getsumshift (sx, length index)
          val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
          val probe' = fieldReconstruction (
                dim, s, freshIndex + length dx', alpha, dx', Vid, hid, nid, fid)
          val (_, body') = arrangeBody (body, Ps, sx', probe')
          val einapp = (y, IR.EINAPP(mkEin(params', index, body'), argsA @ [PArg]))
          in
            code @ [einapp]
          end

  (* createEinApp (body, alpha, index, freshIndex, dim, dx, sx) : builds the EIN that
   * multiplies the lifted field tensor by the transformation factors (T * P * P ... Ps).
   * Param 0 is the [dim, dim] matrix and param 1 is the lifted tensor, whose shape is
   * sizes.
   * Returns (splitvar, ein, sizes, dx', alpha'), where splitvar is the flag from
   * arrangeBody, and dx' and alpha' are the index lists after cleaning.
   *)
    fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (dx', newsx, Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
          val sxx = sx @ newsx
        (* 9 and 7 are placeholder image/kernel param ids; the corresponding fields of the
         * resulting Conv are discarded by the pattern below.
         * NOTE(review): this assumes CleanIndex.clean does not depend on those ids --
         * confirm.
         *)
          val placeholderVid = 9
          val placeholderHid = 7
          val conv = E.Conv(placeholderVid, alpha, placeholderHid, dx')
          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx
                 of [] => ([], index, conv)
                  | _ => CleanIndex.clean (conv, index, sxx)
                (* end case *))
        (* drop constant indices from alpha' and append the derivative indices *)
          fun filterAlpha [] = dx'
            | filterAlpha (E.C _ :: es) = filterAlpha es
            | filterAlpha (e1 :: es) = e1 :: filterAlpha es
          val exp = E.Tensor(tid, filterAlpha alpha')
          val (splitvar, body') = arrangeBody (body, Ps, newsx, exp)
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
          val ein0 = mkEin(params, index, body')
          in
            (splitvar, ein0, sizes, dx', alpha')
          end

  (* liftProbe ((y, einapp), probe, sx) : floats the reconstructed field term into its
   * own assignment.  A fresh variable T (shape sizes) receives the reconstructed
   * field.  y is then computed from the transformation matrix PArg and T by the EIN
   * built in createEinApp.  When arrangeBody reports true, that EIN is passed through
   * EinSums.transform and FloatEin.transform.
   * Returns the transformation code, then the assignment to T, then the assignments
   * for y.
   *)
    fun liftProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, args', code, s, PArg) = handleArgs (Vid, hid, tid, args)
          val freshIndex = getsumshift (sx, length index)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') =
                createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
          val FArg = V.new ("T", Ty.TensorTy sizes)
          val rtn0 = if splitvar
                then FloatEin.transform (y, EinSums.transform ein0, [PArg, FArg])
                else [(y, IR.EINAPP(ein0, [PArg, FArg]))]
        (* reconstruct the lifted probe; its free indices are 0 .. length sizes - 1 *)
          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim])]
          val freshIndex' = length sizes
          val body' = fieldReconstruction (dim, s, freshIndex', alpha', dx, Vid, hid, nid, fid)
          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')
          in
            code @ ((FArg, einApp1) :: rtn0)
          end

  (* expand : MidIR.assign -> MidIR.assign list
   * At this point we only have simple EIN operators.  If the body contains a probe, it
   * is expanded:
   *   - a probe without derivatives is replaced in place (replaceProbe);
   *   - a probe with derivatives is lifted (liftProbe);
   *   - under a summation, a probe without derivatives is replaced, a probe of a scalar
   *     image (empty alpha) is lifted, and any other probe is replaced;
   *   - a summed product eps * probe is replaced.
   * Anything else, including non-EIN assignments, is returned unchanged.
   *)
    fun expand (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
           of E.Probe(E.Conv(_, _, _, []), _) => replaceProbe (e, body, [])
            | E.Probe(E.Conv(_, _, _, _), _) => liftProbe (e, body, [])
            | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) => replaceProbe (e, p, sx)
            | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) => liftProbe (e, p, sx)
            | E.Sum(sx, p as E.Probe _) => replaceProbe (e, p, sx)
            | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) => replaceProbe (e, p, sx)
            | _ => [e]
          (* end case *))
      | expand e = [e]

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0