Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis12/src/compiler/high-to-mid/probe.sml
ViewVC logotype

Diff of /branches/vis12/src/compiler/high-to-mid/probe.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 334, Thu Aug 19 20:53:07 2010 UTC revision 349, Fri Sep 24 00:24:20 2010 UTC
# Line 8  Line 8 
8    
9  structure Probe : sig  structure Probe : sig
10    
11      val expand : MidIL.var * FieldDef.field_def * HighIL.var -> HighIL.assign list      val expand : MidIL.var * FieldDef.field_def * MidIL.var -> MidIL.assign list
12    
13    end = struct    end = struct
14    
# Line 16  Line 16 
16      structure SrcOp = HighOps      structure SrcOp = HighOps
17      structure DstIL = MidIL      structure DstIL = MidIL
18      structure DstOp = MidOps      structure DstOp = MidOps
19        structure DstV = DstIL.Var
20      structure VMap = SrcIL.Var.Map      structure VMap = SrcIL.Var.Map
21        structure IT = Shape
22    
23    (* a tree representation of nested iterations over the image space, where the    (* generate a new variable indexed by dimension *)
24     * height of the tree corresponds to the number of dimensions and at each node      fun newVar_dim (prefix, d) =
25     * we have as many children as there are iterations.            DstV.new (prefix ^ Partials.axisToString(Partials.axis d))
26     *)  
27      structure IT =      fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
28        struct      fun cons (x, args) = (x, DstIL.CONS args)
29          datatype ('nd, 'lf) iter_tree      fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))
30            = LF of 'lf      fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))
           | ND of ('nd * ('nd, 'lf) iter_tree list)  
31    
32          fun create (depth, width, ndAttr, f, lfAttr, init) = let    (* generate code for a evaluating a single element of a probe operation *)
33                fun mk (d, i, arg) = if (d < depth)      fun probeElem {
34                      then ND(ndAttr arg, List.tabulate(width, fn j => mk(d+1, j, f(j, arg))))              dim,        (* dimension of space *)
35                      else LF(lfAttr arg)              h, s,       (* kernel h with support s *)
36                n, f,       (* Dst vars for integer and fractional components of position *)
37                voxIter     (* iterator over voxels *)
38              } (result, pdOp) = let
39            (* generate the variables that hold the convolution coefficients *)
40              val convCoeffs = let
41                    val Partials.D l = pdOp
42                    fun mkVar 0 = newVar_dim("h", dim)
43                      | mkVar 1 = newVar_dim("dh", dim)
44                      | mkVar i = newVar_dim(concat["d", Int.toString i, "h"], dim)
45                    in
46                      List.map mkVar l
47                    end
48            (* for each dimension, we evaluate the kernel at the coordinates for that axis *)
49              val coeffCode = let
50                    fun gen (x, k, (d, code)) = let
51                          val d = d-1
52                          val fd = newVar_dim ("f", d)
53                          val a = DstV.new "a"
54                          val tmps = List.tabulate(2*s, fn i => (DstV.new("t"^Int.toString i), s - (i+1)))
55                          fun mkArg ((t, n), code) = let
56                                val t' = DstV.new "r"
57                in                in
58                  mk (0, 0, init)                                realLit (t', n) ::
59                                  assign (t, DstOp.Add DstOp.realTy, [fd, t']) ::
60                                  code
61                end                end
62                        (* code in reverse order *)
63          fun map (nd, lf) t = let                        val code =
64                fun mapf (LF x) = LF(lf x)                              assign(x, DstOp.EvalKernel(2*s, h, k), [a]) ::
65                  | mapf (ND(i, kids)) = ND(nd i, List.map mapf kids)                              assign(fd, DstOp.Select(dim, d), [f]) ::
66                                cons(a, List.map #1 tmps) ::
67                                  (List.foldl mkArg [] tmps)
68                in                in
69                  mapf t                          (d, List.rev code)
70                end                end
71                    val Partials.D l = pdOp
         fun foldr f init t = let  
               fun fold (LF x, acc) = f(x, acc)  
                 | fold (ND(_, kids), acc) = List.foldr fold acc kids  
72                in                in
73                  fold t                    #2 (ListPair.foldr gen (dim, []) (convCoeffs, l))
74                end                end
75            (* generate the reduction code *)
76              fun genReduce (result, [hh], IT.ND(_, kids), code) = let
77                  (* the kids will all be leaves *)
78                    val vv = DstV.new "vv"
79                    fun getVox (IT.LF{vox, offsets}) = vox
80                    in
81                      cons (vv, List.map getVox kids) ::
82                      assign (result, DstOp.Dot(2*s), [hh, vv]) :: code
83        end        end
84    (* generate a new variable indexed by dimension *)              | genReduce (result, hh::r, IT.ND(_, kids), code) = let
85      local                  val tv = DstV.new "tv"
86        val dimNames = Vector.fromList[ "x", "y", "z" ];                  val tmps = List.tabulate(2*s, fn i => DstV.new("t"^Int.toString i))
87                    fun lp ([], [], code) = code
88                      | lp (t::ts, kid::kids, code) = genReduce(t, r, kid, lp(ts, kids, code))
89                    val code = cons(tv, tmps) :: assign(result, DstOp.Dot(2*s), [hh, tv]) :: code
90      in      in
91      fun newVar_dim (prefix, d) =                    lp (tmps, kids, code)
92            DstIL.Var.new (prefix ^ Vector.sub(dimNames, d))                  end
93              val reduceCode = genReduce (result, convCoeffs, voxIter, [])
94      fun assign (x, rator, args) = (x, DstIL.OP(rator, args))            in
95      fun cons (x, args) = (x, DstIL.CONS args)              coeffCode @ reduceCode
96      fun realLit (x, i) = (x, DstIL.LIT(Literal.Float(FloatLit.fromInt i)))            end
     fun intLit (x, i) = (x, DstIL.LIT(Literal.Int(IntInf.fromInt i)))  
97    
98    (* generate code for probing the field (D^k (v * h)) at pos *)    (* generate code for probing the field (D^k (v * h)) at pos *)
99      fun probe (result, (k, v, h), pos) = let      fun probe (result, (k, v, h), pos) = let
100            val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v            val ImageInfo.ImgInfo{dim, ty=([], ty), ...} = v
101            val dimTy = DstOp.VecTy dim            val dimTy = DstOp.VecTy dim
102            val s = Kernel.support h            val s = Kernel.support h
           val sTy = DstOp.VecTy(2*s)  
103          (* generate the transform code *)          (* generate the transform code *)
104            val x = DstIL.Var.new "x"     (* image-space position *)            val x = DstV.new "x"  (* image-space position *)
105            val f = DstIL.Var.new "f"            val f = DstV.new "f"
106            val nd = DstIL.Var.new "nd"            val nd = DstV.new "nd"
107            val n = DstIL.Var.new "n"            val n = DstV.new "n"
108            val transformCode = [            val transformCode = [
109                    assign(x, DstOp.Transform v, [pos]),                    assign(x, DstOp.Transform v, [pos]),
110                    assign(nd, DstOp.Floor dim, [x]),                    assign(nd, DstOp.Floor dim, [x]),
111                    assign(f, DstOp.Sub dimTy, [x, nd]),                    assign(f, DstOp.Sub dimTy, [x, nd]),
112                    assign(n, DstOp.TruncToInt dim, [nd])                    assign(n, DstOp.TruncToInt dim, [nd])
113                  ]                  ]
114            (* generate the shape of the differentiation tensor with variables representing
115             * the elements
116             *)
117              val diffIter = let
118                    val partial = Partials.partial dim
119                    fun f (i, axes) = Partials.axis i :: axes
120                    fun g axes =
121                          (DstV.new(String.concat("y" :: List.map Partials.axisToString axes)), partial axes)
122                    in
123                      IT.create (k-1, dim, fn _ => (), f, g, [])
124                    end
125          (* generate code to load the voxel data *)          (* generate code to load the voxel data *)
126            val voxIter = let            val voxIter = let
127                  fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)                  fun f (i, (offsets, id)) = (i - (s - 1) :: offsets, i::id)
128                  fun g (offsets, id) = {                  fun g (offsets, id) = {
129                          offsets = offsets,                          offsets = offsets,
130                          vox = DstIL.Var.new(String.concat("v" :: List.map Int.toString id))                          vox = DstV.new(String.concat("v" :: List.map Int.toString id))
131                        }                        }
132                  in                  in
133                    IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))                    IT.create (dim-1, 2*s, fn _ => (), f, g, ([], []))
# Line 95  Line 137 
137                        fun computeIndices (_, []) = ([], [])                        fun computeIndices (_, []) = ([], [])
138                          | computeIndices (i, offset::offsets) = let                          | computeIndices (i, offset::offsets) = let
139                              val index = newVar_dim("i", i)                              val index = newVar_dim("i", i)
140                              val t1 = DstIL.Var.new "t1"                              val t1 = DstV.new "t1"
141                              val t2 = DstIL.Var.new "t2"                              val t2 = DstV.new "t2"
142                              val (indices, code) = computeIndices (i+1, offsets)                              val (indices, code) = computeIndices (i+1, offsets)
143                              val code =                              val code =
144                                    intLit(t1, offset) ::                                    intLit(t1, offset) ::
145                                    assign(t2, DstOp.Select i, [n]) ::                                    assign(t2, DstOp.Select(2*s, i), [n]) ::
146                                    assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::                                    assign(index, DstOp.Add(DstOp.IntTy), [t1, t2]) ::
147                                    code                                    code
148                              val indices = index::indices                              val indices = index::indices
# Line 108  Line 150 
150                                (indices, code)                                (indices, code)
151                              end                              end
152                        val (indices, indicesCode) = computeIndices (0, ~(s-1) :: offsets)                        val (indices, indicesCode) = computeIndices (0, ~(s-1) :: offsets)
153                        val a = DstIL.Var.new "a"                        val a = DstV.new "a"
154                        in                        in
155                          indicesCode :: [                          indicesCode @ [
156                              assign(a, DstOp.VoxelAddress v, indices),                              assign(a, DstOp.VoxelAddress v, indices),
157                              assign(vox, DstOp.LoadVoxels(ty, 2*s))                              assign(vox, DstOp.LoadVoxels(ty, 2*s), [a])
158                            ] @ code                            ] @ code
159                        end                        end
160                  in                  in
161                    IT.foldr genCode [] voxIter                    IT.foldr genCode [] voxIter
162                  end                  end
163            val voxVars = IT.foldr (fn ({vox, ...}, vs) => vox::vs) [] voxIter          (* generate code to evaluate and construct the result tensor *)
164          (* generate the code for computing the convolution coefficients *)            val probeElem = probeElem {dim = dim, h = h, s = s, n = n, f = f, voxIter = voxIter}
165            val convCoeffs = Vector.tabulate (dim,            fun genProbe (result, IT.ND(_, kids as (IT.LF _)::_), code) = let
166                  fn d => Vector.tabulate (k+1,                (* the kids will all be leaves *)
167                    fn 0 => newVar_dim("h", d)                  fun genProbeCode (IT.LF arg, code) = probeElem arg @ code
168                     | 1 => newVar_dim("dh", d)                  fun getProbeVar (IT.LF(t, _)) = t
                    | i => newVar_dim(concat["d", Int.toString i, "h"], d)))  
           fun coefficient (d, k) = Vector.sub(Vector.sub(convCoeffs, d), k)  
           fun genCoeffCode d = if (d < dim)  
                 then let  
                   val code = genCoeffCode (d+1)  
                   val fd = newVar_dim "f"  
                   val a = DstIL.Var.new "a"  
                   val tmps = List.tabulate(2*s,  
                         fn i => (DstIL.Var.new("t"^Int.toString i), s - (i+1)))  
                   fun mkArg ((t, n), code) = let  
                         val t' = DstIL.Var.new "r"  
169                          in                          in
170                            realLit (t', n) ::                    cons (result, List.map getProbeVar kids) :: List.foldr genProbeCode code kids
                           assign (t, DstOp.Add DstOp.realTy, [fd, t']) ::  
                           code  
171                          end                          end
172                    val coeffCode =              | genProbe (result, IT.ND(_, kids), code) = let
173                          cons(a, List.map #1 tmps) ::                  val tmps = List.tabulate(dim, fn i => DstV.new("t"^Int.toString i))
174                          List.tabulate (k+1, fn i =>                  fun lp ([], [], code) = code
175                            assign(coefficient(d, i), DstOp.EvalKernel(2*s, h, i), [a]))                    | lp (t::ts, kid::kids, code) = genProbe(t, kid, lp(ts, kids, code))
176                    val code = cons(result, tmps) :: code
177                    in                    in
178                      assign(fd, DstOp.Select d, f) ::                    lp (tmps, kids, code)
                     (List.foldr mkArg (coeffCode @ code) tmps)  
179                    end                    end
180                  else []            val probeCode = genProbe (result, diffIter, [])
           val coeffCode = genCoeffCode  
         (* generate the reduction code *)  
           fun genReduce (d, IT.ND(kids), code) =  
                 if (d < dim)  
                   then List.foldr (fn (nd, code) => genReduce(d+1, nd, code)) code kids  
                   else let (* the kids will all be leaves *)  
                     val vv = DstIL.Var.new "vv"  
                     fun getVox (IT.LF{vox, offsets}) = vox  
                     val hh = coefficient (d, 0) (* FIXME: what is the right value for k? *)  
181                      in                      in
182                        cons (vv, List.map getVox kids) ::              transformCode @ loadCode @ probeCode
                       assign (t, DstOp.Dot, [hh, vv]) :: code  
                     end  
           val reduceCode = genReduce (1, voxIter, [])  
           in  
             transformCode @ loadCode @ coeffCode @ reduceCode  
183            end            end
184    
185      end      fun expand (result, fld, pos) = let
   
     fun expand (result, FieldDef.CONV(0, img, h), pos) = let  
186            fun expand' (result, FieldDef.CONV(k, v, h)) = let            fun expand' (result, FieldDef.CONV(k, v, h)) = let
187                  val x = DstIL.Var.new "x"                  val x = DstV.new "x"
188                  val xformStm = (x, DstIL.OP(DstOp.Transform img, [pos']))                  val xformStm = (x, DstIL.OP(DstOp.Transform v, [pos]))
189                  in                  in
190                    probe (result, (k, v, h), x) @ [xformStm]                    probe (result, (k, v, h), x) @ [xformStm]
191                  end                  end
192    (* should push negation down to probe operation
193              | expand' (result, FieldDef.NEG fld) = let              | expand' (result, FieldDef.NEG fld) = let
194                  val r = DstIL.Var.new "value"                  val r = DstV.new "value"
195                  val stms = expand' (r, fld)                  val stms = expand' (r, fld)
196                    val ty = ??
197                  in                  in
198                    expand' (r, fld) @ [assign(r, DstOp.Neg ty, [r])]                    expand' (r, fld) @ [assign(r, DstOp.Neg ty, [r])]
199                  end                  end
200    *)
201              | expand' (result, FieldDef.SUM(fld1, dlf2)) = raise Fail "expandInside: SUM"              | expand' (result, FieldDef.SUM(fld1, dlf2)) = raise Fail "expandInside: SUM"
202            in            in
203              List.rev (expand (result, fld))              List.rev (expand' (result, fld))
204            end            end
205    
206    end    end

Legend:
Removed from v.334  
changed lines
  Added in v.349

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0