Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/genH.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/genH.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 2611 - (view) (download)

1 : cchiw 2555 (*hashs Ein Function after substitution*)
2 :     structure gHelper = struct
3 :     local
4 :     structure E = Ein
5 :    
6 :     (* structure genKrn=genKrn*)
7 :    
8 :    
9 :    
10 :     structure DstIL = LowIL
11 :     structure DstTy = LowILTypes
12 :     structure DstOp = LowOps
13 :     structure Var = LowIL.Var
14 :    
15 :     structure SrcIL = MidIL
16 :     structure SrcOp = MidOps
17 :     structure SrcSV = SrcIL.StateVar
18 :     structure SrcTy = MidILTypes
19 :     structure VTbl = SrcIL.Var.Tbl
20 :    
21 :     in
22 :    
23 : cchiw 2608 val testing=0
24 : cchiw 2555
25 :    
26 :     fun insert (key, value) d =fn s =>
27 : cchiw 2608 if s = key then (print(String.concat[Int.toString(key),"=>",Int.toString(value)]);SOME value)
28 : cchiw 2555 else d s
29 :    
30 :     fun lookup k d = d k
31 :     val empty =fn key =>NONE
32 :     fun findDup(list1,list2)=let
33 :     fun current []=NONE
34 :     | current(v::vs)=let
35 :     val m=List.find (fn x => x =v) list2
36 :     in (case m
37 :     of NONE =>current(vs)
38 :     |_=> m
39 :     (*end case*))
40 :     end
41 :     in current list1
42 :     end
43 :    
44 :    
45 :    
46 :     val bV= ref 0
47 :    
48 : cchiw 2605 fun toStringRHS x = (case DstIL.Var.binding x
49 : cchiw 2555 of vb => String.concat[
50 :     "\n Found ", DstIL.vbToString vb,"\n"]
51 :     (* end case *))
52 :    
53 :    
54 :    
55 :    
56 :     fun getKernel x = (case SrcIL.Var.binding x
57 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.Kernel(h, _),_))=> h
58 :     | vb => (raise Fail (String.concat["\n -- Not a kernel, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
59 :     (* end case *))
60 :    
61 :    
62 :     fun getImage x = (case SrcIL.Var.binding x
63 :     of SrcIL.VB_RHS(SrcIL.OP(SrcOp.LoadImage(img),_))=> img
64 :     | vb => (raise Fail (String.concat["\n -- Not an image, ", SrcIL.Var.toString x," found ", SrcIL.vbToString vb,"\n"]))
65 :     (* end case *))
66 :    
67 :    
68 :    
69 :    
70 :    
71 : cchiw 2605 fun toStringAssgn(DstIL.ASSGN (x, DstIL.OP(opss,args)))=let
72 :     val A=[(Var.toString x),"==",DstOp.toString opss," : ",(String.concatWith "," (List.map Var.toString args))]
73 :     in String.concat A
74 :     end
75 :     | toStringAssgn(DstIL.ASSGN(x,DstIL.LIT _))= String.concat[Var.toString x,"==...Lit"]
76 :     | toStringAssgn(DstIL.ASSGN(x,DstIL.CONS (_, varl)))= let
77 : cchiw 2555 val y= List.map (fn e1=> Var.toString e1) varl
78 : cchiw 2605 in String.concat[(Var.toString x),"==",(String.concatWith "," y)] end
79 :     | toStringAssgn(DstIL.ASSGN (x, _))=String.concat[Var.toString x,"==","CONS",toStringRHS x]
80 : cchiw 2555
81 :    
82 :    
83 : cchiw 2605 fun toStringTy(DstTy.IntTy )= "int "
84 :     | toStringTy(DstTy.TensorTy [])= "Real "
85 :     | toStringTy(DstTy.TensorTy(dd))=String.concat[
86 : cchiw 2555 "TensorTy[", String.concatWith "," (List.map Int.toString dd), "] "]
87 :    
88 :    
89 :     fun aaV(opss,args,pre,ty)=let
90 :     val a=DstIL.Var.new(pre ,ty)
91 :     val code=DstIL.ASSGN (a,DstIL.OP(opss,args))
92 : cchiw 2605 val _ =(case testing
93 :     of 0=> 1
94 : cchiw 2608 | _ => (print(String.concat([toStringTy ty,"\n", toStringAssgn code] ));1)
95 : cchiw 2605 (* end case *))
96 : cchiw 2555 in
97 :     (a,[code])
98 :     end
99 :    
100 :    
101 :    
102 :    
103 :     fun mkMultiple(list1,rator,ty)=let
104 :     fun add([],_)=raise Fail "no element in addM"
105 :     | add([e1],_)=(e1,[])
106 :     | add([e1,e2],code)=let
107 :     val (vA,A)=aaV(rator,[e1,e2],"MO",ty)
108 :     in (vA,code@A)
109 :     end
110 :     | add(e1::e2::es,code)=let
111 :     val (vA,A)=aaV(rator,[e1,e2],"MO",ty)
112 :     in add(vA::es,code@A)
113 :     end
114 :     in add(list1,[])
115 :     end
116 :    
117 :    
118 :    
119 :     fun mapIndex(e1,mapp)=(case e1
120 :     of E.V e =>let
121 :     val a=lookup e mapp
122 :     in (case a of NONE=> raise Fail "Outside Bound"
123 :     |SOME s => s) end
124 :     | E.C c=> c
125 :     (*end case*))
126 :    
127 : cchiw 2605 (*
128 : cchiw 2555 fun printIndexXX(n,mapp)=let
129 :     val a=lookup n mapp
130 :     in (case a
131 :     of NONE=> print("-\n")
132 : cchiw 2603 |SOME (s) => ((*print(String.concat[Int.toString(n), "==>",Int.toString(s)]);*)printIndexXX(n+1,mapp))
133 : cchiw 2555 (*end case*))
134 :     end
135 : cchiw 2605 *)
136 : cchiw 2555
137 :    
138 :     fun getShape(params, id)=(case List.nth(params,id)
139 :     of E.TEN(3,[shape])=> DstTy.iVecTy(shape) (*FIX HERE*)
140 :     | E.TEN(_,shape)=> DstTy.TensorTy shape
141 :     |_=> raise Fail "NONE Tensor Param")
142 :    
143 :    
144 :    
145 :     fun mkSca(mapp,(id,ix1,(args,params)))= let
146 : cchiw 2611
147 : cchiw 2555 val ix1'=List.map (fn (e1)=> mapIndex(e1,mapp)) ix1
148 : cchiw 2611
149 : cchiw 2555 val nU=List.nth(args,id)
150 : cchiw 2611
151 : cchiw 2555 val i=DstTy.indexTy(ix1')
152 : cchiw 2611
153 : cchiw 2555 val a=getShape(params,id)
154 : cchiw 2611
155 : cchiw 2555 in aaV(DstOp.S(id, i,a),[nU],"S"^Int.toString(id),DstTy.TensorTy([]))
156 :     end
157 :    
158 :    
159 :     fun mkVec(mapp,(id,ix1,last,(args,params)))= let
160 : cchiw 2611
161 : cchiw 2555 val ix1'=List.map (fn (e1)=> mapIndex(e1,mapp)) ix1
162 :     val nU=List.nth(args,id)
163 :     val i=DstTy.indexTy(ix1')
164 :     val a=getShape(params,id)
165 :     in aaV(DstOp.V(id, last, i,a),[nU],"V"^Int.toString(id),DstTy.TensorTy([last])) end
166 :    
167 :     (*Helper functions for addition *)
168 :     fun handleAddVec(mapp,(es,index,last,args))=let
169 : cchiw 2605
170 : cchiw 2555 fun add([],rest,code)=(rest,code)
171 :     (* | add((id1,[])::es,rest,code)=let
172 :     val (vA,A)= mkVec(mapp,(id1,index,args))
173 :     in add(es,rest@[vA],code@A)
174 :     end*)
175 :    
176 :     | add((id1,ix1)::es,rest,code)=let
177 :     val (vA,A)= mkVec(mapp,(id1,ix1,last,args))
178 :     in add(es,rest@[vA],code@A)
179 :     end
180 :    
181 :     val (rest,code)=add(es,[],[])
182 :     val (vA,A)=mkMultiple( rest,DstOp.addVec(last),DstTy.TensorTy([last]))
183 :     in (vA,code@A)
184 :     end
185 :    
186 :    
187 :    
188 :     (*Subtract SCalars*)
189 :     fun mksubSca(mapp,([(id1,ix1),(id2,ix2)],[],args))= let
190 :     val (vA,A)=mkSca(mapp,(id1,ix1,args))
191 :     val (vB, B)=mkSca(mapp,(id2, ix2,args))
192 :     val (vD, D)=aaV(DstOp.subSca,[vA, vB],"SubSca",DstTy.TensorTy([]))
193 :     in (vD, A@B@D)end
194 :    
195 :    
196 :     (*subtract Vectors*)
197 :     fun mksubVec(mapp,([(id1,ix1),(id2,ix2)],[],last,args))= let
198 : cchiw 2605
199 : cchiw 2555 val (vA,A)= mkVec(mapp,(id1,ix1,last,args))
200 :     val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
201 :     val (vD, D)=aaV(DstOp.subVec(last),[vA, vB],"subVec",DstTy.TensorTy([last]))
202 :     in (vD, A@B@D) end
203 :    
204 :    
205 :    
206 :     (*Product functions*)
207 :     (*product of 2 scalars*)
208 :     fun mkprodSca(mapp,([(id1,ix1),(id2,ix2)],[],args))= let
209 :     val (vA,A)=mkSca(mapp,(id1,ix1,args))
210 :     val (vB, B)=mkSca(mapp,(id2, ix2,args))
211 :     val (vD, D)=aaV(DstOp.prodSca,[vA, vB],"prodSca",DstTy.TensorTy([]))
212 :     in (vD, A@B@D)end
213 :     | mkprodSca _= raise Fail "Prod----d---"
214 :    
215 : cchiw 2584 (*
216 :     (*product of 2 scalars*)
217 :     fun mkprodScaR(_,([(id1,ix1),(id2,ix2)],[],args))= let
218 :     val (vA,A)=mkSca(mapp,(id1,ix1,args))
219 :     val (vB, B)=mkSca(mapp,(id2, ix2,args))
220 :     aaV(DstOp.S(id, i,a),[nU],"S"^Int.toString(id),DstTy.TensorTy([]))
221 : cchiw 2555
222 : cchiw 2584 val (vD, D)=aaV(DstOp.prodSca,[vA, vB],"prodSca",DstTy.TensorTy([]))
223 :     in (vD, A@B@D)end
224 :     | mkprodSca _= raise Fail "Prod----d---"
225 :     *)
226 :    
227 : cchiw 2555 (*product of 1 scalars and 1 projection*)
228 :     fun mkprodScaV(mapp,([(id1,ix1),(id2,ix2)],[],last,args))=let
229 : cchiw 2605
230 : cchiw 2555 val (vA,A)=mkSca(mapp,(id1,ix1,args))
231 :     val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
232 :    
233 :     val (vD, D)=aaV(DstOp.prodScaV(last),[vA, vB],"prodScaV",DstTy.TensorTy([last]))
234 :     in (vD,A@B@D) end
235 :    
236 :     (*product of 2 projections*)
237 :     fun mkprodVec(mapp,([(id1,ix1),(id2,ix2)],[],last,args))= let
238 : cchiw 2605
239 : cchiw 2555 val (vA,A)= mkVec(mapp,(id1,ix1,last,args))
240 :     val (vB, B)= mkVec(mapp,(id2,ix2,last,args))
241 :     val (vD, D)=aaV(DstOp.prodVec(last),[vA, vB],"prodV",DstTy.TensorTy([last]))
242 :     in (vD, A@B@D)
243 :     end
244 :     (*error here *)
245 :     (*summation over product of 2 projections*)
246 :     fun mkprodSumVec(mapp,(m,[],i,args))= let
247 : cchiw 2605
248 : cchiw 2555 val i'=i+1
249 :     val (vD,D)=mkprodVec(mapp,(m,[],i',args))
250 :     val (vE, E)=aaV(DstOp.sumVec(i'),[vD],"sumVec",DstTy.realTy)
251 :     in (vE, D @E)
252 :     end
253 :    
254 :     (*product of -1 and 1 projection*)
255 :     fun mkNegV(mapp,((vA,id,ix),[],last,args))=let
256 :     val (vB, B)= mkVec(mapp,(id,ix,last,args))
257 :     val (vD, D)=aaV(DstOp.prodScaV(last),[vA, vB],"prodScaV",DstTy.TensorTy([last]))
258 :     in (vD,B@D) end
259 :    
260 :    
261 :    
262 :     (*Dot Product like summation
263 :     Does Vec x Vec *)
264 :     fun sumDot(a, ( m,sx,last,args))=let
265 :     val [(E.V v,lb,ub)]=sx
266 :     fun sumI(a,0,rest,code)=let
267 :     val mapp =insert(v, 0) a
268 :     (*val mapp=a@[lb]*)
269 :     val (vD,pre)=mkprodVec(mapp,(m,[],last,args))
270 :     val (vE, E)=aaV(DstOp.sumVec(last),[vD],"SumVec",DstTy.TensorTy([]))
271 :     val rest'=[vE]@rest
272 :     val (vF, F)=mkMultiple( rest',DstOp.addSca,DstTy.TensorTy([]))
273 :     in (vF,pre@E@code@F) end
274 :     | sumI(a,sx,rest',code')=let
275 :     (* val mapp=a@[(sx+lb)]*)
276 :     val mapp =insert(v, (sx+lb)) a
277 :     val (vD,pre)=mkprodVec(mapp,(m,[],last,args))
278 :     val (vE, E)=aaV(DstOp.sumVec(last),[vD],"SumVec",DstTy.TensorTy([]))
279 :     in sumI(a,sx-1,[vE]@rest',pre@E@code') end
280 :     in sumI(a, (ub-lb), [],[]) end
281 :    
282 :    
283 :     (*Can do multiple summations *)
284 :     fun sum(a, ( m,sx,args))=let
285 : cchiw 2608 val _ =print "\n in Summation Helper"
286 :     fun sumI1(left,(v,0,lb1),[],rest,code)=let
287 : cchiw 2605
288 : cchiw 2555 val mapp =insert(v, lb1) left
289 :     val (vD,pre)=mkprodSca(mapp,(m,[],args))
290 :     in ([vD]@rest,pre@code)
291 :     end
292 :     | sumI1(left,(v,i,lb1),[],rest,code)=let
293 :     val mapp =insert(v, i+lb1) left
294 :     val (vD,pre)=mkprodSca(mapp,(m,[],args))
295 :     in sumI1(left,(v,i-1,lb1),[],[vD]@rest,pre@code)
296 :     end
297 :     | sumI1(left,(v,0,lb1),(E.V a,lb2,ub)::sx,rest,code)=let
298 :     val mapp =insert(v, lb1) left
299 :     in sumI1(mapp,(a,ub-lb2,lb2),sx,rest,code) end
300 :     | sumI1(left,(v,s,lb1),(E.V v',lb2,ub)::sx,rest,code)=let
301 :     val mapp =insert(v, s+lb1) left
302 :     val (rest',code')=sumI1(mapp,(v',ub-lb2,lb2),sx,rest,code)
303 :     in sumI1(left,(v,s-1,lb1),(E.V v',lb2,ub)::sx,rest',code') end
304 :    
305 :     val (E.V v,lb,ub)=hd(sx)
306 :     val(li, code)=sumI1(empty,(v,ub-lb,lb),tl(sx),[],[])
307 :     val (vF, F)=mkMultiple(li,DstOp.addSca,DstTy.TensorTy([]))
308 :     in (vF,code@F) end
309 :    
310 :    
311 :     fun mkC n= let
312 :     val (vB,B)=aaV(DstOp.C(n),[],"Const",DstTy.TensorTy([]))
313 :     in (vB,B) end
314 :    
315 :    
316 :     fun evalDelta2(a,b,mapp)= let
317 :     val i=mapIndex(a,mapp)
318 :     val j=mapIndex(b,mapp)
319 :     in if(i=j) then mkC 1 else mkC 0
320 :    
321 :     end
322 :    
323 :     (*Field/Kern*)
324 :     fun evalDelta(dels,mapp)=let
325 :     fun m(a,b)=if(a=b) then 1 else 0
326 :     fun ij(i,j)=(case (i,j)
327 :     of (E.V a, E.V b)=>m(mapIndex(i,mapp),mapIndex(j,mapp))
328 :     | (E.C a, E.V b)=>m(a,mapIndex(j,mapp))
329 :     | (E.V a, E.C b)=>m(mapIndex(i,mapp),b)
330 :     | (E.C a, E.C b)=>m(i,j)
331 :     (*end case*))
332 :     val dels'=List.map ij dels
333 :     in
334 :     List.foldl(fn(x,y)=>x+y) 0 dels'
335 :     end
336 :    
337 :    
338 :    
339 :    
340 :     fun evalEps(a,b,c,mapp)=let
341 :     val i=mapIndex(E.V a,mapp)
342 :     val j=mapIndex(E.V b,mapp)
343 :     val k=mapIndex(E.V c,mapp)
344 :     in
345 :     if(i=j orelse j=k orelse i=k) then 0
346 :     else
347 :     if(j>i) then
348 :     if(j>k andalso k>i) then ~1 else 1
349 :     else if(i>k andalso k>j) then 1 else ~1
350 :    
351 :     end
352 :    
353 :    
354 :    
355 :     fun skeleton A=(case A
356 :     of [DstIL.ASSGN(_,DstIL.OP(DstOp.C 0,_))]=>0
357 :     | [DstIL.ASSGN(_,DstIL.OP(DstOp.C 1,_))]=>1
358 :     | [DstIL.ASSGN(_,DstIL.OP(DstOp.C ~1,_))]=> ~1
359 :     | _ => 9
360 :     (*end case*))
361 :     end
362 :    
363 :    
364 :    
365 :    
366 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0