Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3785, Wed Apr 27 20:24:39 2016 UTC revision 3787, Thu Apr 28 15:26:21 2016 UTC
# Line 6  Line 6 
6   * All rights reserved.   * All rights reserved.
7   *)   *)
8    
 structure ProbeEin : sig  
   
     val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit  
   
   end = struct  
   
     structure IR = MidIR  
     structure Op = MidOps  
     structure V = IR.Var  
     structure Ty = MidTypes  
     structure E = Ein  
     structure T = CoordSpaceTransform  
   
9     (* This file expands probed fields     (* This file expands probed fields
10      * Take a look at ProbeEin tex file for examples      * Take a look at ProbeEin tex file for examples
11      * Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIR.var list )      * Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIR.var list )
# Line 38  Line 25 
25      * img-imginfo about V      * img-imginfo about V
26      *)      *)
27    
28    structure ProbeEin : sig
29    
30        val expand : AvailRHS.t -> MidIR.var * MidIR.rhs -> unit
31    
32      end = struct
33    
34        structure IR = MidIR
35        structure Op = MidOps
36        structure V = IR.Var
37        structure Ty = MidTypes
38        structure E = Ein
39        structure T = CoordSpaceTransform
40    
41      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
42      fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)      fun incUse (IR.V{useCnt, ...}) = (useCnt := !useCnt + 1)
43      fun getRHSDst x = (case IR.Var.getDef x      fun getRHSDst x = (case IR.Var.getDef x
# Line 54  Line 54 
54                    "expected image for ", IR.Var.toString imgArg,                    "expected image for ", IR.Var.toString imgArg,
55                    " but found ", IR.RHS.toString rhs                    " but found ", IR.RHS.toString rhs
56                  ])                  ])
57          )            (* end case *))
58    
59      fun getImagInfo e = (case IR.Var.getDef e      fun getImagInfo e = (case IR.Var.getDef e
60          of IR.OP(Op.LoadImage(Ty.ImageTy info, _), []) => (e, info, NONE)          of IR.OP(Op.LoadImage(Ty.ImageTy info, _), []) => (e, info, NONE)
61          | IR.OP(Op.BorderCtlDefault info, [imgArg])    => (imgArg, info, raise Fail "Default boarder control")              | IR.OP(Op.BorderCtlDefault info, [imgArg]) =>
62                    (imgArg, info, raise Fail "Default boarder control")
63          | IR.OP(Op.BorderCtlClamp info, [imgArg])      => (imgArg, info, SOME IndexCtl.Clamp)          | IR.OP(Op.BorderCtlClamp info, [imgArg])      => (imgArg, info, SOME IndexCtl.Clamp)
64          | IR.OP(Op.BorderCtlMirror info, [imgArg])     => (imgArg, info, SOME IndexCtl.Mirror)          | IR.OP(Op.BorderCtlMirror info, [imgArg])     => (imgArg, info, SOME IndexCtl.Mirror)
65          | IR.OP(Op.BorderCtlWrap info, [imgArg])       => (imgArg, info, SOME IndexCtl.Wrap)          | IR.OP(Op.BorderCtlWrap info, [imgArg])       => (imgArg, info, SOME IndexCtl.Wrap)
66          | rhs => raise Fail (String.concat[          | rhs => raise Fail (String.concat[
67          "expected image for ", IR.Var.toString e,                    "expected image for ", IR.Var.toString e, " but found ", IR.RHS.toString rhs
         " but found ", IR.RHS.toString rhs  
68          ])          ])
69          (* end case *))          (* end case *))
70    
# Line 89  Line 89 
89            val (vN, vF, vP) = T.worldToIndex{            val (vN, vF, vP) = T.worldToIndex{
90                    avail = avail, info = info, img = vI, pos = List.nth(args, tid)                    avail = avail, info = info, img = vI, pos = List.nth(args, tid)
91                  }                  }
       val dim = ImageInfo.dim info  
92            in            in
93              (vI, vH, vN, vF, vP, info, border, dim)              (vI, vH, vN, vF, vP, info, border, ImageInfo.dim info)
94            end            end
95    
96    (*lifted Kernel expressions    (*lifted Kernel expressions
# Line 99  Line 98 
98    *)    *)
99      fun liftKrn (avail, dir, dx, dim, h, vF, s) = let      fun liftKrn (avail, dir, dx, dim, h, vF, s) = let
100          val range = 2*s          val range = 2*s
   
101          (* build position vector for EvalKernel *)          (* build position vector for EvalKernel *)
102          val vX =            val vX = if (dim = 1)
103              if (dim=1) then vF   (* position is a real type*)                  then vF   (* position is a real type*)
104              else AvailRHS.addAssign (avail, "vxindexed_dir"^Int.toString(dir)^"_", Ty.realTy, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))                  else AvailRHS.addAssign (
105                      avail, concat["vxindexed_dir", Int.toString dir, "_"],
106          val vPos =  AvailRHS.addAssign (avail, "kernelpos_dir"^Int.toString(dir)^"_", Ty.TensorTy[range], IR.OP(Op.BuildPos s, [vX]))                    Ty.realTy, IR.OP(Op.Index(Ty.TensorTy[dim], dir), [vF]))
107              val vPos =  AvailRHS.addAssign (
108                    avail, concat["kernpos_dir", Int.toString dir, "_"],
109                    Ty.TensorTy[range], IR.OP(Op.BuildPos s, [vX]))
110          val nKernEvals = List.length dx + 1          val nKernEvals = List.length dx + 1
111          fun mkEval k = AvailRHS.addAssign (avail, "mkeval_dir"^Int.toString(dir)^"_del"^Int.toString k,            fun mkEval k = AvailRHS.addAssign (
112                    avail, concat["mkeval_dir", Int.toString dir, "_del", Int.toString k],
113                  Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))                  Ty.TensorTy[range], IR.OP(Op.EvalKernel(range, h, k), [vPos]))
114          val vKs = List.tabulate(nKernEvals, (fn k => mkEval k))            val vKs = List.tabulate(nKernEvals, fn k => mkEval k)
115          in          in
116            case vKs            case vKs
117             of [v] => v (* scalar result *)             of [v] => v (* scalar result *)
118              | _ => let              | _ => let
119              val consTy = Ty.TensorTy[nKernEvals, range]              val consTy = Ty.TensorTy[nKernEvals, range]
120              in              in
121                  AvailRHS.addAssign (avail, "kernelCons_dir_"^Int.toString(dir),  consTy, IR.CONS(vKs, consTy))                      AvailRHS.addAssign (
122                          avail, concat["kernelCons_dir_", Int.toString dir, "_"],
123                          consTy, IR.CONS(vKs, consTy))
124              end              end
125            (* end case *)            (* end case *)
126          end          end
127    
128      (* FIXME: what does this do??? *)
129      fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border) = let      fun mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border) = let
130         (* creates lb int *)         (* creates lb int *)
131         val vLb = AvailRHS.addAssign (avail, "lit", Ty.intTy,  IR.LIT(Literal.Int (1-(IntInf.fromInt s))))         val vLb = AvailRHS.addAssign (avail, "lit", Ty.intTy,  IR.LIT(Literal.Int (1-(IntInf.fromInt s))))
   
132         (*created n_0 +lb, n_1+lb*)         (*created n_0 +lb, n_1+lb*)
133         fun f i =           fun f i = let
134          let                val vA = AvailRHS.addAssign (
135              val vA = AvailRHS.addAssign (avail, "lit", Ty.intTy,  IR.LIT(Literal.Int (IntInf.fromInt  i)))                      avail, "idx", Ty.intTy,  IR.LIT(Literal.Int (IntInf.fromInt i)))
136              val vB = AvailRHS.addAssign (avail, "subscript", Ty.intTy, IR.OP(Op.Subscript(Ty.TensorTy[dim]), [vN, vA]))                val vB = AvailRHS.addAssign (
137                        avail, "subscript",
138                        Ty.intTy, IR.OP(Op.Subscript(Ty.SeqTy(Ty.intTy, SOME dim)), [vN, vA]))
139              in              in
140                  AvailRHS.addAssign (avail, "add", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))                  AvailRHS.addAssign (avail, "add", Ty.intTy, IR.OP(Op.IAdd, [vB, vLb]))
141              end              end
   
142          (* image positions *)          (* image positions *)
143          val s'= 2*s          val s'= 2*s
144          val supportshape =  List.tabulate(dim, fn _ => s')          val supportshape =  List.tabulate(dim, fn _ => s')
145          val ldty = Ty.TensorTy (shape@supportshape)          val ldty = Ty.TensorTy (shape@supportshape)
146          val vNs = List.tabulate( dim, fn n => f n)          val vNs = List.tabulate( dim, fn n => f n)
147          val vSq = AvailRHS.addAssign (avail, "seq", Ty.TensorTy[9], IR.SEQ(vNs, MidTypes.SeqTy(MidTypes.IntTy, SOME dim)))            val vSq = AvailRHS.addAssign (
148                    avail, "seq",
149    (* FIXME: where does this "9" come from??? *)
150                    Ty.TensorTy[9], IR.SEQ(vNs, MidTypes.SeqTy(MidTypes.IntTy, SOME dim)))
151          val op1 = (case border          val op1 = (case border
152              of NONE => Op.LoadVoxels (info, s)              of NONE => Op.LoadVoxels (info, s)
153              | SOME b =>  Op.LoadVoxelsWithCtl (info, s, b)              | SOME b =>  Op.LoadVoxelsWithCtl (info, s, b)
# Line 151  Line 156 
156              AvailRHS.addAssign (avail, "ldvox", ldty, IR.OP(op1, [vI, vSq]))              AvailRHS.addAssign (avail, "ldvox", ldty, IR.OP(op1, [vI, vSq]))
157          end          end
158    
   
   
159      (*fieldReconstruction expands the body for the probed field*)      (*fieldReconstruction expands the body for the probed field*)
160        fun fieldReconstruction (avail, sx, alpha, shape, dx,  Vid, Vidnew, kid, hid, tid, args) = let        fun fieldReconstruction (avail, sx, alpha, shape, dx,  Vid, Vidnew, kid, hid, tid, args) = let
161          val  (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)          val  (vI, vH, vN, vF, vP, info, border, dim) = handleArgs (avail, Vid, hid, tid, args)
162          val h = getKernelDst vH          val h = getKernelDst vH
163          val s = Kernel.support h          val s = Kernel.support h
   
164          (* creating summation Index *)          (* creating summation Index *)
165          val vs = List.tabulate (dim, fn i => (i +sx))          val vs = List.tabulate (dim, fn i => (i +sx))
166          val esum = List.map (fn i => (E.V i, 1-s, s)) vs          val esum = List.map (fn i => (E.V i, 1-s, s)) vs
   
167          (*represent image in ein expression with tensor*)          (*represent image in ein expression with tensor*)
168          val imgexp= E.Img(Vidnew, alpha, List.map (fn i=> E.Value i)  vs, s, E.None)          val imgexp= E.Img(Vidnew, alpha, List.map (fn i=> E.Value i)  vs, s, E.None)
         (*val imgexp = E.Tensor (Vidnew, alpha@(List.map (fn i => E.V i) vs))*)  
   
169          (* create load voxel operator for image *)          (* create load voxel operator for image *)
170          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)          val vLd = mkLdVoxel (avail, vI, vN, info, alpha, shape, dim, s, border)
   
171          (* create kernel body *)          (* create kernel body *)
172          fun createKrn (0,  krnexp, vAs) = (krnexp, vAs)          fun createKrn (0,  krnexp, vAs) = (krnexp, vAs)
173            | createKrn (dir, krnexp, vAs) = let            | createKrn (dir, krnexp, vAs) = let
# Line 182  Line 180 
180          in          in
181              createKrn (dir', kexp0::krnexp, vA::vAs)              createKrn (dir', kexp0::krnexp, vA::vAs)
182          end          end
   
183        (* final ein expression body to represent field reconstruction *)        (* final ein expression body to represent field reconstruction *)
184        val (krnexp, vKs) = createKrn (dim, [], [])        val (krnexp, vKs) = createKrn (dim, [], [])
185        val exp =  E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))        val exp =  E.Sum(esum, E.Opn(E.Prod, imgexp::krnexp))
# Line 217  Line 214 
214              formBody(E.Sum(sx, E.Opn(E.Prod, exp)))              formBody(E.Sum(sx, E.Opn(E.Prod, exp)))
215            end            end
216    
217      (* FIXME: what does this do??? *)
218      fun arrangeBody (body, Ps, newsx, exp) = (case body      fun arrangeBody (body, Ps, newsx, exp) = (case body
219             of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,exp))             of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,exp))
220              | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) =>              | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) =>
# Line 225  Line 223 
223              | _ => raise Fail "impossible"              | _ => raise Fail "impossible"
224            (* end case *))            (* end case *))
225    
   
   
226      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list
227              -> ein_exp* *code              -> ein_exp* *code
228      * Transforms position to world space      * Transforms position to world space
# Line 235  Line 231 
231      * replace probe with expanded version      * replace probe with expanded version
232      *)      *)
233       fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let       fun replaceProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
   
234       (* tensor ids for position, transform matrix P, and kernel terms*)       (* tensor ids for position, transform matrix P, and kernel terms*)
235            val pid = length params            val pid = length params
236        val Vidnew = pid+1        val Vidnew = pid+1
237        val kid = Vidnew        val kid = Vidnew
   
238            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
239        val E.IMG(dim, shape) = List.nth(params, Vid)        val E.IMG(dim, shape) = List.nth(params, Vid)
240            val freshIndex = getsumshift (sx, length index)            val freshIndex = getsumshift (sx, length index)
241            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)            val (dx', sx', Ps) = T.imageToWorld (freshIndex, dim, dx, pid)
242        val sxn = freshIndex+length dx' (*next available index id *)        val sxn = freshIndex+length dx' (*next available index id *)
243        val (args', vP, probe') = fieldReconstruction (avail, sxn, alpha, shape, dx',  Vid, Vidnew, kid, hid, tid, args)            val (args', vP, probe') = fieldReconstruction (
244                    avail, sxn, alpha, shape, dx', Vid, Vidnew, kid, hid, tid, args)
   
245        (* add new params transformation matrix (Pid), image param, and kernel ids *)        (* add new params transformation matrix (Pid), image param, and kernel ids *)
246        val pP = E.TEN(true, [dim, dim])        val pP = E.TEN(true, [dim, dim])
247        val pV = List.nth(params, Vid)        val pV = List.nth(params, Vid)
# Line 260  Line 253 
253              AvailRHS.addAssignToList (avail, einapp)              AvailRHS.addAssignToList (avail, einapp)
254            end            end
255    
   
256      (*transform T*P*P..Ps*)      (*transform T*P*P..Ps*)
257      fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let      fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
   
258            val Pid = 0            val Pid = 0
259            val tid = 1            val tid = 1
260            val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)            val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
261          (*need to rewrite dx*)          (*need to rewrite dx*)
262            val sxx = sx@newsx            val sxx = sx@newsx
   
263            val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sxx            val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sxx
264  (* QUESTION: what is the significance of "9" and "7" in this code? *)  (* QUESTION: what is the significance of "9" and "7" in this code? *)
265                   of [] => ([], index, E.Conv(9, alpha, 7, dx'))                   of [] => ([], index, E.Conv(9, alpha, 7, dx'))
266                    | _ => CleanIndex.clean(E.Conv(9, alpha, 7, dx'), index, sxx)                    | _ => CleanIndex.clean(E.Conv(9, alpha, 7, dx'), index, sxx)
267                  (* end case *))                  (* end case *))
   
268            fun filterAlpha [] = dx'            fun filterAlpha [] = dx'
269              | filterAlpha (E.C _::es) = filterAlpha es              | filterAlpha (E.C _::es) = filterAlpha es
270              | filterAlpha (e1::es) = e1::(filterAlpha es)              | filterAlpha (e1::es) = e1::(filterAlpha es)
   
271        val exp = E.Tensor(tid, filterAlpha alpha')        val exp = E.Tensor(tid, filterAlpha alpha')
   
272            val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)            val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
273            val params = [E.TEN(true, [dim,dim]), E.TEN(true, sizes)]            val params = [E.TEN(true, [dim,dim]), E.TEN(true, sizes)]
274            val ein0 = mkEin(params, index, body')            val ein0 = mkEin(params, index, body')
   
275            in            in
276              (splitvar, ein0, sizes, dx, alpha')              (splitvar, ein0, sizes, dx, alpha')
277            end            end
278    
   
279    (* floats the reconstructed field term *)    (* floats the reconstructed field term *)
280      fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let      fun liftProbe (avail, (y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
281            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe            val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
282            val freshIndex = getsumshift(sx, length(index))            val freshIndex = getsumshift(sx, length(index))
283        val E.IMG(dim, shape) = List.nth(params, Vid)        val E.IMG(dim, shape) = List.nth(params, Vid)
   
284                      (* transform T*P*P..Ps *)                      (* transform T*P*P..Ps *)
285            val (splitvar, ein0, sizes, dx, alpha') =            val (splitvar, ein0, sizes, dx, alpha') =
286                  createEinApp (body, alpha, index, freshIndex, dim, dx, sx)                  createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
287        val vT = V.new ("TPP", Ty.TensorTy(sizes))        val vT = V.new ("TPP", Ty.TensorTy(sizes))
   
288                      (* reconstruct the lifted probe *)                      (* reconstruct the lifted probe *)
289        (* making params args: image, position, and kernel ids *)        (* making params args: image, position, and kernel ids *)
290        val kid = 0 (* params used *)        val kid = 0 (* params used *)
291        val params' = List.nth(params,Vid)::(List.tabulate(dim,fn _=> E.KRN))        val params' = List.nth(params,Vid)::(List.tabulate(dim,fn _=> E.KRN))
292        (* create body for ein expression *)        (* create body for ein expression *)
293        val sxn = length sizes (*next available index id *)        val sxn = length sizes (*next available index id *)
294        val (args', vP, probe') = fieldReconstruction (avail, sxn, alpha', shape, dx,  Vid, Vid, kid, hid, tid, args)            val (args', vP, probe') =
295                    fieldReconstruction (avail, sxn, alpha', shape, dx,  Vid, Vid, kid, hid, tid, args)
296            val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')            val einApp1 = IR.EINAPP(mkEin(params', sizes, probe'), args')
   
297                    (* transform T*P*P..Ps *)                    (* transform T*P*P..Ps *)
298        val rtn0 = if splitvar        val rtn0 = if splitvar
299        then FloatEin.transform(y, EinSums.transform ein0, [vP, vT])        then FloatEin.transform(y, EinSums.transform ein0, [vP, vT])
300        else [(y, IR.EINAPP(ein0, [vP, vT]))]        else [(y, IR.EINAPP(ein0, [vP, vT]))]
   
301            in            in
302        List.app (fn e => AvailRHS.addAssignToList(avail, e)) (((vT, einApp1)::(rtn0)))        List.app (fn e => AvailRHS.addAssignToList(avail, e)) (((vT, einApp1)::(rtn0)))
303            end            end
304    
   
305    (* expandEinOp: code->  code list    (* expandEinOp: code->  code list
306     * A this point we only have simple ein ops     * A this point we only have simple ein ops
307     * Looks to see if the expression has a probe. If so, replaces it.     * Looks to see if the expression has a probe. If so, replaces it.
# Line 330  Line 311 
311                  replaceProbe (avail, e, body, [])                  replaceProbe (avail, e, body, [])
312              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) =>              | (E.Probe(E.Conv(_, alpha, _, dx) ,_)) =>
313                  liftProbe (avail, e, body, []) (*scans dx for contant*)                  liftProbe (avail, e, body, []) (*scans dx for contant*)
   
314              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>              | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>
315                  replaceProbe (avail, e, p, sx)  (*no dx*)                  replaceProbe (avail, e, p, sx)  (*no dx*)
   
316              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) =>              | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, dx), _))) =>
317                  liftProbe (avail, e, p, sx) (*scalar field*)                  liftProbe (avail, e, p, sx) (*scalar field*)
   
318              | (E.Sum(sx, E.Probe p)) =>              | (E.Sum(sx, E.Probe p)) =>
319                  replaceProbe (avail, e, E.Probe p, sx)                  replaceProbe (avail, e, E.Probe p, sx)
320              | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) =>              | (E.Sum(sx, E.Opn(E.Prod, [eps, E.Probe p]))) =>
321                  replaceProbe (avail, e, E.Probe p, sx)                  replaceProbe (avail, e, E.Probe p, sx)
   
322              | _ => AvailRHS.addAssignToList (avail, e)              | _ => AvailRHS.addAssignToList (avail, e)
323            (* end case *))            (* end case *))
324    

Legend:
Removed from v.3785  
changed lines
  Added in v.3787

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0