Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3734, Thu Apr 7 22:06:42 2016 UTC revision 3735, Fri Apr 8 15:01:28 2016 UTC
# Line 8  Line 8 
8    
9  structure ProbeEin : sig  structure ProbeEin : sig
10    
11      val expand : AvailRHS.t -> MidIR.assign -> MidIR.assign list      val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit
12    
13    end = struct    end = struct
14    
# Line 39  Line 39 
39      *)      *)
40    
41      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
42        fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
43      fun getRHSDst x = (case IR.Var.getDef x      fun getRHSDst x = (case IR.Var.getDef x
44             of IR.OP(rator, args) => (rator, args)             of IR.OP(rator, args) => (rator, args)
45              | rhs => raise Fail(concat[              | rhs => raise Fail(concat[
# Line 73  Line 73 
73          *P is the mid-il var for the (transformation matrix)transpose          *P is the mid-il var for the (transformation matrix)transpose
74      *)      *)
75      fun handleArgs (avail, Vid, hid, tid, args) = let      fun handleArgs (avail, Vid, hid, tid, args) = let
76            val imgArg = List.nth (args, Vid)            val vI = List.nth (args, Vid)
77            val info = getImageDst imgArg            val info = getImageDst vI
78            val s = Kernel.support (getKernelDst (List.nth(args, hid)))        val vH = List.nth(args, hid)
79            val (n, f, P) = T.worldToIndex{            val (vN, vF, vP) = T.worldToIndex{
80                    avail = avail, info = info, img = imgArg, pos = List.nth(args, tid)                    avail = avail, info = info, img = vI, pos = List.nth(args, tid)
81                  }                  }
82            in            in
83              (ImageInfo.dim info, n, f, s, P)              (ImageInfo.dim info, vN, vH, vF, vP)
84            end            end
85    
86    (* build position vector for EvalKernel; args are support, axis, image dimension, position    (*lifted Kernel expressions
87     * vector    args are axis, ein index_ids that represent differentiation,  image dimension, kernel, fractional position, support
88     *)     *)
89      fun buildPos (s, dir, dim, f) = let      fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
90            val x = V.new ("x", Ty.realTy)          val range = 2*s
           val u = V.new ("kernel_pos", Ty.TensorTy[2*s])  
           val stms = [  
                   IR.ASSGN(x, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [f])),  
                   IR.ASSGN(u, IR.OP(Op.BuildPos(s), [x]))  
                 ]  
           in  
             (u, stms)  
           end  
91    
92    (* apply differentiation *)          (* build position vector for EvalKernel *)
93      fun getKrnDel (s, h, k, u) = let          val vX =
94          val v = V.new ("kernel_del"^Int.toString k, Ty.TensorTy[2*s])              if (dim=1) then vF   (* position is a real type*)
95          in              else AvailRHS.addAssign (avail, "vxindexed_dir"^Int.toString(dir)^"_", Ty.realTy, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))
96            (v, IR.ASSGN(v, IR.OP(Op.EvalKernel(2*s, h, k), [u])))  
97          end          val vPos =  AvailRHS.addAssign (avail, "kernelpos_dir"^Int.toString(dir)^"_", Ty.TensorTy[range], IR.OP(Op.BuildPos s, [vX]))
98    
   (*lifted Kernel expressions*)  
 (* TODO: match against argsA list? *)  
     fun liftKrn (dx, dir, dim, argsA, hid, fid, s) = let  
         val (posV, stms) = buildPos (s, dir, dim, List.nth(argsA, fid))  
         val h = List.nth(argsA, hid)  
         fun iter (0, vs, stms) = (vs, stms)  
           | iter (n, vs, stms) = let  
               val n = n-1  
               val (v, stm) = getKrnDel(s, h, n, posV)  
               in  
                 iter (n, v::vs, stm::stms)  
               end  
99          val nKernEvals = List.length dx + 1          val nKernEvals = List.length dx + 1
100          val (vs, stms') = iter (nKernEvals, [], [])          fun mkEval k = AvailRHS.addAssign (avail, "mkeval_dir"^Int.toString(dir)^"_del"^Int.toString k,
101                    Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
102            val vKs = List.tabulate(nKernEvals, (fn k => mkEval k))
103          in          in
104            case vs            case vKs
105             of [v] => (v, stms'@stms) (* scalar result *)             of [v] => v (* scalar result *)
106              | _ => let              | _ => let
107                  val consTy = Ty.TensorTy[nKernEvals, 2*s]              val consTy = Ty.TensorTy[nKernEvals, range]
                 val resV = V.new ("kernel_cons", consTy)  
                 val stm = IR.ASSGN(resV, IR.CONS(vs, consTy))  
108                  in                  in
109                    (resV, stm :: stms' @ stms)                  AvailRHS.addAssign (avail, "kernelCons_dir_"^Int.toString(dir),  consTy, IR.CONS(vKs, consTy))
110                  end                  end
111            (* end case *)            (* end case *)
112          end          end
113    
114      (*fieldReconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id      (*fieldReconstruction:int*int*int,mu list, param_id, param_id, kernel, position, param_id
115      * expands the body for the probed field      * expands the body for the probed field
116      *)      *)
117      fun fieldReconstruction (dimO, s, sx, alpha, dx, argsA, Vid, hid, n, f) = let      fun fieldReconstruction (avail, dim, sx, alpha, dx,  Vid, nid, vH, vF, kid) = let
118          (*1-d fields*)  
119            fun createKRND1 () = let  
120                  val imgpos = [E.Opn(E.Add, [E.Tensor(fid, []), E.Value sx])]          (* image positions for image body *)
121                  val deltas = List.map (fn e =>(E.C 0,e)) dx          val imgpos =
122                  val rest = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value sx))              if (dim = 1)
123                  in              then [E.Opn(E.Add, [E.Tensor(nid, []), E.Value(sx)])]
124                    E.Opn(E.Prod, [E.Img(Vid,alpha,imgpos),rest])              else List.tabulate(dim, fn dir=> E.Opn(E.Add, [E.Tensor(nid, [E.C dir]), E.Value(dir+sx)]))
125                  end          (* image body *)
126          (*createKRN Image field and kernels *)          val imgexp = E.Img(Vid, alpha, imgpos)
127            fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid,alpha,imgpos)::rest)  
128              | createKRN (d, imgpos, rest) = let          val h = getKernelDst vH
129                  val d' = d-1          val s = Kernel.support h
130                  val cx = E.C(d')          (* create kernel body *)
131                  val Vsum = E.Value(sx+d')          fun createKrn (0,  krnexp, vAs) = (krnexp, vAs)
132                  val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])            | createKrn (dir, krnexp, vAs) = let
133                  val deltas = List.map (fn e =>(cx, e)) dx              val dir' = dir-1
134                  val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[cx]), Vsum))              (* ein expression *)
135  (* WHAT NEXT?? *)              val deltas = List.map (fn e =>(E.C(dir'), e)) dx
136                  in              val kexp0 = E.Krn(kid+dir, deltas, E.Value(dir))
137                    createKRN (d', pos0::imgpos, rest0::rest)              (* evalkernel operators *)
138                  end              val vA = liftKrn (avail, dir, dx, dim, h, vF, s)
139          (*sumIndex creating summation Index for body*)          in
140            val esum = List.tabulate (dimO, fn d => (E.V d, 1-s, s))              createKrn (dir', kexp0::krnexp, vA::vAs)
141            val exp = if (dimO = 1) then createKRND1() else createKRN(dimO, [], [])          end
142    
143          (* creating summation Index *)
144           val esum = List.tabulate (dim, fn i => (E.V i, 1-s, s))
145          (* final ein expression body to represent field reconstruction *)
146          val (krnexp, vAs) = createKrn (dim, [], [])
147          val exp =  E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
148            in            in
149              E.Sum(esum, exp)              (exp, vAs)
150            end            end
151    
152     (*getsumshift:sum_indexid list* int list-> int     (*getsumshift:sum_indexid list* int list-> int
# Line 198  Line 184 
184              | _ => raise Fail "impossible"              | _ => raise Fail "impossible"
185            (* end case *))            (* end case *))
186    
187    
188    
189      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list
190              -> ein_exp* *code              -> ein_exp* *code
191      * Transforms position to world space      * Transforms position to world space
# Line 206  Line 194 
194      * replace probe with expanded version      * replace probe with expanded version
195      *)      *)
196       fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let       fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
197            val fid = length params  
198            val nid = fid+1       (* tensor ids for position, transform matrix P, and kernel terms*)
199            val Pid = nid+1            val nid = length params
200              val pid = nid+1
201          val kid = pid+1
202    
203            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
204            val (dim, n, f, s, PArg) = handleArgs (avail, Vid, hid, tid, args)            val (dim, vN, vH, vF, vP)  = handleArgs (avail, Vid, hid, tid, args)
205            val freshIndex = getsumshift (sx, length index)            val freshIndex = getsumshift (sx, length index)
206            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
207            val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]        val sxn = freshIndex+length dx' (*next available index id *)
208            val probe' = fieldReconstruction (avail, dim, s, freshIndex+length dx', alpha, dx', Vid, hid, n, f)            val (probe', vKs) = fieldReconstruction (avail, dim, sxn, alpha, dx', Vid, nid, vH, vF, kid)
209          (* add new params position (nid), transformation matrix (Pid), and kernel ids *)
210          val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim, dim])] @(List.tabulate(dim,fn _=> E.TEN(true,[])))
211            val (_, body') = arrangeBody (body, Ps, sx', probe')            val (_, body') = arrangeBody (body, Ps, sx', probe')
212            val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ [n, f, PArg]))            val einapp = (y, IR.EINAPP(mkEin(params', index, body'), args @ [vN, vP]@vKs))
213          val _ = print(String.concat[ "\nProbe var: ", V.name(y)])
214          (* val _ =  AvailRHS.addAssignToList (avail, einapp)*)(*FIXME: remove this*)
215            in            in
216              AvailRHS.addAssignToList (avail, einapp)              AvailRHS.addAssignToList (avail, einapp)
217            end            end
218    
219      (*transform T*P*P..Ps*)      (*transform T*P*P..Ps*)
220      fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let      fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
221    
222            val Pid = 0            val Pid = 0
223            val tid = 1            val tid = 1
224            val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)            val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
# Line 243  Line 239 
239            in            in
240              (splitvar, ein0, sizes, dx', alpha')              (splitvar, ein0, sizes, dx', alpha')
241            end            end
242            (*
243    (* floats the reconstructed field term *)    (* floats the reconstructed field term *)
244      fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let      fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
           val fid = length params  
           val nid = fid+1  
245            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
246            val (dim, args', s, PArg) = handleArgs(avail, Vid, hid, tid, args)            val (dim, vN, vH, vF, vP)  = handleArgs(avail, Vid, hid, tid, args)
247            val freshIndex = getsumshift(sx, length(index))            val freshIndex = getsumshift(sx, length(index))
248    
249          (* transform T*P*P..Ps *)          (* transform T*P*P..Ps *)
250            val (splitvar, ein0, sizes, dx, alpha') =            val (splitvar, ein0, sizes, dx, alpha') =
251                  createEinApp (body, alpha, index, freshIndex, dim, dx, sx)                  createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
252            val FArg = V.new ("T", Ty.TensorTy(sizes))            val vF = V.new ("T", Ty.TensorTy(sizes))
253            val einApp0 = IR.EINAPP(ein0, [PArg, FArg])            val einApp0 = IR.EINAPP(ein0, [vP, vF])
254            val rtn0 = if splitvar            val rtn0 = if splitvar
255                  then FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg])                  then FloatEin.transform(y, EinSums.transform ein0, [vP, vF])
256                  else [(y, IR.EINAPP(ein0, [PArg, FArg]))]                  else [(y, IR.EINAPP(ein0, [vP, vF]))]
257    
258          (* reconstruct the lifted probe *)          (* reconstruct the lifted probe *)
259            val params' = params@[E.TEN(true, [dim]), E.TEN(true, [dim])]        (* tensor id for position *)
260            val freshIndex' = length sizes        val nid = length params
261            val body' = fieldReconstruction (dim, s, freshIndex', alpha', dx, Vid, hid, nid, fid)        val kid = nid+1
262            val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')        (* add new params position (nid) and kernel ids *)
263          val params' = params@(E.TEN(true, [dim])::(List.tabulate(dim,fn _=> E.TEN(true,[]))))
264              val sxn = length sizes (*next available index id *)
265              val (probe', vKs) = fieldReconstruction (avail, dim, sxn, alpha', dx, Vid, nid, vH, vF, kid)
266          val args' = args@(vN::vKs)
267              val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')
268            in            in
269              code @ (FArg, einApp1) :: rtn0          AvailRHS.addAssignsToList(avail, (vF, einApp1) :: rtn0)
270            end            end
271    *)
272    
273    (* expandEinOp: code->  code list    (* expandEinOp: code->  code list
274     * A this point we only have simple ein ops     * A this point we only have simple ein ops
# Line 275  Line 277 
277      fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body      fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
278             of (E.Probe(E.Conv(_, _, _, []) ,_)) =>             of (E.Probe(E.Conv(_, _, _, []) ,_)) =>
279                  replaceProbe (avail, e, body, [])                  replaceProbe (avail, e, body, [])
280            | (E.Probe _) =>
281            replaceProbe (avail, e, body, [])
282            (*FIXME: Only use replaceProbe while trying to isolate valnum bug*)
283            (*
284              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) =>              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) =>
285                  liftProbe (avail, e, body, []) (*scans dx for contant*)                  liftProbe (avail, e, body, []) (*scans dx for contant*)
286            *)
287              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>
288                  replaceProbe (avail, e, p, sx)  (*no dx*)                  replaceProbe (avail, e, p, sx)  (*no dx*)
289            (*
290              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) =>              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) =>
291                  liftProbe (avail, e, p, sx) (*scalar field*)                  liftProbe (avail, e, p, sx) (*scalar field*)
292            *)
293              | (E.Sum(sx, E.Probe p)) =>              | (E.Sum(sx, E.Probe p)) =>
294                  replaceProbe (avail, e, E.Probe p, sx)                  replaceProbe (avail, e, E.Probe p, sx)
295              | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) =>              | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) =>
296                  replaceProbe (avail, e, E.Probe p, sx)                  replaceProbe (avail, e, E.Probe p, sx)
297    
298              | _ => AvailRHS.addAssignToList (avail, e)              | _ => AvailRHS.addAssignToList (avail, e)
299            (* end case *))            (* end case *))
300    

Legend:
Removed from v.3734  
changed lines
  Added in v.3735

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0