Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/step2.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/step2.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2627 - (view) (download)

1 : cchiw 2615 (*general function for scalars*)
2 : cchiw 2612 structure step2 = struct
3 :     local
4 :     structure DstIL = LowIL
5 :     structure DstTy = LowILTypes
6 :     structure DstOp = LowOps
7 :     structure Var = LowIL.Var
8 :     structure E = Ein
9 :     structure S3=step3
10 :     structure genKrn=genKrn
11 : cchiw 2627 structure tS= toStringEin
12 : cchiw 2612
13 :     in
14 :    
15 :    
16 :     fun insert (key, value) d =fn s =>
17 :     if s = key then SOME value
18 :     else d s
19 :    
20 :     fun lookup k d = d k
21 :     val empty =fn key =>NONE
22 :    
23 :     fun err _=raise Fail("Invalid Field Here")
24 :     fun errS str=raise Fail(str)
25 :    
26 :     (*Helpers for scalars*)
27 : cchiw 2624
28 :     fun mkCons(shape, rest)=let
29 :     val ty=DstTy.TensorTy shape
30 :     val a=DstIL.Var.new("Cons" ,ty)
31 :     val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
32 :     val _=print("###"^tS.toStringAll(ty,code))
33 :     in (a, [code])
34 :     end
35 :    
36 :    
37 : cchiw 2612 val Sca=DstTy.TensorTy([])
38 :     fun mkProdSca rest=S3.aaV(DstOp.prodSca,rest,"prodSca",Sca)
39 :     fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)
40 :     fun mkDivSca rest= S3.aaV(DstOp.divSca,rest,"divSca",Sca)
41 :     fun mkMultipleSca(ids,rator)=S3.mkMultiple(ids,rator,Sca)
42 : cchiw 2620 fun mkInt n= S3.aaV(DstOp.C n,[],"Int",Sca)
43 : cchiw 2612
44 : cchiw 2624 (*
45 :    
46 :     fun mkInt n=let
47 :     val a=DstIL.Var.new("Int" ,Sca)
48 :     val code=DstIL.ASSGN (a,DstIL.LIT(Literal.Int n))
49 :     in
50 :     (a,[code])
51 :     end
52 :     *)
53 :    
54 :    
55 : cchiw 2612 fun prodIter(origIndex,index,nextfn,args)=(let
56 :     val index'=List.map (fn (e)=>(e-1)) index
57 :    
58 :     fun get(n,m,mapp)=let
59 :     val mapp =insert(n, m) mapp
60 :     in
61 :     nextfn(mapp,args)
62 :     end
63 :    
64 :     fun Iter(mapp,[],rest,code,shape,_)=let
65 :     val (vF,code')=nextfn(mapp,args)
66 :     in (vF, code'@code)
67 :     end
68 :     | Iter(mapp,[0], rest, code,shape,n)=let
69 :     val (vF,code')= get(n,0,mapp)
70 :     val(vE,E)=mkCons(shape,[vF]@rest)
71 :     in
72 :     (vE, code'@code@E)
73 :     end
74 :     | Iter(mapp,[c],rest,code,shape,n)=let
75 :     (*val (vF,code')= get(n,c,mapp)
76 :     val (vE,E)=nextfn(mapp,args)*)
77 :     val (vE,E)=get(n,c,mapp)
78 :     in
79 :     Iter(mapp, [c-1], [vE]@rest,E@code,shape,n)
80 :     end
81 :     | Iter(mapp,b::c,rest,ccode,s::shape,n)=let
82 :     val n'=n+1
83 :     fun S(0, rest,code)=let
84 :     val mapp =insert(n, 0) mapp
85 :     val (v',code')=Iter(mapp,c,[],[],shape,n')
86 :     val(vA,A)=mkCons(s::shape,[v']@rest)
87 :     in
88 :     (vA, code'@code@A)
89 :     end
90 :     | S(i, rest, code)= let
91 :     val mapp =insert(n, i) mapp
92 :     val (v',code')=Iter(mapp,c,[],[],shape,n')
93 :     in
94 :     S(i-1,[v']@rest,code'@code)
95 :     end
96 :     val (vA,code')=S(b, [],[])
97 :     in
98 :     (vA,code'@ccode)
99 :     end
100 : cchiw 2615 | Iter _=raise Fail"index' is larger than origIndex"
101 : cchiw 2612 in
102 :     Iter(empty,index',[],[],origIndex,0)
103 :     end)
104 :    
105 :     (*Get constant *)
106 :     fun skeleton A=(case A
107 :     of [DstIL.ASSGN(_,DstIL.OP(DstOp.C c,_))]=>c
108 :     | _ => 9
109 :     (*end case*))
110 :    
111 :     (*Helper Functions for General functions*)
112 :     fun findIX(v, mapp)=(case (lookup v mapp)
113 :     of NONE=> errS( "Outside Bound:"^Int.toString(v))
114 :     |SOME s => s
115 :     (*end case*))
116 :    
117 :    
118 :     fun NegCheckO(vA,A)=(case skeleton A
119 :     of 0 => mkInt 0
120 :     | ~1 => mkInt 1
121 :     | 1 => mkInt ~1
122 :     | _=> let
123 :     val (vB,B)=mkInt ~1
124 :     val (vD,D)=mkProdSca [vB,vA]
125 :     in (vD,A@B@D) end
126 :     (*end case*))
127 :    
128 :    
129 :     fun SubcheckO((vA,A),(vB,B))=(case((skeleton A),(skeleton B))
130 :     of (0,0)=> mkInt 0
131 :     |(0,_)=> let
132 :     val (vD,D)= mkInt ~1
133 :     val (vE,E)= mkProdSca [vD,vB]
134 :     in (vE,B@D@E) end
135 :     | (_,0)=> (vA,A)
136 :     | _ => let
137 :     val (vD,D)= mkSubSca [vA,vB]
138 :     in (vD, A@B@D) end
139 :     (*end case*))
140 :     (*
141 :     fun printMapp mapp=(case (lookup 0 mapp)
142 :     of NONE=>print(String.concat["\n No zero"])
143 :     |SOME s => print(String.concat["\n Found 0 =>",Int.toString(s)])
144 :     (*end case*))
145 :    
146 :     *)
147 :     (* val info=(params,args)*)
148 :    
149 :     (* general expressions-removes zeros*)
150 :     fun generalfn(dict,(body,origargs,info))=let
151 :     val mapp=ref dict
152 : cchiw 2613
153 : cchiw 2612
154 :    
155 :     fun gen body=let
156 :     fun AddcheckO ([],[],[])=let val (vA,A)=mkInt 1 in ([vA],A) end
157 :     | AddcheckO([],ids,code)=(ids,code)
158 :     | AddcheckO(e1::es,ids,code)=let
159 :     val (a,b)=gen e1
160 :     in (case (skeleton b)
161 :     of 0 => AddcheckO(es,ids,code)
162 :     | _ => AddcheckO(es,ids@[a],code@b)
163 :     (*end case*))
164 :     end
165 :     fun ProdcheckO ([],[],[])=let val (vA,A)=mkInt 1 in ([vA],A) end
166 :     | ProdcheckO([],ids,code)=(ids,code)
167 :     | ProdcheckO(e1::es,ids,code)=let
168 :     val (a,b)=gen e1
169 :     in (case (skeleton b)
170 :     of 0 => ([a],b)
171 :     | 1 => ProdcheckO(es,ids,code)
172 :     | _ => ProdcheckO(es,ids@[a],code@b)
173 :     (*end case*))
174 :     end
175 :    
176 :     fun Sumcheck(sumx,e)=let
177 :     fun sumloop mapsum=let
178 :     val _ = mapp:=mapsum
179 :     val(vA,A)=gen e
180 :     in ([vA],A) end
181 :     fun sumI1(left,(v,0,lb1),[],rest,code)=let
182 :     val dict=insert(v, lb1) left
183 :     val (vD,pre)= sumloop dict
184 :     in (vD@rest,pre@code) end
185 :     | sumI1(left,(v,i,lb1),[],rest,code)=let
186 :     val dict=insert(v, (i+lb1)) left
187 :     val (vD,pre)=sumloop dict
188 :     in sumI1(dict,( v,i-1,lb1),[],vD@rest,pre@code) end
189 :     | sumI1(left,(v,0,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
190 :     val dict=insert(v, lb1) left
191 :     in sumI1(dict,(a,ub-lb2,lb2),sx,rest,code) end
192 :     | sumI1(left,(v,s,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
193 :     val dict=insert(v, (s+lb1)) left
194 :     val (rest',code')=sumI1(dict,(a,ub-lb2,lb2),sx,rest,code)
195 :     in sumI1(dict,(v,s-1,lb1),(E.V a,lb2,ub)::sx,rest',code') end
196 : cchiw 2615 | sumI1 _ =raise Fail"None Variable-index in summation"
197 : cchiw 2612 val (E.V v,lb,ub)=hd(sumx)
198 :     in
199 :     sumI1(!mapp,(v,ub-lb,lb),tl(sumx),[],[])
200 :     end
201 :    
202 :     fun iterList(e,rator)= (case e
203 :     of ([id1],code) => (id1,code)
204 :     | (ids,code) => let
205 :     val (vB,B)= mkMultipleSca(ids, rator)
206 :     in (vB,code@B) end
207 :     (*end case*))
208 :     in (case body
209 :     of E.Field _ => err 1
210 :     | E.Partial _ => err 1
211 :     | E.Apply _ => err 1
212 :     | E.Probe _ => err 1
213 :     | E.Conv _ => err 1
214 :     | E.Krn _ => err 1
215 :     | E.Img _ => err 1
216 :     | E.Lift _ => err 1
217 :     | E.Value v => mkInt(findIX(v,!mapp))
218 :     | E.Const c => mkInt c
219 :     | E.Epsilon(i,j,k) => S3.evalEps(!mapp,i,j,k)
220 :     | E.Delta(i,j) => S3.evalDelta2(!mapp,i,j)
221 : cchiw 2613 | E.Tensor(id,ix) => S3.mkSca(!mapp,(id,ix,info))
222 : cchiw 2612 | E.Neg e => NegCheckO(gen e)
223 :     | E.Sub (e1,e2) => SubcheckO(gen e1,gen e2)
224 :     | E.Div(e1,e2) => let
225 :     val (vA,A)=gen e1
226 :     in (case (skeleton A)
227 :     of 0=> mkInt 0
228 :     | _=> let
229 :     val (vB,B)=gen e2
230 :     val (vD,D)= mkDivSca [vA,vB]
231 :     in (vD, A@B@D) end
232 :     (*end case*))
233 :     end
234 : cchiw 2613 | E.Add e => (iterList(AddcheckO(e,[],[]),DstOp.addSca))
235 : cchiw 2612 | E.Prod e => iterList(ProdcheckO(e,[],[]),DstOp.prodSca)
236 :     | E.Sum(sx,E.Prod(E.Img (Vid,_,_)::E.Krn(Hid,del,pos)::es))=>let
237 :     val harg=List.nth(origargs,Hid)
238 :     val imgarg=List.nth(origargs,Vid)
239 :     val h=S3.getKernel(harg)
240 :     val v=S3.getImage(imgarg)
241 :     in
242 :     genKrn.evalField(!mapp,(body,v,h,info))
243 :     end
244 :    
245 :     | E.Sum(sumx, e)=>iterList(Sumcheck(sumx,e),DstOp.addSca)
246 :     (*end case*))
247 :     end
248 :    
249 :     in gen body
250 :     end
251 :    
252 :     end (* local *)
253 :    
254 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0