Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/ProbeEin.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2838 - (view) (download)

1 : cchiw 2608 (* Currently under construction
2 : cchiw 2606 *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 :     structure ProbeEin = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 :     structure mk= mkOperators
24 :     structure SrcIL = HighIL
25 :     structure SrcTy = HighILTypes
26 :     structure SrcOp = HighOps
27 :     structure SrcSV = SrcIL.StateVar
28 :     structure VTbl = SrcIL.Var.Tbl
29 :     structure DstIL = MidIL
30 :     structure DstTy = MidILTypes
31 :     structure DstOp = MidOps
32 :     structure DstV = DstIL.Var
33 :     structure SrcV = SrcIL.Var
34 :     structure P=Printer
35 :     structure F=Filter
36 : cchiw 2611 structure T=TransformEin
37 : cchiw 2838 structure split=Split
38 : cchiw 2606
39 : cchiw 2827 val testing=0
40 : cchiw 2606
41 : cchiw 2827
42 : cchiw 2606 in
43 :    
44 : cchiw 2838 val cnt = ref 0
45 :     fun genName prefix = let
46 :     val n = !cnt
47 :     in
48 :     cnt := n+1;
49 :     String.concat[prefix, "_", Int.toString n]
50 :     end
51 : cchiw 2606
52 : cchiw 2838
53 : cchiw 2606 fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
54 :     fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
55 : cchiw 2829 fun testp n=(case testing
56 :     of 0=> 1
57 :     | _ =>(print(String.concat n);1)
58 :     (*end case*))
59 : cchiw 2606
60 : cchiw 2830 (*transform image-space position x to world space position*)
61 : cchiw 2838
62 :     fun getTys 1= (DstTy.intTy,[],[])
63 :     | getTys dim = (DstTy.iVecTy dim,[dim],[dim,dim])
64 :    
65 :    
66 : cchiw 2830 fun WorldToImagespace(dim,v,posx,imgArgDst)=let
67 : cchiw 2838 val translate=DstOp.Translate v
68 :     val transform=DstOp.Transform v
69 :     val (_ ,fty,pty)=getTys dim
70 :     val mty=DstTy.TensorTy pty
71 :     val rty=DstTy.TensorTy fty
72 :    
73 :     val M = DstV.new (genName "M", mty) (*transform dim by dim?*)
74 :     val T = DstV.new (genName "T", rty)
75 :     val x = DstV.new (genName "x", rty) (*Image-Space position*)
76 :     val x0 = DstV.new (genName "x0", rty)
77 :     val (PosToImgSpaceA,PosToImgSpaceB)=(case dim
78 :     of 1=>(mk.prodScalar,mk.addScalar)
79 :     | _ => (mk.transformA(dim,dim) ,mk.transformB(dim))
80 :     (*end case*))
81 :     val code=[
82 :     assign(M, transform, [imgArgDst]),
83 :     assign(T, translate, [imgArgDst]),
84 :     assignEin(x0, PosToImgSpaceA,[M,posx]) , (*xo=MX*)
85 :     assignEin(x, PosToImgSpaceB,[x0,T]) (*x=x0+T*)
86 : cchiw 2830 ]
87 :     in (M,x,code)
88 : cchiw 2838 end
89 : cchiw 2829
90 : cchiw 2830
91 :     (*Create fractional, and integer position vectors*)
92 :     fun transformToImgSpace (dim,v,posx,imgArgDst)=let
93 : cchiw 2838 val (ity,fty,pty)=getTys dim
94 :     val mty=DstTy.TensorTy pty
95 :     val rty=DstTy.TensorTy fty
96 : cchiw 2830
97 : cchiw 2838 val f = DstV.new ("f", rty) (*fractional*)
98 :     val nd = DstV.new ("nd", rty) (*real position*)
99 :     val n = DstV.new ("n", ity) (*integer position*)
100 :     val P = DstV.new ("P",mty) (*transform dim by dim?*)
101 : cchiw 2830
102 :     val (M,x,code1)=WorldToImagespace(dim,v,posx,imgArgDst)
103 : cchiw 2838 val (P,PCode)=(case dim
104 :     of 1=>(M,[])
105 :     | _ =>(P,[assignEin(P, mk.transpose(pty), [M])])
106 :     (*end case*))
107 : cchiw 2830 val code=[
108 : cchiw 2606 assign(nd, DstOp.Floor dim, [x]), (*nd *)
109 : cchiw 2838 assignEin(f, mk.subTen(fty),[x,nd]), (*fractional*)
110 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
111 : cchiw 2608 ]
112 : cchiw 2838 in ([n,f],P,code1@PCode@code)
113 : cchiw 2606 end
114 :    
115 : cchiw 2838
116 :     fun getRHSDst x = (case DstIL.Var.binding x
117 :     of DstIL.VB_RHS(DstIL.OP(rator, args)) => (rator, args)
118 :     | DstIL.VB_RHS(DstIL.VAR x') => getRHSDst x'
119 :     | vb => raise Fail(concat[ "expected rhs operator for ", DstIL.Var.toString x, "but found ", DstIL.vbToString vb])
120 :     (* end case *))
121 : cchiw 2606
122 :    
123 : cchiw 2838 (*Get Img, and Kern Args*)
124 :     fun getArgsDst(hid,hArg,imgArg,args)=(case (getRHSDst hArg,getRHSDst imgArg)
125 :     of ((DstOp.Kernel(h, i), _ ),(DstOp.LoadImage img, _ ))=> let
126 :     in
127 :     ((Kernel.support h) ,img)
128 :     end
129 :     | _ => raise Fail "Expected Image and kernel argument"
130 :     (*end case*))
131 : cchiw 2606
132 : cchiw 2838
133 : cchiw 2830
134 : cchiw 2838 fun handleArgs(V,hid,t,args)=let
135 :     val hArg=List.nth(args,hid)
136 :     val imgArg=List.nth(args,V)
137 :     val newposArg=List.nth(args,t)
138 :     val (s,img) =getArgsDst(hid,hArg,imgArg,args)
139 : cchiw 2830 val dim=ImageInfo.dim img
140 : cchiw 2838 val (argsT,P,code)=transformToImgSpace(dim,img,newposArg,imgArg)
141 :     in (dim,args@argsT,code, s,P)
142 : cchiw 2606 end
143 :    
144 : cchiw 2608
145 : cchiw 2606 (*Created new body for probe*)
146 :     fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
147 :    
148 :     (*sumIndex creating summaiton Index for body*)
149 : cchiw 2838 fun sumIndex 0=[]
150 : cchiw 2606 |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
151 :    
152 : cchiw 2838
153 :     fun createKRND1 ()=let
154 :     val sum=sx
155 :     val dels=List.map (fn e=>(E.C 0,e)) deltas
156 :     val pos=[E.Add[E.Tensor(fid,[]),E.Value(sum)]]
157 :     val rest= E.Krn(h,dels,E.Sub(E.Tensor(nid,[]),E.Value(sum)))
158 :     in
159 :     E.Prod [E.Img(V,shape,pos),rest]
160 :    
161 :     end
162 :    
163 : cchiw 2606 (*createKRN Image field and kernels *)
164 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
165 :     | createKRN(dim,imgpos,rest)=let
166 :     val dim'=dim-1
167 :     val sum=sx+dim'
168 : cchiw 2830 val dels=List.map (fn e=>(E.C dim',e)) deltas
169 : cchiw 2606 val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
170 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
171 :     in
172 :     createKRN(dim',pos@imgpos,[rest']@rest)
173 :     end
174 :    
175 : cchiw 2838 val exp=(case dim
176 :     of 1 => createKRND1()
177 :     | _=> createKRN(dim, [],[])
178 :     (*end case*))
179 :    
180 : cchiw 2606 val esum=sumIndex (dim)
181 :     in E.Sum(esum, exp)
182 :     end
183 :    
184 :    
185 :     fun ShapeConv([],n)=[]
186 :     | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
187 :     | ShapeConv(E.V v::es, n)=
188 :     if(n>v) then [E.V v] @ ShapeConv(es, n)
189 :     else ShapeConv(es,n)
190 :    
191 :    
192 :     fun mapIndex([],_)=[]
193 :     | mapIndex(E.V v::es,index) = [List.nth(index, v)]@ mapIndex(es,index)
194 :     | mapIndex(E.C c::es,index) = mapIndex(es,index)
195 :    
196 :    
197 :    
198 : cchiw 2838 (* Expand probe in place eplaceProbe(b,params,args, index, sx,args)*)
199 :     fun replaceProbe(b,params,args,index, sumIndex)=let
200 : cchiw 2606
201 : cchiw 2608 val E.Probe(E.Conv(V,alpha,h,dx),E.Tensor(t,_))=b
202 : cchiw 2606 val fid=length(params)
203 : cchiw 2611 val nid=fid+1
204 : cchiw 2606 val n=length(index)
205 : cchiw 2838
206 : cchiw 2611 val nshift=length(dx)
207 : cchiw 2838 val nsumshift =(case sumIndex
208 :     of []=> n
209 : cchiw 2611 | _=>let
210 : cchiw 2838 val (E.V v,_,_)=List.hd(List.rev sumIndex)
211 : cchiw 2611 in v+1
212 : cchiw 2838 end
213 : cchiw 2611 (* end case *))
214 : cchiw 2838
215 :     val aa=List.map (fn (E.V v,_,_)=>Int.toString v) sumIndex
216 :     val _ =testp["\n", "SumIndex" ,(String.concatWith"," aa),"\nThink nshift is ", Int.toString nsumshift]
217 : cchiw 2611
218 :     (*Outer Index-id Of Probe*)
219 :     val VShape=ShapeConv(alpha, n)
220 :     val HShape=ShapeConv(dx, n)
221 :     val shape=VShape@HShape
222 :     (* Bindings for Shape*)
223 :     val shapebind= mapIndex(shape,index)
224 :     val Vshapebind= mapIndex(VShape,index)
225 :    
226 :    
227 : cchiw 2838 val (dim,argsA,code,s,PArg) = handleArgs(V,h,t,args)
228 : cchiw 2611 val (_,_,dx, _,sxT,restT,_,_) = T.Transform(dx,shapebind,Vshapebind,dim,PArg,nsumshift,1,nid+1)
229 :    
230 :     val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]
231 :     val body'' = createBody(dim, s,nsumshift+nshift,alpha,dx,V, h, nid, fid)
232 :     val body' =(case nshift
233 :     of 0=> body''
234 :     | _ => E.Sum(sxT, E.Prod(restT@[body'']))
235 :     (*end case*))
236 :     val args'=argsA@[PArg]
237 : cchiw 2838 in
238 :     (body',params',args' ,code)
239 : cchiw 2606 end
240 :    
241 : cchiw 2611
242 : cchiw 2606 (* sx-[] then move out, otherwise keep in *)
243 : cchiw 2838 fun expandEinOp( e as (y, DstIL.EINAPP(Ein.EIN{params, index, body}, args))) = let
244 : cchiw 2606
245 :    
246 :     (*b-current body, info-original ein op, data-new assigments*)
247 : cchiw 2838 fun rewriteBody b= let
248 : cchiw 2606 in (case b
249 : cchiw 2838 of E.Probe(E.Conv _, E.Tensor _) =>let
250 :     val (body',params',args',newbies)=replaceProbe(b, params,args,index, [])
251 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
252 :     val code=newbies@[einapp]
253 :     in
254 :     (1,code)
255 : cchiw 2606 end
256 : cchiw 2838 | E.Sum(sx,E.Probe e) =>let
257 :     val (body',params',args',newbies)=replaceProbe(E.Probe e,params,args, index, sx)
258 :     val body'=E.Sum(sx,body')
259 :     val einapp=(y,DstIL.EINAPP(Ein.EIN{params=params', index=index, body=body'},args'))
260 :     val code=newbies@[einapp]
261 : cchiw 2606 in
262 : cchiw 2838 (1,code)
263 : cchiw 2606 end
264 : cchiw 2838 | _=> (0,[e])
265 : cchiw 2606 (* end case *))
266 :     end
267 :    
268 :     val empty =fn key =>NONE
269 : cchiw 2838
270 :     val (c,code)=rewriteBody body
271 :     val b=String.concatWith",\t"(List.map split.printEINAPP code)
272 :     val _ =(case c
273 :     of 1 =>print(String.concat["\nbody",split.printEINAPP e, "\n=>\n",b ])
274 :     | _ =>print(String.concat[""])
275 : cchiw 2606 (*end case*))
276 :     in
277 : cchiw 2838 code
278 : cchiw 2606 end
279 :    
280 :     end; (* local *)
281 :    
282 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0