Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/mid-to-low/gen-kernel.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2827, Tue Nov 11 00:18:38 2014 UTC revision 2828, Wed Nov 12 06:30:58 2014 UTC
# Line 10  Line 10 
10      structure S3=step3      structure S3=step3
11   structure tS= toStringEin   structure tS= toStringEin
12  structure P=Printer  structure P=Printer
13       structure Var = LowIL.Var
14    
15      in      in
16    
17  val testing=0  val testing=1
18  val Sca=DstTy.TensorTy []  val Sca=DstTy.TensorTy []
19  val iTy=DstTy.IntTy  val iTy=DstTy.IntTy
20  val addR=DstOp.addSca  val addR=DstOp.addSca
# Line 40  Line 41 
41  fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)  fun mkSubSca rest= S3.aaV(DstOp.subSca,rest,"subSca",Sca)
42  fun mkAddSca rest= S3.aaV(addR,rest,"addSca",Sca)  fun mkAddSca rest= S3.aaV(addR,rest,"addSca",Sca)
43  fun mkAddInt rest= S3.aaV(addR,rest,"addInt",iTy)  fun mkAddInt rest= S3.aaV(addR,rest,"addInt",iTy)
44    fun mkAddPtr (rest,ty)= S3.aaV(addR,rest,"addPtr",ty)
45    fun mkProdInt rest= S3.aaV(DstOp.prodSca,rest,"prodInt",iTy)
46    
47    fun printMap (n,lb)=testp["\n from ",Int.toString n," to ",Int.toString lb]
48    
49  fun mkCons(shape, rest)=let  fun mkCons(shape, rest)=let
50      val ty=DstTy.TensorTy shape      val ty=DstTy.TensorTy shape
# Line 51  Line 56 
56      end      end
57    
58    
 (*Image positions*)  
 fun mkSimpleImgOp(mapp,e,args)=(case e  
     of  E.Add[ E.Tensor(t1,ix1),E.Value v1]=> let  
         val (vA,A)=S3.mkIntAsn(mapp,(t1,ix1,args))  
         val (vB,B)= mkInt (find(v1,mapp))  
         val (vD,D)=mkAddInt[vA,vB]  
         in (vD,A@B@D) end  
     | _ => raise Fail"Probed position is not subtraction or addition"  
     (* end case*))  
   
   
59  (*Add,Subtract Scalars*)  (*Add,Subtract Scalars*)
60  fun mkSimpleOp(mapp,e,args)=let  fun mkSimpleOp(mapp,e,args)=let
61      fun subP e=(case e      fun subP e=(case e
# Line 114  Line 108 
108  fun mkImg(mappOrig,sx,[(fid,ix,px)],v,vNew,info,sid,lb,ub,top, R)=let  fun mkImg(mappOrig,sx,[(fid,ix,px)],v,vNew,info,sid,lb,ub,top, R)=let
109    
110      val dim=length(px)      val dim=length(px)
111      val (vlb,BB)=  mkInt lb  (*last summaiton index*)      val (vlb,BB)=  mkInt lb  (*last summation index*)
112      val ptyTy=DstTy.AddrTy v      val ptyTy=DstTy.AddrTy v
   
113      val (vBase,base)=S3.aaV(DstOp.baseAddr v,[vNew],"baseAddr",ptyTy)      val (vBase,base)=S3.aaV(DstOp.baseAddr v,[vNew],"baseAddr",ptyTy)
     val px=List.rev px  
114    
     fun createImgVar mapp=let  
         val _ =print "\n createImgVar"  
         fun mkpos(e,rest,code)=(case e  
             of [E.Add[E.Tensor(t1,ix1),_]]=> let  
                 val _ =print "\n last mkpos 1"  
115    
116                  val (vA,A)=S3.mkIntAsn(mapp,(t1,ix1,info))      val E.Add[E.Tensor(t1,ix1),_ ]=List.hd(px)
117        val (vA,A)=S3.mkIntAsn(mappOrig,(t1,ix1,info))
118                  val (vC,C)=mkAddInt [vA,vlb]                  val (vC,C)=mkAddInt [vA,vlb]
119                  in      val px=List.drop(px,1)
                     (rest@[vC],code@A@C)  
                 end  
             | pos1::es => let  
                 val _ =print "\n multiple mkpos "  
                 val _=print(P.printbody(pos1))  
                 val (vF,code1)=mkSimpleImgOp(mapp,pos1,info)  
120    
121                  in  
122                      mkpos(es,rest@[vF],code@code1)      (*gets img address*)
123                  end      fun breakImgAddr (imgType,args) = let
124              | _=> raise Fail "Non-addition in Image Position"          val (vX,X)=(case ( ImageInfo.sizes v,args)
125                of ([x, _ ],[a,b]) =>let
126                    val (vA,A)=  mkInt x
127                    val(vB,B)=mkProdInt [vA,a]
128                    val (vC,C)=mkAddInt [b,vB]
129                    in (vC,A@B@C) end
130                | ([x,y,z],[a,b,c])  =>let
131                    val (vA,A)=  mkInt y
132                    val(vB,B)=mkProdInt [vA,a]
133                    val (vC,C)=mkAddInt [b,vB]
134                    val (vD,D)=  mkInt x
135                    val(vE,E)=mkProdInt [vD,vC]
136                    val (vF,F)=mkAddInt [c,vE]
137                    in (vF,A@B@C@D@E@F) end
138          (*end case*))          (*end case*))
139    
140          val (vF,F)= mkpos(px,[],[])          (*Index Image Specific Index*)
141          val imgType=DstTy.imgIndex(List.map (fn (e1)=> S3.mapIndex(e1,mapp)) ix)          val (vC,C)=(case imgType
142          val (vA,A)=S3.aaV(DstOp.imgAddr(v,imgType,dim),[vBase]@vF,"Imageaddress",ptyTy)              of [] => mkAddPtr([vBase,vX],ptyTy)
143          val (vB,B)=S3.aaV(DstOp.imgLoad(v,dim,R),[vA],"imgLoad",DstTy.tensorTy([R]))              | [0] => mkAddPtr([vBase,vX],ptyTy)
144          in              | [i]=> let
145              (vB,F@A@B)                  val (vA,A)= mkAddPtr([vBase,vX],ptyTy)
146                    val (vB,B)=  mkInt i
147                    val (vC,C)= mkAddPtr([vB, vA],ptyTy)
148                    in (vC,A@B@C)end
149            (*end case*))
150            in (vC,X@C) end
151    
152    
153        fun createImgVar mapp=let
154            fun mkpos(e,rest,code)=(case e
155                of [] => (rest@[vC],code)
156                | (E.Add[ E.Tensor(t1,ix1),E.Value v1]::es)=> let
157                    val (vA,A)=S3.mkIntAsn(mapp,(t1,ix1,info))
158                    val (vB,B)= mkInt (find(v1,mapp))
159                    val (vD,D)=mkAddInt[vA,vB]
160                    in mkpos(es,rest@[vD],code@A@B@D) end
161                | _ => raise Fail"Incorrect pos for Image"
162            (*end case*))
163    
164            val (vA,A)= mkpos(px,[],[])
165            val imgType=List.map (fn (e1)=> S3.mapIndex(e1,mapp)) ix
166            val (vC,C)=breakImgAddr (imgType,vA)
167            val (vD,D)=S3.aaV(DstOp.imgLoad(v,dim,R),[vC],"imgLoad",DstTy.tensorTy([R]))
168            in
169                (vD,A@C@D)
170          end          end
171    
172      fun printMap (n,lb)=testp["\n from ",Int.toString n," to ",Int.toString lb]      val R'=R-1
173      fun sumI1(lft,dict,0,0,code,n')=let      fun sumPos([sid],lft,code,dict,0)=let
174          val mapp=insert (n', lb) dict          val n'=lb
175          val _ =printMap(n',lb)          val mapp=insert (sid, n') dict
176          val (lft', code')= createImgVar mapp          val (lft', code')= createImgVar mapp
177          in ([lft']@lft,code'@code) end          in ([lft']@lft,code'@code) end
178      |  sumI1(lft,dict,i,0,code,n')=let      | sumPos([sid],lft,code,dict,r)=let
179          val mapp=insert (n', i-1) dict          val n'=lb+r
180           val _ =printMap(n',i-1)          val mapp=insert (sid, n') dict
181          val (lft', code')=createImgVar mapp          val (lft', code')=createImgVar mapp
182          in sumI1([lft']@lft,dict,i-1,0,code'@code,n') end          in sumPos([sid],[lft']@lft,code'@code,dict,r-1) end
183      | sumI1(lft,dict,0,n,code,n')=let      | sumPos(sid::sxx,lft,code,dict,0)=let
184          val mapp=insert (n', lb) dict          val n'=lb
185           val _ =printMap(n',lb)          val mapp=insert (sid, n') dict
186          in sumI1(lft,mapp,top,n-1,code,n'+1) end          val (lft',code')=sumPos(sxx,lft,[],mapp,R')
187      | sumI1(lft,dict,i,n, code,n')=let          in (lft',code'@code) end
188           val mapp=insert (n', i+lb) dict  
189           val _ =printMap(n',i+lb)      | sumPos(sid::sxx,lft,code,dict,r)=let
190          val (lft',code')=sumI1(lft,mapp,top,n-1,code,n'+1)          val n'=lb+r
191          in sumI1(lft',dict,i-1,n,code',n') end          val mapp=insert (sid, n') dict
192            val (lft',code')=sumPos(sxx,lft,[],mapp,R')
193      val _ = testp["\n top", itos top, "dim", itos dim, "sid", itos sid ]          in
194      val(lft,code)=sumI1([],mappOrig,top,dim-2,[],sid)              sumPos(sid::sxx,lft',code'@code,dict,r-1)
195            end
196        val sxx= List.map(fn (E.V sid,_,_)=> sid) sx
197        val(lft,code)=sumPos(List.drop(sxx,1),[],[],mappOrig,R')
198      in      in
199          (lft,base@BB@code)          (lft,base@BB@A@C@code)
200      end      end
201    
202    
# Line 200  Line 222 
222              in  mkpos(k,fin,l@[l'],dict,i-1,code@code',n)              in  mkpos(k,fin,l@[l'],dict,i-1,code@code',n)
223              end              end
224          (*end case*))          (*end case*))
225      fun evalK([],[],_,code,newId)=(newId,code)      fun evalK([],[],_,_,code,newId)=(newId,code)
226        | evalK(kn::kns,x::xs,n,code,newId)=let        | evalK(kn::kns,x::xs,n,direction,code,newId)=let
227          val (_,dk,_) =kn          val (i,dk,_) =kn
228          val (id,kcode)= evalKrn.expandEvalKernel (R,h, dk, x,n)          val (id,kcode)= evalKrn.expandEvalKernel ("h"^Int.toString(direction),R,h, dk, x,n)
229          in          in
230              evalK(kns,xs,n+1,code@kcode,newId@[id])              evalK(kns,xs,n+1,direction-1,code@kcode,newId@[id])
231          end          end
232       |evalK _ =raise Fail "Non-equal variable list, error in mkKrns"       |evalK _ =raise Fail "Non-equal variable list, error in mkKrns"
233    
234      val(lftkrn,code)=mkpos(newdels,[],[],mappOrig,top,[],sid)      val(lftkrn,code)=mkpos(newdels,[],[],mappOrig,top,[],sid)
235      val (lft,code')=consfn((lftkrn),[],[],top)      val (lft,code')=consfn((lftkrn),[],[],top)
236      val (ids, evalKcode)=evalK(newdels,lft,0,[],[])      val (ids, evalKcode)=evalK(newdels,lft,0,length(newdels)-1,[],[])
237      in      in
238          (List.rev ids, code@code'@evalKcode)          (List.rev ids, code@code'@evalKcode)
239      end      end
# Line 224  Line 246 
246          val ty=DstTy.TensorTy [R]          val ty=DstTy.TensorTy [R]
247          val a=DstIL.Var.new("Cons"  ,ty)          val a=DstIL.Var.new("Cons"  ,ty)
248          val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))          val code=DstIL.ASSGN (a,DstIL.CONS(ty ,rest))
249            val _ =testp[tS.toStringAll(ty,code)]
250          in (a, [code])          in (a, [code])
251          end          end
252    
253    
254      fun dhz([],conslist,rest,code,_,_)=(conslist,code)      fun get3d(imgArg, F0,F1,F2)=let
255        | dhz(e::es,conslist,rest,code,hz,0)=let          val _ =print"\n ____________________DX____________________________"
256          val (vA,A)=sumP(e,hz,R)          val r=R-1
257            fun dhz([] , _ ,code,_,rest)=(rest,code)
258            |dhz(e1::es,rest,code,0,consrest)=let
259                val (vA,A)=sumP(F0,e1,R)
260          val (vD,D)=ConsInt(R,rest@[vA])          val (vD,D)=ConsInt(R,rest@[vA])
261          in dhz(es,conslist@[vD],[],code@A@D,hz,R-1)              in
262                dhz(es,[],code@A@D,r,consrest@[vD])
263                end
264            |dhz(e1::es,rest,code,n,consrest)=let
265                val (vA,A)=sumP(F0,e1,R)
266                in
267                    dhz(es,rest@[vA],code@A,n-1,consrest)
268                end
269            fun dhy([],_,code,_,rest)=(rest,code)
270            |dhy(e1::es,rest,code,0,consrest)=let
271                val (vA,A)=sumP(F1,e1,R)
272                val (vD,D)=ConsInt(R,rest@[vA])
273                in
274                dhy(es,[],code@A@D,r,consrest@[vD])
275          end          end
276        | dhz(e::es,conslist,rest,code,hz,r)=let          |dhy(e1::es,rest,code,n,consrest)=let
277          val (vA,A)=sumP(e,hz,R)              val (vA,A)=sumP(F1,e1,R)
278          in dhz(es,conslist,rest@[vA],code@A,hz,r-1)              in
279                dhy(es,rest@[vA],code@A,n-1,consrest)
280                end
281    
282                  val _ =print(Int.toString(List.length imgArg))
283            val (restZ,codeZ)=dhz(imgArg,[],[],r,[])
284             val _ =print"\n _______________DY__________________________________"
285            val _ =print(Int.toString(List.length restZ))
286    
287    
288            val ([restY],codeY)=dhy(restZ,[],[],r,[])
289                 val _ =print"\n _______________DZ__________________________________"
290    
291             val (vA,A)=sumP(F2,restY,R)
292    
293            in
294                (vA,codeZ@codeY@A)
295      end      end
296    
297    
# Line 260  Line 315 
315              in              in
316                  (vE,code@E)                  (vE,code@E)
317              end              end
318      |   [F0,F1,F2]=>let      |   [F0,F1,F2]=>
319            get3d(imgArg, F0,F1,F2)
320    (*let
321    
322          val (vZ,codeZ)=dhz(imgArg,[],[],[],F0,R-1)          val (vZ,codeZ)=dhz(imgArg,[],[],[],F0,R-1)
323          val (vY,codeY)=dhy(vZ,[],[],F1)          val (vY,codeY)=dhy(vZ,[],[],F1)
324          val (vE,E)=sumP(vY,F2,R)          val (vE,E)=sumP(vY,F2,R)
325          in          in
326              (vE,codeZ@codeY@E)              (vE,codeZ@codeY@E)
327          end          end*)
328      | _ => raise Fail "Kernel dimensions not between 1-3"      | _ => raise Fail "Kernel dimensions not between 1-3"
329          (*end case*))          (*end case*))
330      end      end
# Line 277  Line 335 
335      val (img1,k1)=sortK([],[],e)      val (img1,k1)=sortK([],[],e)
336    
337      val (E.V sid,lb,ub)=  hd(List.rev(sx))      val (E.V sid,lb,ub)=  hd(List.rev(sx))
338           (* val (E.V sid,lb,ub)=hd(sx)*)
339      val top=(ub-lb)      val top=(ub-lb)
340      val R=top+1      val R=top+1
   
   
341      val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,vNew,info,sid,lb,ub,top,R)      val (imgArg,imgCode)= mkImg(mapp,sx,img1,v,vNew,info,sid,lb,ub,top,R)
   
342      val (E.V sid,lb,ub)=hd(sx)      val (E.V sid,lb,ub)=hd(sx)
343      val (krnArg, krnCode)= mkkrns(mapp,sx,k1,h,info,sid,lb,ub,top,R)      val (krnArg, krnCode)= mkkrns(mapp,sx,k1,h,info,sid,lb,ub,top,R)
344      val (vA,A)=prodImgKrn(imgArg,krnArg,R)      val (vA,A)=prodImgKrn(imgArg,krnArg,R)

Legend:
Removed from v.2827  
changed lines
  Added in v.2828

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0