Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Annotation of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3579 - (view) (download)

1 : jhr 3550 (* probe-ein.sml
2 :     *
3 :     * This code is part of the Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     *
5 :     * COPYRIGHT (c) 2016 The University of Chicago
6 :     * All rights reserved.
7 :     *)
8 :    
(* ProbeEin: expands probes of convolution fields (EIN bodies of the form
 * Probe(Conv(...), Tensor(...))) in Mid IR EIN applications into explicit
 * image/kernel reconstruction expressions.
 *
 * NOTE(review): the signature is empty, so nothing defined here (including
 * expandEinOp) is visible outside the structure; presumably an export is still
 * to be added -- confirm.
 *)
structure ProbeEin : sig

  end = struct

    structure E = Ein
    structure DstIR = MidIR
    structure DstOp = MidOps

    structure DstV = DstIR.Var
    structure DstTy = MidTypes
    structure T = CoordSpaceTransform

  (* This file expands probed fields.
   * Take a look at the ProbeEin tex file for examples.
   * Note that the original field is an EIN operator in the form
   * <V_alpha * H^(deltas)>(midIR.var list).
   * Param_ids note the position of an argument in the midIR.var list.
   * Index_ids keep track of the shape of an image or of a differentiation.
   * A mu binds an Index_id.
   * Generally, we will refer to the following:
   *   dim:   dimension of field V
   *   s:     support of kernel H
   *   alpha: the alpha in <V_alpha * H^(deltas)>
   *   deltas: the deltas in <V_alpha * H^(deltas)>
   *   Vid:   param_id for V
   *   hid:   param_id for H
   *   nid:   integer-position param_id
   *   fid:   fractional-position param_id
   *   img:   image info about V
   *)

  (* FIXME: what are these for? should they be settable from the command-line?
   * Within this structure only valnumflag is read (by expandEinOp).
   *)
    val valnumflag = true
    val tsplitvar = true
    val fieldliftflag = true
    val detflag = true

  (* NOTE(review): despite the names, transformToIndexSpace delegates to
   * CoordSpaceTransform.imageToWorld, while transformToImgSpace delegates to
   * CoordSpaceTransform.worldToIndex; confirm the naming is intended.
   *)
    fun transformToIndexSpace e = T.imageToWorld e
    fun transformToImgSpace e = T.worldToIndex e

  (* build an EIN operator from its parameters, index shape, and body *)
    fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
  (* build a Mid IR EIN application of rator to args *)
    fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)

  (* n-ary product / sum of a list of EIN expressions *)
    fun setProd e = E.Opn(E.Prod, e)
    fun setAdd e = E.Opn(E.Add, e)

  (* getRHSDst: follows variable-to-variable bindings starting at x until it
   * reaches an operator application, and returns that operator and its
   * arguments.  Raises Fail for any other kind of binding.
   *)
    fun getRHSDst x = (case DstIR.Var.binding x
           of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)
            | DstIR.VB_RHS(DstIR.VAR x') => getRHSDst x'
            | vb => raise Fail(concat[
                  "expected rhs operator for ", DstIR.Var.toString x,
                  " but found ", DstIR.vbToString vb
                ])
          (* end case *))

  (* getArgsDst: looks up the definitions of the kernel argument hArg and the
   * image argument imgArg, and returns (support of the kernel, image info,
   * dimension of the image).  Raises Fail unless hArg is bound to a Kernel op
   * and imgArg to a LoadImage op.  The args parameter is not used.
   *)
    fun getArgsDst (hArg, imgArg, args) = (case (getRHSDst hArg, getRHSDst imgArg)
           of ((DstOp.Kernel(h, _), _), (DstOp.LoadImage(_, _, img), _)) =>
                (Kernel.support h, img, ImageInfo.dim img)
            | ((k, _), (i, _)) => raise Fail(String.concat[
                  "Expected kernel: ", DstOp.toString k, ", Expected Image: ", DstOp.toString i
                ])
          (* end case *))

  (* handleArgs: uses the param_ids for the image (Vid), kernel (hid), and
   * position tensor (tid) to pick the corresponding Mid IR variables out of
   * args, then transforms the position to index space (transformToImgSpace).
   * Returns (dim, args extended with the transform's extra arguments,
   * code produced by the transform, kernel support, P), where P is the Mid IR
   * variable returned by the transform (per the original notes, the transpose
   * of the transformation matrix).
   *)
    fun handleArgs (Vid, hid, tid, args) = let
          val imgArg = List.nth(args, Vid)
          val hArg = List.nth(args, hid)
          val newposArg = List.nth(args, tid)
          val (s, img, dim) = getArgsDst(hArg, imgArg, args)
          val (argsT, P, code) = transformToImgSpace{info=img, img=imgArg, pos=newposArg}
          in
            (dim, args@argsT, code, s, P)
          end

  (* createBody: expands the body of the probed field into
   *   Sum(esum, Img(Vid, alpha, pos) * Krn(hid, ...) * ...)
   * with one kernel term per axis.  Axis i uses summation index sx+i, and
   * every summation index gets the bounds (1-s, s).
   *)
    fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
        (* 1-d fields: the position tensors are scalars (no index) *)
          fun createKRND1 () = let
                val sum = sx
                val dels = List.map (fn e => (E.C 0, e)) deltas
                val pos = [setAdd[E.Tensor(fid, []), E.Value(sum)]]
                val rest = E.Krn(hid, dels, E.Op2(E.Sub, E.Tensor(nid, []), E.Value(sum)))
                in
                  setProd [E.Img(Vid, alpha, pos), rest]
                end
        (* createKRN: builds the image term and the kernel terms for axes
         * dim-1 down to 0, prepending each axis so the final order is 0..dim-1 *)
          fun createKRN (0, imgpos, rest) = setProd ([E.Img(Vid, alpha, imgpos)] @ rest)
            | createKRN (dim, imgpos, rest) = let
                val dim' = dim-1
                val sum = sx+dim'
                val dels = List.map (fn e => (E.C dim', e)) deltas
                val pos = [setAdd[E.Tensor(fid, [E.C dim']), E.Value(sum)]]
                val rest' = E.Krn(hid, dels, E.Op2(E.Sub, E.Tensor(nid, [E.C dim']), E.Value(sum)))
                in
                  createKRN(dim', pos@imgpos, [rest']@rest)
                end
          val exp = (case dim
                 of 1 => createKRND1()
                  | _ => createKRN(dim, [], [])
                (* end case *))
        (* summation indices for the body: one per axis *)
          val slb = 1-s
          val esum = List.tabulate(dim, fn i => (E.V(i+sx), slb, s))
          in
            E.Sum(esum, exp)
          end

  (* getsumshift: returns a fresh (unused) index id.  With no summation indices
   * this is n; otherwise it is one more than the id of the last summation
   * index in sx.  The two cases must be clauses of a single function: as two
   * separate `fun` declarations the second would shadow the first, and
   * getsumshift([], n) would raise Empty.
   *)
    fun getsumshift ([], n) = n
      | getsumshift (sx, _) = let
          val (E.V v, _, _) = List.last sx
          in
            v+1
          end

  (* formBody: light cleanup that drops empty summations and unwraps a product
   * with a single factor (at the top level or directly under summations).
   *)
    fun formBody (E.Sum([], e)) = formBody e
      | formBody (E.Sum(sx, e)) = E.Sum(sx, formBody e)
      | formBody (E.Opn(E.Prod, [e])) = e
      | formBody e = e

  (* multiPs: multiplies body by the transformation factors Ps, wraps the product
   * in a summation over sx, and cleans up with formBody.  For three and four
   * factors the body is placed last, to match the factor order of the vis
   * branch WorldtoSpace functions; otherwise it is placed first.
   *)
    fun multiPs ([P0, P1, P2], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, body]))
    (*| multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1]))) *)
      | multiPs ([P0, P1, P2, P3], sx, body) = formBody(E.Sum(sx, setProd[P0, P1, P2, P3, body]))
      | multiPs (Ps, sx, body) = formBody(E.Sum(sx, setProd([body]@Ps)))

  (* replaceProbe: replaces the probe in the EIN application bound to y with its
   * expanded version and returns the resulting code.  It:
   *   - transforms the position (handleArgs),
   *   - transforms the derivative indices dx (transformToIndexSpace),
   *   - builds the expanded body (createBody, then multiPs with the
   *     transformation factors),
   *   - appends three tensor params (at ids fid, nid, and Pid) and passes the
   *     extra arguments plus P.
   * The first argument (testN) is not used.
   *)
    fun replaceProbe (testN, (y, DstIR.EINAPP(Ein.EIN{params = params, index = index, body = body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid+1
          val Pid = nid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val nshift = length dx
          val (dim, argsA, code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
          val (dx, newsx1, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
          (*val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)
          val params' = params@[E.TEN(true, [dim]), E.TEN(true, [dim]), E.TEN(true, [dim, dim])]
          val probe' = createBody(dim, s, freshIndex+nshift, alpha, dx, Vid, hid, nid, fid)
          val probe' = multiPs(Ps, newsx1, probe')
          val body' = (case body
                 of E.Sum(sx, E.Probe _) => E.Sum(sx, probe')
                  | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) => E.Sum(sx, setProd[eps0, probe'])
                  | _ => probe'
                (* end case *))
          val einapp = (y, mkEinApp(mkEin(params', index, body'), argsA@[PArg]))
          in
            code@[einapp]
          end

  (* createEinApp: builds the EIN operator that applies the transformation
   * factors to a lifted field tensor (param 1, with P at param 0).  Returns
   * (splitvar, ein, sizes, dx, alpha'): splitvar is false only when the
   * original body is a summation over eps0 * Probe; sizes is the shape of the
   * lifted tensor; dx and alpha' are the cleaned derivative and alpha indices.
   * Assumes the body is already clean.
   *)
    fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let
          val Pid = 0
          val tid = 1
          val (newdx, newsx, Ps) = transformToIndexSpace(freshIndex, dim, dx, Pid)
        (* need to rewrite dx; the ids 9 and 7 in the Conv are placeholders used only
         * to carry alpha and newdx through CleanIndex.clean -- they are discarded *)
          val (_, sizes, E.Conv(_, alpha', _, dx)) = (case sx@newsx
                 of [] => ([], index, E.Conv(9, alpha, 7, newdx))
                  | _ => CleanIndex.clean(E.Conv(9, alpha, 7, newdx), index, sx@newsx)
                (* end case *))
          val params = [E.TEN(true, [dim, dim]), E.TEN(true, sizes)]
        (* shape of the lifted tensor: the non-constant alpha indices, then newdx *)
          val tshape = List.filter (fn (E.C _) => false | _ => true) alpha' @ newdx
          val t = E.Tensor(tid, tshape)
          val (splitvar, body) = (case originalb
                 of E.Sum(sx, E.Probe _) => (true, multiPs(Ps, sx@newsx, t))
                  | E.Sum(sx, E.Opn(E.Prod, [eps0, E.Probe _])) =>
                      (false, E.Sum(sx, setProd[eps0, multiPs(Ps, newsx, t)]))
                  | _ => (true, multiPs(Ps, newsx, t))
                (* end case *))
          val ein0 = mkEin(params, index, body)
          in
            (splitvar, ein0, sizes, dx, alpha')
          end

  (* liftProbe: splits a probe with derivatives into two steps:
   *   1. F := the probed field reconstruction (createBody) with shape sizes;
   *   2. y := the transformation factors applied to F (createEinApp), split with
   *      EinSums.transform / Split.splitEinApp when splitvar is true.
   * Returns the position-transform code followed by the binding of F and the
   * code for y.  The first argument (printStrings) is not used.
   *)
    fun liftProbe (printStrings, (y, DstIR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
          val fid = length params
          val nid = fid+1
          val E.Probe(E.Conv(Vid, alpha, hid, dx), E.Tensor(tid, _)) = probe
          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
          val freshIndex = getsumshift(sx, length index)
        (* transform T*P*P..Ps *)
          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)
          val FArg = DstV.new ("F", DstTy.TensorTy(sizes))
          val einApp0 = mkEinApp(ein0, [PArg, FArg])
          val rtn0 = if splitvar
                then Split.splitEinApp (y, DstIR.EINAPP(EinSums.transform ein0, [PArg, FArg]))
                else [(y, einApp0)]
        (* lifted probe *)
          (*val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]*) (*Fixme: will get type error later*)
          val params' = params@[E.TEN(true, [dim]), E.TEN(true, [dim])]
          val freshIndex' = length sizes
          val body' = createBody(dim, s, freshIndex', alpha', dx, Vid, hid, nid, fid)
          val ein1 = mkEin(params', sizes, body')
          val einApp1 = mkEinApp(ein1, args')
          val code1 = (FArg, einApp1)::rtn0
          in
            code@code1
          end

  (* expandEinOp: (binding, fieldset) -> (code, fieldset, int, int)
   * At this point we only have simple EIN ops.
   * Looks to see whether the expression has a probe and, if so, replaces it.
   * Note how we keep eps expressions so that only the pieces that are used get
   * generated.
   * When valnumflag is set and einSet.rtnVar finds an existing variable v for
   * this application, y is bound to v instead of rewriting.
   * Result components:
   *   - the new code;
   *   - the (possibly updated) fieldset;
   *   - 1 if the body is a probe (possibly under a Sum or a Sum of eps * probe),
   *     else 0;
   *   - 1 if an existing variable was reused, else 0.
   *)
    fun expandEinOp (e as (y, DstIR.EINAPP(ein as Ein.EIN{body, ...}, args)), fieldset) = let
          val avail = AvailRHS.new()  (* NOTE(review): currently unused *)
          fun matchField () = (case body
                 of E.Probe _ => 1
                  | E.Sum(_, E.Probe _) => 1
                  | E.Sum(_, E.Opn(E.Prod, [_, E.Probe _])) => 1
                  | _ => 0
                (* end case *))
          fun rewriteBody () = (case body
                 of (E.Probe(E.Conv(_, _, _, []), _)) =>
                      replaceProbe(0, e, body, [])
                  | (E.Probe(E.Conv(_, _, _, _), _)) =>
                      liftProbe(0, e, body, [])  (* scans dx for constant *)
                  | (E.Sum(sx, p as E.Probe(E.Conv(_, _, _, []), _))) =>
                      replaceProbe(0, e, p, sx)  (* no dx *)
                  | (E.Sum(sx, p as E.Probe(E.Conv(_, [], _, _), _))) =>
                      liftProbe(0, e, p, sx)  (* scalar field *)
                  | (E.Sum(sx, E.Probe p)) =>
                      replaceProbe(0, e, E.Probe p, sx)
                  | (E.Sum(sx, E.Opn(E.Prod, [_, E.Probe p]))) =>
                      replaceProbe(0, e, E.Probe p, sx)
                  | _ => [e]
                (* end case *))
          val (fieldset, var) = if valnumflag
                then einSet.rtnVar(fieldset, y, DstIR.EINAPP(ein, args))
                else (fieldset, NONE)
          in
            case var
             of NONE => (rewriteBody(), fieldset, matchField(), 0)
              | SOME v => ([(y, DstIR.VAR v)], fieldset, matchField(), 1)
            (* end case *)
          end

  end (* ProbeEin *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0