Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Annotation of /branches/charisee/src/compiler/mid-to-low/vec-to-low-set.sml
ViewVC logotype

Annotation of /branches/charisee/src/compiler/mid-to-low/vec-to-low-set.sml

Parent Directory Parent Directory | Revision Log Revision Log


Revision 3544 - (view) (download)

1 : cchiw 3444 (*
2 :     * helper functions used for vector operations.
3 :     * At this point every index is bound to an int
4 :     * and we are ready to return low-IL code.
5 :     * using LowIL vectors ops like subVec,addVec,prodVec..
6 :     *
7 :     * EIN->scan->iter->here.
8 :     *)
9 :     structure VecToLowSet = struct
10 :     local
11 :    
12 :     structure DstIL = LowIL
13 :     structure DstTy = LowILTypes
14 :     structure DstOp = LowOps
15 :     structure LowToS= LowToString
16 :     structure Var = LowIL.Var
17 :     structure E = Ein
18 :     structure H=HelperSet
19 :     in
20 :    
21 :     fun insert e=H.insert e
22 :     fun assgn e=H.assignOP e
23 :     fun indexTensor e=H.indexTensor e
24 :     fun mkProdVec e =H.mkProdVec e
25 :     fun mkSumVec e =H.mkSumVec e
26 :     fun mkMultiple e=H.mkMultiple e
27 :     fun getTensorTy e=H.getTensorTy e
28 :     fun mapIndex e = H.mapIndex e
29 :     fun assignOP e= H.assignOP e
30 :     fun assgnCons e= H.assgnCons e
31 :     fun iToss e = H.iToss e
32 :     val realTy=DstTy.TensorTy []
33 : cchiw 3542 val testing=false
34 :     fun testp n =if (testing) then (print(String.concat n);1) else 1
35 : cchiw 3444
36 :     (*projTensor:dict*(string* E.params*Var list*int*E.tensor_id*E.alpha)->Var*code list
37 :     * projects tensor to a vector
38 :     *just used by EintoVecOps but made sense to keep it here
39 :     *)
40 : cchiw 3542 fun projTensor(setA,_,(_,_,args,_,id,[]))= (setA,List.nth(args,id),[])
41 : cchiw 3444 | projTensor(setA,mapp,(lhs,params,args,vecIX,id,ix))= let
42 :     val nU=List.nth(args,id)
43 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
44 : cchiw 3544 val ix'= ixx (*index Ty*)
45 : cchiw 3444 val argTy= getTensorTy(params,id)
46 :     val vecTy=DstTy.TensorTy [vecIX]
47 :     val opp=DstOp.ProjectLast(id,vecIX,ix',argTy)
48 :     val name=String.concat["ProjLast_",iToss ixx,"_"]
49 :     val opset= lowSet.LowSet.empty
50 :     val (setB,vB,B)= assignOP(setA,opp,[nU],name,vecTy)
51 :     in
52 :     (setB,vB,B)
53 :     end
54 :    
55 :     fun projFirst(setA,mapp,(lhs,params,args,vecIX,id,ix))= let
56 :     val nU=List.nth(args,id)
57 :     val argTy= getTensorTy(params,id)
58 :     val vecTy=DstTy.TensorTy [vecIX]
59 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
60 :     fun f cnt = let
61 : cchiw 3544 val ix'=cnt::ixx
62 : cchiw 3444 val opp=DstOp.IndexTensor(id,ix',argTy)
63 :     val name=String.concat["IndexTensor_",iToss ixx,"_"]
64 :     in
65 :     assignOP(setA,opp,[nU],name,realTy)
66 :     end
67 :     fun iter ([],vCs,Cs)=(vCs,Cs)
68 :     | iter((_,vB,B)::es,vCs,Cs)= iter(es,vCs@[vB],Cs@B)
69 :     val ops=List.tabulate( vecIX, fn e=> f e)
70 :     val (vCs,Cs)=iter(ops,[],[])
71 :     val (setD,vD,D)=assgnCons(setA,"projFirstCons", [vecIX], vCs)
72 : cchiw 3542 in
73 :     (setD,vD, Cs@D)
74 :     end
75 : cchiw 3444
76 :     fun projFirst2(setA,mapp,(lhs,params,args,opset,vecIX,id,ix))= let
77 :     val nU=List.nth(args,id)
78 :     val argTy= getTensorTy(params,id)
79 :     val vecTy=DstTy.TensorTy [vecIX]
80 :     val ixx=(List.map (fn (e1)=> mapIndex(e1,mapp)) ix)
81 : cchiw 3544 val ix'=(*DstTy.index Ty*) ixx
82 : cchiw 3444 val opp=DstOp.ProjectFirst(id,vecIX,ix',argTy)
83 :     val name=String.concat["ProjFirst_",iToss ixx,"_"]
84 : cchiw 3542 in
85 :     assignOP(setA,opp,[nU],name,vecTy)
86 :     end
87 : cchiw 3444
88 :    
89 :     (*--------------------Vectorization Helper Functions--------------------*)
90 :     (*negN:dict*string*E.params*Var list * int*Var*E.tensor_id*E.alpha
91 :     * ->Var*LowIL.ASSGN list*)
92 :     fun negV(setT,mapp,(lhs,params,args,vecIX ,vA, id2,ix2))=let
93 :     val (setB,vB,B)= projTensor(setT,mapp,(lhs,params,args,vecIX,id2,ix2))
94 :     val (setC,vC,C)=assgn(setB,DstOp.prodScaV vecIX,[vA, vB],"prodScaV",DstTy.TensorTy [vecIX])
95 :     in
96 :     (setC,vC,B@C)
97 :     end
98 :    
99 :     (*subN:dict*string*E.params*Var list * int*E.tensor_id*E.alpha*E.tensor_id*E.alpha
100 :     * ->Var*LowIL.ASSGN list
101 :     * Vector Subtraction
102 :     *)
103 :     fun subV(setT,mapp,(lhs,params, args,vecIX,id1,ix1,id2,ix2))=let
104 :     val (setA,vA,A)= projTensor(setT, mapp,(lhs,params,args,vecIX,id1,ix1))
105 :     val (setB,vB,B)= projTensor(setA, mapp,(lhs,params,args,vecIX,id2,ix2))
106 :     val (setC,vC,C)= assgn(setB, DstOp.subVec vecIX ,[vA, vB],"subVec",DstTy.TensorTy [vecIX])
107 :     in
108 :     (setC,vC, A@B@C)
109 :     end
110 :    
111 :     (*Vector Addition
112 :     *addN:dict*string*E.params*Var list * int*(E.tensor_id*E.alpha)list
113 :     * ->Var*LowIL.ASSGN list
114 :     *)
115 :     fun addV(setT,mapp,(lhs,params, args,vecIX,rest))=let
116 :     fun add(setB,[],rest,code)=(setB,rest,code)
117 :     | add(setB,(id1,ix1)::es,rest,code)=let
118 :     val (setC,vC,C)= projTensor(setB, mapp,(lhs,params,args,vecIX,id1,ix1))
119 :     in
120 :     add(setC,es,rest@[vC],code@C)
121 :     end
122 :     val (setB,rest,code)=add(setT,rest,[],[])
123 :     val (setC,vC,C)=mkMultiple(setB,rest,DstOp.addVec vecIX,DstTy.TensorTy([vecIX]))
124 :     in
125 :     (setC,vC,code@C)
126 :     end
127 :    
128 :     (*scaleNN:dict*string*E.params*Var list * int*E.tensor_id*E.alpha*E.tensor_id*E.alpha)
129 :     * ->Var*LowIL.ASSGN list
130 :     *Vector Scaling
131 :     *)
132 :     fun scaleV(setT,mapp,(lhs,params , args,vecIX,id1,ix1,id2,ix2))=let
133 :     val (setA,vA,A)= indexTensor(setT,mapp,(lhs,params,args,id1,ix1,realTy))
134 :     val (setB,vB,B)= projTensor(setA, mapp,(lhs,params,args,vecIX,id2,ix2))
135 :     val (setC,vC,C)= assgn(setB, DstOp.prodScaV vecIX,[vA, vB],"prodScaV",DstTy.TensorTy([vecIX]))
136 :     in
137 :     (setC,vC,A@B@C)
138 :     end
139 :    
140 :     (*prodN:dict*string*E.params*Var list * int*E.tensor_id*E.alpha*E.tensor_id*E.alpha)
141 :     * ->Var*LowIL.ASSGN list
142 :     *Vector Product
143 :     *)
144 :     fun prodV(setT,mapp,(lhs,params,args,vecIX,id1,ix1,id2,ix2))= let
145 :     val (setA,vA,A)= projTensor(setT, mapp,(lhs,params,args,vecIX,id1,ix1))
146 :     val (setB,vB,B)= projTensor(setA, mapp,(lhs,params,args,vecIX,id2,ix2))
147 :     val (setC,vC,C) = mkProdVec(setB,vecIX,[vA, vB])
148 :     in
149 : cchiw 3542 (setC,vC, A@B@C)
150 : cchiw 3444 end
151 :    
152 :     (*dotN:dict*string*E.params*Var list * int*E.tensor_id*E.alpha*E.tensor_id*E.alpha)
153 :     * ->Var*LowIL.ASSGN list
154 :     *dot product
155 :     *)
156 :     fun dotV(setT,mapp,(lhs,params, args,vecIX,id1,ix1,id2,ix2))=let
157 :     val (setD,vD,D)= prodV(setT,mapp,(lhs,params,args,vecIX,id1,ix1,id2,ix2))
158 :     val (setE,vE,E)= mkSumVec(setD,vecIX,[vD])
159 :     in
160 :     (setE,vE, D @E)
161 :     end
162 :    
163 :     fun VM(setT,mapp,(lhs,params, args,vecIX,id1,ix1,id2,ix2))=let
164 :     val vA= List.nth(args,id1)
165 :     val (setB,vB,B)= projTensor(setT, mapp,(lhs,params,args,vecIX,id2,ix2))
166 :     val (setD,vD,D)= mkProdVec(setB,vecIX,[vA, vB])
167 :     val (setE,vE,E)= mkSumVec(setD,vecIX,[vD])
168 :     in
169 :     (setE,vE, B@D@E)
170 :     end
171 :    
172 :     fun MM3(setT,mapp,(lhs,params,args, vecIX,id1,ix1,id2,ix2,_))= let
173 :     val (setA,vA,A)= projTensor(setT, mapp,(lhs,params,args,vecIX,id1,ix1))
174 :     val (setB,vB,B)= projFirst(setA,mapp,(lhs,params,args,vecIX,id2,ix2))
175 :     val (setD,vD,D)= mkProdVec(setB,vecIX,[vA, vB])
176 :     val (setE,vE,E)= mkSumVec(setD,vecIX,[vD])
177 :     in
178 :     (setE,vE,A@B@D@E)
179 :     end
180 :    
181 :     (*sumDotN:dict*string*E.params*Var list * (E.nu)*int*E.tensor_id*E.alpha*E.tensor_id*E.alpha)
182 :     * ->Var*LowIL.ASSGN list
183 :     *Sum of dot Product
184 :     *)
185 :     fun sumDotV(setT,mapp,(lhs,params, args,(E.V v,lb,ub),ub1,id1,ix1,id2,ix2))=let
186 :     val nextfnargs=(lhs,params, args,ub1,id1,ix1,id2,ix2)
187 :     fun sumI(setA,a,0,rest,code)=let
188 :     val mapp =insert(v, 0) a
189 :     val (setE,vE,E)=dotV(setA,mapp,nextfnargs)
190 :     val rest'=[vE]@rest
191 :     val (setF,vF, F)=mkMultiple(setE,rest',DstOp.addSca,realTy)
192 :     in
193 :     (setF,vF,E@code@F)
194 :     end
195 :     | sumI(setA,a,sx,rest',code')=let
196 :     val mapp =insert(v, (sx+lb)) a
197 :     val (setE,vE,E)=dotV(setA,mapp,nextfnargs)
198 :     in
199 :     sumI(setE,a,sx-1,[vE]@rest',E@code')
200 :     end
201 :     in
202 :     sumI(setT,mapp, (ub-lb), [],[])
203 :     end
204 :     | sumDotV _= raise Fail "Non-variable index in summation"
205 :    
206 :     end
207 :    
208 :    
209 :     end

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0