Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2838 - (view) (download)

1 : cchiw 2838 (* Currently under construction
2 :     *
3 :     * COPYRIGHT (c) 2012 The Diderot Project (http://diderot-language.cs.uchicago.edu)
4 :     * All rights reserved.
5 :     *)
6 :    
7 :    
8 :    
9 :    
10 :     structure Split = struct
11 :    
12 :     local
13 :    
14 :     structure E = Ein
15 :     structure mk= mkOperators
16 :     structure SrcIL = HighIL
17 :     structure SrcTy = HighILTypes
18 :     structure SrcOp = HighOps
19 :     structure SrcSV = SrcIL.StateVar
20 :     structure VTbl = SrcIL.Var.Tbl
21 :     structure DstIL = MidIL
22 :     structure DstTy = MidILTypes
23 :     structure DstOp = MidOps
24 :     structure DstV = DstIL.Var
25 :     structure SrcV = SrcIL.Var
26 :     structure P=Printer
27 :     structure F=Filter
28 :     structure T=TransformEin
29 :     structure Var = MidIL.Var
30 :     structure cleanP=cleanParams
31 :     structure cleanI=cleanIndex
32 :    
33 :     val testing=1
34 :     in
35 :    
36 :    
37 :     fun setEin(params,index,body)=Ein.EIN{params=params, index=index, body=body}
38 :     fun assignEinApp(y,params,index,body,args)= (y,DstIL.EINAPP(setEin(params,index,body),args))
39 :     fun setEinZero(y,params,index,args)= (y,DstIL.EINAPP(setEin(params,index,E.Const 0),args))
40 :     fun cleanParams e =cleanP.cleanParams e
41 :     fun cleanIndex e =cleanI.cleanIndex e
42 :     fun itos i =Int.toString i
43 :     fun err str=raise Fail str
44 :     val cnt = ref 0
45 :     fun genName prefix = let
46 :     val n = !cnt
47 :     in
48 :     cnt := n+1;
49 :     String.concat[prefix, "_", Int.toString n]
50 :     end
51 :     fun testp n=(case testing
52 :     of 0=> 1
53 :     | _ =>(print(String.concat n);1)
54 :     (*end case*))
55 :    
56 :     fun printEINAPP(id, DstIL.EINAPP(rator, args))=let
57 :     val a=String.concatWith " , " (List.map Var.toString args)
58 :     in
59 :     String.concat([(DstTy.toString (Var.ty id)),"<",Var.toString id,"> ==",P.printerE rator, a,"\n"])
60 :     end
61 :     | printEINAPP(id, DstIL.OP(rator, args))=let
62 :     val a=String.concatWith " , " (List.map Var.toString args)
63 :     in
64 :     String.concat([(DstTy.toString (Var.ty id)),"<",Var.toString id,"> =",DstOp.toString rator,a,"\n"])
65 :     end
66 :    
67 :     | printEINAPP(id,_)= String.concat([Var.toString id,"<",(DstTy.toString (Var.ty id)),"> non-einapp\n"])
68 :    
69 :    
70 :     (* mkreplacement:params*index*index_list*int list* ein_exp-> ein_exp* params*args*code*
71 :     *creates new param and replacement tensor for the original ein_exp
72 :     *Then cleans params for suebxpression
73 :     *)
74 :     fun mkreplacement(params,args,tshape,sizes,body)=let
75 :     val id=length(params)
76 :     val params'=params@[E.TEN(1,sizes)]
77 :     val e'=E.Tensor(id,tshape)
78 :     val M = DstV.new (genName ("TLifted_"^itos id), DstTy.TensorTy sizes)
79 :     val args'=args@[M]
80 :     val einapp=cleanParams(M,body,params',sizes,args')
81 :     in
82 :     (e',params',args',[einapp])
83 :     end
84 :    
85 :     (* lift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
86 :     *lifts expression and returns replacement tensor
87 :     * cleans the index and params of subexpression
88 :     *)
89 :     fun lift(e,params,index,sx,args)=let
90 :     val (tshape,sizes,body)=cleanIndex(e,index,sx)
91 :     val (Re,Rparams,Rargs,code)=mkreplacement(params,args,tshape,sizes,body)
92 :     in
93 :     (Re,Rparams,Rargs,code)
94 :     end
95 :    
96 :     (* simplelift:ein_app*params*index*sum_id*args-> (ein_exp* params*args*code)
97 :     *lifts expression and returns replacement tensor
98 :     * cleans params of subexpression
99 :     *)
100 :     fun simplelift(e,params,index,args)=(*let
101 :     val tshape=List.map (fn x => E.V x) index
102 :     val(Re,Rparams,Rargs,code)=mkreplacement(params,args,tshape,index,e)
103 :     in
104 :     (Re,Rparams,Rargs,code)
105 :     end
106 :     *)lift(e,params,index,[],args)
107 :    
108 :    
109 :    
110 :     (* isOp: ein->int
111 :     * checks to see if this sub-expression is pulled out or split form original
112 :     * 0-becomes zero,1-remains the same, 2-operator
113 :     *)
114 :     fun isOp e =(case e
115 :     of E.Field _ => 0
116 :     | E.Conv _ => 0
117 :     | E.Apply _ => 0
118 :     | E.Lift _ => 0
119 :     | E.Neg _ => 1
120 :     | E.Add _ => 1
121 :     | E.Sub _ => 1
122 :     | E.Prod _ => 1
123 :     | E.Div _ => 1
124 :     | E.Sum _ => 1
125 :     | E.Probe _ => 1
126 :     | E.Partial _ => err(" Partial used after normalize")
127 :     | E.Krn _ => err("Krn used before expand")
128 :     | E.Value _ => err("Value used before expand")
129 :     | E.Img _ => err("Probe used before expand")
130 :     | _ => 2
131 :     (*end case*))
132 :    
133 :    
134 :     (* simpleOp:ein_exp*params*index*args-> ein_exp*params*args*code
135 :     * If e1 an op then call simplelift() to replace it
136 :     * Otherwise rewrite to 0 or it remains the same
137 :     *)
138 :     fun simpleOp(e1,params,index,args)=(case (isOp e1)
139 :     of 0 => (E.Const 0,params,args,[])
140 :     | 2 => (e1,params,args,[])
141 :     | _ => simplelift(e1,params,index,args)
142 :     (*end*))
143 :    
144 :    
145 :     (* simpleOps:ein_exp list*params*index*args-> ein_exp list*params*args*code
146 :     * calls simpleOp on ein_exp list
147 :     *)
148 :     fun simpleOps(list1,params,index,args)=let
149 :     fun m([],rest,params,args,code)=(rest,params,args,code)
150 :     | m(e1::es,rest,params,args,code)=let
151 :     val (e1',params',args',code')= simpleOp(e1,params,index,args)
152 :     in
153 :     m(es,rest@[e1'],params',args',code@code')
154 :     end
155 :     in
156 :     m(list1,[],params,args,[])
157 :     end
158 :    
159 :     (* prodOps:ein_exp list*params*index*sum_id list*args-> ein_exp list*params*args*code
160 :     * calls lift on ein_exp list
161 :     *)
162 :     fun prodOps(list1,params,index,sx,args)=let
163 :     fun m([],rest,params,args,code)=(rest,params,args,code)
164 :     | m(e1::es,rest,params,args,code)=(case (isOp e1)
165 :     of 0 => m(es,rest@[E.Const 0],params,args,code)
166 :     | 2 => m(es,rest@[e1],params,args,code)
167 :     | 1 => let
168 :     val (e1',params',args',code')= lift(e1,params,index,sx,args)
169 :     in
170 :     m(es,rest@[e1'],params',args',code@code')
171 :     end
172 :     (*end case*))
173 :     in
174 :     m(list1,[],params,args,[])
175 :     end
176 :    
177 :    
178 :     fun isZero(y,body,params,index,args) =(case (cleanI.isZero body)
179 :     of 1=> setEinZero(y,params,[],args)
180 :     | _ => cleanParams(y,body,params,index,args)
181 :     (*end case*))
182 :    
183 :     (* handleNeg:var*ein_exp *params*index*args-> (var*einap)*code
184 :     * calls simpleOp() lift on ein_exp
185 :     *)
186 :     fun handleNeg(y,e1,params,index,args)=let
187 :     val (e1',params',args',code)= simpleOp(e1,params,index,args)
188 :     val body =E.Neg e1'
189 :     val einapp= isZero(y,body,params',index,args')
190 :     in
191 :     (einapp,code)
192 :     end
193 :    
194 :     (* handleSub:var*ein_exp*ein_exp *params*index*args-> (var*einap)*code
195 :     * calls simpleOps() lift on ein_exp
196 :     *)
197 :     fun handleSub(y,e1,e2,params,index,args)=let
198 :     val ([e1',e2'],params',args',code)= simpleOps([e1,e2],params,index,args)
199 :     val body =E.Sub(e1',e2')
200 :     val einapp= isZero(y,body,params',index,args')
201 :     in
202 :     (einapp,code)
203 :     end
204 :    
205 :     (* handleDiv:var*ein_exp *ein_exp*params*index*args-> (var*einap)*code
206 :     * calls simpleOp() lift on ein_exp
207 :     *)
208 :     fun handleDiv(y,e1,e2,params,index,args)=let
209 :     val (e1',params1',args1',code1')=simpleOp(e1,params,index,args)
210 :     val (e2',params2',args2',code2')=simpleOp(e2,params1',[],args1')
211 :     val body =E.Div(e1',e2')
212 :     val einapp= isZero(y,body,params2',index,args2')
213 :     in
214 :     (einapp,code1'@code2')
215 :     end
216 :    
217 :    
218 :     (* handleAdd:var*ein_exp list *params*index*args-> (var*einap)*code
219 :     * calls simpleOps() lift on ein_exp
220 :     *)
221 :     fun handleAdd(y,e1,params,index,args)=let
222 :     val (e1',params',args',code)= simpleOps(e1,params,index,args)
223 :     val body =E.Add e1'
224 :     val einapp= isZero(y,body,params',index,args')
225 :     in
226 :     (einapp,code)
227 :     end
228 :    
229 :     (* handleProd:var*ein_exp list*params*index*args-> (var*einap)*code
230 :     * calls prodOps() lift on ein_exp
231 :     *)
232 :     fun handleProd(y,e1,params,index,args)=let
233 :     val (e1',params',args',code)= prodOps(e1,params,index,[],args)
234 :     val body =E.Prod e1'
235 :     val einapp= isZero(y,body,params',index,args')
236 :     in
237 :     (einapp,code)
238 :     end
239 :    
240 :     (* handleSumProd:var*ein_exp *params*index*args-> (var*einap)*code
241 :     * calls prodOps() lift on ein_exp
242 :     *)
243 :     fun handleSumProd(y,e1,params,index,sx,args)=let
244 :     val _ =List.map (fn (_,_,ub)=> Int.toString ub) sx
245 :     val (e1',params',args',code)= prodOps(e1,params,index,sx,args)
246 :     val body= E.Sum(sx,E.Prod e1')
247 :     val einapp= isZero(y,body,params',index,args')
248 :     in
249 :     (einapp,code)
250 :     end
251 :    
252 :     (* split:var*ein_app-> (var*einap)*code
253 :     * split ein expression into smaller pieces
254 :     *)
255 :     fun split(y,einapp as DstIL.EINAPP(Ein.EIN{params, index, body},args))=let
256 :     val zero= (setEinZero(y,params,[],args),[])
257 :     val default=((y,einapp),[])
258 :     val sumIndex=ref []
259 :     fun rewrite b=(case b
260 :     of E.Probe _ => default
261 :     | E.Conv _ => zero
262 :     | E.Field _ => zero
263 :     | E.Apply _ => zero
264 :     | E.Lift e => zero
265 :     | E.Delta _ => default
266 :     | E.Epsilon _ => default
267 :     | E.Tensor _ => default
268 :     | E.Const _ => default
269 :     | E.Neg e1 => handleNeg(y,e1,params,index,args)
270 :     | E.Sub (e1,e2) => handleSub(y,e1,e2,params,index,args)
271 :     | E.Div (e1,e2) => handleDiv(y,e1,e2,params,index,args)
272 :     | E.Sum(_,E.Probe _) => default
273 :     | E.Sum(_,E.Conv _) => zero
274 :     | E.Sum(sx,E.Prod e1) => handleSumProd(y,e1,params,index,sx,args)
275 :     | E.Sum(sx,E.Neg n) => rewrite (E.Neg(E.Sum(sx,n)))
276 :     | E.Sum(sx,E.Add a) => rewrite (E.Add(List.map (fn e=> E.Sum(sx,e)) a))
277 :     | E.Sum(sx,E.Sub (e1,e2)) => rewrite (E.Sub(E.Sum(sx,e1),E.Sum(sx,e2)))
278 :     | E.Sum(sx,E.Div(e1,e2)) => rewrite(E.Div(E.Sum(sx,e1),E.Sum(sx,e2)))
279 :     | E.Sum(c1, E.Sum (c2,e)) => rewrite (E.Sum (c1@c2,e))
280 :     | E.Sum(sx,_) => default
281 :     | E.Add e1 => handleAdd(y,e1,params,index,args)
282 :     | E.Prod e1 => handleProd(y,e1,params,index,args)
283 :     | E.Partial _ => err(" Partial used after normalize")
284 :     | E.Krn _ => err("Krn used before expand")
285 :     | E.Value _ => err("Value used before expand")
286 :     | E.Img _ => err("Probe used before expand")
287 :     (*end case *))
288 :    
289 :     val (einapp2,newbies) =rewrite body
290 :     in
291 :     (einapp2,newbies)
292 :     end
293 :     |split(y,app) =((y,app),[])
294 :    
295 :    
296 :     (* iterMultiple:code*code=> (code*code)
297 :     * recursively split ein expression into smaller pieces
298 :     *)
299 :     fun iterMultiple(einapp2,newbies2)=let
300 :     fun itercode([],rest,code)=(rest,code)
301 :     | itercode(e1::newbies,rest,code)=let
302 :     val (einapp3,code3) =split(e1)
303 :     val (rest4,code4)=itercode(code3,[],[])
304 :     in itercode(newbies,rest@[einapp3],code4@rest4@code)
305 :     end
306 :     val(rest,code)= itercode(newbies2,[],[])
307 :     in
308 :     (einapp2,code@rest)
309 :     end
310 :    
311 :    
312 :     fun iterSplit(y,einapp)=let
313 :     val (einapp2,newbies2)=split(y,einapp)
314 :     in
315 :     iterMultiple(einapp2,newbies2)
316 :     end
317 :    
318 :    
319 :    
320 :     (* gettest:code*code=> (code*code)
321 :     * print results for splitting einapp
322 :     *)
323 :     fun gettest(einapp)=(case testing
324 :     of 0=>iterSplit(einapp)
325 :     | _=>let
326 :     val star="\n*************\n"
327 :     val _ =print(String.concat[star])
328 :     val (einapp2,newbies)=iterSplit(einapp)
329 :     val a=printEINAPP einapp2
330 :     val b=String.concatWith",\n\t"(List.map printEINAPP newbies)
331 :     val _ =print(String.concat[printEINAPP einapp,"=>",a," newbies\n\t",b, "\n",a,star])
332 :     in
333 :     (einapp2,newbies)
334 :     end
335 :     (*end case*))
336 :    
337 :    
338 :     end; (* local *)
339 :    
340 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0