Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml
ViewVC logotype

Diff of /branches/charisee/src/compiler/high-to-mid/expand-integrate.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 2502, Mon Nov 4 21:33:35 2013 UTC revision 2510, Thu Nov 14 20:33:18 2013 UTC
# Line 5  Line 5 
5   *)   *)
6    
7    
8  (*Expand ProebConv to Probe of individual field *)  (*
9    A couple of different approaches.
10    One approach is to find all the Probe(Conv). Gerenerate exp for it
11    Then use Subst function to sub in. That takes care for index matching and
12    
13    *)
14    
15    (*This approach creates probe expanded terms, and adds params to the end. *)
16    
17    
18  structure Expand = struct  structure Expand = struct
19    
20      local      local
21    
22      structure E = Ein      structure E = Ein
23        structure mk= mkOperators
24    
25    
26    structure SrcIL = HighIL
27    structure SrcTy = HighILTypes
28    structure SrcOp = HighOps
29    structure SrcSV = SrcIL.StateVar
30    structure VTbl = SrcIL.Var.Tbl
31      structure DstIL = MidIL      structure DstIL = MidIL
32      structure DstTy = MidILTypes      structure DstTy = MidILTypes
   
33      structure DstOp = MidOps      structure DstOp = MidOps
34      structure DstV = DstIL.Var      structure DstV = DstIL.Var
35      structure mk= mkOperators      structure SrcV = SrcIL.Var
36    structure P=Printer
37    
38    
39    datatype peanut=    O of  SrcOp.rator | E of Ein.ein|C of SrcTy.ty
40    
41      in      in
42    
43    
44  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))  fun assign (x, rator, args) = (x, DstIL.OP(rator, args))
45  fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))  fun assignEin (x, rator, args) = (x, DstIL.EINAPP(rator, args))
46    
47    fun assign2(x, rator, args) = (x, SrcIL.OP(rator, args))
48    fun assignEin2 (x, rator, args) = (x, SrcIL.EINAPP(rator, args))
49    
50  fun insert (key, value) d =fn s =>  fun insert (key, value) d =fn s =>
51          if s = key then SOME value          if s = key then SOME value
52          else d s          else d s
53  fun lookup k d = d k  fun lookup k d = d k
54    
55    
56    
57    fun getRHS x  = (case SrcIL.Var.binding x
58        of SrcIL.VB_RHS(SrcIL.OP(rator, args)) => (O rator, args)
59        | SrcIL.VB_RHS(SrcIL.VAR x') => getRHS x'
60        | SrcIL.VB_RHS(SrcIL.EINAPP (e,args))=>(E e,args)
61        | SrcIL.VB_RHS(SrcIL.CONS (ty,args))=> (print "cons type";(C ty,args))
62        | vb => raise Fail(concat[
63                "expected rhs operator for ", SrcIL.Var.toString x,
64            "but found ", SrcIL.vbToString vb])
65            (* end case *))
66    
67    
68    
69    
70  (*Create fractional, and integer position vectors*)  (*Create fractional, and integer position vectors*)
71  fun createArgs(dim,v,pos,s)=let  fun createArgs(dim,v,posx,pos)=let
72      val translate=DstOp.Translate v      val translate=DstOp.Translate v
73      val transform=DstOp.Transform v      val transform=DstOp.Transform v
74    
     (* Match EinTypes, or mid-il types?  
     val vecsTy =mk.createVec(2*s)  
     val vecDimTy = mk.createVec(dim)  
     *)  
   
   
75      val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)      val M = DstV.new ("M", DstTy.tensorTy [dim,dim]) (*transform dim by dim?*)
76      val T = DstV.new ("T", DstTy.vecTy dim)          (*translate*)      val T = DstV.new ("T", DstTy.vecTy dim)          (*translate*)
77      val x = DstV.new ("x", DstTy.vecTy dim)      val x = DstV.new ("x", DstTy.vecTy dim)
# Line 53  Line 84 
84      val code=[      val code=[
85              assign(M, transform, []),              assign(M, transform, []),
86              assign(T, translate, []),              assign(T, translate, []),
87              assignEin(x, PosToImgSpace,[M,pos,T]) ,  (* MX+T*)              pos,
88                assignEin(x, PosToImgSpace,[M,posx,T]) ,  (* MX+T*)
89              assign(nd, DstOp.Floor dim, [x]),   (*nd *)              assign(nd, DstOp.Floor dim, [x]),   (*nd *)
90              assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)              assignEin(f, mk.subTen([dim]),[x,nd]),           (*fractional*)
91              assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)              assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)
   
92              ]              ]
93    
94        (*Then f, n are new positions created. add to args list of currrent EinExp*)        (*Then f, n are new positions created. add to args list of currrent EinExp*)
95      val args=[f,n]      val args=[]
96      in (args,code)      in (args,code)
97      end      end
 (*  
   
   
  fun createVec d= S.transform(EinOp.tensorOp,[[d]],[])  
   
   
 (*create fractional, and integer position*)  
 fun createArgs(dim,Kernel?, )=  
   
   
     (*create set of positions*)  
     val s = Kernel.support h  
     val vecsTy =createVec(2*s)  
     val vecDimTy = createVec(dim)  
     val translate=DstOp.Translate v  
     val transform=DstOp.Transform v  
   
   
     (* generate the transform code *)  
     val x = DstV.new ("x", vecDimTy)    (* image-space position *)  
     val f = DstV.new ("f", vecDimTy)  
     val nd = DstV.new ("nd", vecDimTy)  
     val n = DstV.new ("n", DstTy.iVecTy dim)  
     val M = DstV.new ("M", transform)  
     val T = DstV.new ("T", translate)  
   
     val sub= S.transform(EinOp.subTensor,[dim],[])  
98    
99    
         (* M_ij x_i*)  
     val MXop=S.transform(EinOp.innerProduct,[[dim],[],[dim]],[])  
     val MX = DstV.new ("MX", MXop)  
100    
101    (*Notice problem here, args need to be f, n
102            but f and n are mid-il ops.
103            everything else is high-il ops.
104    
105        can't rewrite all args, since some may not be EinApps. Would need to call high.to.midil.expand
106            can't because of circulcar dependcy.
107    
108      val PosToImgSpace=S.transform(EinOp.addTensor,[[dim]],[])          then do we write the whole expression as a high-il? and convert with a different expression?
   
     val toImgSpaceCode = [  
         assignEin(MX, Mxop, [M,pos]),  (*M_{ij}X_i*)  
         assignEin(x, PosToImgSpace,[MX,T])   (* MX+T*)  
         assign(nd, DstOp.Floor dim, [x]),   (*nd *)  
         assignEin(f, sub,[x,nd]),           (*fractional*)  
         assign(n, DstOp.RealToInt dim, [nd]) (*real to Int*)  
     ]  
109    
110  *)  *)
111    
112    
   
113  (*  (*
114  createDels=> creates the kronecker deltas for each Kernel  createDels=> creates the kronecker deltas for each Kernel
115  For each dimesnion a, and each index in derivative b create element (a,b)  For each dimesnion a, and each index in derivative b create element (a,b)
# Line 123  Line 120 
120    
121    
122    
123  fun Position(idt,dict,params,args)=let  
124      val l=lookup idt dict  fun Position(V,t,dict,dim,pos1,args)=let
125        val l=lookup t dict
126      in (case l      in (case l
127          of NONE =>let          of NONE =>let
128          val pos1=length params              val imgArg=List.nth(args, V)
129                val img=(case getRHS imgArg
130                    of (O(SrcOp.LoadImage v), _)=>v
131                    (*end case*))
132    
133    
134                val posArg=List.nth(args, t)
135                in (case getRHS posArg of (E e, arg)=>let
136                    val posx= DstV.new ("pos", DstTy.vecTy dim)
137                    val pos= assignEin(posx, e,[])
138                    val (args',code')=createArgs(dim,img,posx,pos)
139          val pos2=pos1+1          val pos2=pos1+1
140          val dict'=insert(idt,(pos1,pos2)) dict                  val dict'=insert(t,(pos1,pos2)) dict
141          val params'=params@[E.TEN,E.TEN]                  val params'=[E.TEN,E.TEN]
142              (*will call create args*)                  in (pos1,pos2, dict',params',args',code')
143          in (pos1,pos2, dict',params',args)                  end
144                    |_=> (99,77,dict, [],[],[]) (* NOn-einOP tensor, deal with later*)
145                    (*end case*))
146          end          end
147              (*Create new fractional, and n variables,and returns fresh ids*)          | SOME (fid,nid)=>(fid,nid, dict,[],[],[])
         | SOME (fid,nid)=>(fid,nid, dict,params,args)  
148          (*end case*))          (*end case*))
149      end      end
150    
151    
152  fun expandEinProbe(params,body,index,d,args)=(case body  
153      of E.Probe(E.Conv(id,alpha,kid,deltas),E.Tensor(idt,alphat)) =>  fun expandEinProbe((body,(params,index,args,d,code)),sx)=(case body
154          if(id+1>length params) then (print "not enough params" ;(params,body,index,d,args))      of E.Probe(E.Conv(V,shape,h,deltas),E.Tensor(t,alpha)) =>let
155          else (case List.nth(params,id)              val E.IMG(dim)=List.nth(params,V)
             of E.FLD(dim)=>  
             let  
156              val s=2 (*support*)              val s=2 (*support*)
157              val (fid,nid,d',params',args')=Position(idt,d,params,args)  
158                val pnum=length params
159    
160                val (fid,nid,d',params',args',code')=Position(V,t,d,dim,pnum,args)
161              val shift=length index              val shift=length index
162    
163                val m=print (String.concat["got sx", Int.toString(sx)])
164              (*sumIndex creating summaiton Index for body*)              (*sumIndex creating summaiton Index for body*)
165              fun sumIndex(0)=[]              fun sumIndex(0)=[]
166              |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+shift-1),1-s,s)]              |sumIndex(dim)= sumIndex(dim-1)@[(E.V (dim+sx-1),1-s,s)]
167    
168              (*createKRN Image field and kernels *)              (*createKRN Image field and kernels *)
169              fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(id,alpha,imgpos)] @rest)              fun createKRN(0,imgpos,rest)=E.Prod ([E.Img(V,shape,imgpos)] @rest)
170              | createKRN(dim,imgpos,rest)=              | createKRN(dim,imgpos,rest)=let
                 let  
171                  val dim'=dim-1                  val dim'=dim-1
172                  val sum=dim'+shift                  val sum=dim'+shift
173                  val dels=createDels(deltas,dim')                  val dels=createDels(deltas,dim')
174                  val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]                  val pos=[E.Add[E.Tensor(fid,[E.C dim']),E.Value(sum)]]
175                  val rest'= E.Krn(kid,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))                  val rest'= E.Krn(h,dels,E.Sub(E.Tensor(nid,[E.C dim']),E.Value(sum)))
176              in              in
177                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
178              end              end
179    
180    
181              val i=createKRN(dim, [],[])              val exp=createKRN(dim, [],[])
182              val esum=sumIndex dim              val esum=sumIndex (dim)
             val index'=index  
183    
             in (params', E.Sum(esum, i),index', d',args') end  
         | _=>(print "err: non field in param spot";(params,E.Const(0.0) ,index,d,args))  
         (*end case*))  
     |_=>(print "unexpected body" ;(params,body,index,d,args))  
     (*end case*))  
184    
185                in (E.Sum(esum, exp),(params@params',index,args@args',d',code@code')) end
186                (*end case*))
187    
188    
189  (*copied from high-to-mid.sml*)  (*copied from high-to-mid.sml*)
190  fun expandEinOp ( Ein.EIN{params, index, body}, args) = let  fun expandEinOp ( Ein.EIN{params, index, body}, args) = let
191    
192      val dummy=E.Const 0.0  (*tmp variables*)      val dummy=E.Const 0.0  (*tmp variables*)
193        val sumIndex=ref (length index)
194        fun sumI(e)=let
195            val (E.V v,_,_)=List.nth(e, length(e)-1)
196        in (print (String.concat["sumi",Int.toString(v)]);v) end
197    
198      fun rewriteBody exp= let      fun rewriteBody exp= let
199          val (p,body,ix,d,args')= exp          val (body, data)=exp
200          in (case body          in (case body
201          of E.Const _=>exp          of E.Const _=>exp
202          | E.Tensor _=>exp          | E.Tensor _=>exp
# Line 196  Line 206 
206          | E.Epsilon _=>exp          | E.Epsilon _=>exp
207          | E.Partial _=>exp          | E.Partial _=>exp
208          | E.Img _=>  exp          | E.Img _=>  exp
209          | E.Conv _=>(p,dummy,ix,d,args')          | E.Conv _=>(dummy, data)
210          | E.Field _ =>(p,dummy,ix,d,args')          | E.Field _ =>(dummy, data)
211          | E.Apply _ =>(p,dummy,ix,d,args')          | E.Apply _ =>(dummy, data)
212          | E.Neg e=> let          | E.Neg e=> let
213              val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')              val (body',exp')=rewriteBody (e, data)
214              in              in
215                  (p',E.Neg body',ix',d',args'')                  (E.Neg(body'),exp')
216              end              end
217          | E.Sum (c,e)=> let          | E.Sum (c,e)=> let
218              val (p',body',ix',d',args'')=rewriteBody (p,e,ix,d,args')  
219                val m=(sumI(c))+1
220                val x=print (String.concat["sumi",Int.toString(m)])
221                val (body',exp')=(sumIndex:=m;rewriteBody (e,data))
222              in              in
223                  (p',E.Sum(c,body'),ix',d',args'')                  ((E.Sum(c, body'), exp'))
224              end              end
225          | E.Probe(E.Conv _, _) =>expandEinProbe exp          | E.Probe(E.Conv _, _) =>let
226                val ref x=sumIndex
227                val y=print(String.concat["\n ref",Int.toString(x),"--"])
228                in expandEinProbe(exp,x) end
229          | E.Sub(a,b)=>let          | E.Sub(a,b)=>let
230              val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')              val (bodya,dataa)= rewriteBody(a, data)
231              val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')              val (bodyb, datab)= rewriteBody(b, dataa)
232              in (pb,E.Sub( a', b'),bx,db,args''')              in   (E.Sub( bodya, bodyb),datab)
233              end              end
234          | E.Div(a,b)=>let          | E.Div(a,b)=>let
235              val(pa,a',ax,da,args'')= rewriteBody (p,a,ix,d,args')              val (bodya,dataa)= rewriteBody(a, data)
236              val(pb,b',bx,db,args''')= rewriteBody (pa,b,ax,da,args'')              val (bodyb, datab)= rewriteBody(b, dataa)
237              in (pb,E.Div( a', b'),bx,db,args''')              in  (E.Div(bodya, bodyb),datab) end
             end  
238          | E.Add es=> let          | E.Add es=> let
239              fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Add done, ix1,d1,args')              fun filter([], done, data')= (E.Add done, data')
240              | addFilter(p1,ix1,d1, e::es,done,args')=let                  | filter(e::es, done, data')= let
241                  val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')                      val (body', data'')= rewriteBody(e, data')
242                  in  addFilter(p2, ix2, d2, es, done@[e2],args'')                        in filter(es, done@[body'], data'') end
243                  end              in filter(es, [],data) end
244              in  
                 addFilter(p,ix,d,es,[],args')  
             end  
245          | E.Prod es=> let          | E.Prod es=> let
246              fun addFilter(p1,ix1,d1,[],done,args')=(p1,E.Prod done, ix1,d1,args')              fun filter([], done, data)= (E.Prod done, data)
247              | addFilter(p1,ix1,d1, e::es,done,args')=let                  | filter(e::es, done, data)= let
248                  val(p2,e2,ix2,d2,args'')= rewriteBody(p1,e,ix1,d1,args')                      val (body', data')= rewriteBody(e, data)
249                  in  addFilter(p2, ix2, d2, es, done@[e2],args'')                      in filter(es, done@[body'], data') end
250                  end                  in filter(es, [],data) end
             in  
             addFilter(p,ix,d,es,[],args')  
             end  
251          | E.Probe _=> exp          | E.Probe _=> exp
   
252          (* end case *))          (* end case *))
253          end          end
254    
   
255      val empty =fn key =>NONE      val empty =fn key =>NONE
256      val (params',body',ix',_,args')=rewriteBody(params,body,index,empty,args)      val (body',(params', index', args',_,code'))=rewriteBody(body,(params,index,args,empty,[]))
257      val newbie=Ein.EIN{params=params', index=ix', body=body'}      val newbie=Ein.EIN{params=params', index=index', body=body'}
258      in (newbie,args') end      in (newbie,args',code') end
259    
260    end; (* local *)    end; (* local *)
261    

Legend:
Removed from v.2502  
changed lines
  Added in v.2510

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0