Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3731, Thu Apr 7 19:25:15 2016 UTC revision 3732, Thu Apr 7 20:56:16 2016 UTC
# Line 8  Line 8 
8    
9  structure ProbeEin : sig  structure ProbeEin : sig
10    
11      val expand : MidIR.assign -> MidIR.assign list      val expand : AvailRHS.t -> MidIR.assign -> MidIR.assign list
12    
13    end = struct    end = struct
14    
# Line 58  Line 58 
58            (* end case *))            (* end case *))
59    
60      fun getKernelDst hArg = (case IR.Var.getDef hArg      fun getKernelDst hArg = (case IR.Var.getDef hArg
61             of IR.OP(Op.Kernel(h, _), _) => Kernel.support h             of IR.OP(Op.Kernel(h, _), _) => h
62              | rhs => raise Fail (String.concat[              | rhs => raise Fail (String.concat[
63                    "expected kernel for ", IR.Var.toString hArg,                    "expected kernel for ", IR.Var.toString hArg,
64                    " but found ", IR.RHS.toString rhs                    " but found ", IR.RHS.toString rhs
# Line 72  Line 72 
72          *Transforms the position to index space          *Transforms the position to index space
73          *P is the mid-il var for the (transformation matrix)transpose          *P is the mid-il var for the (transformation matrix)transpose
74      *)      *)
75      fun handleArgs (Vid, hid, tid, args) = let      fun handleArgs (avail, Vid, hid, tid, args) = let
76            val imgArg = List.nth (args, Vid)            val imgArg = List.nth (args, Vid)
77            val info = getImageDst imgArg            val info = getImageDst imgArg
78            val s = getKernelDst (List.nth(args, hid))            val s = Kernel.support (getKernelDst (List.nth(args, hid)))
79            val (argsT, P, code) = T.worldToIndex{info = info, img = imgArg, pos = List.nth(args, tid)}            val (n, f, P) = T.worldToIndex{
80                      avail = avail, info = info, img = imgArg, pos = List.nth(args, tid)
81                    }
82            in            in
83              (ImageInfo.dim info, args@argsT, code, s, P)              (ImageInfo.dim info, n, f, s, P)
84              end
85    
86      (* build position vector for EvalKernel; args are support, axis, image dimension, position
87       * vector
88       *)
89        fun buildPos (s, dir, dim, f) = let
90              val x = DstV.new ("x", DstTy.realTy)
91              val u = DstV.new ("kernel_pos", DstTy.TensorTy[2*s])
92              val stms = [
93                      IR.ASSGN(x, IR.OP(Op.Index(DstTy.TensorTy[dim], dir), [f])),
94                      IR.ASSGN(u, IR.OP(Op.BuildPos(s), [x]))
95                    ]
96              in
97                (u, stms)
98              end
99    
100      (* apply differentiation *)
101        fun getKrnDel (s, h, k, u) = let
102            val v = DstV.new ("kernel_del"^Int.toString k, DstTy.TensorTy[2*s])
103            in
104              (v, IR.ASSGN(v, IR.OP(Op.EvalKernel(2*s, h, k), [u])))
105            end
106    
107      (*lifted Kernel expressions*)
108    (* TODO: match against argsA list? *)
109        fun liftKrn (dx, dir, dim, argsA, hid, fid, s) = let
110            val (posV, stms) = buildPos (s, dir, dim, List.nth(args, fid))
111            val h = List.nth(argsA, hid)
112            fun iter (0, vs, stms) = (vs, stms)
113              | iter (n, vs, stms) = let
114                  val n = n-1
115                  val (v, stm) = getKrnDel(s, h, n, posV)
116                  in
117                    iter (n, v::vs, stm::stms)
118                  end
119            val nKernEvals = List.length dx + 1
120            val (vs, stms') = iter (nKernEvals, [], [])
121            in
122              case vs
123               of [v] => (v, stms'@stms) (* scalar result *)
124                | _ => let
125                    val consTy = DstTy.TensorTy[nKernEvals, 2*s]
126                    val resV = DstV.new ("kernel_cons", consTy)
127                    val stm = IR.ASSGN(resV, IR.CONS(vs, consTy))
128                    in
129                      (resV, stm :: stms' @ stms)
130                    end
131              (* end case *)
132            end            end
133    
134      (*fieldReconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id      (*fieldReconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id
135      * expands the body for the probed field      * expands the body for the probed field
136      *)      *)
137      fun fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let      fun fieldReconstruction (dimO, s, sx, alpha, dx, argsA, Vid, hid, n, f) = let
138          (*1-d fields*)          (*1-d fields*)
139            fun createKRND1 () = let            fun createKRND1 () = let
140                  val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]                  val imgpos = [E.Opn(E.Add, [E.Tensor(fid, []), E.Value sx])]
141                  val deltas = List.map (fn e =>(E.C 0,e)) dx                  val deltas = List.map (fn e =>(E.C 0,e)) dx
142                  val rest = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sx)))                  val rest = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value sx))
143                  in                  in
144                    E.Opn(E.Prod, [E.Img(Vid,alpha,imgpos),rest])                    E.Opn(E.Prod, [E.Img(Vid,alpha,imgpos),rest])
145                  end                  end
# Line 102  Line 152 
152                  val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])                  val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])
153                  val deltas = List.map (fn e =>(cx, e)) dx                  val deltas = List.map (fn e =>(cx, e)) dx
154                  val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[cx]),  Vsum))                  val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[cx]),  Vsum))
155                    val (
156                  in                  in
157                    createKRN (d', pos0::imgpos, rest0::rest)                    createKRN (d', pos0::imgpos, rest0::rest)
158                  end                  end
# Line 154  Line 205 
205      * rewrites body      * rewrites body
206      * replace probe with expanded version      * replace probe with expanded version
207      *)      *)
208       fun replaceProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let       fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
209            val fid = length params            val fid = length params
210            val nid = fid+1            val nid = fid+1
211            val Pid = nid+1            val Pid = nid+1
212            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
213            val (dim, argsA, code, s, PArg) = handleArgs (Vid, hid, tid, args)            val (dim, n, f, s, PArg) = handleArgs (avail, Vid, hid, tid, args)
214            val freshIndex = getsumshift (sx, length index)            val freshIndex = getsumshift (sx, length index)
215            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, Pid)
216            val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]            val params' = params @ [E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
217            val probe' = fieldReconstruction (dim, s, freshIndex+length dx', alpha, dx', Vid, hid, nid, fid)            val probe' = fieldReconstruction (avail, dim, s, freshIndex+length dx', alpha, dx', Vid, hid, n, f)
218            val (_, body') = arrangeBody (body, Ps, sx', probe')            val (_, body') = arrangeBody (body, Ps, sx', probe')
219            val einapp = (y, IR.EINAPP(mkEin(params', index, body'), argsA@[PArg]))            val einapp = (y, IR.EINAPP(mkEin(params', index, body'), argsA@[n, f, PArg]))
220            in            in
221              code@[einapp]              AvailRHS.addAssignToList (avail, einapp);
222            end            end
223    
224      (*transform T*P*P..Ps*)      (*transform T*P*P..Ps*)
# Line 194  Line 245 
245            end            end
246    
247    (* floats the reconstructed field term *)    (* floats the reconstructed field term *)
248      fun liftProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let      fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
249            val fid = length(params)            val fid = length params
250            val nid = fid+1            val nid = fid+1
251            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
252            val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)            val (dim, args', s, PArg) = handleArgs(avail, Vid, hid, tid, args)
253            val freshIndex = getsumshift(sx, length(index))            val freshIndex = getsumshift(sx, length(index))
254          (* transform T*P*P..Ps *)          (* transform T*P*P..Ps *)
255            val (splitvar, ein0, sizes, dx, alpha') =            val (splitvar, ein0, sizes, dx, alpha') =
# Line 221  Line 272 
272     * A this point we only have simple ein ops     * A this point we only have simple ein ops
273     * Looks to see if the expression has a probe. If so, replaces it.     * Looks to see if the expression has a probe. If so, replaces it.
274     *)     *)
275      fun expand (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body      fun expand avail (e as (_, IR.EINAPP(Ein.EIN{body, ...}, _))) = (case body
276             of (E.Probe(E.Conv(_, _, _, []) ,_)) => replaceProbe(e, body, [])             of (E.Probe(E.Conv(_, _, _, []) ,_)) =>
277              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) => liftProbe (e, body, []) (*scans dx for contant*)                  replaceProbe (avail, e, body, [])
278              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) => replaceProbe(e, p, sx)  (*no dx*)              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) =>
279              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) => liftProbe (e, p, sx) (*scalar field*)                  liftProbe (avail, e, body, []) (*scans dx for contant*)
280              | (E.Sum(sx, E.Probe p)) => replaceProbe(e, E.Probe p, sx)              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>
281              | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) => replaceProbe(e ,E.Probe p, sx)                  replaceProbe (avail, e, p, sx)  (*no dx*)
282              | _ => [e]              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) =>
283                    liftProbe (avail, e, p, sx) (*scalar field*)
284                | (E.Sum(sx, E.Probe p)) =>
285                    replaceProbe (avail, e, E.Probe p, sx)
286                | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) =>
287                    replaceProbe (avail, e, E.Probe p, sx)
288                | _ => addAssignToList (avail, e)
289            (* end case *))            (* end case *))
290    
291    end (* ProbeEin *)    end (* ProbeEin *)

Legend:
Removed from v.3731  
changed lines
  Added in v.3732

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0