Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

View of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3569 - (download) (annotate)
Mon Jan 11 05:47:54 2016 UTC (4 years, 5 months ago) by cchiw
File size: 11644 byte(s)
small cleanup to probe-ein
(* probe-ein.sml
 *
 * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
 *
 * COPYRIGHT (c) 2016 The University of Chicago
 * All rights reserved.
 *)

structure ProbeEin : sig

  end = struct

    structure E = Ein
    structure DstIL = MidIL
    structure DstOp = MidOps
    structure T = TransformEin
    structure MidToS = MidToString
    structure DstV = DstIL.Var
    structure DstTy = MidILTypes

    (* This file expands probed fields.
     * See the ProbeEin tex file for examples.
     * The original field is an EIN operator of the form
     *     <V_alpha * H^(deltas)>(midIL.var list)
     * Param_ids record the position of an argument in the midIL.var list.
     * Index_ids keep track of the shape of an image or of a differentiation.
     * A mu binds an Index_id.
     * Naming conventions used below:
     *   dim:    dimension of field V
     *   s:      support of kernel H
     *   alpha:  the alpha in <V_alpha * H^(deltas)>
     *   deltas: the deltas in <V_alpha * H^(deltas)>
     *   Vid:    param_id for V
     *   hid:    param_id for H
     *   nid:    param_id of the integer position
     *   fid:    param_id of the fractional position
     *   img:    image info about V
     *)

(* FIXME: what are these for? should they be settable from the command-line?
 * Only valnumflag is read in this file (by expandEinOp); tsplitvar,
 * fieldliftflag and detflag are not referenced here.
 *)
    val valnumflag = true
    val tsplitvar = true
    val fieldliftflag = true
    val detflag = true

    fun transformToIndexSpace e = T.transformToIndexSpace e
    fun transformToImgSpace e = T.transformToImgSpace e

    fun mkEin e = Ein.mkEin e
    fun mkEinApp (rator, args) = DstIL.EINAPP(rator, args)
    fun setConst e = E.setConst e
    fun setNeg e = E.setNeg e
    fun setExp e = E.setExp e
    fun setDiv e = E.setDiv e
    fun setSub e = E.setSub e
    fun setProd e = E.setProd e
    fun setAdd e = E.setAdd e

    (* getRHSDst : DstIL.var -> rator * args
     * Returns the operator and arguments of the OP right-hand side that
     * defines x, following VAR-to-VAR copies. Raises Fail when x is not
     * bound to an operator application.
     *)
    fun getRHSDst x = (case DstIL.Var.binding x
           of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
            | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
            | vb => raise Fail(concat[
                  "expected rhs operator for ", DstIL.Var.toString x,
                  " but found ", DstIL.vbToString vb
                ])
          (* end case *))

    (* getArgsDst : DstIL.var * DstIL.var * 'a -> int * image-info * int
     * hArg must be defined by a Kernel operator and imgArg by a LoadImage
     * operator. Returns the support of the kernel, the image info, and the
     * dimension of the image. Raises Fail otherwise.
     * The third argument is not used; it is kept so callers are unchanged.
     *)
    fun getArgsDst (hArg, imgArg, _) = (case (getRHSDst hArg, getRHSDst imgArg)
           of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
                (Kernel.support h, img, ImageInfo.dim img)
            | ((k, _), (i, _)) => raise Fail(String.concat[
                  "Expected kernel: ", DstOp.toString k, ", Expected Image: ", DstOp.toString i
                ])
          (* end case *))

    (* handleArgs : int * int * int * DstIL.var list
     *     -> int * DstIL.var list * code * int * DstIL.var
     * Uses the param_ids Vid, hid and tid to select the image, kernel and
     * position arguments, and transforms the position into image space with
     * transformToImgSpace. Returns (dim, args extended with the extra arguments
     * produced by the transform, the code that computes them, kernel support s,
     * P), where P is the mid-IL variable returned by the transform
     * (the transformation matrix, according to the original notes).
     *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth(args, Vid)
          val hArg = List.nth(args, hid)
          val newposArg = List.nth(args, tid)
          val (s, img, dim) = getArgsDst(hArg, imgArg, args)
          val (argsT, P, code) = transformToImgSpace(dim, img, newposArg, imgArg)
          in
            (dim, args @ argsT, code, s, P)
          end

    (* createBody : int * int * int * mu list * mu list * param_id * param_id
     *     * param_id * param_id -> ein_exp
     * Expands the body of a probed field into
     *     Sum over (sx .. sx+dim-1, each in 1-s .. s) of
     *         V_alpha(fractional pos + offsets) * product of per-axis kernels.
     *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
          (* 1-d fields: the position tensors are scalars (empty index list) *)
          fun createKRND1 () = let
                val dels = List.map (fn d => (E.C 0, d)) deltas
                val pos = [setAdd[E.Tensor(fid, []), E.Value sx]]
                val krn = E.Krn(hid, dels, setSub(E.Tensor(nid, []), E.Value sx))
                in
                  setProd [E.Img(Vid, alpha, pos), krn]
                end
          (* general case: one image offset and one kernel factor per axis, built
           * from the last axis down to axis 0, so that both lists end up in
           * axis order.
           *)
          fun createKRN (0, imgpos, krns) = setProd (E.Img(Vid, alpha, imgpos) :: krns)
            | createKRN (d, imgpos, krns) = let
                val d' = d - 1
                val sumIdx = sx + d'
                val dels = List.map (fn delta => (E.C d', delta)) deltas
                val pos = setAdd[E.Tensor(fid, [E.C d']), E.Value sumIdx]
                val krn = E.Krn(hid, dels, setSub(E.Tensor(nid, [E.C d']), E.Value sumIdx))
                in
                  createKRN (d', pos :: imgpos, krn :: krns)
                end
          val exp = (case dim
                 of 1 => createKRND1()
                  | _ => createKRN(dim, [], [])
                (* end case *))
          (* summation indices sx .. sx+dim-1, each ranging over 1-s .. s *)
          val lb = 1 - s
          val esum = List.tabulate(dim, fn i => (E.V(i + sx), lb, s))
          in
            E.Sum(esum, exp)
          end

    (* getsumshift : sum_index list * int -> int
     * Returns an unused index_id: one more than the variable of the last
     * summation index in sx, or n when sx is empty.
     * Both cases must be clauses of ONE fun; two separate fun declarations
     * would shadow the empty-list case and make it raise Empty.
     *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = (case List.last sx
           of (E.V v, _, _) => v + 1
            | _ => raise Fail "getsumshift: expected a variable summation index"
          (* end case *))

    (* formBody : ein_exp -> ein_exp
     * Light cleanup: drops empty summations and unwraps one-element products
     * (the latter only at the top or directly under summations).
     *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

    (* multiPs : ein_exp list * sum_index list * ein_exp -> ein_exp
     * Multiplies body by the transformation factors Ps under the summation sx
     * and simplifies with formBody. With three or four factors the Ps are
     * placed before body, to match the product order of the vis branch
     * WorldToSpace functions; otherwise body comes first.
     * A former two-factor case, Sum(sx, P0 * body * P1), is disabled.
     *)
    fun multiPs ([P0, P1, P2], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, body]))
      | multiPs ([P0, P1, P2, P3], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, P3, body]))
      | multiPs (Ps, sx, body) = formBody(E.Sum(sx, setProd(body :: Ps)))

    (* multiMergePs : like multiPs, but with exactly two factors and two
     * summation indices it nests each summation around its own factor.
     * Not referenced in this file.
     *)
    fun multiMergePs ([P0, P1], [sx0, sx1], body) =
          E.Sum([sx0], setProd[P0, E.Sum([sx1], setProd[P1, body])])
      | multiMergePs e = multiPs e

    (* replaceProbe : 'a * (var * EINAPP) * ein_exp * sum_index list -> code
     * Expands the probe p in place, inside the same EIN operator:
     *   - transforms the position to image space (handleArgs),
     *   - transforms the derivative indices to index space,
     *   - builds the kernel/image body (createBody) and multiplies it by the
     *     transformation factors (multiPs),
     *   - re-wraps it in the original summation (and eps factor) if present.
     * Returns the code computing the transform followed by the new binding
     * for y. The first argument is unused.
     *)
    fun replaceProbe (_, (y, DstIL.EINAPP(e, args)), p, sx) = let
          val params = Ein.params e
          val index = Ein.index e
          val fid = length params
          val nid = fid + 1
          val Pid = nid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
          (* number of derivative indices before the index-space transform *)
          val nshift = length dx
          val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
          val (dx', newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
          val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim]), E.TEN(1, [dim, dim])]
          val probeBody = createBody(dim, s, freshIndex + nshift, alpha, dx', Vid, hid, nid, fid)
          val probeBody = multiPs(Ps, newsx, probeBody)
          val body' = (case Ein.body e
                 of E.Sum(sx', E.Probe _) => E.Sum(sx', probeBody)
                  | E.Sum(sx', E.Opn(E.Prod, [eps0, E.Probe _])) => E.Sum(sx', setProd[eps0, probeBody])
                  | _ => probeBody
                (* end case *))
          val einapp = (y, mkEinApp(mkEin(params', index, body'), argsA @ [PArg]))
          in
            code @ [einapp]
          end

    (* createEinApp : ein_exp * mu list * int list * int * int * mu list
     *     * sum_index list -> bool * ein * int list * mu list * mu list
     * Builds the EIN operator that applies the transformation factors Ps
     * (param 0) to the lifted probe result F (param 1, a tensor of shape
     * sizes). Returns (splitvar, ein, sizes, rewritten dx, rewritten alpha).
     * splitvar is false only when the original body is Sum(sx, eps * probe).
     *)
    fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          (* assumes the body is already clean *)
          val (newdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
          (* rewrite alpha and dx; only those two are read back from the Conv,
           * so the 9 and 7 are presumably placeholder param ids for V and H
           *)
          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sx @ newsx
                 of [] => ([], index, E.Conv(9, alpha, 7, newdx))
                  | allsx => cleanIndex.cleanIndex(E.Conv(9, alpha, 7, newdx), index, allsx)
                (* end case *))
          val params = [E.TEN(1, [dim, dim]), E.TEN(1, sizes)]
          (* the shape of F: the non-constant indices of alpha' then newdx *)
          val tshape = List.filter (fn E.C _ => false | _ => true) alpha' @ newdx
          val t = E.Tensor(tid, tshape)
          val (splitvar, body) = (case originalb
                 of E.Sum(sx', E.Probe _) => (true, multiPs(Ps, sx' @ newsx, t))
                  | E.Sum(sx', E.Opn(E.Prod, [eps0, E.Probe _])) =>
                      (false, E.Sum(sx', setProd[eps0, multiPs(Ps, newsx, t)]))
                  | _ => (true, multiPs(Ps, newsx, t))
                (* end case *))
          in
            (splitvar, mkEin(params, index, body), sizes, dx', alpha')
          end

    (* liftProbe : 'a * (var * EINAPP) * ein_exp * sum_index list -> code
     * Like replaceProbe, but splits the work in two:
     *   - a fresh variable F is bound to an EIN operator that computes the
     *     probe itself (createBody over the shape sizes);
     *   - y is rebound to an EIN operator over [P, F] built by createEinApp;
     *     when createEinApp reports splitvar, that operator is passed through
     *     SummationEin.main and Split.splitEinApp.
     * Returns the transform code, then F's binding, then y's binding(s).
     * The first argument is unused.
     *)
    fun liftProbe (_, (y, DstIL.EINAPP(e, args)), p, sx) = let
          val params = Ein.params e
          val index = Ein.index e
          val fid = length params
          val nid = fid + 1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = p
          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
          (* outer operator: the transformation factors applied to F *)
          val (splitvar, ein0, sizes, dx', alpha') =
                createEinApp(Ein.body e, alpha, index, freshIndex, dim, dx, sx)
          val FArg = DstV.new ("F", DstTy.TensorTy sizes)
          val rtn0 = if splitvar
                then Split.splitEinApp (y, DstIL.EINAPP(SummationEin.main ein0, [PArg, FArg]))
                else [(y, mkEinApp(ein0, [PArg, FArg]))]
          (* lifted probe, bound to F *)
          val params' = params @ [E.TEN(3, [dim]), E.TEN(1, [dim])]
          val body' = createBody(dim, s, length sizes, alpha', dx', Vid, hid, nid, fid)
          val rtn1 = (FArg, mkEinApp(mkEin(params', sizes, body'), args'))
          in
            code @ (rtn1 :: rtn0)
          end

    (* expandEinOp : (var * EINAPP) * fieldset -> code * fieldset * int * int
     * At this point there are only simple EIN operators.
     * If value numbering is on and einSet.rtnVar finds an existing variable v
     * for this operator, y is bound to v (last component 1). Otherwise the
     * body is inspected for a probe and, if one is found, it is expanded by
     * replaceProbe or liftProbe (last component 0); eps factors are kept so
     * only the pieces that are used get generated.
     * The third component is 1 when the body is a (possibly summed) probe.
     *)
    fun expandEinOp (e as (y, DstIL.EINAPP(ein, args)), fieldset) = let
          val b = Ein.body ein
          fun rewriteBody b = (case b
                 of E.Probe(E.Conv(_, _, _, []), _) => replaceProbe(0, e, b, [])
                  | E.Probe(E.Conv _, _) => liftProbe(0, e, b, [])
                  | E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _)) =>
                      replaceProbe(0, e, p, sx)           (* no dx *)
                  | E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _)) =>
                      liftProbe(0, e, p, sx)              (* scalar field *)
                  | E.Sum(sx, p as E.Probe _) => replaceProbe(0, e, p, sx)
                  | E.Sum(sx, E.Opn(E.Prod, [_, p as E.Probe _])) => replaceProbe(0, e, p, sx)
                  | _ => [e]
                (* end case *))
          val (fieldset, var) = if valnumflag
                then einSet.rtnVar(fieldset, y, DstIL.EINAPP(ein, args))
                else (fieldset, NONE)
          val isField = (case b
                 of E.Probe _ => 1
                  | E.Sum(_, E.Probe _) => 1
                  | E.Sum(_, E.Opn(E.Prod, [_, E.Probe _])) => 1
                  | _ => 0
                (* end case *))
          in
            (case var
              of NONE => (rewriteBody b, fieldset, isField, 0)
               | SOME v => ([(y, DstIL.VAR v)], fieldset, isField, 1)
            (* end case *))
          end

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0