Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2583, Thu Apr 10 19:50:28 2014 UTC revision 2584, Tue Apr 15 03:22:58 2014 UTC
# Line 35  Line 35 
35  structure SrcV = SrcIL.Var  structure SrcV = SrcIL.Var
36  structure P=Printer  structure P=Printer
37  structure shiftHtM=shiftHtM  structure shiftHtM=shiftHtM
38    structure split=splitHtM
39    
40    
41  datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int  datatype peanut=    O of  DstOp.rator | E of Ein.ein|C of DstTy.ty|S of int
# Line 43  Line 44 
44    
45    
46  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
47  fun assignEin (x, rator, args) = (print "assignment";(x, DstIL.EINAPP(rator, args)))  fun assignEin (x, rator, args) = ((x, DstIL.EINAPP(rator, args)))
48    
49  fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))  fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))
50  fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))  fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))
# Line 94  Line 95 
95    
96    
97  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
98  fun createArgs2  (dim,v,posx)=let  fun createArgs  (dim,v,posx)=let
99      val zz=print(String.concat["\n XYZ DIM-createArgs2-",Int.toString(dim),"-"])  
100      val translate=DstOp.Translate v      val translate=DstOp.Translate v
101      val transform=DstOp.Transform v      val transform=DstOp.Transform v
102      val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)      val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
# Line 117  Line 118 
118          assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)          assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
119      ]      ]
120    
121      in ([f,n],code)      in ([n,f],code)
122      end      end
123    
124    
 fun Position(img,t,newposArg,dim,pos1,ppos)=let  
     val zz=print(String.concat["\n XYZ DIM-Position-",Int.toString(dim),"-"])  
     val (args',code')=createArgs2(dim,img,newposArg)  
     val pos2=pos1+1  
     in (pos1,pos2, ppos,args',code')  
     end  
125    
126    
127    
# Line 138  Line 133 
133  | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)  | createDels(d::ds,dim)= [( E.C dim,d)]@createDels(ds,dim)
134    
135    
136  fun replaceH(kvar, place,args)=let  
     val l1=List.take(args, place)  
     val l2=List.drop(args,place+1)  
     in l1@[kvar]@l2 end  
137    
138  (*Get Img, and Kern Args*)  (*Get Img, and Kern Args*)
139  fun getArgs(hid,hArg,V,imgArg,args)=  fun getArgs(hid,hArg,V,imgArg,args)=
# Line 149  Line 141 
141          of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let          of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let
142              val hvar=DstV.new ("KNL", DstTy.KernelTy)              val hvar=DstV.new ("KNL", DstTy.KernelTy)
143              val imgvar=DstV.new ("IMG", DstTy.ImageTy img)              val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
144                fun replaceH(kvar, place,args)=let
145                    val l1=List.take(args, place)
146                    val l2=List.drop(args,place+1)
147                    in l1@[kvar]@l2 end
148    
149              val args1=replaceH(hvar, hid,args)              val args1=replaceH(hvar, hid,args)
150              val args2=replaceH(imgvar, V,args1)              val args2=replaceH(imgvar, V,args1)
151              in              in
# Line 158  Line 155 
155          |  _ => raise Fail "Not a kernel argument"          |  _ => raise Fail "Not a kernel argument"
156    
157    
158    (*Get Img, and Krn Args*)
159    fun getArgs3(hid,hArg,V,imgArg,args)=
160        case (getRHS2 hArg,getRHS2 imgArg)
161            of ((O2(SrcOp.Kernel(h, i)),arg),(O2(SrcOp.LoadImage img),_))=> let
162                val hvar=DstV.new ("KNL", DstTy.KernelTy)
163                val imgvar=DstV.new ("IMG", DstTy.ImageTy img)
164                in
165                (Kernel.support h ,img, [assign (hvar, DstOp.Kernel(h, i), []), assign(imgvar,DstOp.LoadImage img,[])],[imgvar, hvar])
166                end
167            | ((O2(SrcOp.Kernel(h, i)),arg),_)=> raise Fail "Not an img Argument"
168            |  _ => raise Fail "Not a kernel argument"
169    
170    
 fun expandEinProbe(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let  
   
   
     val E.IMG(dim)=List.nth(params,V)  
     val zz=print(String.concat["\n XYZ DIM-expanEin-",Int.toString(dim),"-"])  
   
     val kArg=List.nth(origargs,h)  
     val imgArg=List.nth(origargs,V)  
     val (s,img,argcode,args2) =getArgs(h,kArg,V,imgArg,args)  
     val newposArg=List.nth(args, t)  
   
     val ppos=[E.TEN(1,[dim]),E.TEN(3,[dim])]  
171    
     val (fid,nid,params',args',code')=Position(img,t,newposArg,dim,(length params),ppos)  
     val shift=(length index)  
172    
173     (* val z=print(String.concat["\n SHIFt SET To",Int.toString(shift),"SX IS ", Int.toString(sx)])*)  fun createBody(dim, s,sx,shape,deltas,V, h, nid, fid)=let
174    
175      (*sumIndex creating summaiton Index for body*)      (*sumIndex creating summaiton Index for body*)
176      fun sumIndex(0)=[]      fun sumIndex(0)=[]
# Line 189  Line 183 
183          val dim'=dim-1          val dim'=dim-1
184          val sum=sx+dim'          val sum=sx+dim'
185          val dels=createDels(deltas,dim')          val dels=createDels(deltas,dim')
         (*val L=print "\n creatWith "  
         val LL=print(Int.toString(dim'))*)  
186          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]          val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
187          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))          val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
   
188          in          in
189              createKRN(dim',pos@imgpos,[rest']@rest)              createKRN(dim',pos@imgpos,[rest']@rest)
190          end          end
191    
192      val exp=createKRN(dim, [],[])      val exp=createKRN(dim, [],[])
193      val esum=sumIndex (dim)      val esum=sumIndex (dim)
194      in (params@params',E.Sum(esum, exp),args2@args',argcode@code') end      in E.Sum(esum, exp)
195        end
196    
197     (* Expand probe in place *)
198    fun expandEinProbe4(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let
199    
200    
201        val E.IMG(dim)=List.nth(params,V)
202        val kArg=List.nth(origargs,h)
203        val imgArg=List.nth(origargs,V)
204        val newposArg=List.nth(args, t)
205    
206        val (s,img,argcode,args2) =getArgs(h,kArg,V,imgArg,args)
207        val params'=[E.TEN(3,[dim]),E.TEN(1,[dim])]
208        val (args',code')=createArgs(dim,img,newposArg)
209        val fid=length params
210        val nid=fid+1
211    
212        val body=createBody(dim, s,sx,shape,deltas,V, h, nid, fid)
213        in (params@params',body,args2@args',argcode@code') end
214    
215    (*Lift Probe*)
216    fun expandEinProbe3(E.Probe(E.Conv(V,shape,h, deltas),E.Tensor(t,alpha)),(params,args),index, sx,origargs)=let
217    
218    
219        val E.IMG(dim)=List.nth(params,V)
220        val kArg=List.nth(origargs,h)
221        val imgArg=List.nth(origargs,V)
222        val newposArg=List.nth(args, t)
223    
224        val (s,img,argcode,argsVH) =getArgs3(h,kArg,V,imgArg,args)
225        val params'=[E.IMG(dim),E.KRN,E.TEN(3,[dim]),E.TEN(1,[dim])]
226        val (argsT,code')=createArgs(dim,img,newposArg)
227    
228        val h=1
229        val fid=2
230        val nid=3
231    
232        val body=createBody(dim, s,sx,shape,deltas,0, h, nid, fid)
233        in (params',body,argsVH@argsT,argcode@code') end
234    
235    
236    
237  fun TS x  = case SrcIL.Var.binding x  fun TS x  = case SrcIL.Var.binding x
# Line 216  Line 247 
247    
248    
249    
250    
251  fun ShapeConv([],n)=[]  fun ShapeConv([],n)=[]
252      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)      | ShapeConv(E.C c::es, n)=ShapeConv(es, n)
253      | ShapeConv(E.V v::es, n)=      | ShapeConv(E.V v::es, n)=
# Line 223  Line 255 
255          else ShapeConv(es,n)          else ShapeConv(es,n)
256    
257    
258    (*Lift probe*)
259  fun foundProbe(b,(params,args),index, sumIndex,origargs)=let  fun foundProbe3(b,(params,args),index, sumIndex,origargs)=let
260        val _=print "\n Lift Probe\n"
261      val E.Probe(E.Conv(_,alpha,_,dx),pos)=b      val E.Probe(E.Conv(_,alpha,_,dx),pos)=b
262    
263      val newId=length(params)      val newId=length(params)
# Line 243  Line 275 
275      val newArg = DstV.new ("PC", DstTy.tensorTy shape')      val newArg = DstV.new ("PC", DstTy.tensorTy shape')
276    
277    
   
278      (*Expand Probe*)      (*Expand Probe*)
279      val ns=length sumIndex      val ns=length sumIndex
280      val (params',body',args',code) =(case ns      val (params',body',args',code) =(case ns
281          of 0=> expandEinProbe(b,(params,args),index, n,origargs)          of 0=> expandEinProbe3(b,(params,args),index, n,origargs)
282          |_=>let          |_=>let
283              val (E.V v,_,_)=List.nth(sumIndex, ns-1)              val (E.V v,_,_)=List.nth(sumIndex, ns-1)
284              val (p,body',a,c)= expandEinProbe(b,(params,args),index, v+1,origargs)              val (p,body',a,c)= expandEinProbe3(b,(params,args),index, v+1,origargs)
285              in  (p,E.Sum(sumIndex ,body'),a,c)              in  (p,E.Sum(sumIndex ,body'),a,c)
286              end              end
287      (* end case *))      (* end case *))
# Line 258  Line 289 
289    
290      val (p',i',b',a')=shiftHtM.clean(params', index,body', args')      val (p',i',b',a')=shiftHtM.clean(params', index,body', args')
291      val newbie'=Ein.EIN{params=p', index=i', body=b'}      val newbie'=Ein.EIN{params=p', index=i', body=b'}
292      val zz=print(P.printerE(newbie'))       val _ = print(String.concat["\n ", split.printA(newArg, newbie', a'),"\n"])
293      val data=assignEin (newArg, newbie', a')      val data=assignEin (newArg, newbie', a')
     val gg=print(String.concat["\n\n ---",P.printerE(newbie'),"\n\n"])  
294      in (newB, (params@[newP],args@[newArg]) ,code@[data])      in (newB, (params@[newP],args@[newArg]) ,code@[data])
295      end      end
296    
297    
298     (* Expand probe in place *)
299     fun foundProbe4(b,(params,args),index, sumIndex,origargs)=let
300        val _=print "\n Don't replace probe \n"
301        val E.Probe(E.Conv(_,alpha,_,dx),pos)=b
302    
303        val newId=length(params)
304        val n=length(index)
305    
306  (*copied from high-to-mid.sml*)      (*Expand Probe*)
307        val ns=length sumIndex
308        val (params',body',args',code) =(case ns
309            of 0=> expandEinProbe4(b,(params,args),index, n,origargs)
310            |_=>let
311                val (E.V v,_,_)=List.nth(sumIndex, ns-1)
312                (* val _ =print(String.concat["\n Found sum sx, last element is ",Int.toString(v), "\n"])*)
313                in expandEinProbe4(b,(params,args),index, v+1,origargs)
314                end
315        (* end case *))
316    
317            val UUU=Ein.EIN{params=params', index=index, body=body'}
318            val _ =print(String.concat["\n $$$ new sub-expression $$$ \n",P.printerE(UUU),"\n"])
319    
320        in (body',(params',args') ,code)
321     end
322    
323    
324    fun flatten []=[]
325        | flatten(e1::es)=e1@(flatten es)
326    
327    
328     (* sx-[] then move out, otherwise keep in *)
329  fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let  fun expandEinOp ( Ein.EIN{params, index, body}, origargs,args) = let
330    
331      val dummy=E.Const 0      val dummy=E.Const 0
# Line 274  Line 333 
333    
334      (*b-current body, info-original ein op, data-new assigments*)      (*b-current body, info-original ein op, data-new assigments*)
335      fun rewriteBody(b,info)= let      fun rewriteBody(b,info)= let
336            val t=print(String.concat["\n\n ***:",Printer.printbody(b),"$ \n"])
337          in (case b          in (case b
338              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let              of  E.Sum(c,  E.Probe(E.Conv v, E.Tensor t)) =>let
339                  val ref sx=sumIndex                  val ref sx=sumIndex
340                  val (t',info',data)=foundProbe(E.Probe(E.Conv v, E.Tensor t ), info,index, sx@c,origargs)                  in (case sx
341                  in (t',info',data)                      of [] => foundProbe3(E.Probe(E.Conv v, E.Tensor t ), info,index, c,origargs)
342                        | _ =>  foundProbe4(E.Probe(E.Conv v, E.Tensor t ), info,index, (flatten sx)@c,origargs)
343                    (* end case*))
344              end              end
345          | E.Probe(E.Conv _, E.Tensor _) =>let          | E.Probe(E.Conv _, E.Tensor _) =>let
346              val ref sx=sumIndex              val ref sx=sumIndex
347              val (t',info',data)=foundProbe(b, info,index, sx,origargs)              in (case sx
348              in (t',info',data)                  of []=> foundProbe3(b, info,index, [],origargs)
349                    | _=> foundProbe4(b, info,index, flatten sx,origargs)
350                 (* end case*))
351              end              end
352          | E.Probe _=> (dummy,info,[])          | E.Probe _=> (dummy,info,[])
353          | E.Conv _=>  (dummy,info,[])          | E.Conv _=>  (dummy,info,[])
# Line 295  Line 359 
359                  (E.Neg(body'),info',data')                  (E.Neg(body'),info',data')
360              end              end
361          | E.Sum (c,e)=> let          | E.Sum (c,e)=> let
362              val ref sx=sumIndex              val ref x=sumIndex
363              val s=sumIndex:=(sx@c)              val c'=[c]@x
364              val (body',info',data')=rewriteBody (e,info)              val (body',info',data')=(sumIndex:=c';rewriteBody (e,info))
365                val ref s=sumIndex
366                val z=hd(s)
367              in              in
368                  (E.Sum(c, body'),info',data')                  (sumIndex:=tl(s);(E.Sum(z, body'),info',data'))
369              end              end
370          | E.Sub(a,b)=>let          | E.Sub(a,b)=>let
371              val (bodyA,infoA,dataA)= rewriteBody(a,info)              val (bodyA,infoA,dataA)= rewriteBody(a,info)
# Line 328  Line 394 
394          end          end
395    
396       val empty =fn key =>NONE       val empty =fn key =>NONE
397      val mm=print "Starting Exapnd"      val mm=print "\n ************************** \n Starting Exapnd"
398      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))      val (body',(params',args'),newbies)=rewriteBody(body,(params,args))
399      val e'=Ein.EIN{params=params', index=index, body=body'}      val e'=Ein.EIN{params=params', index=index, body=body'}
400      val rr=print (String.concat[P.printbody(body),"DONE expand"])  
401        val rr=print (String.concat[P.printerE(e'),"\n DONE expand ************************** \n "])
402      in ((e',args'),newbies) end      in ((e',args'),newbies) end
403    
404    end; (* local *)    end; (* local *)

Legend:
Removed from v.2583  
changed lines
  Added in v.2584

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0