Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/ein/type-ein.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/ein/type-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2498 - (view) (download)

1 : cchiw 2498 structure TypeEin = struct
2 :    
3 :     local
4 :    
5 :     structure E = Ein
6 :     structure P = Printer
7 :    
8 :     in
9 :    
10 :    
11 :    
12 :     datatype greek_type=epsTy|deltaTy
13 :    
14 :     datatype ein_type = realTy
15 :     | G of E.mu list
16 :     | ten of E.mu list
17 :     | fld of E.mu list
18 :     | fldmid of E.mu list
19 :     | imageTy of E.mu list
20 :     | kernelTy of E.mu list
21 :     | partialTy of E.mu list
22 :     | errTy
23 :    
24 :     (*can only be C type*)
25 :     fun position([realTy])=realTy
26 :     | position((realTy)::es)= position(es)
27 :     | position _= errTy
28 :    
29 :    
30 :     fun err(msg)= (errTy)
31 :    
32 :    
33 :     fun printIndex([])=""
34 :     | printIndex(E.C x::ix)= String.concat[Int.toString(x),printIndex(ix)]
35 :     | printIndex(E.V v::ix)= String.concat[Int.toString(v),printIndex(ix)]
36 :    
37 :     fun printTy(ty)= (case ty of realTy=> print "realTy"
38 :     | G _=> print "Greek Type"
39 :     | ten ix=> print (String.concat["TEN-" ,printIndex(ix)])
40 :     | fld ix=> print (String.concat["FLD-HIGH-" ,printIndex(ix)])
41 :     | fldmid ix=> print (String.concat["FLD-MID-" ,printIndex(ix)])
42 :     | imageTy ix=> print (String.concat["Image-" ,printIndex(ix)])
43 :     | kernelTy ix=> print (String.concat["KRN-" ,printIndex(ix)])
44 :     | partialTy ix=> print (String.concat["Partial-" ,printIndex(ix)])
45 :     | errTy=> print "err")
46 :    
47 :    
48 :    
49 :     (*single index no duplicates*)
50 :     fun sortIndex(ilist)=let
51 :     fun sort([],ix)=ix
52 :     | sort (E.C _::es,ix)=sort(es,ix)
53 :     | sort(e::es,ix)=let
54 :     val r=List.find (fn(x) => x=e) ix
55 :     in (case r
56 :     of NONE=>sort(es, ix@[e])
57 :     |_=> sort(es, ix)
58 :     (*end case*))
59 :     end
60 :     in sort(ilist,[]) end
61 :    
62 :    
63 :     (*compares one list of indices to another*)
64 :     fun sortFldIndex(part,ilist)=let
65 :     fun sort([],ix)=ix
66 :     | sort (E.C _::es,ix)=sort(es,ix)
67 :     | sort(e::es,ix)=let
68 :     val r=List.find (fn(x) => x=e) ix
69 :     in (case r
70 :     of NONE=>sort(es, ix@[e])
71 :     |_=> sort(es, ix)
72 :     (*end case*))
73 :     end
74 :     in sort(part,ilist) end
75 :    
76 :    
77 :     fun removeSumIndex(ilist,sum)=let
78 :     fun sort([],ix,rest)=rest
79 :     | sort (E.C _::es,ix,rest)=sort(es,ix,rest)
80 :     | sort(e::es,ix,rest)=let
81 :     val r=List.find (fn(x) => x=e) ix
82 :     in (case r
83 :     of NONE=>sort(es, ix,rest@[e])
84 :     |_=> sort(es, ix,rest)
85 :     (*end case*))
86 :     end
87 :     in sort(ilist,sum,[]) end
88 :    
89 :    
90 :    
91 :     fun evalAdd [fld f]= fld f
92 :     | evalAdd [ten t]= ten t
93 :     | evalAdd [realTy]= realTy
94 :     | evalAdd [imageTy i]= imageTy i
95 :     | evalAdd(realTyy::realTy::el)=evalAdd(realTy::el)
96 :     | evalAdd(fld F1:: fld F2::el)=evalAdd(fld F1::el)
97 :     (*if (F1=F2) then evalAdd(fld F1::el)
98 :     else errTy*)
99 :     | evalAdd(ten T1::ten T2::el)=evalAdd(ten T1::el)
100 :     (* if(T1=T2) then evalAdd(ten T1::el)
101 :     else errTy*)
102 :     | evalAdd(imageTy i::imageTy _::el)=evalAdd(imageTy i::el)
103 :     | evalAdd _= errTy
104 :    
105 :     fun evalProd([])=errTy
106 :     | evalProd(errTy::el)= errTy
107 :     | evalProd([e1])=e1
108 :     | evalProd(realTy::el)=evalProd el
109 :     | evalProd(G g::ten t::es)= evalProd (ten(t@g)::es)
110 :     | evalProd(G g::fld t::es)= evalProd (fld(t@g)::es)
111 :     (*Fields..*)
112 :     | evalProd [fld f,realTy] =fld f
113 :     | evalProd [fld f,G _] =fld f
114 :     | evalProd [fld t, partialTy p]= ten(sortIndex(t@p))
115 :     | evalProd([fld f,_])= err "can not multiply field and other type "
116 :     (* Tensors *)
117 :     | evalProd [ten t ,realTy]= ten t
118 :     | evalProd [ten t,G _]= ten t
119 :     | evalProd [ten t,ten T2]=ten(sortIndex(t@T2))
120 :     | evalProd [ten t, partialTy p]= ten(sortIndex(t@p))
121 :     (*kernels*)
122 :     | evalProd [kernelTy k,realTy]= kernelTy k
123 :     | evalProd [kernelTy k,kernelTy _]= kernelTy k
124 :     | evalProd [kernelTy k,imageTy i]= fldmid(i@k)
125 :     (*Partials*)
126 :     | evalProd [partialTy p,realTy]= partialTy p
127 :     | evalProd [partialTy p,G _]= partialTy p
128 :     | evalProd [partialTy p,ten T]= ten(sortIndex(T@p))
129 :     | evalProd [partialTy p,fld T]= ten(sortIndex(T@p))
130 :     | evalProd [partialTy p,partialTy p2]= partialTy(p@p2)
131 :     (*Image *)
132 :     | evalProd [imageTy i,realTy]= imageTy i
133 :     | evalProd [imageTy i ,G _]= imageTy i
134 :     | evalProd [imageTy i ,imageTy i2]= imageTy(i@i2)
135 :     | evalProd [imageTy i, kernelTy k]= fldmid(i@k)
136 :     | evalProd [a,b]= errTy
137 :     | evalProd (e1::es)= evalProd [e1,evalProd(es)]
138 :    
139 :     fun evalSub(a,b)=(case (a,b)
140 :     of (realTy, realTy)=>realTy
141 :     | (ten T1, ten T2)=>ten T1
142 :     |(fld F1, fld F2)=>fld F1
143 :     | (imageTy i, imageTy _)=> imageTy i
144 :     | (fldmid f, fldmid _)=> fldmid f
145 :     | (fldmid f, realTy)=> fldmid f
146 :     | (realTy, fldmid f)=> fldmid f
147 :     | (fld f, realTy) => fld f
148 :     | (realTy, fld f) => fld f
149 :     | (imageTy i, realTy) => imageTy i
150 :     | (realTy, imageTy i) => imageTy i
151 :     |_=>errTy)
152 :    
153 :     fun evalDiv(a,b)=(case (a,b)
154 :     of(realTy ,realTy)=>realTy
155 :     | (fld f, realTy)=> fld f
156 :     | (ten t, realTy)=>ten t
157 :     | _=>errTy)
158 :    
159 :     fun evalProbe(a,b,phase)=if (phase>1) then err "wrong phase for Probe op"
160 :     else (case (a,b)
161 :     of (fld f,ten _)=>fld f
162 :     | (fld f, realTy)=>fld f
163 :     | (fld f, _)=> err "wrong pos for field probe"
164 :     | _=>err "Not a fieldTy in probe"
165 :     (*end case*))
166 :    
167 :     fun evalKrn(dels,phase)= if (3>phase) then err "wrong phrase for kernel"
168 :     else let
169 :     fun size([])=[]
170 :     | size((i ,j)::dels)= [j]@ size(dels)
171 :     in kernelTy(size(dels)) end
172 :    
173 :     fun evalApply(e1,e2,phase)=
174 :     if (phase>1) then err "wrong phase for apply"
175 :     else (case (e1,e2)
176 :     of (partialTy a, fld b)=>fld(a@b)
177 :     |_ =>errTy
178 :     (*end case*))
179 :    
180 :    
181 :     fun evalSum(sx,m)=(case m
182 :     of ten ix=>(
183 :     let val ix'=removeSumIndex(ix,sx)
184 :     in (case ix' of [] => realTy
185 :     |_ =>ten ix')end
186 :     (*end case*))
187 :     | fld ix=>(
188 :     let val ix'=removeSumIndex(ix,sx)
189 :     in (case ix' of [] => realTy
190 :     |_ =>fld ix')end
191 :     (*end case*))
192 :     | fldmid ix =>m
193 :     | realTy=>realTy
194 :     |_=> errTy
195 :     (*end case*))
196 :    
197 :    
198 :     fun checkTenParam(id,params, ix)=
199 :     if(id>length(params))then (print "in here";errTy)
200 :     else(let
201 :     val p=List.nth(params,id)
202 :     in (case p
203 :     of E.TEN => let
204 :     val m = (sortIndex(ix))
205 :     in(case m
206 :     of []=> realTy
207 :     | _=> ten m)
208 :     end
209 :     | _ =>errTy)
210 :     end)
211 :    
212 :     fun checkFldParam(id, params,ix)=
213 :     if(id>length(params))then (print "in here";errTy)
214 :     else(
215 :     let
216 :     val p=List.nth(params,id)
217 :     in (case p
218 :     of E.FLD _ =>fld ix
219 :     |_=> errTy)
220 :     end)
221 :    
222 :    
223 :     fun checker (Ein.EIN{params, index, body},phase) = let
224 :    
225 :    
226 :    
227 :    
228 :     (*make sure the finished expression is the shape it is suppose to be *)
229 :     fun checkbody term = (case term
230 :     of E.Const r => realTy
231 :     | E.Tensor(id, ix) =>checkTenParam(id,params,ix)
232 :     | E.Delta(i,j)=>G [i,j]
233 :     | E.Value(ix) =>realTy
234 :     | E.Epsilon(i,j,k) => G [E.V i,E.V j,E.V k]
235 :     | E.Sum (sx,e1)=> evalSum(sx,checkbody e1)
236 :    
237 :    
238 :     | E.Neg e1=> checkbody e1
239 :     | E.Partial a => partialTy(sortIndex(a))
240 :     | E.Add es => evalAdd(List.map checkbody es)
241 :     | E.Sub(e1, e2) =>evalSub(checkbody e1 ,checkbody e2)
242 :     | E.Prod el => evalProd(List.map checkbody el)
243 :     | E.Div(e1,e2)=> evalDiv (checkbody e1,checkbody e2)
244 :     (*Phase dependent operators*)
245 :     | E.Field(id, alpha)=>
246 :     if (phase>1) then err "wrong phase for Field"
247 :     else checkFldParam(id,params,alpha)
248 :     | E.Apply(e1, e2)=> evalApply(checkbody e1,checkbody e2,phase)
249 :     | E.Probe (e1,e2)=>evalProbe(checkbody e1, checkbody e2,phase)
250 :     | E.Conv (fid,alpha, tid, beta)=>
251 :     if (phase>1) then err "wrong phase for convolution"
252 :     else (case checkFldParam(fid, params,alpha)
253 :     of fld f=> fld(sortFldIndex(beta,f))
254 :     |_=> errTy
255 :     (*end case*))
256 :    
257 :     (*Phase 2 Mid-IL*)
258 :     | E.Krn (_,dels,_) =>evalKrn(dels,phase)
259 :     | E.Img(id,ix,pos)=>
260 :     if(3 >phase) then errTy
261 :     else (case position(List.map checkbody pos)
262 :     of errTy=> err "Not an image position"
263 :     | _ =>(case checkFldParam(id, params,[])
264 :     of errTy=> errTy
265 :     |_=>imageTy(sortIndex(ix))
266 :     (*end case*))
267 :     (*end case*))
268 :     (* end case *))
269 :     in
270 :     checkbody body
271 :     end
272 :    
273 :    
274 :     end; (* local *)
275 :    
276 :     end (* local *)

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0