Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3579, Tue Jan 12 22:51:51 2016 UTC revision 3580, Tue Jan 12 23:29:57 2016 UTC
# Line 36  Line 36 
36      * img-imginfo about V      * img-imginfo about V
37      *)      *)
38    
 (* FIXME: what are these for? should they be settable from the command-line? *)  
     val valnumflag = true  
     val tsplitvar = true  
     val fieldliftflag = true  
     val detflag = true  
   
     fun transformToIndexSpace e = T.imageToWorld e  
     fun transformToImgSpace e = T.worldToIndex e  
   
39      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
40      fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)      fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)
   
41      fun setProd e= E.Opn(E.Prod, e)      fun setProd e= E.Opn(E.Prod, e)
     fun setAdd e= E.Opn(E.Add, e)  
42    
43      fun getRHSDst x  = (case DstIR.Var.binding x      fun getRHSDst x  = (case DstIR.Var.binding x
44             of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)             of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)
# Line 85  Line 74 
74            val hArg = List.nth(args, hid)            val hArg = List.nth(args, hid)
75            val newposArg = List.nth(args, tid)            val newposArg = List.nth(args, tid)
76            val (s,img, dim) = getArgsDst(hArg, imgArg, args)            val (s,img, dim) = getArgsDst(hArg, imgArg, args)
77        val (argsT, P, code) = transformToImgSpace{info=img, img=imgArg, pos=newposArg}        val (argsT, P, code) = T.worldToIndex{info = img, img = imgArg, pos = newposArg}
78            in            in
79              (dim, args@argsT, code, s, P)              (dim, args@argsT, code, s, P)
80            end            end
# Line 93  Line 82 
82      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id
83      * expands the body for the probed field      * expands the body for the probed field
84      *)      *)
85      fun createBody (dim, s, sx, alpha, deltas, Vid, hid, nid, fid)=let      fun createBody (dimO, s, sx, alpha, deltas, Vid, hid, nid, fid) = let
86          (*1-d fields*)          (*1-d fields*)
87          fun createKRND1 () = let          fun createKRND1 () = let
88              val sum = sx              val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]
89              val dels = List.map (fn e=>(E.C 0,e)) deltas              val dels = List.map (fn e=>(E.C 0,e)) deltas
90              val pos = [setAdd[E.Tensor(fid,[]), E.Value(sum)]]              val rest = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sx)))
             val rest = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sum)))  
91              in              in
92                 setProd [E.Img(Vid,alpha,pos),rest]                 setProd [E.Img(Vid,alpha,imgpos),rest]
93              end              end
94          (*createKRN Image field and kernels *)          (*createKRN Image field and kernels *)
95        fun createKRN (0, imgpos, rest) = setProd ([E.Img(Vid,alpha,imgpos)] @rest)          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid,alpha,imgpos)::rest)
96          | createKRN (dim, imgpos, rest) = let          | createKRN (d, imgpos, rest) = let
97              val dim' = dim-1              val d' = d-1
98              val sum = sx+dim'              val cx = E.C(d')
99              val dels = List.map (fn e=>(E.C dim',e)) deltas              val Vsum = E.Value(sx+d')
100              val pos = [setAdd[E.Tensor(fid,[E.C dim']), E.Value(sum)]]              val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])
101              val rest' =  E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[E.C dim']), E.Value(sum)))              val dels = List.map (fn e =>(cx, e)) deltas
102              in              val rest0 = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[cx]),  Vsum))
103                  createKRN(dim',pos@imgpos,[rest']@rest)              in
104              end                  createKRN(d',pos0::imgpos,rest0::rest)
105          val exp = (case dim              end
106              of 1 => createKRND1()  
107              | _ => createKRN(dim, [], [])          (*sumIndex creating summation Index for body*)
108              (* end case *))          val esum = List.tabulate(dimO, (fn d =>(E.V d, 1-s, s)))
109          (*sumIndex creating summaiton Index for body*)          val exp = if (dimO=1) then createKRND1() else createKRN(dimO, [], [])
         val slb = 1-s  
         val esum = List.tabulate(dim, (fn dim=>(E.V (dim+sx), slb, s)))  
110      in      in
111          E.Sum(esum, exp)          E.Sum(esum, exp)
112      end      end
# Line 129  Line 115 
115      *get fresh/unused index_id, returns int      *get fresh/unused index_id, returns int
116      *)      *)
117      fun getsumshift ([], n) = n      fun getsumshift ([], n) = n
118      fun getsumshift (sx, n) = let        | getsumshift (sx, n) = let
119          val (E.V v,_,_) = List.hd( List.rev sx)          val (E.V v,_,_) = List.hd( List.rev sx)
120          in          in
121              v+1              v+1
122          end          end
123    
124      (*formBody:ein_exp->ein_exp      (*formBody:ein_exp->ein_exp*)
     *)  
125      fun formBody (E.Sum([],e))=formBody e      fun formBody (E.Sum([],e))=formBody e
126        | formBody (E.Sum(sx,e))= E.Sum(sx,formBody e)        | formBody (E.Sum(sx,e))= E.Sum(sx,formBody e)
127        | formBody (E.Opn(E.Prod, [e]))=e        | formBody (E.Opn(E.Prod, [e]))=e
# Line 146  Line 131 
131      fun multiPs ([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))      fun multiPs ([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
132      (*| multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1]))) *)      (*| multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1]))) *)
133        | multiPs ([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))        | multiPs ([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
134        | multiPs (Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))        | multiPs (Ps,sx,body) = formBody(E.Sum(sx,setProd(body::Ps)))
135    
136        fun arrangeBody(body,probe') = (case body
137                of E.Sum(sx, E.Probe _)                     => E.Sum(sx,probe')
138                | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,E.Opn(E.Prod,[eps0,probe']))
139                | _                                         => probe'
140            (* end case *))
141    
142    
143      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list
# Line 156  Line 147 
147      * rewrites body      * rewrites body
148      * replace probe with expanded version      * replace probe with expanded version
149      *)      *)
150       fun replaceProbe (testN, (y, DstIR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let       fun replaceProbe((y, DstIR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let
151          val fid = length(params)          val fid = length(params)
152          val nid = fid+1          val nid = fid+1
153          val Pid = nid+1          val Pid = nid+1
# Line 164  Line 155 
155          val nshift = length(dx)          val nshift = length(dx)
156          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
157          val freshIndex = getsumshift(sx,length(index))          val freshIndex = getsumshift(sx,length(index))
158          val (dx,newsx1,Ps) = transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps) = T.imageToWorld(freshIndex,dim,dx,Pid)
159          (*val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)          (*val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)
160           val params'=params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]           val params'=params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]
161          val probe' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)          val probe' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
162          val probe' = multiPs(Ps,newsx1,probe')          val probe' = multiPs(Ps,newsx1,probe')
163          val body' = (case body          val body' = arrangeBody(body,probe')
             of E.Sum(sx, E.Probe _)              => E.Sum(sx,probe')  
             | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,probe'])  
             | _                                  => probe'  
             (* end case *))  
164          val einapp=(y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))          val einapp=(y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))
165          in          in
166              code@[einapp]              code@[einapp]
# Line 184  Line 171 
171          val tid = 1          val tid = 1
172    
173          (*Assumes body is already clean*)          (*Assumes body is already clean*)
174          val (newdx, newsx, Ps)=transformToIndexSpace(freshIndex, dim, dx, Pid)          val (newdx, newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
175    
176          (*need to rewrite dx*)          (*need to rewrite dx*)
177          val (_, sizes, e as E.Conv(_,alpha',_,dx)) = (case sx@newsx          val (_, sizes, e as E.Conv(_,alpha',_,dx)) = (case sx@newsx
# Line 194  Line 181 
181          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]
182          fun filterAlpha []=[]          fun filterAlpha []=[]
183            | filterAlpha (E.C _::es) = filterAlpha es            | filterAlpha (E.C _::es) = filterAlpha es
184            | filterAlpha (e1::es) = [e1]@(filterAlpha es)            | filterAlpha (e1::es) = e1::(filterAlpha es)
185          val tshape = filterAlpha(alpha')@newdx          val tshape = filterAlpha(alpha')@newdx
186          val t = E.Tensor(tid, tshape)          val t = E.Tensor(tid, tshape)
187          val (splitvar, body) = (case originalb          val (splitvar, body) = (case originalb
# Line 208  Line 195 
195              (splitvar, ein0, sizes, dx, alpha')              (splitvar, ein0, sizes, dx, alpha')
196          end          end
197    
198      fun liftProbe (printStrings, (y, DstIR.EINAPP(Ein.EIN{params , index , body }, args)), probe, sx) = let      fun liftProbe ((y, DstIR.EINAPP(Ein.EIN{params , index , body }, args)), probe, sx) = let
199          val fid = length(params)          val fid = length(params)
200          val nid = fid+1          val nid = fid+1
201          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
# Line 247  Line 234 
234      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
235      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
236      *)      *)
237     fun expandEinOp (e as (y, DstIR.EINAPP(ein as Ein.EIN{params , index , body }, args)), fieldset) = let     fun expandEinOp (e as (y, DstIR.EINAPP(ein as Ein.EIN{params , index , body }, args)), avail) = let
         val avail = AvailRHS.new()  
238          fun matchField()=(case body          fun matchField()=(case body
239              of E.Probe _ => 1              of E.Probe _ => 1
240              | E.Sum (_, E.Probe _) => 1              | E.Sum (_, E.Probe _) => 1
# Line 257  Line 243 
243              (* end case *))              (* end case *))
244          fun rewriteBody()=(case body          fun rewriteBody()=(case body
245              of  (E.Probe(E.Conv(_,_,_,[]),_))              of  (E.Probe(E.Conv(_,_,_,[]),_))
246                  => replaceProbe(0,e,body,[])                 => replaceProbe(e, body, [])
247              | (E.Probe(E.Conv (_,alpha,_,dx),_))              | (E.Probe(E.Conv (_,alpha,_,dx),_))
248                  => liftProbe (0,e,body,[]) (*scans dx for contant*)                 => liftProbe (e, body, []) (*scans dx for contant*)
249              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
250                  => replaceProbe(0,e,p, sx)  (*no dx*)                 => replaceProbe(e, p, sx)  (*no dx*)
251              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
252                  => liftProbe (0,e,p,sx) (*scalar field*)                 => liftProbe (e, p, sx) (*scalar field*)
253              | (E.Sum(sx,E.Probe p))              | (E.Sum(sx,E.Probe p))
254                  => replaceProbe(0,e,E.Probe p, sx)                 => replaceProbe(e, E.Probe p, sx)
255              | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))              | (E.Sum(sx,E.Opn(E.Prod,[eps,E.Probe p])))
256                  => replaceProbe(0,e,E.Probe p,sx)                 => replaceProbe(e ,E.Probe p, sx)
257              | _ => [e]              | _ => [e]
258              (* end case *))              (* end case *))
259          val (fieldset,var) = if (valnumflag)          in
260              then  einSet.rtnVar(fieldset, y, DstIR.EINAPP(ein, args))              (rewriteBody(), avail, matchField(), 0)
             else (fieldset, NONE)  
         in  (case var  
             of NONE => (rewriteBody(), fieldset, matchField(), 0)  
             | SOME v => ([(y,DstIR.VAR v)], fieldset, matchField(), 1)  
             (* end case *))  
261          end          end
262    
263    end (* ProbeEin *)    end (* ProbeEin *)

Legend:
Removed from v.3579  
changed lines
  Added in v.3580

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0