Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3578, Tue Jan 12 19:54:13 2016 UTC revision 3579, Tue Jan 12 22:51:51 2016 UTC
# Line 13  Line 13 
13      structure E = Ein      structure E = Ein
14      structure DstIR = MidIR      structure DstIR = MidIR
15      structure DstOp = MidOps      structure DstOp = MidOps
16      structure T = TransformEin  
17      structure DstV = DstIR.Var      structure DstV = DstIR.Var
18      structure DstTy = MidTypes      structure DstTy = MidTypes
19      structure T = CoordSpaceTransform      structure T = CoordSpaceTransform
# Line 42  Line 42 
42      val fieldliftflag = true      val fieldliftflag = true
43      val detflag = true      val detflag = true
44    
45      fun transformToIndexSpace e = T.transformToIndexSpace e      fun transformToIndexSpace e = T.imageToWorld e
46      fun transformToImgSpace e = T.transformToImgSpace  e      fun transformToImgSpace e = T.worldToIndex e
47    
48      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
49      fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)      fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)
50      fun setConst e = E.setConst e  
51      fun setNeg e = E.setNeg e      fun setProd e= E.Opn(E.Prod, e)
52      fun setExp e = E.setExp e      fun setAdd e= E.Opn(E.Add, e)
     fun setDiv e= E.setDiv e  
     fun setSub e= E.setSub e  
     fun setProd e= E.setProd e  
     fun setAdd e= E.setAdd e  
53    
54      fun getRHSDst x  = (case DstIR.Var.binding x      fun getRHSDst x  = (case DstIR.Var.binding x
55             of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)             of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)
# Line 89  Line 85 
85            val hArg = List.nth(args, hid)            val hArg = List.nth(args, hid)
86            val newposArg = List.nth(args, tid)            val newposArg = List.nth(args, tid)
87            val (s,img, dim) = getArgsDst(hArg, imgArg, args)            val (s,img, dim) = getArgsDst(hArg, imgArg, args)
88            val (argsT, P, code) = transformToImgSpace(dim, img, newposArg, imgArg)        val (argsT, P, code) = transformToImgSpace{info=img, img=imgArg, pos=newposArg}
89            in            in
90              (dim, args@argsT, code, s, P)              (dim, args@argsT, code, s, P)
91            end            end
# Line 103  Line 99 
99              val sum = sx              val sum = sx
100              val dels = List.map (fn e=>(E.C 0,e)) deltas              val dels = List.map (fn e=>(E.C 0,e)) deltas
101              val pos = [setAdd[E.Tensor(fid,[]), E.Value(sum)]]              val pos = [setAdd[E.Tensor(fid,[]), E.Value(sum)]]
102              val rest = E.Krn(hid, dels, setSub(E.Tensor(nid,[]), E.Value(sum)))              val rest = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sum)))
103              in              in
104                 setProd [E.Img(Vid,alpha,pos),rest]                 setProd [E.Img(Vid,alpha,pos),rest]
105              end              end
# Line 114  Line 110 
110              val sum = sx+dim'              val sum = sx+dim'
111              val dels = List.map (fn e=>(E.C dim',e)) deltas              val dels = List.map (fn e=>(E.C dim',e)) deltas
112              val pos = [setAdd[E.Tensor(fid,[E.C dim']), E.Value(sum)]]              val pos = [setAdd[E.Tensor(fid,[E.C dim']), E.Value(sum)]]
113              val rest' =  E.Krn(hid, dels, setSub(E.Tensor(nid,[E.C dim']), E.Value(sum)))              val rest' =  E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[E.C dim']), E.Value(sum)))
114              in              in
115                  createKRN(dim',pos@imgpos,[rest']@rest)                  createKRN(dim',pos@imgpos,[rest']@rest)
116              end              end
# Line 148  Line 144 
144    
145      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
146      fun multiPs ([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))      fun multiPs ([P0,P1,P2],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,body]))
147      (*      (*| multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1]))) *)
       | multiPs([P0,P1],sx,body)=formBody(E.Sum(sx, setProd([P0,body,P1])))  
       *)  
148        | multiPs ([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))        | multiPs ([P0,P1,P2,P3],sx,body)= formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))
149        | multiPs (Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))        | multiPs (Ps,sx,body)=formBody(E.Sum(sx,setProd([body]@Ps)))
150    
151    
     fun multiMergePs ([P0, P1], [sx0, sx1], body) = E.Sum([sx0],setProd[P0,E.Sum([sx1],setProd[P1,body])])  
       | multiMergePs e = multiPs e  
   
   
152      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list
153              -> ein_exp* *code              -> ein_exp* *code
154      * Transforms position to world space      * Transforms position to world space
# Line 166  Line 156 
156      * rewrites body      * rewrites body
157      * replace probe with expanded version      * replace probe with expanded version
158      *)      *)
159       fun replaceProbe (testN, (y, DstIR.EINAPP(e,args)), p, sx) = let       fun replaceProbe (testN, (y, DstIR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let
         val params = Ein.params e  
160          val fid = length(params)          val fid = length(params)
161          val nid = fid+1          val nid = fid+1
162          val Pid = nid+1          val Pid = nid+1
163          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
164          val nshift = length(dx)          val nshift = length(dx)
165          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
166          val freshIndex = getsumshift(sx,length(Ein.index e))          val freshIndex = getsumshift(sx,length(index))
167          val (dx,newsx1,Ps) = transformToIndexSpace(freshIndex,dim,dx,Pid)          val (dx,newsx1,Ps) = transformToIndexSpace(freshIndex,dim,dx,Pid)
168          val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]          (*val params'=params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)
169          val body' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)           val params'=params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]
170          val body' = multiPs(Ps,newsx1,body')          val probe' = createBody(dim, s,freshIndex+nshift,alpha,dx,Vid, hid, nid, fid)
171          val body'= (case (Ein.body e)          val probe' = multiPs(Ps,newsx1,probe')
172              of E.Sum(sx, E.Probe _)              => E.Sum(sx,body')          val body' = (case body
173              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,body'])              of E.Sum(sx, E.Probe _)              => E.Sum(sx,probe')
174              | _                                  => body'              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ]))  => E.Sum(sx,setProd[eps0,probe'])
175                | _                                  => probe'
176              (* end case *))              (* end case *))
177          val einapp=(y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))          val einapp=(y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))
178          in          in
# Line 199  Line 189 
189          (*need to rewrite dx*)          (*need to rewrite dx*)
190          val (_, sizes, e as E.Conv(_,alpha',_,dx)) = (case sx@newsx          val (_, sizes, e as E.Conv(_,alpha',_,dx)) = (case sx@newsx
191              of [] => ([], index, E.Conv(9,alpha,7,newdx))              of [] => ([], index, E.Conv(9,alpha,7,newdx))
192              | _ => cleanIndex.cleanIndex(E.Conv(9,alpha,7,newdx), index, sx@newsx)              | _ => CleanIndex.clean(E.Conv(9,alpha,7,newdx), index, sx@newsx)
193              (* end case *))              (* end case *))
194          val params = [E.TEN(1,[dim,dim]), E.TEN(1,sizes)]          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]
195          fun filterAlpha []=[]          fun filterAlpha []=[]
196            | filterAlpha (E.C _::es) = filterAlpha es            | filterAlpha (E.C _::es) = filterAlpha es
197            | filterAlpha (e1::es) = [e1]@(filterAlpha es)            | filterAlpha (e1::es) = [e1]@(filterAlpha es)
# Line 218  Line 208 
208              (splitvar, ein0, sizes, dx, alpha')              (splitvar, ein0, sizes, dx, alpha')
209          end          end
210    
211      fun liftProbe (printStrings, (y, DstIR.EINAPP(e, args)), p, sx) = let      fun liftProbe (printStrings, (y, DstIR.EINAPP(Ein.EIN{params , index , body }, args)), probe, sx) = let
   
         val params = Ein.params e  
         val index = Ein.index e  
212          val fid = length(params)          val fid = length(params)
213          val nid = fid+1          val nid = fid+1
214          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = p          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
215          val nshift = length(dx)          val nshift = length(dx)
216          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
217          val freshIndex = getsumshift(sx, length(index))          val freshIndex = getsumshift(sx, length(index))
218    
219          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
220          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(Ein.body e, alpha, index, freshIndex, dim, dx, sx)          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)
221          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg  = DstV.new ("F", DstTy.TensorTy(sizes))
222          val einApp0 = mkEinApp(ein0, [PArg,FArg])          val einApp0 = mkEinApp(ein0, [PArg,FArg])
223          val rtn0 = (case splitvar          val rtn0 = (case splitvar
224              of false => [(y, mkEinApp(ein0, [PArg,FArg]))]              of false => [(y, mkEinApp(ein0, [PArg,FArg]))]
225              | _      => let              | _      => let
226                  val bind3 = (y, DstIR.EINAPP(SummationEin.main ein0, [PArg, FArg]))                  val bind3 = (y, DstIR.EINAPP(EinSums.transform  ein0, [PArg, FArg]))
227                  in                  in
228                      Split.splitEinApp bind3                      Split.splitEinApp bind3
229                  end                  end
230              (* end case *))              (* end case *))
231    
232          (*lifted probe*)          (*lifted probe*)
233          val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]          (*val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]*) (*Fixme: will get type error later*)
234            val params' = params@[E.TEN(true,[dim]), E.TEN(true,[dim])]
235          val freshIndex'= length(sizes)          val freshIndex'= length(sizes)
236          val body' = createBody(dim, s, freshIndex',alpha', dx, Vid, hid, nid, fid)          val body' = createBody(dim, s, freshIndex',alpha', dx, Vid, hid, nid, fid)
237          val ein1=mkEin(params', sizes, body')          val ein1=mkEin(params', sizes, body')
238          val einApp1=mkEinApp(ein1, args')          val einApp1=mkEinApp(ein1, args')
239          val rtn1=(FArg, einApp1)          val code1=(FArg, einApp1)::rtn0
240          in          in
241              code@[rtn1]@rtn0              code@code1
242          end          end
243    
244    
# Line 259  Line 247 
247      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
248      * Note how we keeps eps expressions so only generate pieces that are used      * Note how we keeps eps expressions so only generate pieces that are used
249      *)      *)
250     fun expandEinOp (e as (y, DstIR.EINAPP(ein, args)), fieldset) = let     fun expandEinOp (e as (y, DstIR.EINAPP(ein as Ein.EIN{params , index , body }, args)), fieldset) = let
251          fun rewriteBody b=(case b          val avail = AvailRHS.new()
252            fun matchField()=(case body
253                of E.Probe _ => 1
254                | E.Sum (_, E.Probe _) => 1
255                | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _])) => 1
256                | _ => 0
257                (* end case *))
258            fun rewriteBody()=(case body
259              of  (E.Probe(E.Conv(_,_,_,[]),_))              of  (E.Probe(E.Conv(_,_,_,[]),_))
260                  => replaceProbe(0,e,b,[])                  => replaceProbe(0,e,body,[])
261              | (E.Probe(E.Conv (_,alpha,_,dx),_))              | (E.Probe(E.Conv (_,alpha,_,dx),_))
262                  => liftProbe (0,e,b,[]) (*scans dx for contant*)                  => liftProbe (0,e,body,[]) (*scans dx for contant*)
263              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,_,_,[]),_)))
264                  => replaceProbe(0,e,p, sx)  (*no dx*)                  => replaceProbe(0,e,p, sx)  (*no dx*)
265              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))              | (E.Sum(sx,p as E.Probe(E.Conv(_,[],_,dx),_)))
# Line 275  Line 270 
270                  => replaceProbe(0,e,E.Probe p,sx)                  => replaceProbe(0,e,E.Probe p,sx)
271              | _ => [e]              | _ => [e]
272              (* end case *))              (* end case *))
273          val (fieldset,var) = (case valnumflag          val (fieldset,var) = if (valnumflag)
274              of true => einSet.rtnVar(fieldset,y,DstIR.EINAPP(ein,args))              then  einSet.rtnVar(fieldset, y, DstIR.EINAPP(ein, args))
275              | _     => (fieldset,NONE)              else (fieldset, NONE)
         (* end case *))  
   
         fun matchField b=(case b  
             of E.Probe _ => 1  
             | E.Sum (_, E.Probe _) => 1  
             | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _])) => 1  
             | _ => 0  
             (* end case *))  
         val b = Ein.body ein  
   
276          in  (case var          in  (case var
277              of NONE => ((rewriteBody(Ein.body ein), fieldset, matchField(Ein.body ein), 0))              of NONE => (rewriteBody(), fieldset, matchField(), 0)
278              | SOME v => (("\n mapp_replacing"^(P.printerE ein)^":");([(y,DstIR.VAR v)], fieldset, matchField(Ein.body ein), 1))              | SOME v => ([(y,DstIR.VAR v)], fieldset, matchField(), 1)
279              (* end case *))              (* end case *))
280          end          end
281    

Legend:
Removed from v.3578  
changed lines
  Added in v.3579

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0