Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory | Revision Log | View Patch

revision 3646, Tue Feb 2 14:09:56 2016 UTC | revision 3647, Tue Feb 2 15:02:53 2016 UTC
# Line 93  Line 93 
93    (* generate low-IL code for scaling a non-scalar tensor; `sId` is the scalar    (* generate low-IL code for scaling a non-scalar tensor; `sId` is the scalar
94     * parameter's ID and `vId` is the tensor parameter's ID.     * parameter's ID and `vId` is the tensor parameter's ID.
95     *)     *)
96      fun expandScale (sId, vId, params, body, index, args) = let      fun expandScale (sId, vId, shape, params, body, index, args) = let
97            val (n, vecIX, index') = dropIndex index            val (n, vecIX, index') = dropIndex index
98            in            in
99              case matchLast(index, n)              case matchLast(shape, n)
100               of SOME ix => let               of SOME ix => let
                   val avail = AvailRHS.new()  
101                    val vecA = createI (params, args, sId, [])                    val vecA = createI (params, args, sId, [])
102                    val vecB = createP (params, args, vecIX, vId, ix)                    val vecB = createP (params, args, vecIX, vId, ix)
103                    in                    in
104                      unroll (avail, index, index', EinToVec.op2, (vecIX, Op.VScale vecIX, vecA, vecB))                      unroll (
105                          AvailRHS.new(), index, index',
106                          EinToVec.op2, (vecIX, Op.VScale vecIX, vecA, vecB))
107                    end                    end
108                | _ => scalarExpand (params, body, index, args)                | _ => scalarExpand (params, body, index, args)
109              (* end case *)              (* end case *)
110            end            end
111    
112      (*handleSumProd:E.body*int list*info ->Var*LowIR.ASSN list    (* handle potential sum-of-products (i.e., inner products); otherwise fall back to the
113      * info:(string*E.EIN*Var list)     * general scalar case.
     * low-IL code for double dot product  
     * Sigma_{i, j} A_ij B_ij  
114      *)      *)
115      fun handleSumProd2 (params, body, index, args) = let      fun expandInner (params, body, index, args) = (case body
116            val E.Sum(             of E.Sum(
117                    [(E.V v, _, ub)],
118                    E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
119                  ) => (case (matchFindLast(alpha, v), matchFindLast(beta, v))
120                     of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
121                        (* v is the last index of alpha, beta and nowhere else *)
122                          val avail = AvailRHS.new()
123                          val vecIX= ub+1
124                          val vecA = createP(params, args, vecIX, id1, ix1)
125                          val vecB = createP(params, args, vecIX, id2, ix2)
126                          in
127                            unroll (avail, index, index, EinToVec.dotV, (vecIX, vecA, vecB))
128                          end
129                      | _ => scalarExpand (params, body, index, args)
130                    (* end case *))
131                | E.Sum(
132                  [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],                  [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
133                  E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])                  E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])
134                ) = body                ) => let
135            fun check (v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))            fun check (v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))
136              of ((SOME ix1, NONE), (SOME ix2, NONE)) => let              of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
137                  (*v is the last index of alpha, beta and nowhere else, possible sumProd*)                            (* v is the last index of alpha, beta and nowhere else *)
                 val avail = AvailRHS.new()  
                 (*val nextfnargs = (Ein.params e, args, sx, ub+1, id1, ix1, id2, ix2)*)  
138                  val vecIX= ub+1                  val vecIX= ub+1
139                  val vecA = createP(params, args, vecIX, id1, ix1)                  val vecA = createP(params, args, vecIX, id1, ix1)
140                  val vecB = createP(params, args, vecIX, id2, ix2)                  val vecB = createP(params, args, vecIX, id2, ix2)
141                  in                  in
142                      SOME(unroll (avail, index, index, EinToVec.sumDotV, (sx, vecIX, vecA, vecB)))                                SOME(unroll (
143                                      AvailRHS.new(), index, index,
144                                      EinToVec.sumDotV, (sx, vecIX, vecA, vecB)))
145                  end                  end
146              | _ => NONE              | _ => NONE
147              (* end case *))              (* end case *))
148          in (case check(v1, ub1, (E.V v2, lb2, ub2))                  in
149                      case check(v1, ub1, (E.V v2, lb2, ub2))
150              of SOME e =>e              of SOME e =>e
151              | _ => (case check(v2, ub2, (E.V v1, lb1, ub1))              | _ => (case check(v2, ub2, (E.V v1, lb1, ub1))
152                  of SOME e => e                  of SOME e => e
153                  |_ => scalarExpand (params, body, index, args)                  |_ => scalarExpand (params, body, index, args)
154                  (* end case *))                  (* end case *))
155              (* end case *))                    (* end case *)
156          end          end
157                |  _ => scalarExpand (params, body, index, args)
158              (* end case *))
159    
160    (* expand an Ein expression that has a non-scalar result *)    (* expand an Ein expression that has a non-scalar result *)
161      fun nonScalar (params, body, index, args) = (case body      fun nonScalar (params, body, index, args) = (case body
# Line 223  Line 240 
240                      | _ => scalarExpand (params, body, index, args)                      | _ => scalarExpand (params, body, index, args)
241                    (* end case *)                    (* end case *)
242                  end                  end
243  (* QUESTION: since we are guaranteed a non-scalar result, the only case in gen that can              |  _ => expandInner (params, body, index, args)
  * apply is the default case; right?  
  *)  
 (*  
             |  _ => gen()  
 *)  
             | _ => scalarExpand (params, body, index, args)  
244            (* end case *))            (* end case *))
245    
246    (* scan:var*E.Ein*Var list * Var list-> Var*LowIR.Assgn list    (* scan:var*E.Ein*Var list * Var list-> Var*LowIR.Assgn list
247     * scans body  for vectorization potential     * scans body  for vectorization potential
248     *)     *)
249      fun expand (y, Ein.EIN{params, index, body}, args : LowIR.var list) = let      fun expand (y, Ein.EIN{params, index, body}, args : LowIR.var list) = let
           val info = (e, args)  
           val all = (b, index, info)  
         (* any result type *)  
           fun gen () = (case body  
                  of E.Sum([(E.V v, _, ub)],  
                       E.Opn(E.Prod, [E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)])  
                     ) => (case (matchFindLast(alpha, v), matchFindLast(beta, v))  
                        of ((SOME ix1, NONE), (SOME ix2, NONE)) => let  
                          (* v is the last index of alpha, beta and nowhere else *)  
                             val avail = AvailRHS.new()  
                             val vecIX= ub+1  
                             val vecA = createP(params, args, vecIX, id1, ix1)  
                             val vecB = createP(params, args, vecIX, id2, ix2)  
                             in  
                               unroll (avail, index, index, EinToVec.dotV, (vecIX, vecA, vecB))  
                             end  
                         | _ => scalarExpand (params, body, index, args)  
                       (* end case *))  
                   | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>  
                       handleSumProd2 all  
                   |  _ => scalarExpand (params, body, index, args)  
                 (* end case *))  
250  (* FIXME: We really only care if the last dimension of a non-scalar shape is "3", since it  (* FIXME: We really only care if the last dimension of a non-scalar shape is "3", since it
251   * causes poor performance in the generated code.  Really, this should be handled by the LowIR   * causes poor performance in the generated code.  Really, this should be handled by the LowIR
252   * to TreeIR transform, where we take into account machine vector widths.   * to TreeIR transform, where we take into account machine vector widths.
# Line 267  Line 256 
256                    | [3, 3, 3] => scalarExpand (params, body, index, args)                    | [3, 3, 3] => scalarExpand (params, body, index, args)
257                    | _::_ => nonScalar (params, body, index, args)                    | _::_ => nonScalar (params, body, index, args)
258  (* TODO: inline gen here, once we've checked that it isn't required in nonScalar. *)  (* TODO: inline gen here, once we've checked that it isn't required in nonScalar. *)
259                    | _ => gen ()                    | _ => expandInner (params, body, index, args)
260                  (* end case *))                  (* end case *))
261            val (x, asgn) :: rest = AvailRHS.getAssignments avail            val (x, asgn) :: rest = AvailRHS.getAssignments avail
262            in            in

Legend:
Removed from v.3646  
changed lines
  Added in v.3647

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0