(* splitHtM: split EIN operators into simpler pieces before code generation.
 *
 * Some EIN operator bodies nest a compound subexpression (Neg, Add, Sub,
 * Prod, Div or Sum) inside another operator.  Such a body is rewritten:
 *   - the nested subexpression becomes its own EIN operator, bound to a
 *     fresh MidIL variable ("Q1", "Q2", ...);
 *   - the outer operator reads that variable through an extra tensor
 *     parameter (E.TEN) and an extra argument.
 * The entry point is splitein.
 *)
structure splitHtM = struct

  local
    structure E = Ein
    structure DstIL = MidIL
    structure DstTy = MidILTypes
    structure shift = shiftHtM
    structure P = Printer
    structure Var = MidIL.Var
    structure HVar = HighIL.Var
  in

  (* printA (id, e, arg): debug rendering "<id> ==<ein><arg1> , <arg2>..." of
   * an EIN assignment whose id and arguments are MidIL variables.
   *)
    fun printA (id, e, arg) = let
          val argStr = String.concatWith " , " (List.map Var.toString arg)
          in
            String.concat [Var.toString id, " ==", P.printerE e, argStr]
          end

  (* printAA (id, e, arg): like printA, but the arguments are HighIL variables.
   * NOTE(review): id is still printed with MidIL's Var.toString -- confirm
   * whether callers pass a MidIL or a HighIL variable for id.
   *)
    fun printAA (id, e, arg) = let
          val argStr = String.concatWith " , " (List.map HVar.toString arg)
          in
            String.concat [Var.toString id, " ==", P.printerE e, argStr]
          end

  (* createEin (params, index, body): build an EIN operator record. *)
    fun createEin (params, index, body) = E.EIN{params=params, index=index, body=body}

  (* flat xss: concatenate a list of lists, preserving order. *)
    fun flat xs = List.concat xs

  (* Counter used to number the variables created by fresh.  It is shared by
   * all calls and is never reset in this structure.
   *)
    val counter = ref 0

  (* fresh ty: a new MidIL variable of type ty named "Q<n>", where n is the
   * next value of counter.
   *)
    fun fresh ty = let
          val next = !counter + 1
          val x = DstIL.Var.new ("Q" ^ Int.toString next, ty)
          in
            counter := next;
            x
          end

  (* createnewb (params, index, args, (id, e)): turn a lifted subexpression e
   * into a standalone operator (id, EIN{p', index, b'}, args').
   * shift.cleanParams derives p'/b'/args' from e, params and args (it
   * presumably drops the parameters e does not reference -- see shiftHtM).
   *)
    fun createnewb (params, index, args, (id, e)) = let
          val (p', b', args') = shift.cleanParams (e, params, args)
          in
            (id, createEin (p', index, b'), args')
          end

  (* createnewP (params, args, (id, e, ix)): like createnewb, but the lifted
   * subexpression carries its own index list ix (as produced by hProd).
   *)
    fun createnewP (params, args, (id, e, ix)) = let
          val (p', b', args') = shift.cleanParams (e, params, args)
          in
            (id, createEin (p', ix, b'), args')
          end

  (* findOp e: 1 if e is a compound operator that should be split out
   * (Neg, Add, Sub, Prod, Div, Sum), 0 otherwise.
   *)
    fun findOp e = (case e
           of E.Neg _ => 1
            | E.Add _ => 1
            | E.Sub _ => 1
            | E.Prod _ => 1
            | E.Div _ => 1
            | E.Sum _ => 1
            | _ => 0
          (* end case *))

  (* liftOperand (e, shape, mkRef): the operand-lifting step shared by
   * handleNeg, handleSub and handleDiv.
   *   - If e is compound (findOp e <> 0), e is replaced by the tensor
   *     reference mkRef () and the result is
   *       (mkRef (), [(q, e)], [E.TEN(1, shape)], [q])
   *     where q is a fresh variable of type TensorTy shape.
   *   - Otherwise the result is (e, [], [], []).
   *)
    fun liftOperand (e, shape, mkRef) =
          if findOp e = 0
            then (e, [], [], [])
            else let
              val q = fresh (DstTy.TensorTy shape)
              in
                (mkRef (), [(q, e)], [E.TEN(1, shape)], [q])
              end

  (* Outermost operator is Neg.
   * handleNeg (params, index, e1, args) returns (newOps, (p', b', args')):
   *   - newOps are the operators split out of e1;
   *   - (p', b', args') is the cleaned-up outer operator.
   *)
    fun handleNeg (params, index, e1, args) = let
          val nextId = ref (length params)
          val ix = List.tabulate (length index, E.V)
          fun mkTensor () = let
                val idx = !nextId
                in
                  nextId := idx + 1;
                  E.Tensor(idx, ix)
                end
          val (lft1, newbies1, params1, args1) = liftOperand (e1, index, mkTensor)
          val (p', b', args') =
                shift.cleanParams (E.Neg lft1, params @ params1, args @ args1)
          val z1 = List.map (fn nb => createnewb (params, index, args, nb)) newbies1
          in
            (z1, (p', b', args'))
          end

  (* Outermost operator is Add.
   * Nested Adds are flattened into the same sum.  Every other compound term
   * is replaced by a reference to a new tensor parameter and split out as
   * its own operator.
   *)
    fun handleAdd (params, index, list1, args) = let
          val nextId = ref (length params)
          val ix = List.tabulate (length index, E.V)
          fun mkTensor () = let
                val idx = !nextId
                in
                  nextId := idx + 1;
                  E.Tensor(idx, ix)
                end
        (* lift compound term e out; es are the terms still to be sorted *)
          fun foundOp (e, es, (lft, nbs, ps, qs)) = let
                val q = fresh (DstTy.TensorTy index)
                in
                  (es, (lft @ [mkTensor ()], nbs @ [(q, e)], ps @ [E.TEN(1, index)], qs @ [q]))
                end
          fun sort ([], m) = m
            | sort (E.Add terms :: es, m) = sort (terms @ es, m)
            | sort (e :: es, m as (l, nbs, ps, qs)) =
                if findOp e <> 0
                  then sort (foundOp (e, es, m))
                  else sort (es, (l @ [e], nbs, ps, qs))
          val (lft, newbies, params', args') = sort (list1, ([], [], [], []))
          val (p', b', args') =
                shift.cleanParams (E.Add lft, params @ params', args @ args')
          val z = List.map (fn nb => createnewb (params, index, args, nb)) newbies
          in
            (z, (p', b', args'))
          end

  (* Outermost operator is Sub: split compound operands out of e1 and e2. *)
    fun handleSub (params, index, e1, e2, args) = let
          val nextId = ref (length params)
          val ix = List.tabulate (length index, E.V)
          fun mkTensor () = let
                val idx = !nextId
                in
                  nextId := idx + 1;
                  E.Tensor(idx, ix)
                end
          val (lft1, newbies1, params1, args1) = liftOperand (e1, index, mkTensor)
          val (lft2, newbies2, params2, args2) = liftOperand (e2, index, mkTensor)
          val (p', b', args') =
                shift.cleanParams (E.Sub(lft1, lft2), params @ params1 @ params2, args @ args1 @ args2)
          val z = List.map (fn nb => createnewb (params, index, args, nb)) (newbies1 @ newbies2)
          in
            (z, (p', b', args'))
          end

  (* Outermost operator is Div.
   *   - A compound numerator is lifted with the operator's full index.
   *   - A compound denominator is lifted as a scalar: it is referenced as
   *     E.Tensor(idx, []), its operator is built with index [], and so its
   *     variable and parameter have shape [].
   * Bug fixes:
   *   - the numerator's index variables are now E.V 0 .. E.V (n-1), as in
   *     the other handlers; previously the dimension sizes were used;
   *   - a lifted denominator is now typed as a scalar; previously it was
   *     typed with the full index.
   *)
    fun handleDiv (params, index, e1, e2, args) = let
          val nextId = ref (length params)
          val ix = List.tabulate (length index, E.V)
          fun newIdx () = let
                val idx = !nextId
                in
                  nextId := idx + 1;
                  idx
                end
          fun mkTensor () = E.Tensor(newIdx (), ix)
          fun mkSca () = E.Tensor(newIdx (), [])
          val (lft1, newbies1, params1, args1) = liftOperand (e1, index, mkTensor)
          val (lft2, newbies2, params2, args2) = liftOperand (e2, [], mkSca)
          val (p', b', args') =
                shift.cleanParams (E.Div(lft1, lft2), params @ params1 @ params2, args @ args1 @ args2)
          val z1 = List.map (fn nb => createnewb (params, index, args, nb)) newbies1
          val z2 = List.map (fn nb => createnewb (params, [], args, nb)) newbies2
          in
            (z1 @ z2, (p', b', args'))
          end

  (* hProd (params, index, list1, args): sort the factors of a product.
   *   - Nested Prods are flattened into the same product.
   *   - Other compound factors are replaced by new tensor parameters.  Their
   *     index is reduced by shift.cleanIndex.
   *   - Probe raises Fail.
   * Returns (lft, newbies, newParams, newArgs), where each newbie is
   * (q, e', ix).
   * Bug fix: each lifted factor now gets its own parameter index.  Before,
   * foundOp read the counter without advancing it, so every lifted factor
   * referred to the same parameter.
   *)
    fun hProd (params, index, list1, args) = let
          val nextId = ref (length params)
          val n = length index
        (* allocate the next tensor parameter for e *)
          fun mkPTensor e = let
                val idx = !nextId
                val (ix, index', e') = shift.cleanIndex (e, n, index)
                in
                  nextId := idx + 1;
                  ([E.Tensor(idx, ix)], index', e')
                end
          fun foundOp (e, es, (lft, nbs, ps, qs)) = let
                val (ref', ix, e') = mkPTensor e
                val q = fresh (DstTy.TensorTy ix)
                in
                  (es, (lft @ ref', nbs @ [(q, e', ix)], ps @ [E.TEN(1, ix)], qs @ [q]))
                end
          fun sort ([], m) = m
            | sort (E.Prod factors :: es, m) = sort (factors @ es, m)
            | sort (E.Probe _ :: _, _) = raise Fail("Probe- Should have been expanded")
            | sort (e :: es, m as (l, nbs, ps, qs)) =
                if findOp e <> 0
                  then sort (foundOp (e, es, m))
                  else sort (es, (l @ [e], nbs, ps, qs))
          in
            sort (list1, ([], [], [], []))
          end

  (* Outermost operator is Prod. *)
    fun handleProd (params, index, list1, args) = let
          val (lft, newbies, params', args') = hProd (params, index, list1, args)
          val (p', b', args') =
                shift.cleanParams (E.Prod lft, params @ params', args @ args')
          val z = List.map (fn nb => createnewP (params, args, nb)) newbies
          in
            (z, (p', b', args'))
          end

  (* Outermost operator is Sum over a Prod.
   * The product is split only if the summation indices are
   *   V n, V (n+1), ...  with lower bound 0,
   * where n = length ind.  In that case the product is split with the
   * extended index ind @ [upper bounds].  Otherwise the operator is
   * returned unchanged.
   *)
    fun handleSumProd (params, ind, sx, list1, args) = let
          fun outerBounds (acc, [], _) = SOME acc
            | outerBounds (acc, (E.V s, 0, ub) :: rest, next) =
                if s = next then outerBounds (acc @ [ub], rest, next + 1) else NONE
            | outerBounds _ = NONE  (* unexpected index or lower bound: can't split *)
          in
            case outerBounds ([], sx, length ind)
             of NONE => ([], (params, E.Sum(sx, E.Prod list1), args))
              | SOME index' => let
                  val index = ind @ index'
                  val (lft, newbies, params', args') = hProd (params, index, list1, args)
                  val (p', b', args') =
                        shift.cleanParams (E.Sum(sx, E.Prod lft), params @ params', args @ args')
                  val z = List.map (fn nb => createnewP (params, args, nb)) newbies
                  in
                    (z, (p', b', args'))
                  end
            (* end case *)
          end

  (* genfn (id, EIN{params, index, body}, args): split one operator.
   * Returns (newOps, (id, ein', args')).  newOps is empty when nothing was
   * split out; in that case ein' is rebuilt from the original
   * params/body/args.
   * Before dispatching, nested Neg/Sub/Div patterns are normalized
   * algebraically, e.g. (a-b)-(c-d) => (a+d)-(b+c).
   *)
    fun genfn (id, Ein.EIN{params, index, body}, args) = let
          val notDone = ([], (params, body, args))
          fun gen body = (case body
                 of E.Field _ => raise Fail "Invalid Field here "
                  | E.Partial _ => raise Fail "Invalid Partial here "
                  | E.Apply _ => raise Fail "Invalid Apply here "
                  | E.Probe _ => raise Fail "Probe- Should have been expanded"
                  | E.Conv _ => notDone
                  | E.Krn _ => notDone
                  | E.Img _ => notDone
                  | E.Const _ => notDone
                  | E.Tensor(_, []) => notDone
                  | E.Prod(E.Img _ :: _) => notDone
                  | E.Neg(E.Neg e) => gen e
                  | E.Neg e => handleNeg (params, index, e, args)
                  | E.Add a => handleAdd (params, index, a, args)
                  | E.Sub(E.Sub(a, b), E.Sub(c, d)) => gen (E.Sub(E.Add[a, d], E.Add[b, c]))
                  | E.Sub(E.Sub(a, b), e2) => gen (E.Sub(a, E.Add[b, e2]))
                  | E.Sub(e1, E.Sub(c, d)) => gen (E.Add[E.Sub(e1, c), d])
                  | E.Sub(e1, e2) => handleSub (params, index, e1, e2, args)
                  | E.Div(E.Div(a, b), E.Div(c, d)) => gen (E.Div(E.Prod[a, d], E.Prod[b, c]))
                  | E.Div(E.Div(a, b), c) => gen (E.Div(a, E.Prod[b, c]))
                  | E.Div(a, E.Div(b, c)) => gen (E.Div(E.Prod[a, c], b))
                  | E.Div(e1, e2) => handleDiv (params, index, e1, e2, args)
                  | E.Prod e => handleProd (params, index, e, args)
                  | E.Sum(_, E.Prod(E.Img _ :: _)) => notDone
                  | E.Sum(sx, E.Prod e) => handleSumProd (params, index, sx, e, args)
                  | _ => notDone
                (* end case *))
          val (newbie, (p, b, arg)) = gen body
          in
            (newbie, (id, createEin (p, index, b), arg))
          end

  (* splitIt (change, e): split e, then recursively split every operator that
   * was split out of it.
   * Returns (flag, ops):
   *   - ops lists the split-out operators before the operator that uses them;
   *   - flag is 1 if anything was split, else change.
   *)
    fun splitIt (change, e) = let
          val (newbie, e') = genfn e
          in
            case newbie
             of [] => (change, [e'])
              | _ => let
                  fun splitOne nb = let
                        val (_, defs) = splitIt (1, nb)
                        in
                          defs
                        end
                  in
                    (1, flat (List.map splitOne newbie) @ [e'])
                  end
          end

  (* splitein (id, ein, arg): entry point.
   *   - Prints the operator before and after shift.clean.
   *   - Splits the cleaned operator with splitIt and returns its result.
   * NOTE(review): the before/after printing goes to stdout on every call;
   * consider routing it through a debug flag.
   *)
    fun splitein (id, ein as E.EIN{params, index, body}, arg) = let
          val _ = print (printA (id, ein, arg))
          val _ = print "\n \t changed to =>\n \t"
          val (p', i', b', args') = shift.clean (params, index, body, arg)
          val einn' = createEin (p', i', b')
          val _ = print (printA (id, einn', args'))
          in
            splitIt (0, (id, einn', args'))
          end

  end (* local *)

end