Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-il/move-sums.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-il/move-sums.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2838 - (view) (download)

1 : cchiw 2605 (* Funtions push summations downt to necessary expressions *)
2 :    
3 :     structure SummationEin = struct
4 :    
5 :     local
6 :    
7 :     structure E = Ein
8 :     structure P=Printer
9 :     structure F=Filter
10 :    
11 :     in
12 :    
13 :     fun rewriteProd A=(case A
14 :     of [A]=> A
15 :     | A => E.Prod A
16 :     (*end case*))
17 :    
18 :     fun rewriteSum(c,p)= E.Sum(c, rewriteProd p)
19 :    
20 : cchiw 2838
21 : cchiw 2605 fun embed(A, c1, B, c2, C)=let
22 :     val C'=rewriteSum(c2,C)
23 :     in (A, c1,B@[C'])
24 :     end
25 :    
26 : cchiw 2838 fun rewriteProdSum(pre,_,[])=rewriteProd pre
27 :     | rewriteProdSum(pre,outer,post)=rewriteProd(pre@[rewriteSum(outer,post)])
28 : cchiw 2605
29 : cchiw 2838
30 : cchiw 2605 (* return pre, outer Sum, post*)
31 :     fun splitMultipleSum(c1,c2,pre,post)=(case (pre,post)
32 :     of (A, []) => let
33 :     val (pre,post)= F.splitSum(c1,A)
34 :     in (pre,[c1],post)
35 :     end
36 :     | ([],B) => (case F.splitSum(c1,B)
37 :     of ([],D) => ([],[c1]@c2,D)
38 :     | (C,[]) => ([],c2,C)
39 :     | (C,D) => embed([],c2,C,[c1],D)
40 :     (*end case*))
41 :     | (A,B) => (case (F.splitSum(c1,A),F.splitSum(c1,B))
42 :     of ((C,[]),(E,[])) => (C ,c2,E)
43 :     | ((C,D),(E,[])) => (C@[rewriteSum([c1],D)], c2,E)
44 :     | ((C,[]),(E,F)) => embed(C,c2,E,[c1],F)
45 : cchiw 2608 | ((C,D),_)=> embed(C,[c1],D,c2,B)
46 : cchiw 2605 (*end case*))
47 :     (*end case*))
48 :    
49 :     fun shiftSum(sx,e)= let
50 :     val sx'=List.rev(sx)
51 :     val c2=List.hd(sx')
52 :     val (A,B)=F.splitSum(c2,e)
53 :    
54 : cchiw 2838 fun double([], outer, pre, post)= rewriteProdSum(pre,outer,post)
55 : cchiw 2605 | double(c1::cs, outer, pre, post)= let
56 :     val (pre',outer', post')=splitMultipleSum(c1,outer,pre,post)
57 :     in double(cs, outer',pre',post')
58 :     end
59 :     in double(List.tl(sx'),[c2],A,B)
60 :     end
61 :    
62 :     (*Move Summation indices around *)
63 :     fun cleanSummation (Ein.EIN{params, index, body}) = let
64 :     fun rewriteBody body =(case body
65 :     of E.Const _ => body
66 :     | E.Tensor _ => body
67 :     | E.Field _ => body
68 :     | E.Delta _ => body
69 :     | E.Epsilon _ => body
70 :     | E.Conv _ => body
71 :     | E.Partial _ => body
72 :     | E.Krn _ => raise Fail"Krn before Expand"
73 :     | E.Img _ => raise Fail"Img before Expand"
74 :     | E.Value _ => raise Fail"Value before Expand"
75 :     | E.Neg e => E.Neg(rewriteBody e)
76 :     | E.Add es => E.Add(List.map (fn e=> rewriteBody e) es)
77 :     | E.Sub(e1,e2) => E.Sub(rewriteBody e1, rewriteBody e2)
78 :     | E.Prod es => E.Prod(List.map (fn e=> rewriteBody e) es)
79 :     | E.Div(e1,e2) => E.Div(rewriteBody e1, rewriteBody e2)
80 :     | E.Apply(e1,e2) => E.Apply(rewriteBody e1, rewriteBody e2)
81 :     | E.Probe(e1,e2) => E.Probe(e1, rewriteBody e2)
82 :     | E.Lift e => E.Lift(rewriteBody e)
83 : cchiw 2838 | E.Sum (sx,E.Prod[e]) => shiftSum(sx,[e])
84 : cchiw 2611 | E.Sum(sx,E.Prod e) => shiftSum(sx,e)
85 :     | E.Sum (sx,e) => shiftSum(sx,[e])
86 : cchiw 2605 (* end case *))
87 :    
88 :     val b=rewriteBody body
89 :     in (Ein.EIN{params=params, index=index, body=b})
90 :     end
91 :    
92 :     (*
93 :     fun tester e=print( String.concat["tester \n",P.printerE(e),"===>",P.printerE(cleanSummation2 e)])
94 :     val v0=E.V 0
95 :     val v1=E.V 1
96 :     val v2=E.V 2
97 :     val vv0=(v0,0,0)
98 :     val vv1=(v1,0,0)
99 :     val vv2=(v2,0,0)
100 :    
101 :     val t0=E.Tensor(0,[v0])
102 :     val t1=E.Tensor(0,[v1])
103 :     val t2=E.Tensor(0,[v2])
104 :    
105 :     val t01=E.Tensor(0,[v0,v1])
106 :     val t12=E.Tensor(0,[v0,v1,v2])
107 :    
108 :     val A= E.EIN{params = [], index = [],
109 :     body = E.Sum([vv0,vv1],E.Prod[t0])}
110 :    
111 :     val B= E.EIN{params = [], index = [],
112 :     body = E.Sum([vv0,vv1],E.Prod[t1])}
113 :    
114 :     val C= E.EIN{params = [], index = [],
115 :     body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2])}
116 :    
117 :     val D= E.EIN{params = [], index = [],
118 :     body = E.Sum([vv0,vv1],E.Prod[t0,t01])}
119 :    
120 :     val E= E.EIN{params = [], index = [],
121 :     body = E.Sum([vv0,vv1],E.Prod[t1,t01])}
122 :    
123 :     val F= E.EIN{params = [], index = [],
124 :     body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2,t01])}
125 :    
126 :     val G= E.EIN{params = [], index = [],
127 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2])}
128 :    
129 :     val H= E.EIN{params = [], index = [],
130 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2])}
131 :    
132 :     val I= E.EIN{params = [], index = [],
133 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01])}
134 :    
135 :    
136 :     val J= E.EIN{params = [], index = [],
137 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2,t12])}
138 :    
139 :     val K= E.EIN{params = [], index = [],
140 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2,t12])}
141 :    
142 :     val L= E.EIN{params = [], index = [],
143 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01,t12])}
144 :    
145 :    
146 :    
147 :     val M= E.EIN{params = [], index = [],
148 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t2,t12])}
149 :    
150 :     val N= E.EIN{params = [], index = [],
151 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t2,t12])}
152 :    
153 :     val O= E.EIN{params = [], index = [],
154 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t12])}
155 :    
156 :    
157 :    
158 :    
159 :     fun Y _=List.map tester [A,B,C,D,E,F,G,H,I,J,K,L,M,N,O ]
160 :    
161 :     fun cleanSummation e=(print "pre";Y 1;cleanSummation2 e)
162 :     *)
163 :     end
164 :    
165 :    
166 :    
167 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0