Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/high-to-mid/probe-ein.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3580, Tue Jan 12 23:29:57 2016 UTC revision 3581, Wed Jan 13 03:18:52 2016 UTC
# Line 10  Line 10 
10    
11    end = struct    end = struct
12    
     structure E = Ein  
     structure DstIR = MidIR  
     structure DstOp = MidOps  
13    
14      structure DstV = DstIR.Var      structure IR = MidIR
15      structure DstTy = MidTypes      structure IROp = MidOps
16        structure V = IR.Var
17        structure Ty = MidTypes
18        structure E = Ein
19      structure T = CoordSpaceTransform      structure T = CoordSpaceTransform
20    
21      (* This file expands probed fields      (* This file expands probed fields
# Line 28  Line 28 
28      * dim:dimension of field V      * dim:dimension of field V
29      * s: support of kernel H      * s: support of kernel H
30      * alpha: The alpha in <V_alpha * H^(deltas)>      * alpha: The alpha in <V_alpha * H^(deltas)>
31      * deltas: The deltas in <V_alpha * H^(deltas)>      * dx: The dx in <V_alpha * nabla_dx H>
32        * deltas: The deltas in <V_alpha * h^(deltas) h^(deltas)>
33      * Vid:param_id for V      * Vid:param_id for V
34      * hid:param_id for H      * hid:param_id for H
35      * nid: integer position param_id      * nid: integer position param_id
# Line 37  Line 38 
38      *)      *)
39    
40      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}      fun mkEin (params, index, body) = Ein.EIN{params = params, index = index, body = body}
41      fun mkEinApp (rator, args) = DstIR.EINAPP(rator, args)      fun getRHSDst x = (case IR.Var.binding x
42      fun setProd e = E.Opn(E.Prod, e)             of IR.VB_RHS(IR.OP(rator, args)) => (rator, args)
43                | IR.VB_RHS(IR.VAR x') => getRHSDst x'
     fun getRHSDst x = (case DstIR.Var.binding x  
            of DstIR.VB_RHS(DstIR.OP(rator, args)) => (rator, args)  
             | DstIR.VB_RHS(DstIR.VAR x') => getRHSDst x'  
44              | vb => raise Fail(concat[              | vb => raise Fail(concat[
45                    "expected rhs operator for ", DstIR.Var.toString x,                    "expected rhs operator for ", IR.Var.toString x,
46                    " but found ", DstIR.vbToString vb                    " but found ", IR.vbToString vb
47                  ])                  ])
48            (* end case *))            (* end case *))
49    
50      (* getArgsDst:MidIR.Var* MidIR.Var->int, ImageInfo, int      fun getImageDst (imgArg, args) = (case (getRHSDst imgArg)
51          uses the Param_ids for the image, kernel,          of (IROp.LoadImage(_, _,info),_) => info
52          and position tensor to get the Mid-IR arguments          | (i,_) => raise Fail (String.concat[" Expected Image: ", IROp.toString i])
53      returns the support of ther kernel, and image          (* end case *))
54      *)  
55      fun getArgsDst (hArg, imgArg, args) = (case (getRHSDst hArg, getRHSDst imgArg)      fun getKernelDst (hArg, args) = (case (getRHSDst hArg)
56             of ((DstOp.Kernel(h, i), _ ), (DstOp.LoadImage(_, _, img), _ )) =>          of (IROp.Kernel(h, _), _) => Kernel.support h
57                  (Kernel.support h, img, ImageInfo.dim img)          | (k,_) => raise Fail (String.concat["Expected kernel: ", IROp.toString k])
              | ((k,_), (i,_)) => raise Fail (String.concat[  
                     "Expected kernel: ", DstOp.toString k, ", Expected Image: ", DstOp.toString i  
                   ])  
58            (* end case *))            (* end case *))
59    
60      (*handleArgs():int*int*int*Mid IR.Var list      (*handleArgs():int*int*int*Mid IR.Var list
# Line 71  Line 66 
66      *)      *)
67      fun handleArgs (Vid, hid, tid, args) = let      fun handleArgs (Vid, hid, tid, args) = let
68            val imgArg = List.nth(args, Vid)            val imgArg = List.nth(args, Vid)
69            val hArg = List.nth(args, hid)            val info = getImageDst(imgArg, args)
70            val newposArg = List.nth(args, tid)        val s = getKernelDst( List.nth(args, hid), args)
71            val (s,img, dim) = getArgsDst(hArg, imgArg, args)        val (argsT, P, code) = T.worldToIndex{info = info, img = imgArg, pos = List.nth(args, tid)}
       val (argsT, P, code) = T.worldToIndex{info = img, img = imgArg, pos = newposArg}  
72            in            in
73              (dim, args@argsT, code, s, P)          (ImageInfo.dim info, args@argsT, code, s, P)
74            end            end
75    
76      (*createBody:int*int*int,mu list, param_id, param_id, param_id, param_id      (*fieldreconstruction:int*int*int,mu list, param_id, param_id, param_id, param_id
77      * expands the body for the probed field      * expands the body for the probed field
78      *)      *)
79      fun createBody (dimO, s, sx, alpha, deltas, Vid, hid, nid, fid) = let      fun fieldreconstruction (dimO, s, sx, alpha, dx, Vid, hid, nid, fid) = let
80          (*1-d fields*)          (*1-d fields*)
81          fun createKRND1 () = let          fun createKRND1 () = let
82              val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]              val imgpos = [E.Opn(E.Add,[E.Tensor(fid,[]), E.Value(sx)])]
83              val dels = List.map (fn e =>(E.C 0,e)) deltas              val deltas = List.map (fn e =>(E.C 0,e)) dx
84              val rest = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sx)))              val rest = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[]), E.Value(sx)))
85              in              in
86                 setProd [E.Img(Vid,alpha,imgpos),rest]                E.Opn(E.Prod, [E.Img(Vid,alpha,imgpos),rest])
87              end              end
88          (*createKRN Image field and kernels *)          (*createKRN Image field and kernels *)
89          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid,alpha,imgpos)::rest)          fun createKRN (0, imgpos, rest) = E.Opn(E.Prod, E.Img(Vid,alpha,imgpos)::rest)
# Line 98  Line 92 
92              val cx = E.C(d')              val cx = E.C(d')
93              val Vsum = E.Value(sx+d')              val Vsum = E.Value(sx+d')
94              val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])              val pos0 = E.Opn(E.Add, [E.Tensor(fid,[cx]), Vsum])
95              val dels = List.map (fn e =>(cx, e)) deltas              val deltas = List.map (fn e =>(cx, e)) dx
96              val rest0 = E.Krn(hid, dels, E.Op2(E.Sub,E.Tensor(nid,[cx]),  Vsum))              val rest0 = E.Krn(hid, deltas, E.Op2(E.Sub,E.Tensor(nid,[cx]),  Vsum))
97              in              in
98                  createKRN(d',pos0::imgpos,rest0::rest)                  createKRN(d',pos0::imgpos,rest0::rest)
99              end              end
# Line 128  Line 122 
122        | formBody e = e        | formBody e = e
123    
124      (* silly change in order of the product to match vis branch WorldtoSpace functions*)      (* silly change in order of the product to match vis branch WorldtoSpace functions*)
125      fun multiPs ([P0,P1,P2],sx,body) = formBody(E.Sum(sx, setProd[P0,P1,P2,body]))      fun multiPs(Ps,sx,body) = let
126      (*| multiPs([P0,P1],sx,body) = formBody(E.Sum(sx, setProd([P0,body,P1]))) *)          val exp= (case Ps
127        | multiPs ([P0,P1,P2,P3],sx,body) = formBody(E.Sum(sx, setProd[P0,P1,P2,P3,body]))              of [P0,P1,P2] => [P0,P1,P2,body]
128        | multiPs (Ps,sx,body) = formBody(E.Sum(sx,setProd(body::Ps)))              (*| [P0,P1] => [P0,body,P1] *)
129                | [P0,P1,P2,P3] => [P0,P1,P2,P3,body]
130      fun arrangeBody(body,probe') = (case body              | _ => body::Ps
131              of E.Sum(sx, E.Probe _)                     => E.Sum(sx,probe')              (* end case *))
132              | E.Sum(sx,E.Opn(E.Prod,[eps0,E.Probe _ ])) => E.Sum(sx,E.Opn(E.Prod,[eps0,probe']))          in formBody(E.Sum(sx, E.Opn(E.Prod, exp))) end
133              | _                                         => probe'  
134        fun arrangeBody(body, Ps, newsx, exp)=(case body
135                of E.Sum(sx, E.Probe _ ) => (true, multiPs(Ps, sx@newsx,exp))
136                | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false, E.Sum(sx, E.Opn(E.Prod, [eps0, multiPs(Ps, newsx,exp)])))
137                | E.Probe _ => (true, multiPs(Ps, newsx, exp))
138                | _ => raise Fail "impossible"
139          (* end case *))          (* end case *))
140    
   
141      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list      (* replaceProbe:ein_exp* params *midIR.var list * int list* sum_id list
142              -> ein_exp* *code              -> ein_exp* *code
143      * Transforms position to world space      * Transforms position to world space
# Line 147  Line 145 
145      * rewrites body      * rewrites body
146      * replace probe with expanded version      * replace probe with expanded version
147      *)      *)
148       fun replaceProbe((y, DstIR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let       fun replaceProbe((y, IR.EINAPP(Ein.EIN{params = params, index = index, body = body},args)), probe, sx) = let
149          val fid = length(params)          val fid = length(params)
150          val nid = fid+1          val nid = fid+1
151          val Pid = nid+1          val Pid = nid+1
152          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
         val nshift = length(dx)  
153          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)          val (dim,argsA,code,s,PArg) = handleArgs(Vid,hid,tid,args)
154          val freshIndex = getsumshift(sx,length(index))          val freshIndex = getsumshift(sx,length(index))
155          val (dx,newsx1,Ps) = T.imageToWorld(freshIndex,dim,dx,Pid)          val (dx',sx',Ps) = T.imageToWorld(freshIndex,dim,dx,Pid)
156          (*val params' = params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)          (*val params' = params@[E.TEN(3,[dim]),E.TEN(1,[dim]),E.TEN(1,[dim,dim])]*)
157           val params' = params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]           val params' = params@[E.TEN(true,[dim]),E.TEN(true,[dim]),E.TEN(true,[dim,dim])]
158          val probe' = createBody(dim, s, freshIndex+nshift, alpha, dx, Vid, hid, nid, fid)          val probe' = fieldreconstruction(dim, s, freshIndex+length(dx'), alpha, dx', Vid, hid, nid, fid)
159          val probe' = multiPs(Ps,newsx1,probe')          val (_, body') = arrangeBody(body, Ps, sx', probe')
160          val body' = arrangeBody(body,probe')          val einapp = (y,IR.EINAPP(mkEin(params',index,body'),argsA@[PArg]))
         val einapp = (y,mkEinApp(mkEin(params',index,body'),argsA@[PArg]))  
161          in          in
162              code@[einapp]              code@[einapp]
163          end          end
164    
165      fun createEinApp (originalb, alpha, index, freshIndex, dim, dx, sx) = let      (*transform T*P*P..Ps*)
166        fun createEinApp (body, alpha, index, freshIndex, dim, dx, sx) = let
167          val Pid = 0          val Pid = 0
168          val tid = 1          val tid = 1
169            val (dx', newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)
         (*Assumes body is already clean*)  
         val (newdx, newsx, Ps) = T.imageToWorld(freshIndex, dim, dx, Pid)  
170    
171          (*need to rewrite dx*)          (*need to rewrite dx*)
172          val (_, sizes, e as E.Conv(_,alpha',_,dx)) = (case sx@newsx          val sxx = sx@newsx
173              of [] => ([], index, E.Conv(9,alpha,7,newdx))          val (_, sizes, E.Conv(_, alpha', _, dx')) = (case sxx
174              | _ => CleanIndex.clean(E.Conv(9,alpha,7,newdx), index, sx@newsx)              of [] => ([], index, E.Conv(9,alpha,7,dx'))
175                | _ => CleanIndex.clean(E.Conv(9,alpha,7,dx'), index, sxx)
176              (* end case *))              (* end case *))
177          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]          fun filterAlpha [] = dx'
         fun filterAlpha [] = []  
178            | filterAlpha (E.C _::es) = filterAlpha es            | filterAlpha (E.C _::es) = filterAlpha es
179            | filterAlpha (e1::es) = e1::(filterAlpha es)            | filterAlpha (e1::es) = e1::(filterAlpha es)
180          val tshape = filterAlpha(alpha')@newdx          val exp = E.Tensor(tid, filterAlpha(alpha'))
181          val t = E.Tensor(tid, tshape)          val (splitvar, body') = arrangeBody(body, Ps, newsx, exp)
182          val (splitvar, body) = (case originalb          val params = [E.TEN(true,[dim,dim]), E.TEN(true,sizes)]
183              of E.Sum(sx, E.Probe _)             => (true, multiPs(Ps, sx@newsx,t))          val ein0 = mkEin(params, index, body')
             | E.Sum(sx, E.Opn(E.Prod,[eps0,E.Probe _ ])) => (false, E.Sum(sx, setProd[eps0, multiPs(Ps, newsx, t)]))  
             | _ => (true, multiPs(Ps, newsx, t))  
             (* end case *))  
   
         val ein0 = mkEin(params, index, body)  
184          in          in
185              (splitvar, ein0, sizes, dx, alpha')              (splitvar, ein0, sizes, dx', alpha')
186          end          end
187    
188      fun liftProbe ((y, DstIR.EINAPP(Ein.EIN{params , index , body }, args)), probe, sx) = let      (*floats the reconstructed field term*)
189        fun liftProbe ((y, IR.EINAPP(Ein.EIN{params, index, body}, args)), probe, sx) = let
190          val fid = length(params)          val fid = length(params)
191          val nid = fid+1          val nid = fid+1
192          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe          val E.Probe(E.Conv(Vid,alpha,hid,dx),E.Tensor(tid,_)) = probe
         val nshift = length(dx)  
193          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)          val (dim, args', code, s, PArg) = handleArgs(Vid, hid, tid, args)
194          val freshIndex = getsumshift(sx, length(index))          val freshIndex = getsumshift(sx, length(index))
195    
196          (*transform T*P*P..Ps*)          (*transform T*P*P..Ps*)
197          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)          val (splitvar, ein0, sizes, dx, alpha') = createEinApp(body, alpha, index, freshIndex, dim, dx, sx)
198          val FArg = DstV.new ("F", DstTy.TensorTy(sizes))          val FArg = V.new ("T", Ty.TensorTy(sizes))
199          val einApp0 = mkEinApp(ein0, [PArg,FArg])          val einApp0 = IR.EINAPP(ein0, [PArg,FArg])
200          val rtn0 = (case splitvar          val rtn0 = if(splitvar)
201              of false => [(y, mkEinApp(ein0, [PArg,FArg]))]              then List.map (fn IR.ASSGN(e)=>e) (FloatEin.transform(y, EinSums.transform ein0, [PArg, FArg]))
202              | _     => let              else [(y,IR.EINAPP(ein0, [PArg,FArg]))]
                 val bind3 = (y, DstIR.EINAPP(EinSums.transform  ein0, [PArg, FArg]))  
                 in  
                     Split.splitEinApp bind3  
                 end  
             (* end case *))  
203    
204          (*lifted probe*)          (*reconstruct the lifted probe*)
205          (*val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]*) (*Fixme: will get type error later*)          (*val params' = params@[E.TEN(3,[dim]), E.TEN(1,[dim])]*) (*Fixme: will get type error later*)
206          val params' = params@[E.TEN(true,[dim]), E.TEN(true,[dim])]          val params' = params@[E.TEN(true,[dim]), E.TEN(true,[dim])]
207          val freshIndex' = length(sizes)          val freshIndex' = length(sizes)
208          val body' = createBody(dim, s, freshIndex',alpha', dx, Vid, hid, nid, fid)          val body' = fieldreconstruction(dim, s, freshIndex',alpha', dx, Vid, hid, nid, fid)
209          val ein1 = mkEin(params', sizes, body')          val einApp1 = IR.EINAPP(mkEin(params', sizes, body'), args')
         val einApp1 = mkEinApp(ein1, args')  
210          val code1 = (FArg, einApp1)::rtn0          val code1 = (FArg, einApp1)::rtn0
211          in          in
212              code@code1              code@code1
# Line 232  Line 216 
216      (* expandEinOp: code->  code list      (* expandEinOp: code->  code list
217      * A this point we only have simple ein ops      * A this point we only have simple ein ops
218      * Looks to see if the expression has a probe. If so, replaces it.      * Looks to see if the expression has a probe. If so, replaces it.
     * Note how we keeps eps expressions so only generate pieces that are used  
219      *)      *)
220     fun expandEinOp (e as (y, DstIR.EINAPP(ein as Ein.EIN{params , index , body }, args)), avail) = let     fun expandEinOp (e as (_, IR.EINAPP(Ein.EIN{body, ...},_))) =
221          fun matchField() = (case body          (case body
             of E.Probe _ => 1  
             | E.Sum (_, E.Probe _) => 1  
             | E.Sum(_, E.Opn(E.Prod,[ _ ,E.Probe _])) => 1  
             | _ => 0  
             (* end case *))  
         fun rewriteBody() = (case body  
222              of  (E.Probe(E.Conv(_,_,_,[]),_))              of  (E.Probe(E.Conv(_,_,_,[]),_))
223                 => replaceProbe(e, body, [])                 => replaceProbe(e, body, [])
224              | (E.Probe(E.Conv (_,alpha,_,dx),_))              | (E.Probe(E.Conv (_,alpha,_,dx),_))
# Line 256  Line 233 
233                 => replaceProbe(e ,E.Probe p, sx)                 => replaceProbe(e ,E.Probe p, sx)
234              | _ => [e]              | _ => [e]
235              (* end case *))              (* end case *))
         in  
             (rewriteBody(), avail, matchField(), 0)  
         end  
236    
237    end (* ProbeEin *)    end (* ProbeEin *)

Legend:
Removed from v.3580  
changed lines
  Added in v.3581

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0