Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3582, Wed Jan 13 22:14:05 2016 UTC revision 3583, Wed Jan 13 22:28:29 2016 UTC
# Line 39  Line 39 
39      *)      *)
40    
41      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
42    
43      fun getRHSDst x = (case IR.Var.binding x      fun getRHSDst x = (case IR.Var.binding x
44             of IR.VB_RHS(IR.OP(rator, args)) => (rator, args)             of IR.VB_RHS(IR.OP(rator, args)) => (rator, args)
45              | IR.VB_RHS(IR.VAR x') => getRHSDst x'              | IR.VB_RHS(IR.VAR x') => getRHSDst x'
# Line 75  Line 76 
76          (ImageInfo.dim info, args@argsT, code, s, P)          (ImageInfo.dim info, args@argsT, code, s, P)
77            end            end
78    
79      (*fieldreconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id      (*fieldReconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id
80      * expands the body for the probed field      * expands the body for the probed field
81      *)      *)
82      fun fieldreconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let      fun fieldReconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let
83          (*1-d fields*)          (*1-d fields*)
84          fun createKRND1 () = let          fun createKRND1 () = let
85              val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]              val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]
# Line 99  Line 100 
100              in              in
101                  createKRN(d',pos0::imgpos,rest0::rest)                  createKRN(d',pos0::imgpos,rest0::rest)
102              end              end
   
103          (*sumIndex creating summation Index for body*)          (*sumIndex creating summation Index for body*)
104          val esum = List.tabulate(dimO, (fn d =>(E.V d, 1-s, s)))            val esum = List.tabulate (dimO, fn d => (E.V d, 1-s, s))
105          val exp = if (dimO=1) then createKRND1() else createKRN(dimO, [], [])          val exp = if (dimO=1) then createKRND1() else createKRN(dimO, [], [])
106      in      in
107          E.Sum(esum, exp)          E.Sum(esum, exp)
# Line 127  Line 127 
127      fun multiPs(Ps,sx,body) = let      fun multiPs(Ps,sx,body) = let
128          val exp= (case Ps          val exp= (case Ps
129              of [P0,P1,P2] => [P0,P1,P2,body]              of [P0,P1,P2] => [P0,P1,P2,body]
             (*| [P0,P1] => [P0,body,P1] *)  
130              | [P0,P1,P2,P3] => [P0,P1,P2,P3,body]              | [P0,P1,P2,P3] => [P0,P1,P2,P3,body]
131              | _ => body::Ps              | _ => body::Ps
132              (* end case *))              (* end case *))
133          in formBody(E.Sum(sx, E.Opn(E.Prod, exp))) end            in
134                formBody(E.Sum(sx, E.Opn(E.Prod, exp)))
135              end
136    
137      fun arrangeBody(body, Ps, newsx, exp)=(case body      fun arrangeBody(body, Ps, newsx, exp)=(case body
138              of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,exp))              of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,exp))
139              | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx,exp)])))              | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) =>
140                    (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx,exp)])))
141              | E.Probe _ => (true, multiPs(Ps, newsx, exp))              | E.Probe _ => (true, multiPs(Ps, newsx, exp))
142              | _ => raise Fail "impossible"              | _ => raise Fail "impossible"
143              (* end case *))              (* end case *))
# Line 147  Line 149 
149      * rewrites body      * rewrites body
150      * replace probe with expanded version      * replace probe with expanded version
151      *)      *)
152       fun replaceProbe((y, IR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let       fun replaceProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
153          val fid = length(params)            val fid = length params
154          val nid = fid+1          val nid = fid+1
155          val Pid = nid+1          val Pid = nid+1
156          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
157          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
158          val freshIndex = getsumshift(sx,length(index))            val freshIndex = getsumshift (sx, length index)
159          val (dx',sx',Ps) = T.imageToWorld(freshIndex,dim,dx,Pid)          val (dx',sx',Ps) = T.imageToWorld(freshIndex,dim,dx,Pid)
         (*val params' = params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)  
160          val params' = params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]          val params' = params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]
161          val probe' = fieldreconstruction(dim, s, freshIndex+length(dx'), alpha, dx', Vid, hid, nid, fid)            val probe' = fieldReconstruction (dim, s, freshIndex+length dx', alpha, dx', Vid, hid, nid, fid)
162          val (_, body') = arrangeBody(body, Ps, sx', probe')          val (_, body') = arrangeBody(body, Ps, sx', probe')
163          val einapp = (y,IR.EINAPP(mkEin(params',index,body'),argsA@[PArg]))          val einapp = (y,IR.EINAPP(mkEin(params',index,body'),argsA@[PArg]))
164          in          in
# Line 169  Line 170 
170          val Pid = 0          val Pid = 0
171          val tid = 1          val tid = 1
172          val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)          val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
   
173          (*need to rewrite dx*)          (*need to rewrite dx*)
174          val sxx = sx@newsx          val sxx = sx@newsx
175          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx
176    (* QUESTION: what is the significance of "9" and "7" in this code? *)
177              of [] => ([], index, E.Conv(9,alpha,7,dx'))              of [] => ([], index, E.Conv(9,alpha,7,dx'))
178              | _ => CleanIndex.clean(E.Conv(9,alpha,7,dx'), index, sxx)              | _ => CleanIndex.clean(E.Conv(9,alpha,7,dx'), index, sxx)
179              (* end case *))              (* end case *))
180          fun filterAlpha [] = dx'          fun filterAlpha [] = dx'
181            | filterAlpha (E.C _::es) = filterAlpha es            | filterAlpha (E.C _::es) = filterAlpha es
182            | filterAlpha (e1::es) = e1::(filterAlpha es)            | filterAlpha (e1::es) = e1::(filterAlpha es)
183          val exp = E.Tensor(tid, filterAlpha(alpha'))            val exp = E.Tensor(tid, filterAlpha alpha')
184          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
185          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]
186          val ein0 = mkEin(params, index, body')          val ein0 = mkEin(params, index, body')
# Line 194  Line 195 
195          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
196          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
197          val freshIndex = getsumshift(sx, length(index))          val freshIndex = getsumshift(sx, length(index))
   
198          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
199          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)            val (splitvar, ein0, sizes, dx, alpha') =
200                    createEinApp (body, alpha, index, freshIndex, dim, dx, sx)
201          val FArg = V.new ("T", Ty.TensorTy(sizes))          val FArg = V.new ("T", Ty.TensorTy(sizes))
202          val einApp0 = IR.EINAPP(ein0, [PArg,FArg])          val einApp0 = IR.EINAPP(ein0, [PArg,FArg])
203          val rtn0 = if(splitvar)            val rtn0 = if splitvar
204              then FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg])              then FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg])
205              else [(y,IR.EINAPP(ein0, [PArg,FArg]))]              else [(y,IR.EINAPP(ein0, [PArg,FArg]))]
   
206          (*reconstruct the lifted probe*)          (*reconstruct the lifted probe*)
         (*val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]*) (*Fixme: will get type error later*)  
207          val params' = params@[E.TEN(true,[dim]), E.TEN(true,[dim])]          val params' = params@[E.TEN(true,[dim]), E.TEN(true,[dim])]
208          val freshIndex' = length(sizes)            val freshIndex' = length sizes
209          val body' = fieldreconstruction(dim, s, freshIndex',alpha', dx, Vid, hid, nid, fid)            val body' = fieldReconstruction (dim, s, freshIndex', alpha', dx, Vid, hid, nid, fid)
210          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')
         val code1 = (FArg, einApp1)::rtn0  
211          in          in
212              code@code1              code @ (FArg, einApp1) :: rtn0
213          end          end
214    
   
215      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
216      * A this point we only have simple ein ops      * A this point we only have simple ein ops
217      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.

Legend:
Removed from v.3582  
changed lines
  Added in v.3583

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0