Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/evalImg-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/evalImg-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3624 - (view) (download)

1 : jhr 3624 (* evalImg-set.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
9 : cchiw 3602 structure EvalImg = struct
10 : cchiw 3444 local
11 :    
12 : cchiw 3602 structure Op = LowOps
13 :     structure Ty = LowILTypes
14 :     structure IL = LowIL
15 : cchiw 3444 structure Var = LowIL.Var
16 :     structure E = Ein
17 : cchiw 3602 structure P = Printer
18 :     structure H = Helper
19 :     structure IMap = IntRedBlackMap
20 : cchiw 3444
21 : cchiw 3602 in
22 : cchiw 3444
23 : cchiw 3602 fun lookup k d = IMap.find (d, k)
24 :     fun insert (k, v) d = IMap.insert (d, k, v)
25 :     fun find e = H.find e
26 :     fun mapIndex e = H.mapIndex e
27 :     fun mkInt n = H.mkInt n
28 :     fun assgn e = H.assignOP e
29 : cchiw 3444 fun indexTensor e = H.indexTensor e
30 : cchiw 3602 fun mkAddInt e = H.mkAddInt e
31 : cchiw 3444 fun mkAddPtr e = H.mkAddPtr e
32 : cchiw 3602 fun mkProdInt e = H.mkProdInt e
33 :     fun err str = raise Fail (str)
34 :     fun psize n = foldl (fn (a, b) => b*a) 1 n
35 :     fun asize n = foldl (fn (a, b) => b+a) 0 n
36 : cchiw 3444
37 : cchiw 3602 (* mkImg:dict*string*E.params*var list*sum_id list* () *image*var*int*int*int
38 : cchiw 3444 * ->var*lowIL.assgn
39 :     * The image "imgarg" is probed at positions
40 :     * Σ_{sx} V_alpha[pos0::px]
41 : cchiw 3553 * sumPos () iterates over the summation indices and creates a mapp for the indicies
42 :     * once mapp (j->2 k->0) is created sumPos () calls createImgVar () to get the addr of Σ_k V_{i}[T_j, T_k]
43 : cchiw 3602 * createImgVar () uses mkpos (), getPosAddr () and getImgAddr () to get imgvar
44 : cchiw 3553 *)
45 : cchiw 3602 fun mkImg (avail, mappOrig, params, args, sx, (_, v_alpha, pos0::px), v, imgarg, lb, range0, range1) = let
46 :     val dim = ImageInfo.dim v
47 :     val ptyTy = Ty.AddrTy v
48 :     val sizes = ImageInfo.sizes v
49 :     val (avail, vBase) = assgn (avail, Op.baseAddr v, [imgarg], "baseAddr", ptyTy)
50 :     (*base address*)
51 :     val (avail, vShapeShift) = mkInt (avail, psize (ImageInfo.voxelShape v))
52 :     (*shift of the image field.*)
53 : cchiw 3444
54 : cchiw 3602 (*Since the image is loaded as a vector
55 : cchiw 3444 * we evaluate the first position just once
56 : cchiw 3602 * Σ_{ij..} V[T+i, T+j...]-> Σ_{j..} V[T+j...]
57 : cchiw 3444 * and we drop the first summation index
58 : cchiw 3602 * Additionally, summation indices are reversed
59 : cchiw 3553 * that inner loop is second (y) axis and outer loop is third (z) axis
60 :     *)
61 : cchiw 3602 val (avail, vPos0, sxx) = let
62 :     val E.Opn (E.Add, [E.Tensor (t1, ix1), _ ]) = pos0
63 :     val (avail, vA) = indexTensor (avail, mappOrig, ( params, args, t1, ix1, Ty.IntTy))
64 :     val (avail, vB) = mkInt (avail, lb)
65 :     val (avail, vC) = mkAddInt (avail, [vA, vB])
66 :     val sxx = List.rev (List.tl(List.map (fn (E.V sid, _, _) => sid) sx))
67 : cchiw 3444 in
68 : cchiw 3602 (avail, vC, sxx)
69 : cchiw 3444 end
70 : cchiw 3602 (* mkpos:ein_exp list*var list*IL.assgn list
71 : cchiw 3444 * transform ein_exp to low-il
72 :     * returns var for the position
73 : cchiw 3553 *)
74 : cchiw 3602 fun mkpos (avail, e, mapp, rest) = (case e
75 :     of [] => (avail, rest)
76 :     | (E.Opn (E.Add, ([E.Tensor (t1, ix1), E.Value v1])) ::es) => let
77 :     val (avail, vA) = indexTensor (avail, mapp, (params, args, t1, ix1, Ty.IntTy))
78 :     val (avail, rest') = (case find (v1, mapp)
79 :     of 0 => (avail, vA)
80 :     | j => let
81 :     val (avail, vB) = mkInt (avail, j)
82 :     val (avail, vC) = mkAddInt (avail, [vA, vB])
83 :     in (avail, vC) end
84 :     (*end case*))
85 :     in mkpos (avail, es, mapp, rest@[rest']) end
86 : cchiw 3553 | e1::_ => raise Fail ("Incorrect pos for Image: "^P.printbody e1)
87 : cchiw 3602 (*end case*))
88 :     (* getPosAddr:var list->var*IL.assgn list
89 : cchiw 3553 * create position addr based on image info's shapeshift, image info's sizes, and args
90 :     * args are the variables for this specific positions. V_ ([x, y])
91 :     * returns vPosAddr, PosAddrcode
92 :     *)
93 : cchiw 3602 fun getPosAddr (avail, args) = case (sizes, args)
94 :     of ([ _ ], [i]) => mkProdInt (avail, [vShapeShift, i]) (*1-d*)
95 :     | ([x, _ ], [i, j]) =>let (*2-d*)
96 :     val (avail, vA) = mkInt (avail, x)
97 :     val (avail, vB) = mkProdInt (avail, [vA, j])
98 :     val (avail, vC) = mkAddInt (avail, [i, vB])
99 :     in mkProdInt (avail, [vShapeShift, vC]) end
100 :     | ([x, y, _], [i, j, k]) =>let (*3-d*)
101 :     val (avail, vA) = mkInt (avail, y)
102 :     val (avail, vB) = mkProdInt (avail, [vA, k])
103 :     val (avail, vC) = mkAddInt (avail, [j, vB])
104 :     val (avail, vD) = mkInt (avail, x)
105 :     val (avail, vE) = mkProdInt (avail, [vD, vC])
106 :     val (avail, vF) = mkAddInt (avail, [i, vE])
107 :     in mkProdInt (avail, [vShapeShift, vF]) end
108 :    
109 :     (* getImgAddr:int list *var->var*IL.assgn list
110 :     * creates image address with ^position address, imgType, and base address
111 : cchiw 3553 * imgType are image specific indices V[0, 1] (_)
112 : cchiw 3602 * ->returns (vImgAddr, ImgAddrcode)
113 : cchiw 3553 *)
114 : cchiw 3602 fun getImgAddr (avail, imgType, vPosAddr) = case imgType
115 :     of [] => mkAddPtr (avail, [vBase, vPosAddr], ptyTy)
116 :     | [0] => mkAddPtr (avail, [vBase, vPosAddr], ptyTy)
117 : cchiw 3444 | [_] => let
118 : cchiw 3602 val (avail, vA) = mkAddPtr (avail, [vBase, vPosAddr], ptyTy)
119 :     val (avail, vB) = mkInt (avail, asize imgType)
120 :     in mkAddPtr (avail, [vB, vA], ptyTy) end
121 : cchiw 3553 | [i, j] => let
122 : cchiw 3602 val [a, b] = ImageInfo.voxelShape v
123 :     val (avail, vA) = mkAddPtr (avail, [vBase, vPosAddr], ptyTy)
124 :     val (avail, vB) = mkInt (avail, (b*j) +i)
125 :     in mkAddPtr (avail, [vB, vA], ptyTy) end
126 : cchiw 3542
127 : cchiw 3444
128 : cchiw 3602 (* createImgVar:dict->var*IL.assgn list
129 : cchiw 3444 * gets low-il var for loading an image address
130 : cchiw 3553 *)
131 : cchiw 3602 fun createImgVar (avail, mapp) = let
132 :     val (avail, vA) = mkpos (avail, px, mapp, []) (*transforms the probed position to low-il*)
133 :     val posArgs = vPos0::vA (*adds intial position to ^*)
134 :     val (avail, vPosAddr) = getPosAddr (avail, posArgs) (*position address*)
135 :     val imgType = List.map (fn (e1) => mapIndex (e1, mapp)) v_alpha (*img specific index*)
136 :     val (avail, vImgAddr) = getImgAddr (avail, imgType, vPosAddr) (*img address*)
137 : cchiw 3444 in
138 : cchiw 3602 assgn (avail, Op.imgLoad (v, dim, range1), [vImgAddr], "imgrng", Ty.tensorTy ([range1]))
139 : cchiw 3444 end
140 :    
141 : cchiw 3602 val range0List = List.tabulate (range0+1, fn e =>e)
142 :     (* sumPos:index_id * var list*lowil.assgn list*dict*int
143 : cchiw 3444 * ->var*lowil.assgn list
144 :     * sumPos iterates over the summation indices and creates mapp
145 : cchiw 3553 *)
146 : cchiw 3602 fun sumPos (avail, [], rest, dict, _) = let
147 :     val (avail, rest') = createImgVar (avail, dict)
148 :     in (avail, rest'::rest) end
149 :     | sumPos (avail, [sid], rest, dict, [r]) = let
150 : cchiw 3553 val n' = lb+r
151 : cchiw 3602 val mapp = insert (sid, n') dict
152 :     val (avail, rest') = createImgVar (avail, mapp)
153 :     in (avail, rest'::rest) end
154 :     | sumPos (avail, [sid], rest, dict, r::es) = let
155 : cchiw 3553 val n' = lb+r
156 : cchiw 3602 val mapp = insert (sid, n') dict
157 :     val (avail, rest') = createImgVar (avail, mapp)
158 :     in sumPos (avail, [sid], rest'::rest, dict, es) end
159 :     | sumPos (avail, sid::sxx, rest, dict, [r]) = let
160 : cchiw 3553 val n' = lb+r
161 : cchiw 3602 val mapp = insert (sid, n') dict
162 :     in
163 :     sumPos (avail, sxx, rest, mapp, range0List)
164 :     end
165 :     | sumPos (avail, sid::sxx, rest, dict, r::es) = let
166 : cchiw 3553 val n' = lb+r
167 : cchiw 3602 val mapp = insert (sid, n') dict
168 :     val (avail, rest') = sumPos (avail, sxx, rest, mapp, range0List)
169 : cchiw 3444 in
170 : cchiw 3602 sumPos (avail, sid::sxx, rest', dict, es)
171 : cchiw 3444 end
172 : cchiw 3602 val (avail,ids) = sumPos (avail, sxx, [], mappOrig, range0List)
173 : cchiw 3444 in
174 : cchiw 3602 (avail, List.rev ids)
175 : cchiw 3444 end
176 :    
177 :    
178 : cchiw 3602 end (* local *)
179 : cchiw 3444
180 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0