Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/mid-to-low/field-to-low.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/mid-to-low/field-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3796 - (view) (download)

1 : jhr 3727 (* field-to-low.sml
2 :     *
3 :     * NOTE: this code will need to be changed if we ever want to support different kernels
4 :     * for different axes
5 :     *
6 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
7 :     *
8 :     * COPYRIGHT (c) 2016 The University of Chicago
9 :     * All rights reserved.
10 :     *)
11 :    
12 :     structure FieldToLow : sig
13 :    
14 :     (* expand a MidIR probe to LowIR code. The arguments are:
15 :     *
16 : jhr 3728 * avail -- available LowIR assignments
17 :     * mapp -- mapping from iteration indices to deBruijn indices
18 :     * sx -- summation bounds
19 :     * prod -- body of summation, which is a product of Ein expressions
20 : jhr 3745 * args -- the actual arguments of the enclosing Ein expression
21 : jhr 3727 *)
22 :     val expand : {
23 : jhr 3790 avail : AvailRHS.t,
24 :     mapp : int IntRedBlackMap.map,
25 :     sx : Ein.sumrange list,
26 :     prod : Ein.ein_exp list,
27 :     args : LowIR.var list
28 :     } -> LowIR.var
29 : jhr 3727
30 :     end = struct
31 :    
32 :     structure IR = LowIR
33 :     structure Ty = LowTypes
34 :     structure Op = LowOps
35 :     structure Var = LowIR.Var
36 :     structure E = Ein
37 :     structure Mk = MkLowIR
38 :     structure IMap = IntRedBlackMap
39 :    
(* getHolder (args, id, piece)
 *
 * Select the variable that stands for one piece of the id'th argument.
 * The id'th element of args is looked up and its defining right-hand side
 * is inspected:
 *   - a CONS binding yields the piece'th element of the CONS arguments;
 *   - an OP binding yields the argument variable itself (piece is ignored);
 *   - any other binding raises Fail, reporting the variable, its binding,
 *     and the argument index.
 * List.nth raises Subscript if id or piece is out of range.
 *)
fun getHolder (args, id, piece) = let
      val holder = List.nth(args, id)
    (* report a binding form that we do not know how to index *)
      fun badBinding rhs = raise Fail(String.concat[
              "getHolder found ", Var.name holder, "=", IR.RHS.toString rhs,
              " at ", Int.toString id
            ])
      in
        case IR.Var.getDef holder
         of IR.CONS(elems, _) => List.nth(elems, piece)
          | IR.OP _ => holder
          | other => badBinding other
        (* end case *)
      end
53 : jhr 3727
(* imgToArgs (avail, mapp, sx, img, args)
 *
 * Evaluate an image expression E.Img(Vid, alpha, vs, s, _) by emitting
 * "proj" assignments (via AvailRHS.addAssign on avail) that project the
 * image argument, and return the list of resulting variables.
 *
 *   avail -- available LowIR assignments; new assignments are added to it
 *   mapp  -- index map used to resolve the indices in alpha (Mk.lookupMu)
 *   sx    -- summation bounds (not used by this function)
 *   args  -- actual arguments; the image variable is the Vid'th element
 *
 * Each projection has type TensorTy [2*s] and is indexed by the resolved
 * alpha indices followed by the position indices.  The number of position
 * indices depends on the length of vs:
 *   1 position  -- a single projection with no position index
 *   2 positions -- one projection per i in [0, 2*s)
 *   3 positions -- one projection per (i, j) in [0, 2*s) x [0, 2*s)
 * Any other length raises Fail "unsupported dimension".
 *
 * The projections are emitted in increasing index order (j varying fastest
 * in the 3-d case), but the returned list is in the reverse of that order,
 * i.e., the projection for the largest index comes first.
 *
 * NOTE: the indexing scheme used in getIX depends on the data layout created
 * when the voxels are loaded; change it here if that layout changes.
 *)
fun imgToArgs (avail, mapp, sx, E.Img(Vid, alpha, vs, s, _), args) = let
      val vI = List.nth(args, Vid)
      val range1 = 2*s
      val range = List.tabulate (range1, fn e => e)
      val beta = List.map (fn id => Mk.lookupMu(mapp, id)) alpha
    (* index tensor with image shape and position: v_beta[idxs] *)
      fun getIX idxs = AvailRHS.addAssign (
            avail, "proj", Ty.TensorTy [range1],
            IR.OP(Op.ProjectLast(IR.Var.ty vI, beta @ idxs), [vI]))
    (* project row i and push the result onto the accumulator *)
      fun pushRow (i, acc) = getIX [i] :: acc
    (* project every (i, j) for a fixed i, pushing results onto the accumulator *)
      fun pushPlane (i, acc) =
            List.foldl (fn (j, acc') => getIX [i, j] :: acc') acc range
      in
        case vs
         of [_] => [getIX []]                           (* 1-d case *)
          | [_, _] => List.foldl pushRow [] range       (* 2-d case *)
          | [_, _, _] => List.foldl pushPlane [] range  (* 3-d case *)
          | _ => raise Fail "unsupported dimension"
        (* end case *)
      end
87 : cchiw 3793
(* prodImgKrn (avail, imgArg, krnArg, s)
 *
 * Convolution product of image and kernel.
 *
 *   avail  -- available LowIR assignments; new assignments are added to it
 *   imgArg -- list of image variables (as produced by imgToArgs)
 *   krnArg -- list of kernel variables, one per axis (1 to 3 of them)
 *   s      -- the value from the image expression; vectors have width 2*s
 *
 * The work is done with Mk.vecDot on width-2*s operands and with "cons_"
 * assignments that pack 2*s results into a TensorTy [2*s] value.  The case
 * analysis is on the number of kernels:
 *   1 kernel  -- the single image variable dotted with the kernel
 *   2 kernels -- every image variable dotted with the first kernel, the
 *                results packed, then dotted with the second kernel
 *   3 kernels -- every image variable dotted with the first kernel and
 *                packed in consecutive groups of 2*s; each group dotted with
 *                the second kernel and packed; the result dotted with the
 *                third kernel
 * Any other combination (including one kernel with a number of image
 * variables other than one) raises Fail.
 *)
fun prodImgKrn (avail, imgArg, krnArg, s) = let
    (* number of arguments for a CONS *)
      val width = 2*s
      val lastIdx = width-1
      val vecTy = Ty.TensorTy [width]
    (* pack the given variables into a new tensor variable *)
      fun pack vars = AvailRHS.addAssign (avail, "cons_", vecTy, IR.CONS (vars, vecTy))
      fun dot (a, b) = Mk.vecDot(avail, width, a, b)
    (* dot every element of vs with h (in order) and pack the results *)
      fun dotAndPack (vs, h) = pack (List.map (fn v => dot (v, h)) vs)
    (* dot every element of vs with h and pack each consecutive run of `width`
     * results; a trailing run that is too short to fill a group is discarded.
     * Both accumulators are kept in reverse order.
     *)
      fun dotInGroups (vs, h) = let
            fun lp ([], _, _, groups) = List.rev groups
              | lp (v::rest, pending, 0, groups) = let
                  val d = dot (h, v)
                  val g = pack (List.rev (d :: pending))
                  in
                    lp (rest, [], lastIdx, g :: groups)
                  end
              | lp (v::rest, pending, n, groups) =
                  lp (rest, dot (h, v) :: pending, n-1, groups)
            in
              lp (vs, [], lastIdx, [])
            end
      in
      (* create the product by case analysis of the dimension *)
        case (krnArg, imgArg)
         of ([h0], [img]) => dot (img, h0)              (* 1-D case *)
          | ([h0, h1], _) => let                        (* 2-D case *)
              val rows = dotAndPack (imgArg, h0)
              in
                dot (rows, h1)
              end
          | ([h0, h1, h2], _) => let                    (* 3-D case *)
              val planes = dotInGroups (imgArg, h0)
              val rows = dotAndPack (planes, h1)
              in
                dot (h2, rows)
              end
          | _ => raise Fail "Kernel dimensions not between 1-3"
        (* end case *)
      end
136 : cchiw 3741
(* expand a MidIR probe to LowIR code.  The arguments are:
 *
 *      avail   -- available LowIR assignments
 *      mapp    -- mapping from iteration indices to deBruijn indices
 *      sx      -- summation bounds
 *      prod    -- body of summation; it must be an image term (E.Img)
 *                 followed only by kernel terms (E.Krn)
 *      args    -- the actual arguments of the enclosing Ein expression
 *
 * Returns the LowIR variable that holds the result of the probe.
 * Raises Fail if prod does not have the expected shape.
 *)
fun expand {avail, mapp, sx, prod = E.Img(e1 as (_, _, _, s, _)) :: krnexps, args} = let
    (* evaluate the image expression first, so that its assignments are
     * added to avail before anything else
     *)
      val imgArgs = imgToArgs(avail, mapp, sx, E.Img e1, args)
    (* get the piece for a kernel term *)
      fun getf (E.Krn(id, dels, _)) = let
          (* evaluate dels to an integer *)
            val delta = List.foldl (fn ((i, j), y) => Mk.evalDelta(mapp, i, j) + y) 0 dels
            in
              getHolder (args, id, delta) (* selects variable in holder/cons list *)
            end
        | getf _ = raise Fail "FieldToLow.expand: expected only kernel terms after the image term"
    (* evaluate kernel expression(s); this does not create code *)
      val krnArgs = List.map getf krnexps
      in
        prodImgKrn (avail, imgArgs, krnArgs, s)
      end
  | expand _ = raise Fail "FieldToLow.expand: product must start with an image term"
153 : jhr 3727
154 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0