Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3030 - (view) (download)

1 : cchiw 2845 (* Expands probe ein
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :     structure ProbeEin = struct
8 :    
9 :     local
10 :    
11 :     structure E = Ein
12 :     structure DstIL = MidIL
13 :     structure DstOp = MidOps
14 :     structure P=Printer
15 : cchiw 2611 structure T=TransformEin
16 : cchiw 2845 structure MidToS=MidToString
17 : cchiw 2976 structure DstV = DstIL.Var
18 :     structure DstTy = MidILTypes
19 :    
20 : cchiw 2606 in
21 :    
22 : cchiw 2870 (* This file expands probed fields
23 :     * Take a look at ProbeEin tex file for examples
24 :     *Note that the original field is an EIN operator in the form <V_alpha * H^(deltas)>(midIL.var list )
25 :     * Param_ids are used to note the placement of the argument in the midIL.var list
26 :     * Index_ids keep track of the shape of an Image or differentiation.
27 :     * Mu bind Index_id
28 :     * Generally, we will refer to the following
29 :     *dim:dimension of field V
30 :     * s: support of kernel H
31 :     * alpha: The alpha in <V_alpha * H^(deltas)>
32 :     * deltas: The deltas in <V_alpha * H^(deltas)>
33 :     * Vid:param_id for V
34 :     * hid:param_id for H
35 :     * nid: integer position param_id
36 :     * fid :fractional position param_id
37 :     *img-imginfo about V
38 :     *)
39 : cchiw 2843
40 : cchiw 2923 val testing=0
41 : cchiw 2845 val cnt = ref 0
42 : cchiw 2606
43 : cchiw 3030 fun printEINAPP e=MidToString.printEINAPP e
44 : cchiw 2845 fun transformToIndexSpace e=T.transformToIndexSpace e
45 :     fun transformToImgSpace e=T.transformToImgSpace e
46 :     fun testp n=(case testing
47 :     of 0=> 1
48 :     | _ =>(print(String.concat n);1)
49 :     (*end case*))
50 :     fun getRHSDst x = (case DstIL.Var.binding x
51 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
52 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
53 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
54 :     (* end case *))
55 : cchiw 2838
56 : cchiw 2606
57 : cchiw 2845 (* getArgsDst:MidIL.Var* MidIL.Var->int, ImageInfo, int
58 :     uses the Param_ids for the image, kernel,
59 :     and position tensor to get the Mid-IL arguments
60 :     returns the support of ther kernel, and image
61 :     *)
62 :     fun getArgsDst(hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
63 :     of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let
64 :     in
65 :     ((Kernel.support h) ,img,ImageInfo.dim img)
66 :     end
67 :     | _ => raise Fail "Expected Image and kernel arguments"
68 :     (*end case*))
69 : cchiw 2606
70 :    
71 : cchiw 2845 (*handleArgs():int*int*int*Mid IL.Var list
72 :     ->int*Mid.ILVars list* code*int* low-il-var
73 :     * uses the Param_ids for the image, kernel, and tensor
74 :     * and gets the mid-IL vars for each.
75 :     *Transforms the position to index space
76 :     *P is the mid-il var for the (transformation matrix)transpose
77 :     *)
78 :     fun handleArgs(Vid,hid,tid,args)=let
79 :     val imgArg=List.nth(args,Vid)
80 :     val hArg=List.nth(args,hid)
81 :     val newposArg=List.nth(args,tid)
82 :     val (s,img,dim) =getArgsDst(hArg,imgArg,args)
83 :     val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
84 : cchiw 2606 in
85 : cchiw 2845 (dim,args@argsT,code, s,P)
86 : cchiw 2606 end
87 : cchiw 2838
88 : cchiw 2845 (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
89 :     * expands the body for the probed field
90 :     *)
91 :     fun createBody(dim, s,sx,alpha,deltas,Vid, hid, nid, fid)=let
92 :     (*1-d fields*)
93 :     fun createKRND1 ()=let
94 :     val sum=sx
95 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
96 :     val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
97 :     val rest= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
98 :     in
99 :     E.Prod [E.Img(Vid,alpha,pos),rest]
100 : cchiw 2843 end
101 : cchiw 2845 (*createKRN Image field and kernels *)
102 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(Vid,alpha,imgpos)] @rest)
103 :     | createKRN(dim,imgpos,rest)=let
104 :     val dim'=dim-1
105 :     val sum=sx+dim'
106 :     val dels=List.map (fn e=>(E.C dim',e)) deltas
107 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
108 :     val rest'= E.Krn(hid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
109 :     in
110 :     createKRN(dim',pos@imgpos,[rest']@rest)
111 :     end
112 :     val exp=(case dim
113 :     of 1 => createKRND1()
114 :     | _=> createKRN(dim, [],[])
115 :     (*end case*))
116 :     (*sumIndex creating summaiton Index for body*)
117 :     val slb=1-s
118 :     val esum=List.tabulate(dim, (fn dim=>(E.V (dim+sx),slb,s)))
119 : cchiw 2843 in
120 : cchiw 2845 E.Sum(esum, exp)
121 : cchiw 2606 end
122 :    
123 : cchiw 2845 (*getsumshift:sum_indexid list* int list-> int
124 :     *get fresh/unused index_id, returns int
125 :     *)
126 :     fun getsumshift(sx,index) =let
127 :     val nsumshift= (case sx
128 :     of []=> length(index)
129 :     | _=>let
130 :     val (E.V v,_,_)=List.hd(List.rev sx)
131 :     in v+1
132 :     end
133 :     (* end case *))
134 :     val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sx
135 :     val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),
136 :     "\nThink nshift is ", Int.toString nsumshift]
137 :     in
138 :     nsumshift
139 :     end
140 : cchiw 2611
141 : cchiw 2845 (*formBody:ein_exp->ein_exp
142 :     *just does a quick rewrite
143 :     *)
144 :     fun formBody(E.Sum([],e))=formBody e
145 :     | formBody(E.Sum(sx,e))= E.Sum(sx,formBody e)
146 :     | formBody(E.Prod [e])=e
147 :     | formBody e=e
148 : cchiw 2606
149 : cchiw 2976 (* silly change in order of the product to match vis branch WorldtoSpace functions*)
150 :     fun multiPs([P0,P1,P2],sx,body)= formBody(E.Sum(sx, E.Prod([P0,P1,P2,body])))
151 :     | multiPs(Ps,sx,body)=formBody(E.Sum(sx, E.Prod([body]@Ps)))
152 :    
153 : cchiw 2845 (* replaceProbe:ein_exp* params *midIL.var list * int list* sum_id list
154 :     -> ein_exp* *code
155 :     * Transforms position to world space
156 :     * transforms result back to index_space
157 :     * rewrites body
158 :     * replace probe with expanded version
159 :     *)
160 :     fun replaceProbe(b,params,args,index, sx)=let
161 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b
162 :     val fid=length(params)
163 :     val nid=fid+1
164 :     val Pid=nid+1
165 :     val nshift=length(dx)
166 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
167 :     val freshIndex=getsumshift(sx,index)
168 :     val (dx,newsx1,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
169 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
170 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
171 : cchiw 2976 val body' = multiPs(Ps,newsx1,body')
172 : cchiw 2845 val args'=argsA@[PArg]
173 :     in
174 :     (body',params',args' ,code)
175 :     end
176 : cchiw 2976
177 :     fun mkEin(params,index,body)=E.EIN{params=params, index=index,body=body}
178 :     fun mkEinApp(rator,args)=DstIL.EINAPP(rator,args)
179 :    
180 :     fun createEinApp(alpha,index,freshIndex,dim,dx)= let
181 :     val Pid=0
182 :     val tid=1
183 :    
184 :     val (newdx,newsx,Ps)=transformToIndexSpace(freshIndex,dim,dx,Pid)
185 :     val params=[E.TEN(1,[dim,dim]),E.TEN(1,index)]
186 :     val t=E.Tensor(tid,alpha@newdx)
187 :     val body = multiPs(Ps,newsx,t)
188 :     val rator=mkEin(params,index,body)
189 :     in
190 :     rator
191 :     end
192 :    
193 :    
194 :    
195 :     fun liftProbe(y,b,params,args,index, sx)=let
196 :     val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_))=b
197 :     val fid=length(params)
198 :     val nid=fid+1
199 :     val nshift=length(dx)
200 :     val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
201 :     val freshIndex=getsumshift(sx,index)
202 :    
203 :    
204 :     (*transform T*P*P..Ps*)
205 :     val FArg = DstV.new ("F", DstTy.TensorTy(index))
206 :     val rator=createEinApp(alpha,index,freshIndex,dim,dx)
207 :     val einApp0=mkEinApp(rator,[PArg,FArg])
208 :    
209 :    
210 :     (*lifted probe*)
211 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim])]
212 :     val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
213 :     val args'=argsA
214 :     val ein1=mkEin(params', index,body')
215 :     val einApp1=mkEinApp(ein1,args')
216 :    
217 :     val rtn=code@[(FArg,einApp1),(y,einApp0)]
218 :    
219 : cchiw 3030 val _=print(String.concat["\n lift probe:\n",P.printbody b,"\n\t=>\n\t", printEINAPP (FArg,einApp1), "&\n\t", printEINAPP (y,einApp0)])
220 : cchiw 2976 in
221 :     rtn
222 :     end
223 :    
224 :    
225 : cchiw 2845 (* expandEinOp: code-> code list
226 :     *Looks to see if the expression has a probe. If so, replaces it.
227 :     * Note how we keeps eps expressions so only generate pieces that are used
228 :     *)
229 : cchiw 2923 fun expandEinOp( e as (y, DstIL.EINAPP(ein as Ein.EIN{params, index, body}, args))) = let
230 : cchiw 2845 fun rewriteBody b=(case b
231 : cchiw 2976 of E.Probe(E.Field _,_)=> raise Fail"Poorly formed EIN operator."
232 :     | E.Probe(E.Conv(_,_,_,[]),_) =>let
233 : cchiw 3017
234 : cchiw 2976 val (body',params',args',newbies)=replaceProbe(b,params,args, index, [])
235 : cchiw 2845 val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
236 :     val code=newbies@[einapp]
237 : cchiw 3030 val _=print(String.concat["\n Replace probe:\n",P.printbody b,"\n=>",P.printbody body'])
238 : cchiw 2845 in
239 :     code
240 :     end
241 : cchiw 3030 | E.Probe(E.Conv _,_) =>let
242 :     val _=print(String.concat["\n Lift probe:\n",P.printerE ein])
243 :     in liftProbe(y,b,params,args, index, [])
244 :     end
245 : cchiw 2845 | E.Sum(sx,E.Probe e) =>let
246 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
247 :     val body'=E.Sum(sx,body')
248 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
249 :     val code=newbies@[einapp]
250 : cchiw 3030 val _=print(String.concat["\n Replace probe:\n",P.printerE ein,"\n=>",P.printbody body'])
251 : cchiw 2870
252 : cchiw 2845 in
253 :     code
254 :     end
255 :     | E.Sum(sx,E.Prod[eps,E.Probe e]) =>let
256 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
257 :     val body'=E.Sum(sx,E.Prod[eps,body'])
258 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
259 :     val code=newbies@[einapp]
260 : cchiw 3030 val _=print(String.concat["\n Replace probe:\n",P.printerE ein,"\n=>",P.printbody body'])
261 : cchiw 2870 in
262 : cchiw 2845 code
263 :     end
264 : cchiw 2922 | _=> [(y, DstIL.EINAPP(ein,args))]
265 : cchiw 2845 (* end case *))
266 :     in
267 :     rewriteBody body
268 :     end
269 : cchiw 2843
270 : cchiw 2606 end; (* local *)
271 :    
272 : cchiw 2845 end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0