Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2502 - (view) (download)

1 : cchiw 2498 (* examples.sml
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :     (*Expand ProebConv to Probe of individual field *)
9 :     structure Expand = struct
10 :    
11 :     local
12 :    
13 :     structure E = Ein
14 : cchiw 2502 structure DstIL = MidIL
15 :     structure DstTy = MidILTypes
16 : cchiw 2498
17 : cchiw 2502 structure DstOp = MidOps
18 :     structure DstV = DstIL.Var
19 :     structure mk= mkOperators
20 :    
21 :    
22 : cchiw 2498 in
23 :    
24 :    
25 : cchiw 2502 fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
26 :     fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))
27 : cchiw 2498 fun insert (key, value) d =fn s =>
28 :     if s = key then SOME value
29 :     else d s
30 :     fun lookup k d = d k
31 :    
32 : cchiw 2499
33 : cchiw 2502 (*Create fractional, and integer position vectors*)
34 :     fun createArgs(dim,v,pos,s)=let
35 :     val translate=DstOp.Translate v
36 :     val transform=DstOp.Transform v
37 : cchiw 2499
38 : cchiw 2502 (* Match EinTypes, or mid-il types?
39 :     val vecsTy =mk.createVec(2*s)
40 :     val vecDimTy = mk.createVec(dim)
41 :     *)
42 :    
43 :    
44 :     val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
45 :     val T = DstV.new ("T", DstTy.vecTy dim) (*translate*)
46 :     val x = DstV.new ("x", DstTy.vecTy dim)
47 :     val f = DstV.new ("f", DstTy.vecTy dim) (*fractional*)
48 :     val nd = DstV.new ("nd", DstTy.vecTy dim) (*real position*)
49 :     val n = DstV.new ("n", DstTy.iVecTy dim) (*interger position*)
50 :    
51 :    
52 :     val PosToImgSpace=mk.transform(dim,dim)
53 :     val code=[
54 :     assign(M, transform, []),
55 :     assign(T, translate, []),
56 :     assignEin(x, PosToImgSpace,[M,pos,T]) , (* MX+T*)
57 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
58 :     assignEin(f, mk.subTen([dim]),[x,nd]), (*fractional*)
59 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
60 :    
61 :     ]
62 :    
63 :     (*Then f, n are new positions created. add to args list of currrent EinExp*)
64 :     val args=[f,n]
65 :     in (args,code)
66 :     end
67 : cchiw 2498 (*
68 : cchiw 2499
69 :    
70 :     fun createVec d= S.transform(EinOp.tensorOp,[[d]],[])
71 :    
72 :    
73 :     (*create fractional, and integer position*)
74 :     fun createArgs(dim,Kernel?, )=
75 :    
76 :    
77 :     (*create set of positions*)
78 :     val s = Kernel.support h
79 :     val vecsTy =createVec(2*s)
80 :     val vecDimTy = createVec(dim)
81 :     val translate=DstOp.Translate v
82 :     val transform=DstOp.Transform v
83 :    
84 :    
85 :     (* generate the transform code *)
86 :     val x = DstV.new ("x", vecDimTy) (* image-space position *)
87 :     val f = DstV.new ("f", vecDimTy)
88 :     val nd = DstV.new ("nd", vecDimTy)
89 :     val n = DstV.new ("n", DstTy.iVecTy dim)
90 :     val M = DstV.new ("M", transform)
91 :     val T = DstV.new ("T", translate)
92 :    
93 :     val sub= S.transform(EinOp.subTensor,[dim],[])
94 :    
95 :    
96 :     (* M_ij x_i*)
97 :     val MXop=S.transform(EinOp.innerProduct,[[dim],[],[dim]],[])
98 :     val MX = DstV.new ("MX", MXop)
99 :    
100 :    
101 :    
102 :     val PosToImgSpace=S.transform(EinOp.addTensor,[[dim]],[])
103 :    
104 :     val toImgSpaceCode = [
105 :     assignEin(MX, Mxop, [M,pos]), (*M_{ij}X_i*)
106 :     assignEin(x, PosToImgSpace,[MX,T]) (* MX+T*)
107 :     assign(nd, DstOp.Floor dim, [x]), (*nd *)
108 :     assignEin(f, sub,[x,nd]), (*fractional*)
109 :     assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
110 :     ]
111 :    
112 :     *)
113 :    
114 :    
115 :    
116 :     (*
117 : cchiw 2498 createDels=> creates the kronecker deltas for each Kernel
118 :     For each dimesnion a, and each index in derivative b create element (a,b)
119 :     *)
120 :     fun createDels([],_)= []
121 : cchiw 2499 | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
122 : cchiw 2498
123 :    
124 :    
125 : cchiw 2499
126 : cchiw 2498 fun Position(idt,dict,params,args)=let
127 :     val l=lookup idt dict
128 :     in (case l
129 :     of NONE =>let
130 : cchiw 2499 val pos1=length params
131 :     val pos2=pos1+1
132 :     val dict'=insert(idt,(pos1,pos2)) dict
133 :     val params'=params@[E.TEN,E.TEN]
134 :     (*will call create args*)
135 :     in (pos1,pos2, dict',params',args)
136 :     end
137 : cchiw 2498 (*Create new fractional, and n variables,and returns fresh ids*)
138 :     | SOME (fid,nid)=>(fid,nid, dict,params,args)
139 :     (*end case*))
140 :     end
141 :    
142 :    
143 :     fun expandEinProbe(params,body,index,d,args)=(case body
144 :     of E.Probe(E.Conv(id,alpha,kid,deltas),E.Tensor(idt,alphat)) =>
145 :     if(id+1>length params) then (print "not enough params" ;(params,body,index,d,args))
146 :     else (case List.nth(params,id)
147 :     of E.FLD(dim)=>
148 :     let
149 :     val s=2 (*support*)
150 :     val (fid,nid,d',params',args')=Position(idt,d,params,args)
151 :     val shift=length index
152 :    
153 :     (*sumIndex creating summaiton Index for body*)
154 :     fun sumIndex(0)=[]
155 : cchiw 2499 |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+shift-1),1-s,s)]
156 : cchiw 2498
157 :     (*createKRN Image field and kernels *)
158 :     fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(id,alpha,imgpos)] @rest)
159 :     | createKRN(dim,imgpos,rest)=
160 :     let
161 :     val dim'=dim-1
162 :     val sum=dim'+shift
163 :     val dels=createDels(deltas,dim')
164 :     val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
165 :     val rest'= E.Krn(kid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
166 :     in
167 :     createKRN(dim',pos@imgpos,[rest']@rest)
168 :     end
169 :    
170 :    
171 :     val i=createKRN(dim, [],[])
172 :     val esum=sumIndex dim
173 : cchiw 2499 val index'=index
174 : cchiw 2498
175 :     in (params', E.Sum(esum, i),index', d',args') end
176 :     | _=>(print "err: non field in param spot";(params,E.Const(0.0) ,index,d,args))
177 :     (*end case*))
178 :     |_=>(print "unexpected body" ;(params,body,index,d,args))
179 :     (*end case*))
180 :    
181 :    
182 :    
183 :     (*copied from high-to-mid.sml*)
184 :     fun expandEinOp ( Ein.EIN{params, index, body}, args) = let
185 :    
186 :     val dummy=E.Const 0.0 (*tmp variables*)
187 :    
188 :     fun rewriteBody exp= let
189 :     val (p,body,ix,d,args')= exp
190 :     in (case body
191 :     of E.Const _=>exp
192 :     | E.Tensor _=>exp
193 :     | E.Krn _=>exp
194 :     | E.Delta _=>exp
195 :     | E.Value _ =>exp
196 :     | E.Epsilon _=>exp
197 :     | E.Partial _=>exp
198 :     | E.Img _=> exp
199 :     | E.Conv _=>(p,dummy,ix,d,args')
200 :     | E.Field _ =>(p,dummy,ix,d,args')
201 :     | E.Apply _ =>(p,dummy,ix,d,args')
202 :     | E.Neg e=> let
203 :     val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')
204 :     in
205 :     (p',E.Neg body',ix',d',args'')
206 :     end
207 :     | E.Sum (c,e)=> let
208 :     val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')
209 :     in
210 :     (p',E.Sum(c,body'),ix',d',args'')
211 :     end
212 :     | E.Probe(E.Conv _, _) =>expandEinProbe exp
213 :     | E.Sub(a,b)=>let
214 :     val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')
215 :     val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')
216 :     in (pb,E.Sub( a', b'),bx,db,args''')
217 :     end
218 :     | E.Div(a,b)=>let
219 :     val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')
220 :     val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')
221 :     in (pb,E.Div( a', b'),bx,db,args''')
222 :     end
223 :     | E.Add es=> let
224 :     fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Add done, ix1,d1,args')
225 :     | addFilter(p1,ix1,d1, e::es,done,args')=let
226 :     val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')
227 :     in addFilter(p2, ix2, d2, es, done@[e2],args'')
228 :     end
229 :     in
230 :     addFilter(p,ix,d,es,[],args')
231 :     end
232 :     | E.Prod es=> let
233 :     fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Prod done, ix1,d1,args')
234 :     | addFilter(p1,ix1,d1, e::es,done,args')=let
235 :     val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')
236 :     in addFilter(p2, ix2, d2, es, done@[e2],args'')
237 :     end
238 :     in
239 :     addFilter(p,ix,d,es,[],args')
240 :     end
241 :     | E.Probe _=> exp
242 :    
243 :     (* end case *))
244 :     end
245 :    
246 :    
247 :     val empty =fn key =>NONE
248 :     val (params',body',ix',_,args')=rewriteBody(params,body,index,empty,args)
249 :     val newbie=Ein.EIN{params=params', index=ix', body=body'}
250 :     in (newbie,args') end
251 :    
252 :     end; (* local *)
253 :    
254 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0