Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2522 - (view) (download)

1 : cchiw 2498 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 : cchiw 2510 (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 : cchiw 2498 structure Expand = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 : cchiw 2510 structure mk= mkOperators
24 : cchiw 2498
25 : cchiw 2510
26 :     structure SrcIL = HighIL
27 :     structure SrcTy = HighILTypes
28 :     structure SrcOp = HighOps
29 :     structure SrcSV = SrcIL.StateVar
30 :     structure VTbl = SrcIL.Var.Tbl
31 :     structure DstIL = MidIL
32 :     structure DstTy = MidILTypes
33 :     structure DstOp = MidOps
34 : cchiw 2515 structure DstV = DstIL.Var
35 :     structure SrcV = SrcIL.Var
36 : cchiw 2510 structure P=Printer
37 : cchiw 2502
38 : cchiw 2510
39 : cchiw 2515 datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
40 : cchiw 2522 datatype peanut2= O2 of SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
41 : cchiw 2498 in
42 :    
43 :    
44 : cchiw 2502 fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
45 :     fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))
46 : cchiw 2510
47 :     fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))
48 :     fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))
49 :    
50 : cchiw 2498 fun insert (key, value) d =fn s =>
51 :     if s = key then SOME value
52 :     else d s
53 :     fun lookup k d = d k
54 :    
55 : cchiw 2499
56 : cchiw 2510
57 : cchiw 2515 fun getRHS x = (case DstIL.Var.binding x
58 : cchiw 2522 of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
59 : cchiw 2515 | DstIL.VB_RHS(DstIL.VAR x') => getRHS x'
60 :     | DstIL.VB_RHS(DstIL.EINAPP (e,args))=>(E e,args)
61 : cchiw 2522 | DstIL.VB_RHS(DstIL.CONS (ty,args))=>(C ty,args)
62 :     | DstIL.VB_NONE=>(S 2,[])
63 : cchiw 2510 | vb => raise Fail(concat[
64 : cchiw 2515 "expected rhs operator for ", DstIL.Var.toString x,
65 :     "but found ", DstIL.vbToString vb])
66 :     (* end case *))
67 : cchiw 2510
68 :    
69 : cchiw 2522 fun getRHS2 x = (case SrcIL.Var.binding x
70 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
71 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
72 :     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
73 :     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
74 :     | SrcIL.VB_NONE=>(S2 2,[])
75 :     | vb => raise Fail(concat[
76 :     "expected rhs operator for ", SrcIL.Var.toString x,
77 :     "but found ", SrcIL.vbToString vb])
78 :     (* end case *))
79 : cchiw 2510
80 :    
81 : cchiw 2515
82 : cchiw 2522
83 :     fun PrintBIND x = (case DstIL.Var.binding x
84 :     of vb=> print(String.concat[ "\n expected rhs operator for ", DstIL.Var.toString x,
85 :     " but found ", DstIL.vbToString vb,"\n \n"])
86 :     (* end case *))
87 :    
88 :     fun PrintBIND2 x = (case SrcIL.Var.binding x
89 :     of vb=> print(String.concat[ "\n expected rhs operator for ", SrcIL.Var.toString x,
90 :     " but found ", SrcIL.vbToString vb,"\n \n"])
91 :     (* end case *))
92 :    
93 :    
94 :    
95 :    
96 :    
97 :    
98 : cchiw 2502 (*Create fractional, and integer position vectors*)
99 : cchiw 2515 fun createArgs(dim,v,posx,pos)=let
100 :    
101 : cchiw 2502 val translate=DstOp.Translate v
102 :     val transform=DstOp.Transform v
103 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
104 :     val T = DstV.new ("T", DstTy.vecTy dim) (*translate*)
105 :     val x = DstV.new ("x", DstTy.vecTy dim)
106 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
107 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
108 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
109 :    
110 :    
111 :     val PosToImgSpace=mk.transform(dim,dim)
112 :     val code=[
113 :     assign(M, transform, []),
114 :     assign(T, translate, []),
115 : cchiw 2510 pos,
116 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
117 : cchiw 2502 assign(nd, DstOp.Floor dim, [x]), (*nd *)
118 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
119 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
120 :     ]
121 :    
122 : cchiw 2515 in ([f,n],code)
123 : cchiw 2502 end
124 : cchiw 2499
125 :    
126 : cchiw 2522 (*Create fractional, and integer position vectors*)
127 :     fun createArgs2 (dim,v,posx)=let
128 : cchiw 2499
129 : cchiw 2522 val translate=DstOp.Translate v
130 :     val transform=DstOp.Transform v
131 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
132 :     val T = DstV.new ("T", DstTy.vecTy dim) (*translate*)
133 :     val x = DstV.new ("x", DstTy.vecTy dim)
134 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
135 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
136 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
137 :    
138 :     val PosToImgSpace=mk.transform(dim,dim)
139 :     val code=[
140 :     assign(M, transform, []),
141 :     assign(T, translate, []),
142 :    
143 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
144 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
145 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
146 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
147 :     ]
148 :    
149 :     in ([f,n],code)
150 :     end
151 :    
152 :    
153 :     (*Currently can't get rhs of image*)
154 :    
155 :     fun createImg (dict,dim,pos1,imgArg,posArg)=let
156 :     val g=print "\n\n \t\t *** Did not find the proper bindings for img\n\n"
157 :     val info= PrintBIND imgArg
158 :     val info2= PrintBIND posArg
159 :     val a= DstV.new ("pos1F", DstTy.vecTy dim)
160 :     val b= DstV.new ("pos2F", DstTy.vecTy dim)
161 :     val bug=(pos1,pos1+1,dict,[E.TEN 1,E.TEN 1],[a,b],[])
162 :     in bug end
163 :    
164 :     fun Position(V,t,dict,dim,pos1,orig,args)=let
165 : cchiw 2515 val l=lookup t dict
166 : cchiw 2499
167 : cchiw 2515 in (case l
168 :     of NONE =>let
169 : cchiw 2522 val newimgArg=List.nth(args, V)
170 :     val newposArg=List.nth(args, t)
171 :     val imgArg=List.nth(orig,V)
172 :     val posArg=List.nth(orig,t)
173 :     in (case (getRHS2 imgArg,getRHS2 posArg)
174 :     of ((O2(SrcOp.LoadImage img), a1),(C2 ty, a2))=>
175 : cchiw 2499
176 : cchiw 2522 createImg(dict,dim,pos1,newimgArg,newposArg)
177 :     (*How to reassign arguments?*)
178 :     (*let
179 :    
180 :     val (args',code')=createArgs2(dim,img,newposArg)
181 :     val pos2=pos1+1
182 :     val dict'=insert(t,(pos1,pos2)) dict
183 :     val params'=[E.TEN 1,E.TEN 1]
184 :     val code= [assign(newimgArg,DstOp.LoadImage img,[])]
185 :    
186 :     in (pos1,pos2, dict',params',args',code'@code)
187 :     end
188 :     *)
189 :    
190 :    
191 :     (*
192 : cchiw 2515 val posx= DstV.new ("pos", DstTy.vecTy dim)
193 :     val pos= assignEin(posx, e,[])
194 :     val (args',code')=createArgs(dim,img,posx,pos)
195 :     val pos2=pos1+1
196 :     val dict'=insert(t,(pos1,pos2)) dict
197 : cchiw 2521 val params'=[E.TEN 1,E.TEN 1]
198 : cchiw 2515 in (pos1,pos2, dict',params',args',code')
199 :     end
200 : cchiw 2522 | ((O2(Src.LoadImage v,_),(C _,_))=>
201 :     let
202 :     val (args',code')=createArgs2(dim,v,posArg)
203 :     val pos2=pos1+1
204 :     val dict'=insert(t,(pos1,pos2)) dict
205 :     val params'=[E.TEN 1,E.TEN 1]
206 :     in (pos1,pos2, dict',params',args',code')
207 :     end
208 :    
209 :     *)
210 :     |_=>createImg(dict,dim,pos1,newimgArg,newposArg)
211 : cchiw 2515 (*end case*))
212 :     end
213 :    
214 :     | SOME (fid,nid)=>(fid,nid, dict,[],[],[]))
215 :     (*end case*)
216 :     end
217 : cchiw 2499
218 :     (*
219 : cchiw 2498 createDels=> creates the kronecker deltas for each Kernel
220 :     For each dimesnion a, and each index in derivative b create element (a,b)
221 :     *)
222 :     fun createDels([],_)= []
223 : cchiw 2515 | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
224 : cchiw 2498
225 :    
226 :    
227 : cchiw 2522
228 :     fun kernTransition(k,x)= case getRHS2 k
229 :     of (O2(SrcOp.Kernel(h, i)),arg)=> (Kernel.support h ,[assign (x, DstOp.Kernel(h, i), [])])
230 :     |_ => raise Fail "Not a kernel argument"
231 :    
232 :    
233 :    
234 :    
235 :     fun expandEinProbe((body,(params,index,args,d,code,change)),origargs,sx)=(case body
236 : cchiw 2510 of E.Probe(E.Conv(V,shape,h,deltas),E.Tensor(t,alpha)) =>let
237 :     val E.IMG(dim)=List.nth(params,V)
238 : cchiw 2522
239 : cchiw 2510
240 : cchiw 2522
241 :     (*Kernel arg*)
242 :     val hnew=List.nth(args,h)
243 :     val hnewx=DstIL.Var.new("hnew " ,DstTy.KernelTy)
244 :     val (s,harg)=kernTransition(List.nth(origargs,h),hnewx)
245 :     val ss=print(String.concat["\n Support",Int.toString(s)])
246 :    
247 :    
248 : cchiw 2510 val pnum=length params
249 :    
250 : cchiw 2522 val (fid,nid,d',params',args',code')=Position(V,t,d,dim,pnum,origargs,args)
251 : cchiw 2498 val shift=length index
252 : cchiw 2510
253 : cchiw 2522
254 : cchiw 2515
255 : cchiw 2498 (*sumIndex creating summaiton Index for body*)
256 :     fun sumIndex(0)=[]
257 : cchiw 2515 |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
258 : cchiw 2498
259 :     (*createKRN Image field and kernels *)
260 : cchiw 2510 fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
261 : cchiw 2515 | createKRN(dim,imgpos,rest)=let
262 : cchiw 2522 val dim'=dim-1
263 :     val sum=dim'+shift
264 :     val dels=createDels(deltas,dim')
265 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
266 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
267 : cchiw 2498 in
268 :     createKRN(dim',pos@imgpos,[rest']@rest)
269 :     end
270 :    
271 :    
272 : cchiw 2510 val exp=createKRN(dim, [],[])
273 :     val esum=sumIndex (dim)
274 : cchiw 2522
275 : cchiw 2510
276 : cchiw 2522 in (E.Sum(esum, exp),(params@params',index,args@args',d',code@harg@code',1)) end
277 : cchiw 2510 (*end case*))
278 : cchiw 2498
279 :    
280 :     (*copied from high-to-mid.sml*)
281 : cchiw 2522 fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
282 : cchiw 2498
283 :     val dummy=E.Const 0.0 (*tmp variables*)
284 : cchiw 2522 (*Maybe conv->0 sumewhere else *)
285 : cchiw 2510 val sumIndex=ref (length index)
286 :     fun sumI(e)=let
287 :     val (E.V v,_,_)=List.nth(e, length(e)-1)
288 : cchiw 2522 in v end
289 : cchiw 2498
290 :     fun rewriteBody exp= let
291 : cchiw 2510 val (body, data)=exp
292 : cchiw 2498 in (case body
293 :     of E.Const _=>exp
294 :     | E.Tensor _=>exp
295 :     | E.Krn _=>exp
296 :     | E.Delta _=>exp
297 :     | E.Value _ =>exp
298 :     | E.Epsilon _=>exp
299 :     | E.Partial _=>exp
300 :     | E.Img _=> exp
301 : cchiw 2522 | E.Conv _=>(print "\n No Probe, used, can not expand \n" ; (dummy,data))
302 : cchiw 2510 | E.Field _ =>(dummy, data)
303 :     | E.Apply _ =>(dummy, data)
304 : cchiw 2498 | E.Neg e=> let
305 : cchiw 2510 val (body',exp')=rewriteBody (e, data)
306 : cchiw 2498 in
307 : cchiw 2510 (E.Neg(body'),exp')
308 : cchiw 2498 end
309 :     | E.Sum (c,e)=> let
310 : cchiw 2510
311 :     val m=(sumI(c))+1
312 :     val (body',exp')=(sumIndex:=m;rewriteBody (e,data))
313 : cchiw 2498 in
314 : cchiw 2510 ((E.Sum(c, body'), exp'))
315 : cchiw 2498 end
316 : cchiw 2510 | E.Probe(E.Conv _, _) =>let
317 :     val ref x=sumIndex
318 : cchiw 2522 in expandEinProbe(exp,origargs, x) end
319 : cchiw 2498 | E.Sub(a,b)=>let
320 : cchiw 2510 val (bodya,dataa)= rewriteBody(a, data)
321 :     val (bodyb, datab)= rewriteBody(b, dataa)
322 :     in (E.Sub( bodya, bodyb),datab)
323 : cchiw 2498 end
324 :     | E.Div(a,b)=>let
325 : cchiw 2510 val (bodya,dataa)= rewriteBody(a, data)
326 :     val (bodyb, datab)= rewriteBody(b, dataa)
327 :     in (E.Div(bodya, bodyb),datab) end
328 : cchiw 2498 | E.Add es=> let
329 : cchiw 2510 fun filter([], done, data')= (E.Add done, data')
330 :     | filter(e::es, done, data')= let
331 :     val (body', data'')= rewriteBody(e, data')
332 :     in filter(es, done@[body'], data'') end
333 :     in filter(es, [],data) end
334 :    
335 : cchiw 2498 | E.Prod es=> let
336 : cchiw 2510 fun filter([], done, data)= (E.Prod done, data)
337 :     | filter(e::es, done, data)= let
338 :     val (body', data')= rewriteBody(e, data)
339 :     in filter(es, done@[body'], data') end
340 :     in filter(es, [],data) end
341 : cchiw 2498 | E.Probe _=> exp
342 :     (* end case *))
343 : cchiw 2510 end
344 : cchiw 2498
345 : cchiw 2510 val empty =fn key =>NONE
346 : cchiw 2515 val (body',(params', index', args',_,code',change))=rewriteBody(body,(params,index,args,empty,[],0))
347 : cchiw 2510 val newbie=Ein.EIN{params=params', index=index', body=body'}
348 : cchiw 2515 in (change,newbie,args',code') end
349 : cchiw 2498
350 :     end; (* local *)
351 :    
352 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0