Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/high-il/move-sums.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/high-il/move-sums.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2608 - (view) (download)

1 : cchiw 2605 (* Funtions push summations downt to necessary expressions *)
2 :    
3 :     structure SummationEin = struct
4 :    
5 :     local
6 :    
7 :     structure E = Ein
8 :     structure P=Printer
9 :     structure F=Filter
10 :    
11 :     in
12 :    
13 :     fun rewriteProd A=(case A
14 :     of [A]=> A
15 :     | A => E.Prod A
16 :     (*end case*))
17 :    
18 :     fun rewriteSum(c,p)= E.Sum(c, rewriteProd p)
19 :    
20 :     fun embed(A, c1, B, c2, C)=let
21 :     val C'=rewriteSum(c2,C)
22 :     in (A, c1,B@[C'])
23 :     end
24 :    
25 :    
26 :     (* return pre, outer Sum, post*)
27 :     fun splitMultipleSum(c1,c2,pre,post)=(case (pre,post)
28 :     of (A, []) => let
29 :     val (pre,post)= F.splitSum(c1,A)
30 :     in (pre,[c1],post)
31 :     end
32 :     | ([],B) => (case F.splitSum(c1,B)
33 :     of ([],D) => ([],[c1]@c2,D)
34 :     | (C,[]) => ([],c2,C)
35 :     | (C,D) => embed([],c2,C,[c1],D)
36 :     (*end case*))
37 :     | (A,B) => (case (F.splitSum(c1,A),F.splitSum(c1,B))
38 :     of ((C,[]),(E,[])) => (C ,c2,E)
39 :     | ((C,D),(E,[])) => (C@[rewriteSum([c1],D)], c2,E)
40 :     | ((C,[]),(E,F)) => embed(C,c2,E,[c1],F)
41 : cchiw 2608 | ((C,D),_)=> embed(C,[c1],D,c2,B)
42 : cchiw 2605 (*end case*))
43 :     (*end case*))
44 :    
45 :     fun shiftSum(sx,e)= let
46 :     val sx'=List.rev(sx)
47 :     val c2=List.hd(sx')
48 :     val (A,B)=F.splitSum(c2,e)
49 :    
50 :     fun double([], outer, pre, post)= rewriteProd(pre@[rewriteSum(outer,post)])
51 :     | double(c1::cs, outer, pre, post)= let
52 :     val (pre',outer', post')=splitMultipleSum(c1,outer,pre,post)
53 :     in double(cs, outer',pre',post')
54 :     end
55 :     in double(List.tl(sx'),[c2],A,B)
56 :     end
57 :    
58 :     (*Move Summation indices around *)
59 :     fun cleanSummation (Ein.EIN{params, index, body}) = let
60 :     fun rewriteBody body =(case body
61 :     of E.Const _ => body
62 :     | E.Tensor _ => body
63 :     | E.Field _ => body
64 :     | E.Delta _ => body
65 :     | E.Epsilon _ => body
66 :     | E.Conv _ => body
67 :     | E.Partial _ => body
68 :     | E.Krn _ => raise Fail"Krn before Expand"
69 :     | E.Img _ => raise Fail"Img before Expand"
70 :     | E.Value _ => raise Fail"Value before Expand"
71 :     | E.Neg e => E.Neg(rewriteBody e)
72 :     | E.Add es => E.Add(List.map (fn e=> rewriteBody e) es)
73 :     | E.Sub(e1,e2) => E.Sub(rewriteBody e1, rewriteBody e2)
74 :     | E.Prod es => E.Prod(List.map (fn e=> rewriteBody e) es)
75 :     | E.Div(e1,e2) => E.Div(rewriteBody e1, rewriteBody e2)
76 :     | E.Apply(e1,e2) => E.Apply(rewriteBody e1, rewriteBody e2)
77 :     | E.Probe(e1,e2) => E.Probe(e1, rewriteBody e2)
78 :     | E.Lift e => E.Lift(rewriteBody e)
79 :     | E.Sum(sx,E.Prod e) =>shiftSum(sx,e)
80 :     | E.Sum _ => body
81 :     (* end case *))
82 :    
83 :     val b=rewriteBody body
84 :     in (Ein.EIN{params=params, index=index, body=b})
85 :     end
86 :    
87 :     (*
88 :     fun tester e=print( String.concat["tester \n",P.printerE(e),"===>",P.printerE(cleanSummation2 e)])
89 :     val v0=E.V 0
90 :     val v1=E.V 1
91 :     val v2=E.V 2
92 :     val vv0=(v0,0,0)
93 :     val vv1=(v1,0,0)
94 :     val vv2=(v2,0,0)
95 :    
96 :     val t0=E.Tensor(0,[v0])
97 :     val t1=E.Tensor(0,[v1])
98 :     val t2=E.Tensor(0,[v2])
99 :    
100 :     val t01=E.Tensor(0,[v0,v1])
101 :     val t12=E.Tensor(0,[v0,v1,v2])
102 :    
103 :     val A= E.EIN{params = [], index = [],
104 :     body = E.Sum([vv0,vv1],E.Prod[t0])}
105 :    
106 :     val B= E.EIN{params = [], index = [],
107 :     body = E.Sum([vv0,vv1],E.Prod[t1])}
108 :    
109 :     val C= E.EIN{params = [], index = [],
110 :     body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2])}
111 :    
112 :     val D= E.EIN{params = [], index = [],
113 :     body = E.Sum([vv0,vv1],E.Prod[t0,t01])}
114 :    
115 :     val E= E.EIN{params = [], index = [],
116 :     body = E.Sum([vv0,vv1],E.Prod[t1,t01])}
117 :    
118 :     val F= E.EIN{params = [], index = [],
119 :     body = E.Sum([vv0,vv1],E.Prod[t0,t1,t2,t01])}
120 :    
121 :     val G= E.EIN{params = [], index = [],
122 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2])}
123 :    
124 :     val H= E.EIN{params = [], index = [],
125 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2])}
126 :    
127 :     val I= E.EIN{params = [], index = [],
128 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01])}
129 :    
130 :    
131 :     val J= E.EIN{params = [], index = [],
132 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t01,t2,t12])}
133 :    
134 :     val K= E.EIN{params = [], index = [],
135 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t01,t2,t12])}
136 :    
137 :     val L= E.EIN{params = [], index = [],
138 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t01,t12])}
139 :    
140 :    
141 :    
142 :     val M= E.EIN{params = [], index = [],
143 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t2,t12])}
144 :    
145 :     val N= E.EIN{params = [], index = [],
146 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t1,t2,t12])}
147 :    
148 :     val O= E.EIN{params = [], index = [],
149 :     body = E.Sum([vv0,vv1,vv2],E.Prod[t0,t1,t2,t12])}
150 :    
151 :    
152 :    
153 :    
154 :     fun Y _=List.map tester [A,B,C,D,E,F,G,H,I,J,K,L,M,N,O ]
155 :    
156 :     fun cleanSummation e=(print "pre";Y 1;cleanSummation2 e)
157 :     *)
158 :     end
159 :    
160 :    
161 :    
162 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0