SCM Repository
Annotation of /branches/charisee/src/compiler/mid-to-low/step2.sml
Parent Directory
|
Revision Log
Revision 2669 - (view) (download)
(* step2: general code generation for scalar-valued Ein expressions.
 *
 * generalfn lowers an Ein body to a pair (v, code) where code is a list of
 * LowIL assignments and v is the LowIL variable holding the result.  Index
 * variables are bound through a functional dictionary (insert/lookup/empty)
 * mapping an integer index-variable id to its current integer value.
 * Literal 0/1/~1 results are detected with skeleton so that trivial
 * additions, products, negations and subtractions are folded away.
 *)
structure step2 = struct
  local
    structure DstIL = LowIL
    structure DstTy = LowILTypes
    structure DstOp = LowOps
    structure Var = LowIL.Var
    structure E = Ein
    structure S3 = step3
    structure genKrn = genKrn
    structure tS = toStringEin
  in

  (* Functional dictionary: a dictionary is a function from keys to options.
   * insert returns a new dictionary in which key maps to value and every other
   * key is answered by the old dictionary d (so newer bindings shadow older). *)
    fun insert (key, value) d = fn s => if s = key then SOME value else d s

  (* lookup k d: the binding of k in d, or NONE *)
    fun lookup k d = d k

  (* the dictionary with no bindings *)
    val empty = fn key => NONE

  (* raised for Ein constructs that must not reach this pass *)
    fun err _ = raise Fail("Invalid Field Here")
    fun errS str = raise Fail(str)

  (* Helpers for scalars *)

  (* mkCons (shape, rest): build a tensor of the given shape from the variables
   * in rest.  Returns the fresh variable and the single assignment defining it.
   * (An unconditional debug print of the generated assignment used to live
   * here; it was removed because it wrote to stdout for every CONS node.) *)
    fun mkCons (shape, rest) = let
          val ty = DstTy.TensorTy shape
          val a = DstIL.Var.new("Cons", ty)
          val code = DstIL.ASSGN(a, DstIL.CONS(ty, rest))
          in
            (a, [code])
          end

  (* the scalar (rank-0 tensor) type and scalar operation builders *)
    val Sca = DstTy.TensorTy([])
    fun mkProdSca rest = S3.aaV(DstOp.prodSca, rest, "prodSca", Sca)
    fun mkSubSca rest = S3.aaV(DstOp.subSca, rest, "subSca", Sca)
    fun mkDivSca rest = S3.aaV(DstOp.divSca, rest, "divSca", Sca)
    fun mkMultipleSca (ids, rator) = S3.mkMultiple(ids, rator, Sca)
    fun mkInt n = S3.mkInt n

  (* prodIter (origIndex, index, nextfn, args): enumerate every index tuple
   * within index (one size per axis) and call nextfn (mapp, args) for each,
   * where mapp binds axis position 0, 1, ... to that tuple's components.
   * The per-element results are assembled into nested CONS nodes whose
   * shapes are taken from origIndex (the innermost CONS uses the remaining
   * suffix of origIndex).  Returns the outermost variable and all code.
   * Raises Fail if an axis size is not positive (the counters below would
   * otherwise never reach their 0 base case) or if index has more axes than
   * origIndex. *)
    fun prodIter (origIndex, index, nextfn, args) = let
          val _ = if List.exists (fn e => e < 1) index
                then raise Fail "prodIter: index sizes must be positive"
                else ()
        (* largest value each axis counter takes *)
          val index' = List.map (fn e => e - 1) index
        (* evaluate one element with axis n bound to m *)
          fun get (n, m, mapp) = nextfn(insert(n, m) mapp, args)
          fun Iter (mapp, [], _, code, _, _) = let
                val (vF, code') = nextfn(mapp, args)
                in
                  (vF, code' @ code)
                end
            | Iter (mapp, [0], rest, code, shape, n) = let
              (* last element of the innermost axis: emit the CONS *)
                val (vF, code') = get(n, 0, mapp)
                val (vE, E) = mkCons(shape, vF :: rest)
                in
                  (vE, code' @ code @ E)
                end
            | Iter (mapp, [c], rest, code, shape, n) = let
                val (vE, E) = get(n, c, mapp)
                in
                  Iter(mapp, [c-1], vE :: rest, E @ code, shape, n)
                end
            | Iter (mapp, b::c, _, ccode, s::shape, n) = let
                val n' = n + 1
              (* count axis n down from b to 0, recursing into the inner axes *)
                fun S (0, rest, code) = let
                      val (v', code') = Iter(insert(n, 0) mapp, c, [], [], shape, n')
                      val (vA, A) = mkCons(s::shape, v' :: rest)
                      in
                        (vA, code' @ code @ A)
                      end
                  | S (i, rest, code) = let
                      val (v', code') = Iter(insert(n, i) mapp, c, [], [], shape, n')
                      in
                        S(i-1, v' :: rest, code' @ code)
                      end
                val (vA, code') = S(b, [], [])
                in
                  (vA, code' @ ccode)
                end
            | Iter _ = raise Fail "index' is larger than origIndex"
          in
            Iter(empty, index', [], [], origIndex, 0)
          end

  (* skeleton A: classify code A when it is a single assignment of the integer
   * literal 0, 1 or ~1 (returning that value); any other code yields 9, a
   * sentinel that none of the callers below treat as a literal. *)
    fun skeleton A = (case A
           of [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 0))] => 0
            | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int 1))] => 1
            | [DstIL.ASSGN(_, DstIL.LIT(Literal.Int ~1))] => ~1
            | _ => 9
          (* end case *))

  (* Helper functions for general functions *)

  (* findIX (v, mapp): the value bound to index variable v; Fail if unbound *)
    fun findIX (v, mapp) = (case lookup v mapp
           of NONE => errS("Outside Bound:" ^ Int.toString(v))
            | SOME s => s
          (* end case *))

  (* NegCheckO (vA, A): negate vA, folding literal 0, 1 and ~1; otherwise
   * emits ~1 * vA after A *)
    fun NegCheckO (vA, A) = (case skeleton A
           of 0 => mkInt 0
            | ~1 => mkInt 1
            | 1 => mkInt ~1
            | _ => let
                val (vB, B) = mkInt ~1
                val (vD, D) = mkProdSca [vB, vA]
                in
                  (vD, A @ B @ D)
                end
          (* end case *))

  (* SubcheckO ((vA, A), (vB, B)): vA - vB, dropping literal-zero operands
   * (0 - b becomes ~1 * b) *)
    fun SubcheckO ((vA, A), (vB, B)) = (case (skeleton A, skeleton B)
           of (0, 0) => mkInt 0
            | (0, _) => let
                val (vD, D) = mkInt ~1
                val (vE, E) = mkProdSca [vD, vB]
                in
                  (vE, B @ D @ E)
                end
            | (_, 0) => (vA, A)
            | _ => let
                val (vD, D) = mkSubSca [vA, vB]
                in
                  (vD, A @ B @ D)
                end
          (* end case *))

  (* generalfn (dict, (body, origargs, info)): lower the scalar Ein expression
   * body with the index bindings in dict, removing literal zeros/ones where
   * possible.  origargs supplies the kernel/image arguments for the
   * probe-style Sum(Prod(Img :: Krn :: _)) form; info is passed through to
   * step3/genKrn.  Returns (result variable, code). *)
    fun generalfn (dict, (body, origargs, info)) = let
        (* current index bindings; temporarily rebound while a Sum is unrolled *)
          val mapp = ref dict

          fun gen body = let
              (* lower the terms of an Add, dropping literal-zero terms *)
                fun AddcheckO ([], [], []) = let val (vA, A) = mkInt 0 in ([vA], A) end
                  | AddcheckO ([], ids, code) = (ids, code)
                  | AddcheckO (e1::es, ids, code) = let
                      val (a, b) = gen e1
                      in
                        (case skeleton b
                          of 0 => AddcheckO(es, ids, code)
                           | _ => AddcheckO(es, ids @ [a], code @ b)
                        (* end case *))
                      end

              (* lower the factors of a Prod: a literal-zero factor makes the
               * whole product that zero, literal-one factors are dropped *)
                fun ProdcheckO ([], [], []) = let val (vA, A) = mkInt 1 in ([vA], A) end
                  | ProdcheckO ([], ids, code) = (ids, code)
                  | ProdcheckO (e1::es, ids, code) = let
                      val (a, b) = gen e1
                      in
                        (case skeleton b
                          of 0 => ([a], b)
                           | 1 => ProdcheckO(es, ids, code)
                           | _ => ProdcheckO(es, ids @ [a], code @ b)
                        (* end case *))
                      end

              (* unroll Sum(sumx, e): lower e once per point of the index
               * ranges in sumx, returning the non-zero term variables and all
               * generated code.  The bindings in force before the sum are
               * restored afterwards so summation indices do not leak into
               * sibling expressions. *)
                fun Sumcheck (sumx, e) = let
                      val saved = !mapp
                    (* lower e under bindings mapsum; a literal-zero term
                     * contributes its code but no variable *)
                      fun sumloop mapsum = let
                            val _ = mapp := mapsum
                            val (vA, A) = gen e
                            in
                              (case skeleton A
                                of 0 => ([], A)
                                 | _ => ([vA], A)
                              (* end case *))
                            end
                    (* (v, i, lb1): index variable v, countdown i, lower bound
                     * lb1 (v takes the value i + lb1); the list holds the
                     * remaining (inner) summation indices *)
                      fun sumI1 (left, (v, 0, lb1), [], rest, code) = let
                            val dict = insert(v, lb1) left
                            val (vD, pre) = sumloop dict
                            in
                              (vD @ rest, pre @ code)
                            end
                        | sumI1 (left, (v, i, lb1), [], rest, code) = let
                            val dict = insert(v, i + lb1) left
                            val (vD, pre) = sumloop dict
                            in
                              sumI1(dict, (v, i-1, lb1), [], vD @ rest, pre @ code)
                            end
                        | sumI1 (left, (v, 0, lb1), (E.V a, lb2, ub) :: sx, rest, code) = let
                            val dict = insert(v, lb1) left
                            in
                              sumI1(dict, (a, ub-lb2, lb2), sx, rest, code)
                            end
                        | sumI1 (left, (v, s, lb1), (E.V a, lb2, ub) :: sx, rest, code) = let
                            val dict = insert(v, s + lb1) left
                            val (rest', code') = sumI1(dict, (a, ub-lb2, lb2), sx, rest, code)
                            in
                              sumI1(dict, (v, s-1, lb1), (E.V a, lb2, ub) :: sx, rest', code')
                            end
                        | sumI1 _ = raise Fail "None Variable-index in summation"
                    (* an empty range would make the countdown above skip its
                     * 0 base case and never terminate *)
                      val _ = List.app
                            (fn (_, lb, ub) =>
                                if ub < lb then raise Fail "Empty summation range" else ())
                              sumx
                      val result = (case sumx
                             of (E.V v, lb, ub) :: sx => sumI1(saved, (v, ub-lb, lb), sx, [], [])
                              | _ => raise Fail "None Variable-index in summation"
                            (* end case *))
                      in
                        mapp := saved;
                        result
                      end

              (* combine lowered terms with rator; an empty sum is literal 0 and
               * a single term needs no operation *)
                fun iterList (e, DstOp.addSca) = (case e
                       of ([], _) => mkInt 0
                        | ([id1], code) => (id1, code)
                        | (ids, code) => let
                            val (vB, B) = mkMultipleSca(ids, DstOp.addSca)
                            in
                              (vB, code @ B)
                            end
                      (* end case *))
                  | iterList (e, rator) = (case e
                       of ([id1], code) => (id1, code)
                        | (ids, code) => let
                            val (vB, B) = mkMultipleSca(ids, rator)
                            in
                              (vB, code @ B)
                            end
                      (* end case *))
                in
                  (case body
                    of E.Field _ => err 1
                     | E.Partial _ => err 1
                     | E.Apply _ => err 1
                     | E.Probe _ => err 1
                     | E.Conv _ => err 1
                     | E.Krn _ => err 1
                     | E.Img _ => err 1
                     | E.Lift _ => err 1
                     | E.Value v => mkInt(findIX(v, !mapp))
                     | E.Const c => mkInt c
                     | E.Epsilon(i, j, k) => S3.evalEps(!mapp, i, j, k)
                     | E.Delta(i, j) => S3.evalDelta2(!mapp, i, j)
                     | E.Tensor(id, ix) => S3.mkSca(!mapp, (id, ix, info))
                     | E.Neg e => NegCheckO(gen e)
                     | E.Sub(e1, e2) => SubcheckO(gen e1, gen e2)
                     | E.Div(e1, e2) => let
                         val (vA, A) = gen e1
                         in
                         (* a zero numerator makes the quotient zero; the
                          * denominator is then not lowered at all *)
                           (case skeleton A
                             of 0 => mkInt 0
                              | _ => let
                                  val (vB, B) = gen e2
                                  val (vD, D) = mkDivSca [vA, vB]
                                  in
                                    (vD, A @ B @ D)
                                  end
                           (* end case *))
                         end
                     | E.Add e => iterList(AddcheckO(e, [], []), DstOp.addSca)
                     | E.Prod e => iterList(ProdcheckO(e, [], []), DstOp.prodSca)
                     | E.Sum(_, E.Prod(E.Img(Vid, _, _) :: E.Krn(Hid, _, _) :: _)) => let
                       (* image/kernel convolution: delegate to genKrn *)
                         val harg = List.nth(origargs, Hid)
                         val imgarg = List.nth(origargs, Vid)
                         val h = S3.getKernel(harg)
                         val v = S3.getImage(imgarg)
                         in
                           genKrn.evalField(!mapp, (body, v, h, info))
                         end
                     | E.Sum(sumx, e) => iterList(Sumcheck(sumx, e), DstOp.addSca)
                  (* end case *))
                end
          in
            gen body
          end

  end (* local *)

end
root@smlnj-gforge.cs.uchicago.edu | ViewVC Help |
Powered by ViewVC 1.0.0 |