Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step2.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step2.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2843 - (view) (download)

1 : cchiw 2615 (*general function for scalars*)
2 : cchiw 2612 structure step2 = struct
3 :     local
4 :     structure DstIL = LowIL
5 :     structure DstTy = LowILTypes
6 :     structure DstOp = LowOps
7 :     structure Var = LowIL.Var
8 :     structure E = Ein
9 :     structure S3=step3
10 :     structure genKrn=genKrn
11 : cchiw 2627 structure tS= toStringEin
12 : cchiw 2612
13 :     in
14 :    
15 :    
16 :     fun insert (key, value) d =fn s =>
17 :     if s = key then SOME value
18 :     else d s
19 :    
20 :     fun lookup k d = d k
21 :     val empty =fn key =>NONE
22 :    
23 :     fun err _=raise Fail("Invalid Field Here")
24 :     fun errS str=raise Fail(str)
25 :    
26 :     (*Helpers for scalars*)
27 : cchiw 2624
28 :     fun mkCons(shape, rest)=let
29 :     val ty=DstTy.TensorTy shape
30 :     val a=DstIL.Var.new("Cons" ,ty)
31 :     val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
32 : cchiw 2838 (*val _=print("###"^tS.toStringAll(ty,code))*)
33 : cchiw 2624 in (a, [code])
34 :     end
35 :    
36 :    
37 : cchiw 2612 val Sca=DstTy.TensorTy([])
38 : cchiw 2676 val addR=DstOp.addSca
39 :    
40 :    
41 : cchiw 2789 fun mkProdSca(lhs,rest)=S3.aaV(DstOp.prodSca,rest,lhs^"prodSca",Sca)
42 :     fun mkSubSca(lhs,rest)= S3.aaV(DstOp.subSca,rest,lhs^"subSca",Sca)
43 :     fun mkDivSca(lhs,rest)= S3.aaV(DstOp.divSca,rest,lhs^"divSca",Sca)
44 :     fun mkMultipleSca(info,ids,rator)=S3.mkMultiple(info,ids,rator,Sca)
45 : cchiw 2827 fun mkReal n=S3.mkReal n
46 : cchiw 2612
47 : cchiw 2624
48 :    
49 :    
50 : cchiw 2628
51 : cchiw 2612 fun prodIter(origIndex,index,nextfn,args)=(let
52 :     val index'=List.map (fn (e)=>(e-1)) index
53 :    
54 :     fun get(n,m,mapp)=let
55 :     val mapp =insert(n, m) mapp
56 :     in
57 :     nextfn(mapp,args)
58 :     end
59 :    
60 :     fun Iter(mapp,[],rest,code,shape,_)=let
61 :     val (vF,code')=nextfn(mapp,args)
62 :     in (vF, code'@code)
63 :     end
64 :     | Iter(mapp,[0], rest, code,shape,n)=let
65 :     val (vF,code')= get(n,0,mapp)
66 :     val(vE,E)=mkCons(shape,[vF]@rest)
67 :     in
68 :     (vE, code'@code@E)
69 :     end
70 :     | Iter(mapp,[c],rest,code,shape,n)=let
71 :     (*val (vF,code')= get(n,c,mapp)
72 :     val (vE,E)=nextfn(mapp,args)*)
73 :     val (vE,E)=get(n,c,mapp)
74 :     in
75 :     Iter(mapp, [c-1], [vE]@rest,E@code,shape,n)
76 :     end
77 :     | Iter(mapp,b::c,rest,ccode,s::shape,n)=let
78 :     val n'=n+1
79 :     fun S(0, rest,code)=let
80 :     val mapp =insert(n, 0) mapp
81 :     val (v',code')=Iter(mapp,c,[],[],shape,n')
82 :     val(vA,A)=mkCons(s::shape,[v']@rest)
83 :     in
84 :     (vA, code'@code@A)
85 :     end
86 :     | S(i, rest, code)= let
87 :     val mapp =insert(n, i) mapp
88 :     val (v',code')=Iter(mapp,c,[],[],shape,n')
89 :     in
90 :     S(i-1,[v']@rest,code'@code)
91 :     end
92 :     val (vA,code')=S(b, [],[])
93 :     in
94 :     (vA,code'@ccode)
95 :     end
96 : cchiw 2615 | Iter _=raise Fail"index' is larger than origIndex"
97 : cchiw 2612 in
98 :     Iter(empty,index',[],[],origIndex,0)
99 :     end)
100 :    
101 :     (*Get constant *)
102 :     fun skeleton A=(case A
103 : cchiw 2637 of [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 0))]=>0
104 : cchiw 2669 | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int 1))]=> 1
105 :     | [DstIL.ASSGN(_,DstIL.LIT(Literal.Int ~1))]=> ~1
106 : cchiw 2612 | _ => 9
107 :     (*end case*))
108 :    
109 :     (*Helper Functions for General functions*)
110 :     fun findIX(v, mapp)=(case (lookup v mapp)
111 :     of NONE=> errS( "Outside Bound:"^Int.toString(v))
112 :     |SOME s => s
113 :     (*end case*))
114 :    
115 :    
116 : cchiw 2789 fun NegCheckO(lhs,(vA,A))=(case skeleton A
117 : cchiw 2827 of 0 => mkReal 0
118 :     | ~1 => mkReal 1
119 :     | 1 => mkReal ~1
120 : cchiw 2612 | _=> let
121 : cchiw 2827 val (vB,B)=mkReal ~1
122 : cchiw 2789 val (vD,D)=mkProdSca (lhs,[vB,vA])
123 : cchiw 2612 in (vD,A@B@D) end
124 :     (*end case*))
125 :    
126 :    
127 : cchiw 2789 fun SubcheckO(lhs,(vA,A),(vB,B))=(case((skeleton A),(skeleton B))
128 : cchiw 2827 of (0,0)=> mkReal 0
129 : cchiw 2612 |(0,_)=> let
130 : cchiw 2827 val (vD,D)= mkReal ~1
131 : cchiw 2789 val (vE,E)= mkProdSca(lhs, [vD,vB])
132 : cchiw 2612 in (vE,B@D@E) end
133 :     | (_,0)=> (vA,A)
134 :     | _ => let
135 : cchiw 2789 val (vD,D)= mkSubSca(lhs,[vA,vB])
136 : cchiw 2612 in (vD, A@B@D) end
137 :     (*end case*))
138 :    
139 :    
140 : cchiw 2838
141 : cchiw 2612 (* general expressions-removes zeros*)
142 : cchiw 2843 fun generalfn(dict,(body,origargs,info as (lhs,_,_,_)))=let
143 : cchiw 2612 val mapp=ref dict
144 : cchiw 2613
145 : cchiw 2612
146 :    
147 :     fun gen body=let
148 : cchiw 2827 fun AddcheckO ([],[],[])=let val (vA,A)=mkReal 0 in ([vA],A) end
149 : cchiw 2612 | AddcheckO([],ids,code)=(ids,code)
150 :     | AddcheckO(e1::es,ids,code)=let
151 :     val (a,b)=gen e1
152 :     in (case (skeleton b)
153 :     of 0 => AddcheckO(es,ids,code)
154 :     | _ => AddcheckO(es,ids@[a],code@b)
155 :     (*end case*))
156 :     end
157 : cchiw 2827 fun ProdcheckO ([],[],[])=let val (vA,A)=mkReal 1 in ([vA],A) end
158 : cchiw 2612 | ProdcheckO([],ids,code)=(ids,code)
159 :     | ProdcheckO(e1::es,ids,code)=let
160 :     val (a,b)=gen e1
161 :     in (case (skeleton b)
162 :     of 0 => ([a],b)
163 :     | 1 => ProdcheckO(es,ids,code)
164 :     | _ => ProdcheckO(es,ids@[a],code@b)
165 :     (*end case*))
166 :     end
167 :    
168 :     fun Sumcheck(sumx,e)=let
169 :     fun sumloop mapsum=let
170 :     val _ = mapp:=mapsum
171 :     val(vA,A)=gen e
172 : cchiw 2669 in (case (skeleton A)
173 :     of 0 => ([],A)
174 :     | _ => ([vA],A)
175 :     (*end case*))
176 :     end
177 :    
178 :     (*in ([vA],A) end*)
179 : cchiw 2612 fun sumI1(left,(v,0,lb1),[],rest,code)=let
180 :     val dict=insert(v, lb1) left
181 :     val (vD,pre)= sumloop dict
182 :     in (vD@rest,pre@code) end
183 :     | sumI1(left,(v,i,lb1),[],rest,code)=let
184 :     val dict=insert(v, (i+lb1)) left
185 :     val (vD,pre)=sumloop dict
186 :     in sumI1(dict,( v,i-1,lb1),[],vD@rest,pre@code) end
187 :     | sumI1(left,(v,0,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
188 :     val dict=insert(v, lb1) left
189 :     in sumI1(dict,(a,ub-lb2,lb2),sx,rest,code) end
190 :     | sumI1(left,(v,s,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
191 :     val dict=insert(v, (s+lb1)) left
192 :     val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)
193 :     in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end
194 : cchiw 2615 | sumI1 _ =raise Fail"None Variable-index in summation"
195 : cchiw 2612 val (E.V v,lb,ub)=hd(sumx)
196 :     in
197 :     sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])
198 :     end
199 :    
200 : cchiw 2669 fun iterList(e, DstOp.addSca)=(case e
201 : cchiw 2827 of ([],code)=>let val (vA,A)=mkReal 0 in (vA,A) end
202 : cchiw 2669 | ([id1],code) => (id1,code)
203 :     | (ids,code) => let
204 : cchiw 2789 val (vB,B)= mkMultipleSca(info,ids,addR)
205 : cchiw 2669 in (vB,code@B) end
206 :     (*end case*))
207 :    
208 :     | iterList(e,rator)= (case e
209 : cchiw 2612 of ([id1],code) => (id1,code)
210 :     | (ids,code) => let
211 : cchiw 2789 val (vB,B)= mkMultipleSca(info,ids, rator)
212 : cchiw 2612 in (vB,code@B) end
213 :     (*end case*))
214 : cchiw 2669
215 : cchiw 2612 in (case body
216 :     of E.Field _ => err 1
217 :     | E.Partial _ => err 1
218 :     | E.Apply _ => err 1
219 :     | E.Probe _ => err 1
220 :     | E.Conv _ => err 1
221 :     | E.Krn _ => err 1
222 :     | E.Img _ => err 1
223 :     | E.Lift _ => err 1
224 : cchiw 2827 | E.Value v => mkReal(findIX(v,!mapp))
225 :     | E.Const c => mkReal c
226 : cchiw 2612 | E.Epsilon(i,j,k) => S3.evalEps(!mapp,i,j,k)
227 : cchiw 2843 | E.Eps2(i,j) => S3.evalEps2(!mapp,i,j)
228 : cchiw 2612 | E.Delta(i,j) => S3.evalDelta2(!mapp,i,j)
229 : cchiw 2613 | E.Tensor(id,ix) => S3.mkSca(!mapp,(id,ix,info))
230 : cchiw 2789 | E.Neg e => NegCheckO(lhs,gen e)
231 :     | E.Sub (e1,e2) => SubcheckO(lhs,gen e1,gen e2)
232 : cchiw 2843 | E.Div(e1 as E.Tensor(_,[_]),e2 as E.Tensor(_,[]))=>
233 :     gen (E.Prod[E.Div(E.Const 1, e2),e1])
234 : cchiw 2612 | E.Div(e1,e2) => let
235 :     val (vA,A)=gen e1
236 :     in (case (skeleton A)
237 : cchiw 2827 of 0=> mkReal 0
238 : cchiw 2612 | _=> let
239 :     val (vB,B)=gen e2
240 : cchiw 2789 val (vD,D)= mkDivSca(lhs, [vA,vB])
241 : cchiw 2612 in (vD, A@B@D) end
242 :     (*end case*))
243 :     end
244 : cchiw 2676 | E.Add e => (iterList(AddcheckO(e,[],[]),addR))
245 : cchiw 2612 | E.Prod e => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)
246 : cchiw 2829 | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let
247 : cchiw 2612 val harg=List.nth(origargs,Hid)
248 :     val imgarg=List.nth(origargs,Vid)
249 : cchiw 2829 val h=S3.getKernel harg
250 :     val v=S3.getImageSrc imgarg
251 : cchiw 2843 val(_,_,_,args)=info
252 : cchiw 2829 val imgargnew=List.nth(args,Vid)
253 : cchiw 2612 in
254 : cchiw 2829 genKrn.evalField(!mapp,(body,(v,imgargnew),h,info))
255 : cchiw 2612 end
256 : cchiw 2829
257 : cchiw 2676 | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),addR)
258 : cchiw 2612 (*end case*))
259 :     end
260 :    
261 :     in gen body
262 :     end
263 :    
264 :     end (* local *)
265 :    
266 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0