Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2554 - (view) (download)

1 : cchiw 2522 (* Split Functions before code generation process*)
2 :     structure splitHtM = struct
3 :     local
4 :     structure E = Ein
5 :     structure DstIL = MidIL
6 :     structure DstTy = MidILTypes
7 :     structure shift=shiftHtM
8 :     structure P=Printer
9 : cchiw 2525 structure Var = MidIL.Var
10 :     structure HVar = HighIL.Var
11 : cchiw 2522 in
12 :    
13 :    
14 :     fun printA(id,e,arg)=let
15 :     val a=String.concatWith " , " (List.map Var.toString arg)
16 :     in String.concat([(Var.toString id)," ==",P.printerE e, a])
17 :     end
18 :    
19 :     fun printAA(id,e,arg)=let
20 :     val a=String.concatWith " , " (List.map HVar.toString arg)
21 :     in String.concat([(Var.toString id)," ==",P.printerE e, a])
22 :     end
23 :    
24 :    
25 :     fun createEin( params,index, body)=Ein.EIN{params=params, index=index, body=body}
26 :     fun flat xs = List.foldr op@ [] xs
27 :     val counter=ref 0
28 :    
29 :     (*How to create new ein variable*)
30 : cchiw 2525 fun fresh ty=let
31 : cchiw 2522 val ref x=counter
32 :     val m=x+1
33 : cchiw 2525 val x=DstIL.Var.new("Q" ^ Int.toString(m) ,ty)
34 : cchiw 2522 in (counter:=m;x) end
35 :    
36 :     fun createnewb (params,index,args,(id,e))=let
37 :     val (p',b',args)=shift.cleanParams(e,params,args)
38 :     val a=createEin(p',index, b')
39 :     in (id,a,args)
40 :     end
41 :    
42 :     fun createnewP (params,args,(id,e,ix))=let
43 :     val (p',b',args)=shift.cleanParams(e,params,args)
44 :     val a=createEin(p',ix, b')
45 :     in (id,a,args)
46 :     end
47 :    
48 : cchiw 2525 fun findOp e=(case e
49 :     of E.Neg _=>1
50 :     | E.Add _=>1
51 :     | E.Sub _=>1
52 :     | E.Prod _=>1
53 :     | E.Div _=>1
54 :     | E.Sum _ =>1
55 :     | _=>0
56 :     (*end case*))
57 : cchiw 2522
58 :    
59 : cchiw 2525
60 : cchiw 2522 (*Outside Operator is Neg*)
61 : cchiw 2525 fun handleNeg(params, index,e1,args)=let
62 :     val id=ref (length params)
63 :     val n=length index
64 :    
65 :     val ix=List.tabulate (n,fn v=> E.V(v))
66 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
67 :    
68 :     fun divsort(e)= let
69 :     val s=findOp e
70 :     in (case s
71 :     of 0=>(e,[],[],[])
72 :     | _=> let
73 :     val q=fresh(DstTy.TensorTy(index))
74 :     in (mkTensor 0, [(q, e)],[E.TEN(1,index)],[q]) end
75 :     (*end case*))
76 :     end
77 :    
78 :     val (lft1, newbies1,params1,args1)=divsort(e1)
79 :     val (p',b',args')= shift.cleanParams(E.Neg(lft1),params@params1, args@args1)
80 :     val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
81 :     in
82 :     (z1,(p',b',args'))
83 :     end
84 :    
85 :    
86 :    
87 :     (*let
88 : cchiw 2522 val id=ref (length params)
89 : cchiw 2525 val n=length index
90 :     val ix=List.tabulate (n,fn v=> E.V(v))
91 : cchiw 2522 fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
92 :    
93 :     fun replace e'= let
94 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
95 : cchiw 2522 val t=mkTensor 0
96 :     val newbie=createEin( params,index,e')
97 : cchiw 2525 in ([(q,newbie,args)],(params@[E.TEN(1,index)], E.Neg t,[q]))
98 : cchiw 2522 end
99 :    
100 : cchiw 2525 fun sort e1= (case e1
101 :     of E.Add _=> replace e1
102 :     | E.Sub _=> replace e1
103 :     | E.Prod _=> replace e1
104 :     | E.Div _=> replace e1
105 :     | E.Sum _=> replace e1
106 :     | _=>([],params, E.Neg e1, args)
107 :     (*end case*))
108 : cchiw 2522
109 : cchiw 2525 val (newbies1,params1,lft1,args1)=sort e
110 :     val (p',b',args')= shift.cleanParams(lft1,params@params1, args@args1)
111 :     val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
112 :     in
113 :     (z1,(p',b',args'))
114 : cchiw 2522 end
115 :    
116 : cchiw 2525 *)
117 : cchiw 2522
118 :    
119 :    
120 : cchiw 2525
121 : cchiw 2522 (*Outside Operator is Add*)
122 :     fun handleAdd(params, index,list1,args)=let
123 :     val id=ref (length params)
124 :     val n=length index
125 :     val ix=List.tabulate (n,fn v=> E.V(v))
126 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);[E.Tensor(idx,ix)]) end
127 :    
128 :     fun foundOp(e,es,(lft,newbies,params,args))=let
129 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
130 :     in (es,(lft@(mkTensor 0), newbies@[(q, e)],params@[E.TEN(1,index)],args@[q]))
131 : cchiw 2522 end
132 :    
133 :    
134 :     fun sort([], m)=m
135 :     | sort(e::es,m)=(case e
136 :     of E.Add p => sort(p@es, m)
137 :     | E.Sub _=>sort (foundOp(e, es,m))
138 :     | E.Prod _=>sort (foundOp(e, es,m))
139 :     | E.Div _=>sort (foundOp(e, es,m))
140 :     | E.Neg _=>sort (foundOp(e, es,m))
141 :     | E.Sum _=>sort (foundOp(e, es,m))
142 :     | _ => let
143 :     val (l,n, p, a)=m
144 :     in sort(es,(l@[e],n,p,a)) end
145 :     (*end case *))
146 :    
147 :     val (lft, newbies,params',args')=sort(list1,([],[],[],[]))
148 :     val (p',b',args')= shift.cleanParams(E.Add(lft),params@params', args@args')
149 :     val z=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies
150 :     in
151 :     (z,(p',b',args'))
152 :     end
153 :    
154 :    
155 :    
156 :    
157 :    
158 : cchiw 2525
159 : cchiw 2522 (*Outside Operator is Sub*)
160 :     fun handleSub(params, index,e1,e2,args)=let
161 :     val id=ref (length params)
162 :     val n=length index
163 :     val ix=List.tabulate (n,fn v=> E.V(v))
164 :    
165 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
166 :    
167 :     fun subsort(e)= let
168 :     val s=findOp e
169 :     in (case s
170 :     of 0=>(e,[],[],[])
171 :    
172 :     | _=> let
173 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
174 :     in (mkTensor 0, [(q, e)],[E.TEN(1,index)],[q]) end
175 : cchiw 2522 (*end case*))
176 :     end
177 :    
178 :     val (lft1, newbies1,params1,args1)=subsort e1
179 :     val (lft2, newbies2,params2,args2)=subsort e2
180 :     val (p',b',args')= shift.cleanParams(E.Sub(lft1,lft2), params@params1@params2, args@args1@args2)
181 :     val newbies=newbies1@newbies2
182 :     val z=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies
183 :     in
184 :     (z,(p',b',args'))
185 :     end
186 :    
187 :     (*Outside Operator is Div *)
188 :     fun handleDiv(params, index,e1,e2,args)=let
189 :     val id=ref (length params)
190 : cchiw 2554 val ix=List.tabulate (n,fn v=> E.V(v))
191 : cchiw 2522 fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
192 :     fun mkSca _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,[])) end
193 :    
194 :     fun divsort(e,nextfn)= let
195 :     val s=findOp e
196 :     in (case s
197 :     of 0=>(e,[],[],[])
198 :    
199 :     | _=> let
200 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
201 :     in (nextfn 0, [(q, e)],[E.TEN(1,index)],[q]) end
202 : cchiw 2522 (*end case*))
203 :     end
204 :    
205 :     val (lft1, newbies1,params1,args1)=divsort(e1,mkTensor)
206 :     val (lft2, newbies2,params2,args2)=divsort(e2,mkSca)
207 :     val (p',b',args')= shift.cleanParams(E.Div(lft1,lft2),params@params1@params2, args@args1@args2)
208 :    
209 :     val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
210 :     val z2=List.map (fn(e)=> createnewb(params,[],args,e) ) newbies2
211 :     in
212 :     (z1@z2,(p',b',args'))
213 :     end
214 :    
215 :    
216 :    
217 : cchiw 2525
218 : cchiw 2522 fun hProd(params, index,list1,args)=let
219 :     val id=ref (length params)
220 :     val n=length index
221 :    
222 :     fun mkPTensor e=let
223 :     val ref idx= id
224 :     val (ix,index',e')=shift.cleanIndex(e, n, index)
225 :     in (id:=(idx+1);([E.Tensor(idx,ix)],index',e')) end
226 :    
227 :    
228 :     fun foundOp (e,es,(lft,newbies, params, args))=let
229 : cchiw 2525
230 :     val ref idx= id
231 :     val (ix,index',e')=shift.cleanIndex(e, n, index)
232 :     val (p,ix,e')=([E.Tensor(idx,ix)],index',e')
233 :     val q=fresh(DstTy.TensorTy(ix))
234 :     in (es,(lft@p, newbies@[(q, e',ix)],params@[E.TEN(1,ix)],args@[q])) end
235 : cchiw 2522
236 :    
237 :     fun sort([], m)=m
238 :     | sort(e::es,m)=(case e
239 :     of E.Add _ => sort (foundOp(e,es,m))
240 :     | E.Sub _=>sort (foundOp(e, es,m))
241 :     | E.Prod p=>sort (p@es,m)
242 :     | E.Div _=>sort (foundOp(e, es,m))
243 :     | E.Neg _=>sort (foundOp(e, es,m))
244 :     | E.Sum _=>sort (foundOp(e, es,m))
245 :     | E.Probe _=> raise Fail("Probe- Should have been expanded")
246 :     | _ => let
247 :     val (l,n, p, a)=m
248 :     in sort(es,(l@[e],n,p,a)) end
249 :     (*end case *))
250 :    
251 :     in
252 :     sort(list1,([],[],[],[]))
253 :     end
254 :    
255 :     fun handleProd(params, index,list1,args)=let
256 :     val (lft, newbies,params',args')=hProd(params, index,list1,args)
257 :     val (p',b',args')= shift.cleanParams(E.Prod lft, params@params', args@args')
258 :     val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
259 :     in
260 :     (z,(p',b',args'))
261 :     end
262 :    
263 :     fun handleSumProd(params, ind,sx,list1,args)=let
264 :     val id=ref (length params)
265 :     val n=length ind
266 : cchiw 2553 val m=print (String.concat["\n In Sum Prod", "n",Int.toString(n)])
267 : cchiw 2522
268 :    
269 :    
270 : cchiw 2553 fun g(lft,[],_)=(1,lft) (*lft-outer index*)
271 :     | g(lft,(E.V s,0,ub)::es,n')=if(s=n') then (print "match";g(lft@[ub],es,n'+1)) else (0,[])
272 :     | g _ =(0,[]) (*Can't be split, weird bound*)
273 : cchiw 2522
274 :     val (c,index')= g([],sx,n)
275 :    
276 :     in case c
277 :     of 0=> ([],(params,E.Sum(sx, E.Prod(list1)),args))
278 :     |_=>let
279 :     val index=ind@index'
280 :     val (lft, newbies,params',args')=hProd(params, index,list1,args)
281 :     val (p',b',args')= shift.cleanParams(E.Sum(sx,E.Prod lft), params@params', args@args')
282 :     val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
283 :     in
284 :     (z,(p',b',args'))
285 :     end
286 :     end
287 :    
288 :    
289 :    
290 :    
291 :    
292 :     fun genfn(id,Ein.EIN{params, index, body},args)= let
293 :    
294 :     val notDone=([],(params,body,args))
295 :     fun gen body=(case body
296 :     of E.Field _ =>raise Fail(concat["Invalid Field here "] )
297 :     | E.Partial _ =>raise Fail(concat["Invalid Partial here "] )
298 :     | E.Apply _ =>raise Fail(concat["Invalid Apply here "] )
299 :     | E.Probe _ => raise Fail("Probe- Should have been expanded")
300 :     | E.Conv _ =>notDone
301 :     | E.Krn _ =>notDone
302 :     | E.Img _=> notDone
303 :     | E.Const _=> notDone
304 :     | E.Tensor(id,[])=> notDone
305 :     | E.Prod(E.Img _ :: _)=>notDone
306 :     | E.Neg(E.Neg e)=> gen e
307 :     | E.Neg e=> handleNeg(params, index,e,args)
308 : cchiw 2553 | E.Add a => (print "Add";handleAdd(params, index,a,args))
309 : cchiw 2522 | E.Sub(E.Sub(a,b),E.Sub(c,d))=> gen(E.Sub(E.Add[a,d],E.Add[b,c]))
310 :     | E.Sub(E.Sub(a,b),e2)=>gen (E.Sub(a,E.Add[b,e2]))
311 :     | E.Sub(e1,E.Sub(c,d))=>gen(E.Add([E.Sub(e1,c),d]))
312 : cchiw 2553 | E.Sub(e1,e2)=>(print "SUB";handleSub(params, index,e1,e2,args))
313 : cchiw 2522 | E.Div(E.Div(a,b),E.Div(c,d))=> gen(E.Div(E.Prod[a,d],E.Prod[b,c]))
314 :     | E.Div(E.Div(a,b),c)=> gen(E.Div(a, E.Prod[b,c]))
315 :     | E.Div(a,E.Div(b,c))=> gen(E.Div(E.Prod[a,c],b))
316 :     | E.Div(e1,e2)=>handleDiv(params, index,e1,e2,args)
317 : cchiw 2553 | E.Prod e=> (print "PROD ";handleProd(params, index,e,args))
318 : cchiw 2522
319 :     | E.Sum(_,E.Prod(E.Img _ :: _ ))=>notDone
320 : cchiw 2553 | E.Sum(sx,E.Prod e)=>(print "CAT"; handleSumProd(params, index,sx,e,args))
321 : cchiw 2523
322 :    
323 : cchiw 2522 | _=> notDone
324 :     (*end case*))
325 :    
326 :    
327 :     val (newbie,(p,b,arg))= gen body
328 :     val e'=createEin(p,index, b)
329 :    
330 :    
331 :     val f= (id,e',arg)
332 :     in (newbie, f)
333 :     end
334 :    
335 :    
336 :    
337 :     fun splitIt (change,e)=let
338 :     val (newbie, e')= genfn e
339 :     in (case length(newbie)
340 :     of 0=>(change,[e'])
341 :     | _=> let
342 :     val a=List.map (fn(e1)=>splitIt(1,e1)) newbie
343 :     val newbie'=flat(List.map (fn(e1,e2)=>e2) a)
344 :     in (1,newbie'@[e']) end
345 :     (*end case *))
346 :    
347 :     end
348 :    
349 :     fun splitein(id,E.EIN{params,index,body},arg)=let
350 :     val m=print(printA(id,E.EIN{params=params,index=index,body=body},arg))
351 :     val g=print "\n \t changed to =>\n \t"
352 :     val (p',i',b',args')=shiftHtM.clean(params, index, body, arg)
353 :     val einn'=createEin(p',i', b')
354 :     val m=print(printA(id,einn',args'))
355 :    
356 :     in
357 :     splitIt(0,(id,einn',args'))
358 :     end
359 :    
360 :    
361 :    
362 :    
363 :    
364 :     end (* local *)
365 :    
366 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0