Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-to-mid/split-einHtM.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2584 - (view) (download)

1 : cchiw 2522 (* Split Functions before code generation process*)
2 :     structure splitHtM = struct
3 :     local
4 :     structure E = Ein
5 :     structure DstIL = MidIL
6 :     structure DstTy = MidILTypes
7 :     structure shift=shiftHtM
8 :     structure P=Printer
9 : cchiw 2525 structure Var = MidIL.Var
10 :     structure HVar = HighIL.Var
11 : cchiw 2522 in
12 :    
13 :    
14 :     fun printA(id,e,arg)=let
15 :     val a=String.concatWith " , " (List.map Var.toString arg)
16 :     in String.concat([(Var.toString id)," ==",P.printerE e, a])
17 :     end
18 :    
19 :     fun printAA(id,e,arg)=let
20 :     val a=String.concatWith " , " (List.map HVar.toString arg)
21 :     in String.concat([(Var.toString id)," ==",P.printerE e, a])
22 :     end
23 :    
24 :    
25 :     fun createEin( params,index, body)=Ein.EIN{params=params, index=index, body=body}
26 :     fun flat xs = List.foldr op@ [] xs
27 :     val counter=ref 0
28 :    
29 :     (*How to create new ein variable*)
30 : cchiw 2525 fun fresh ty=let
31 : cchiw 2522 val ref x=counter
32 :     val m=x+1
33 : cchiw 2525 val x=DstIL.Var.new("Q" ^ Int.toString(m) ,ty)
34 : cchiw 2522 in (counter:=m;x) end
35 :    
36 :     fun createnewb (params,index,args,(id,e))=let
37 :     val (p',b',args)=shift.cleanParams(e,params,args)
38 :     val a=createEin(p',index, b')
39 :     in (id,a,args)
40 :     end
41 :    
42 :     fun createnewP (params,args,(id,e,ix))=let
43 :     val (p',b',args)=shift.cleanParams(e,params,args)
44 :     val a=createEin(p',ix, b')
45 :     in (id,a,args)
46 :     end
47 :    
48 : cchiw 2525 fun findOp e=(case e
49 : cchiw 2555 of E.Neg _=>1
50 :     | E.Add _=>1
51 :     | E.Sub _=>1
52 :     | E.Prod _=>1
53 :     | E.Div _=>1
54 :     | E.Sum _ =>1
55 :     | _=>0
56 : cchiw 2525 (*end case*))
57 : cchiw 2522
58 :    
59 : cchiw 2525
60 : cchiw 2522 (*Outside Operator is Neg*)
61 : cchiw 2525 fun handleNeg(params, index,e1,args)=let
62 :     val id=ref (length params)
63 :     val n=length index
64 :    
65 :     val ix=List.tabulate (n,fn v=> E.V(v))
66 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
67 :    
68 :     fun divsort(e)= let
69 :     val s=findOp e
70 :     in (case s
71 :     of 0=>(e,[],[],[])
72 :     | _=> let
73 :     val q=fresh(DstTy.TensorTy(index))
74 :     in (mkTensor 0, [(q, e)],[E.TEN(1,index)],[q]) end
75 :     (*end case*))
76 :     end
77 :    
78 :     val (lft1, newbies1,params1,args1)=divsort(e1)
79 :     val (p',b',args')= shift.cleanParams(E.Neg(lft1),params@params1, args@args1)
80 :     val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
81 :     in
82 :     (z1,(p',b',args'))
83 :     end
84 :    
85 :    
86 : cchiw 2522 (*Outside Operator is Add*)
87 :     fun handleAdd(params, index,list1,args)=let
88 : cchiw 2584
89 : cchiw 2522 val id=ref (length params)
90 :     val n=length index
91 :     val ix=List.tabulate (n,fn v=> E.V(v))
92 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);[E.Tensor(idx,ix)]) end
93 :    
94 :     fun foundOp(e,es,(lft,newbies,params,args))=let
95 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
96 :     in (es,(lft@(mkTensor 0), newbies@[(q, e)],params@[E.TEN(1,index)],args@[q]))
97 : cchiw 2522 end
98 :    
99 :    
100 :     fun sort([], m)=m
101 :     | sort(e::es,m)=(case e
102 :     of E.Add p => sort(p@es, m)
103 :     | E.Sub _=>sort (foundOp(e, es,m))
104 :     | E.Prod _=>sort (foundOp(e, es,m))
105 :     | E.Div _=>sort (foundOp(e, es,m))
106 :     | E.Neg _=>sort (foundOp(e, es,m))
107 :     | E.Sum _=>sort (foundOp(e, es,m))
108 :     | _ => let
109 :     val (l,n, p, a)=m
110 :     in sort(es,(l@[e],n,p,a)) end
111 :     (*end case *))
112 :    
113 :     val (lft, newbies,params',args')=sort(list1,([],[],[],[]))
114 :     val (p',b',args')= shift.cleanParams(E.Add(lft),params@params', args@args')
115 : cchiw 2555
116 :    
117 : cchiw 2522 val z=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies
118 :     in
119 :     (z,(p',b',args'))
120 :     end
121 :    
122 :    
123 :    
124 :    
125 :    
126 : cchiw 2525
127 : cchiw 2522 (*Outside Operator is Sub*)
128 :     fun handleSub(params, index,e1,e2,args)=let
129 : cchiw 2555 val gg=print "SUBXX"
130 : cchiw 2522 val id=ref (length params)
131 :     val n=length index
132 :     val ix=List.tabulate (n,fn v=> E.V(v))
133 :    
134 :     fun mkTensor _=let val ref idx= id in (id:=(idx+1);E.Tensor(idx,ix)) end
135 :    
136 :     fun subsort(e)= let
137 :     val s=findOp e
138 :     in (case s
139 :     of 0=>(e,[],[],[])
140 :     | _=> let
141 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
142 :     in (mkTensor 0, [(q, e)],[E.TEN(1,index)],[q]) end
143 : cchiw 2522 (*end case*))
144 :     end
145 :    
146 :     val (lft1, newbies1,params1,args1)=subsort e1
147 :     val (lft2, newbies2,params2,args2)=subsort e2
148 :     val (p',b',args')= shift.cleanParams(E.Sub(lft1,lft2), params@params1@params2, args@args1@args2)
149 :     val newbies=newbies1@newbies2
150 :     val z=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies
151 :     in
152 :     (z,(p',b',args'))
153 :     end
154 :    
155 :     (*Outside Operator is Div *)
156 :     fun handleDiv(params, index,e1,e2,args)=let
157 :     val id=ref (length params)
158 : cchiw 2555 val n=length index
159 : cchiw 2554 val ix=List.tabulate (n,fn v=> E.V(v))
160 : cchiw 2555 fun mkTensor _=let val ref idx= id val _ =id:=(idx+1) in (E.Tensor(idx,ix),E.TEN(1,index)) end
161 :     fun mkSca _= let val ref idx= id val _ =id:=(idx+1) in (E.Tensor(idx,[]),E.TEN(1,[])) end
162 : cchiw 2522
163 :     fun divsort(e,nextfn)= let
164 :     val s=findOp e
165 :     in (case s
166 :     of 0=>(e,[],[],[])
167 :    
168 :     | _=> let
169 : cchiw 2525 val q=fresh(DstTy.TensorTy(index))
170 : cchiw 2555 val (a,b)=nextfn 0
171 :     in (a,[b], [(q, e)],[q]) end
172 : cchiw 2522 (*end case*))
173 :     end
174 :    
175 : cchiw 2555 val (lft1,params1, newbies1,args1)=divsort(e1,mkTensor)
176 :     val (lft2,params2, newbies2,args2)=divsort(e2,mkSca)
177 : cchiw 2522 val (p',b',args')= shift.cleanParams(E.Div(lft1,lft2),params@params1@params2, args@args1@args2)
178 :    
179 :     val z1=List.map (fn(e)=> createnewb(params,index,args,e) ) newbies1
180 :     val z2=List.map (fn(e)=> createnewb(params,[],args,e) ) newbies2
181 :     in
182 :     (z1@z2,(p',b',args'))
183 :     end
184 :    
185 :    
186 :    
187 : cchiw 2525
188 : cchiw 2522 fun hProd(params, index,list1,args)=let
189 :     val id=ref (length params)
190 :     val n=length index
191 :    
192 : cchiw 2555
193 : cchiw 2522
194 :    
195 :     fun foundOp (e,es,(lft,newbies, params, args))=let
196 : cchiw 2525
197 :     val ref idx= id
198 :     val (ix,index',e')=shift.cleanIndex(e, n, index)
199 :     val (p,ix,e')=([E.Tensor(idx,ix)],index',e')
200 :     val q=fresh(DstTy.TensorTy(ix))
201 : cchiw 2555 in (id:=(idx+1);(es,(lft@p, newbies@[(q, e',ix)],params@[E.TEN(1,ix)],args@[q]))) end
202 :    
203 : cchiw 2522
204 :     fun sort([], m)=m
205 :     | sort(e::es,m)=(case e
206 :     of E.Add _ => sort (foundOp(e,es,m))
207 :     | E.Sub _=>sort (foundOp(e, es,m))
208 : cchiw 2555 | E.Prod p=>(sort (p@es,m))
209 : cchiw 2522 | E.Div _=>sort (foundOp(e, es,m))
210 :     | E.Neg _=>sort (foundOp(e, es,m))
211 :     | E.Sum _=>sort (foundOp(e, es,m))
212 :     | E.Probe _=> raise Fail("Probe- Should have been expanded")
213 :     | _ => let
214 :     val (l,n, p, a)=m
215 :     in sort(es,(l@[e],n,p,a)) end
216 :     (*end case *))
217 :    
218 :     in
219 :     sort(list1,([],[],[],[]))
220 :     end
221 :    
222 :     fun handleProd(params, index,list1,args)=let
223 :     val (lft, newbies,params',args')=hProd(params, index,list1,args)
224 :     val (p',b',args')= shift.cleanParams(E.Prod lft, params@params', args@args')
225 :     val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
226 :     in
227 :     (z,(p',b',args'))
228 :     end
229 :    
230 : cchiw 2584
231 :     (* need to figure this out*)
232 : cchiw 2555 fun handleSumProd(paramsO, indO,sxO,list1O,argsO)=let
233 :     val id=ref (length paramsO)
234 :     val n=length indO
235 : cchiw 2553 val m=print (String.concat["\n In Sum Prod", "n",Int.toString(n)])
236 : cchiw 2555 val (params,ind,E.Sum(sx,E.Prod list1),args)=shiftHtM.clean(paramsO, indO,E.Sum(sxO,E.Prod list1O), argsO)
237 : cchiw 2522
238 : cchiw 2555 fun g(lft,[],_)=(1,lft)
239 :     | g(lft,(E.V s,0,ub)::es,n')=
240 :     if(s=n') then (g(lft@[ub+1],es,n'+1)) else (0,[])
241 : cchiw 2553 | g _ =(0,[]) (*Can't be split, weird bound*)
242 : cchiw 2522 val (c,index')= g([],sx,n)
243 :    
244 :     in case c
245 :     of 0=> ([],(params,E.Sum(sx, E.Prod(list1)),args))
246 :     |_=>let
247 :     val index=ind@index'
248 :     val (lft, newbies,params',args')=hProd(params, index,list1,args)
249 :     val (p',b',args')= shift.cleanParams(E.Sum(sx,E.Prod lft), params@params', args@args')
250 :     val z=List.map (fn(e)=> createnewP(params,args,e) ) newbies
251 :     in
252 :     (z,(p',b',args'))
253 :     end
254 :     end
255 :    
256 :    
257 :    
258 :    
259 :    
260 :     fun genfn(id,Ein.EIN{params, index, body},args)= let
261 : cchiw 2584
262 : cchiw 2555 val notDone=([],(params,body,args))
263 : cchiw 2522
264 :     fun gen body=(case body
265 :     of E.Field _ =>raise Fail(concat["Invalid Field here "] )
266 :     | E.Partial _ =>raise Fail(concat["Invalid Partial here "] )
267 :     | E.Apply _ =>raise Fail(concat["Invalid Apply here "] )
268 :     | E.Probe _ => raise Fail("Probe- Should have been expanded")
269 :     | E.Conv _ =>notDone
270 :     | E.Krn _ =>notDone
271 :     | E.Img _=> notDone
272 :     | E.Const _=> notDone
273 :     | E.Tensor(id,[])=> notDone
274 :     | E.Prod(E.Img _ :: _)=>notDone
275 :     | E.Neg(E.Neg e)=> gen e
276 :     | E.Neg e=> handleNeg(params, index,e,args)
277 : cchiw 2555 | E.Add a => (handleAdd(params, index,a,args))
278 : cchiw 2522 | E.Sub(E.Sub(a,b),E.Sub(c,d))=> gen(E.Sub(E.Add[a,d],E.Add[b,c]))
279 :     | E.Sub(E.Sub(a,b),e2)=>gen (E.Sub(a,E.Add[b,e2]))
280 :     | E.Sub(e1,E.Sub(c,d))=>gen(E.Add([E.Sub(e1,c),d]))
281 : cchiw 2555 | E.Sub(e1,e2)=>(handleSub(params, index,e1,e2,args))
282 : cchiw 2522 | E.Div(E.Div(a,b),E.Div(c,d))=> gen(E.Div(E.Prod[a,d],E.Prod[b,c]))
283 :     | E.Div(E.Div(a,b),c)=> gen(E.Div(a, E.Prod[b,c]))
284 :     | E.Div(a,E.Div(b,c))=> gen(E.Div(E.Prod[a,c],b))
285 :     | E.Div(e1,e2)=>handleDiv(params, index,e1,e2,args)
286 : cchiw 2555 | E.Prod e=> (handleProd(params, index,e,args))
287 : cchiw 2522 | E.Sum(_,E.Prod(E.Img _ :: _ ))=>notDone
288 : cchiw 2584 (* | E.Sum(sx,E.Prod e)=>(handleSumProd(params, index,sx,e,args))*)
289 : cchiw 2522 | _=> notDone
290 :     (*end case*))
291 :    
292 :    
293 :     val (newbie,(p,b,arg))= gen body
294 :     val e'=createEin(p,index, b)
295 :    
296 :    
297 :     val f= (id,e',arg)
298 :     in (newbie, f)
299 :     end
300 :    
301 :    
302 :    
303 :     fun splitIt (change,e)=let
304 :     val (newbie, e')= genfn e
305 :     in (case length(newbie)
306 :     of 0=>(change,[e'])
307 :     | _=> let
308 :     val a=List.map (fn(e1)=>splitIt(1,e1)) newbie
309 :     val newbie'=flat(List.map (fn(e1,e2)=>e2) a)
310 :     in (1,newbie'@[e']) end
311 :     (*end case *))
312 :    
313 :     end
314 :    
315 :     fun splitein(id,E.EIN{params,index,body},arg)=let
316 : cchiw 2584 val _=print(String.concat["\n Pre Shift \n", printA(id,E.EIN{params=params,index=index,body=body},arg)])
317 :    
318 : cchiw 2522 val (p',i',b',args')=shiftHtM.clean(params, index, body, arg)
319 :     val einn'=createEin(p',i', b')
320 : cchiw 2584 val _ =print(String.concat[ "\n Post shift-\n ",printA(id,einn',args'),"\n \n"])
321 : cchiw 2522
322 :     in
323 :     splitIt(0,(id,einn',args'))
324 :     end
325 :    
326 :    
327 :    
328 :    
329 :    
330 :     end (* local *)
331 :    
332 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0