Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2525 - (view) (download)

1 : cchiw 2498 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 : cchiw 2510 (*
9 :     A couple of different approaches.
10 :     One approach is to find all the Probe(Conv). Gerenerate exp for it
11 :     Then use Subst function to sub in. That takes care for index matching and
12 :    
13 :     *)
14 :    
15 :     (*This approach creates probe expanded terms, and adds params to the end. *)
16 :    
17 :    
18 : cchiw 2498 structure Expand = struct
19 :    
20 :     local
21 :    
22 :     structure E = Ein
23 : cchiw 2510 structure mk= mkOperators
24 : cchiw 2498
25 : cchiw 2510
26 :     structure SrcIL = HighIL
27 :     structure SrcTy = HighILTypes
28 :     structure SrcOp = HighOps
29 :     structure SrcSV = SrcIL.StateVar
30 :     structure VTbl = SrcIL.Var.Tbl
31 :     structure DstIL = MidIL
32 :     structure DstTy = MidILTypes
33 :     structure DstOp = MidOps
34 : cchiw 2515 structure DstV = DstIL.Var
35 :     structure SrcV = SrcIL.Var
36 : cchiw 2510 structure P=Printer
37 : cchiw 2502
38 : cchiw 2510
39 : cchiw 2515 datatype peanut= O of DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
40 : cchiw 2522 datatype peanut2= O2 of SrcOp.rator | E2 of Ein.ein|C2 of SrcTy.ty|S2 of int
41 : cchiw 2498 in
42 :    
43 :    
44 : cchiw 2502 fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
45 :     fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))
46 : cchiw 2510
47 :     fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))
48 :     fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))
49 :    
50 : cchiw 2498 fun insert (key, value) d =fn s =>
51 :     if s = key then SOME value
52 :     else d s
53 :     fun lookup k d = d k
54 :    
55 : cchiw 2499
56 : cchiw 2510
57 : cchiw 2515 fun getRHS x = (case DstIL.Var.binding x
58 : cchiw 2522 of DstIL.VB_RHS(DstIL.OP(rator, args)) => (O rator, args)
59 : cchiw 2515 | DstIL.VB_RHS(DstIL.VAR x') => getRHS x'
60 :     | DstIL.VB_RHS(DstIL.EINAPP (e,args))=>(E e,args)
61 : cchiw 2522 | DstIL.VB_RHS(DstIL.CONS (ty,args))=>(C ty,args)
62 :     | DstIL.VB_NONE=>(S 2,[])
63 : cchiw 2510 | vb => raise Fail(concat[
64 : cchiw 2515 "expected rhs operator for ", DstIL.Var.toString x,
65 :     "but found ", DstIL.vbToString vb])
66 :     (* end case *))
67 : cchiw 2510
68 :    
69 : cchiw 2522 fun getRHS2 x = (case SrcIL.Var.binding x
70 :     of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O2 rator, args)
71 :     | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS2 x'
72 :     | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E2 e,args)
73 :     | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=>(C2 ty,args)
74 :     | SrcIL.VB_NONE=>(S2 2,[])
75 :     | vb => raise Fail(concat[
76 :     "expected rhs operator for ", SrcIL.Var.toString x,
77 :     "but found ", SrcIL.vbToString vb])
78 :     (* end case *))
79 : cchiw 2510
80 :    
81 : cchiw 2515
82 : cchiw 2522
83 :     fun PrintBIND x = (case DstIL.Var.binding x
84 :     of vb=> print(String.concat[ "\n expected rhs operator for ", DstIL.Var.toString x,
85 :     " but found ", DstIL.vbToString vb,"\n \n"])
86 :     (* end case *))
87 :    
88 :     fun PrintBIND2 x = (case SrcIL.Var.binding x
89 :     of vb=> print(String.concat[ "\n expected rhs operator for ", SrcIL.Var.toString x,
90 :     " but found ", SrcIL.vbToString vb,"\n \n"])
91 :     (* end case *))
92 :    
93 :    
94 :    
95 :    
96 :    
97 :    
98 : cchiw 2502 (*Create fractional, and integer position vectors*)
99 : cchiw 2515 fun createArgs(dim,v,posx,pos)=let
100 :    
101 : cchiw 2502 val translate=DstOp.Translate v
102 :     val transform=DstOp.Transform v
103 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
104 :     val T = DstV.new ("T", DstTy.vecTy dim) (*translate*)
105 :     val x = DstV.new ("x", DstTy.vecTy dim)
106 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
107 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
108 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
109 :    
110 :    
111 :     val PosToImgSpace=mk.transform(dim,dim)
112 :     val code=[
113 :     assign(M, transform, []),
114 :     assign(T, translate, []),
115 : cchiw 2510 pos,
116 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
117 : cchiw 2502 assign(nd, DstOp.Floor dim, [x]), (*nd *)
118 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
119 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
120 :     ]
121 :    
122 : cchiw 2515 in ([f,n],code)
123 : cchiw 2502 end
124 : cchiw 2499
125 :    
126 : cchiw 2522 (*Create fractional, and integer position vectors*)
127 :     fun createArgs2 (dim,v,posx)=let
128 : cchiw 2499
129 : cchiw 2522 val translate=DstOp.Translate v
130 :     val transform=DstOp.Transform v
131 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
132 : cchiw 2523 val T = DstV.new ("T", DstTy.tensorTy [dim,dim]) (*translate*)
133 : cchiw 2522 val x = DstV.new ("x", DstTy.vecTy dim)
134 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
135 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
136 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
137 : cchiw 2525
138 : cchiw 2522
139 :     val PosToImgSpace=mk.transform(dim,dim)
140 :     val code=[
141 :     assign(M, transform, []),
142 :     assign(T, translate, []),
143 :    
144 :     assignEin(x, PosToImgSpace,[M,posx,T]) , (* MX+T*)
145 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
146 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
147 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
148 :     ]
149 :    
150 :     in ([f,n],code)
151 :     end
152 :    
153 :    
154 :    
155 : cchiw 2525 fun Position(img,t,newposArg,dict,dim,pos1,ppos)=let
156 :     val l=lookup t dict
157 : cchiw 2522
158 : cchiw 2515 in (case l
159 :     of NONE =>let
160 : cchiw 2525
161 : cchiw 2523
162 : cchiw 2525 val (args',code')=createArgs2(dim,img,newposArg)
163 :     val pos2=pos1+1
164 :     val dict'=insert(t,(pos1,pos2)) dict
165 :    
166 :    
167 :    
168 :     in (pos1,pos2, dict',ppos,args',code')
169 :     end
170 : cchiw 2523
171 : cchiw 2499
172 : cchiw 2515 | SOME (fid,nid)=>(fid,nid, dict,[],[],[]))
173 :     (*end case*)
174 :     end
175 : cchiw 2499
176 :     (*
177 : cchiw 2498 createDels=> creates the kronecker deltas for each Kernel
178 :     For each dimesnion a, and each index in derivative b create element (a,b)
179 :     *)
180 :     fun createDels([],_)= []
181 : cchiw 2515 | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
182 : cchiw 2498
183 :    
184 : cchiw 2525 fun replaceH(kvar, place,args)=let
185 :     val l1=List.take(args, place)
186 :     val l2=List.drop(args,place+1)
187 :     in l1@[kvar]@l2 end
188 : cchiw 2498
189 : cchiw 2525 (*Get Img, and Kern Args*)
190 :     fun getArgs(hid,hArg,V,imgArg,args)=
191 :     case (getRHS2 hArg,getRHS2 imgArg)
192 :     of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let
193 :     val hvar=DstV.new ("KNL", DstTy.KernelTy)
194 :     val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
195 :     val args1=replaceH(hvar, hid,args)
196 :     val args2=replaceH(imgvar, V,args1)
197 :     in
198 :     (Kernel.support h ,img, [assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])],args2)
199 :     end
200 :     | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"
201 :     | _ => raise Fail "Not a kernel argument"
202 : cchiw 2522
203 :    
204 :    
205 : cchiw 2525 fun expandEinProbe((body,(params,index,args,d,code,change)),origargs,sx)=(case body
206 :     of E.Probe(E.Conv(V,shape,h,deltas),E.Tensor(t,alpha)) =>let
207 : cchiw 2522
208 : cchiw 2525 val m=print "IN EXPAND\n"
209 : cchiw 2510 val E.IMG(dim)=List.nth(params,V)
210 :    
211 : cchiw 2525
212 :     val kArg=List.nth(origargs,h)
213 :     val imgArg=List.nth(origargs,V)
214 :     val (s,img,argcode,args2) =getArgs(h,kArg,V,imgArg,args)
215 :    
216 :     val newposArg=List.nth(args, t)
217 : cchiw 2522
218 : cchiw 2525 val ppos=[E.TEN(1,[dim]),E.TEN(3,[dim])]
219 :     val (fid,nid,d',params',args',code')=Position(img,t,newposArg,d,dim,(length params),ppos)
220 :     val shift=(length index)
221 : cchiw 2510
222 : cchiw 2525 val z=print(String.concat["\n SHIFt SET To",Int.toString(shift),"SX IS ", Int.toString(sx)])
223 : cchiw 2515
224 : cchiw 2498 (*sumIndex creating summaiton Index for body*)
225 :     fun sumIndex(0)=[]
226 : cchiw 2515 |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
227 : cchiw 2498
228 :     (*createKRN Image field and kernels *)
229 : cchiw 2510 fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
230 : cchiw 2515 | createKRN(dim,imgpos,rest)=let
231 : cchiw 2522 val dim'=dim-1
232 : cchiw 2525 val sum=sx+dim'
233 : cchiw 2522 val dels=createDels(deltas,dim')
234 : cchiw 2525 val L=print "\n creatWith "
235 :     val LL=print(Int.toString(dim'))
236 : cchiw 2522 val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
237 :     val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
238 : cchiw 2525
239 : cchiw 2498 in
240 :     createKRN(dim',pos@imgpos,[rest']@rest)
241 :     end
242 :    
243 :    
244 : cchiw 2510 val exp=createKRN(dim, [],[])
245 :     val esum=sumIndex (dim)
246 : cchiw 2522
247 : cchiw 2525
248 :     in (E.Sum(esum, exp),(params@params',index,args2@args',d',argcode@code@code',1)) end
249 : cchiw 2510 (*end case*))
250 : cchiw 2498
251 :    
252 : cchiw 2525 fun TS x = case SrcIL.Var.binding x
253 :     of vb => String.concat[SrcIL.Var.toString x,"\n Found ", SrcIL.vbToString vb,"\n"]
254 :     (* end case *)
255 :    
256 :     fun TT x = case DstIL.Var.binding x
257 :     of vb => String.concat[DstIL.Var.toString x,"\n Found ", DstIL.vbToString vb,"\n"]
258 :     (* end case *)
259 :    
260 :    
261 :    
262 : cchiw 2498 (*copied from high-to-mid.sml*)
263 : cchiw 2522 fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
264 : cchiw 2525 (*
265 : cchiw 2523 val g=print "expand is called"
266 : cchiw 2525 val gg1=print(String.concatWith ","(List.map TS origargs))
267 :     val gg2=print "\n ----------- newbie --------- \n"
268 :     val gg3=print(String.concatWith ","(List.map TT args))
269 :     *)
270 : cchiw 2498 val dummy=E.Const 0.0 (*tmp variables*)
271 : cchiw 2522 (*Maybe conv->0 sumewhere else *)
272 : cchiw 2510 val sumIndex=ref (length index)
273 :     fun sumI(e)=let
274 :     val (E.V v,_,_)=List.nth(e, length(e)-1)
275 : cchiw 2525 val ref x=sumIndex
276 :     val v'=v+1
277 :     in if(x>v') then x else v'
278 :     end
279 : cchiw 2498
280 :     fun rewriteBody exp= let
281 : cchiw 2510 val (body, data)=exp
282 : cchiw 2498 in (case body
283 :     of E.Const _=>exp
284 :     | E.Tensor _=>exp
285 :     | E.Krn _=>exp
286 :     | E.Delta _=>exp
287 :     | E.Value _ =>exp
288 :     | E.Epsilon _=>exp
289 :     | E.Partial _=>exp
290 :     | E.Img _=> exp
291 : cchiw 2522 | E.Conv _=>(print "\n No Probe, used, can not expand \n" ; (dummy,data))
292 : cchiw 2510 | E.Field _ =>(dummy, data)
293 :     | E.Apply _ =>(dummy, data)
294 : cchiw 2498 | E.Neg e=> let
295 : cchiw 2510 val (body',exp')=rewriteBody (e, data)
296 : cchiw 2498 in
297 : cchiw 2510 (E.Neg(body'),exp')
298 : cchiw 2498 end
299 :     | E.Sum (c,e)=> let
300 : cchiw 2510
301 : cchiw 2525 val m=(sumI(c))
302 : cchiw 2510 val (body',exp')=(sumIndex:=m;rewriteBody (e,data))
303 : cchiw 2498 in
304 : cchiw 2510 ((E.Sum(c, body'), exp'))
305 : cchiw 2498 end
306 : cchiw 2510 | E.Probe(E.Conv _, _) =>let
307 :     val ref x=sumIndex
308 : cchiw 2522 in expandEinProbe(exp,origargs, x) end
309 : cchiw 2498 | E.Sub(a,b)=>let
310 : cchiw 2510 val (bodya,dataa)= rewriteBody(a, data)
311 :     val (bodyb, datab)= rewriteBody(b, dataa)
312 :     in (E.Sub( bodya, bodyb),datab)
313 : cchiw 2498 end
314 :     | E.Div(a,b)=>let
315 : cchiw 2510 val (bodya,dataa)= rewriteBody(a, data)
316 :     val (bodyb, datab)= rewriteBody(b, dataa)
317 :     in (E.Div(bodya, bodyb),datab) end
318 : cchiw 2498 | E.Add es=> let
319 : cchiw 2510 fun filter([], done, data')= (E.Add done, data')
320 :     | filter(e::es, done, data')= let
321 :     val (body', data'')= rewriteBody(e, data')
322 :     in filter(es, done@[body'], data'') end
323 :     in filter(es, [],data) end
324 :    
325 : cchiw 2498 | E.Prod es=> let
326 : cchiw 2510 fun filter([], done, data)= (E.Prod done, data)
327 :     | filter(e::es, done, data)= let
328 :     val (body', data')= rewriteBody(e, data)
329 :     in filter(es, done@[body'], data') end
330 :     in filter(es, [],data) end
331 : cchiw 2498 | E.Probe _=> exp
332 :     (* end case *))
333 : cchiw 2510 end
334 : cchiw 2498
335 : cchiw 2510 val empty =fn key =>NONE
336 : cchiw 2515 val (body',(params', index', args',_,code',change))=rewriteBody(body,(params,index,args,empty,[],0))
337 : cchiw 2510 val newbie=Ein.EIN{params=params', index=index', body=body'}
338 : cchiw 2515 in (change,newbie,args',code') end
339 : cchiw 2498
340 :     end; (* local *)
341 :    
342 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0