Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3632, Sun Jan 31 17:45:58 2016 UTC revision 3641, Mon Feb 1 03:53:00 2016 UTC
# Line 82  Line 82 
82      fun createI (params, args, id, ix) =      fun createI (params, args, id, ix) =
83            VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Indx)            VecToLow.Param(id, List.nth (args, id), H.getTensorTy (params, id), ix, VecToLow.Indx)
84    
85      (*handleAdd:E.body*int list*info ->Var*LowIR.ASSN list    (* generate low-IL code for scaling a non-scalar tensor *)
     * info:(string*E.EIN*Var list)  
     * low-IL code for adding two vectors  
     *)  
     fun handleAdd (E.Opn(E.Add, es), index, info) = let  
         val (n, vecIX, index') = dropIndex index  
         (*check that each tensor in addition list has matching indices*)  
         fun sample ([], rest) = let  
             val avail = AvailRHS.new()  
             val nextfnargs = (vecIX, List.rev rest)  
             in  
                 iter(avail, index, index', VecToLow.addV, nextfnargs)  
             end  
         | sample (E.Tensor(id1, alpha)::ts, rest) = (case (matchLast(alpha, n))  
             of SOME ix1    => sample(ts, createP(params, vecIX, id1, ix1)::rest)  
             | _            => runGeneralCase info  
             (* end case *))  
         | sample _ = runGeneralCase info  
         in  
             sample(es, [])  
         end  
   
     (*handleScale:E.tensor_id*E.tensor_id*E.alpha*int list*info ->Var*LowIR.ASSN list  
     * info:(string*E.EIN*Var list)  
     * low-IL code for adding scaling a vector  
     *)  
86      fun handleScale (id1, id2, alpha2, params, index) = let      fun handleScale (id1, id2, alpha2, params, index) = let
87            val (n, vecIX, index') = dropIndex index            val (n, vecIX, index') = dropIndex index
88            in            in
# Line 124  Line 99 
99              (* end case *)              (* end case *)
100            end            end
101    
     (*handleProd:E.body*int list*info ->Var*LowIR.ASSN list  
     * info:(string*E.EIN*Var list)  
     * low-IL code for vector product  
     *)  
     fun handleProd (E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)]), index, info) = let  
         val (e, args) = info  
         val (n, vecIX, index') = dropIndex index  
         val avail = AvailRHS.new()  
         in case(matchFindLast(alpha, n), matchFindLast(beta, n))  
             of ((SOME ix1, NONE), (SOME ix2, NONE)) => let  
                 (*n is the last index of alpha, beta and nowhere else, possible modulate*)  
                 val vecA = createP(params, args, vecIX, id1, ix1)  
                 val vecB = createP(params, args, vecIX, id2, ix2)  
                 val nextfnargs = (vecIX, Op.VMul vecIX, vecA, vecB)  
                 in  
                     iter(avail, index, index', VecToLow.op2, nextfnargs)  
                 end  
             | ((NONE, NONE), (SOME ix2, NONE)) =>let  
                 (*n is the last index of beta and nowhere else, possible scaleVector*)  
                 val vecA = createI(params, args, id1, alpha)  
                 val vecB = createP(params, args, vecIX, id2, ix2)  
                 val nextfnargs = (vecIX, Op.VScale vecIX, vecA, vecB)  
                 in  
                     iter(avail, index, index', VecToLow.op2 , nextfnargs)  
                 end  
             | ((SOME ix1, NONE), (NONE, NONE)) =>let  
                 (*n is the last index of alpha and nowhere else, ossile scaleVector*)  
                 val vecA = createI(params, args, id2, beta)  
                 val vecB = createP(params, args, vecIX, id1, ix1)  
                 val nextfnargs = (vecIX, Op.VScale vecIX, vecA, vecB)  
   
                 in  
                     iter(avail, index, index', VecToLow.op2, nextfnargs)  
                 end  
             | _ => runGeneralCase info  
         end  
   
102      (*handleSumProd:E.body*int list*info ->Var*LowIR.ASSN list      (*handleSumProd:E.body*int list*info ->Var*LowIR.ASSN list
103      * info:(string*E.EIN*Var list)      * info:(string*E.EIN*Var list)
104      * low-IL code for dot product      * low-IL code for dot product
# Line 185  Line 123 
123      * low-IL code for double dot product      * low-IL code for double dot product
124      * Sigma_{i, j} A_ij B_ij      * Sigma_{i, j} A_ij B_ij
125      *)      *)
126      fun handleSumProd2 (body, index, info) = let      fun handleSumProd2 (params, body, index, args) = let
127            val E.Sum(            val E.Sum(
128                  [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],                  [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
129                  E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])                  E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])
# Line 213  Line 151 
151              (* end case *))              (* end case *))
152          end          end
153    
154    (* scan:var*E.Ein*Var list * Var list-> Var*LowIR.Assgn list    (* expand an Ein expression that has a non-scalar result *)
155     * scans body  for vectorization potential      fun nonScalar (params, body, index, args) = (case body
156     *)             of E.Op2(E.Sub, E.Tensor(id1, alpha as _::_), E.Tensor(id2, beta as _::_)) => let
     fun expand (y, Ein.EIN{params, index, body}, args : LowIR.var list) = let  
           val info = (e, args)  
           val all = (b, index, info)  
         (* any result type *)  
           fun gen () = (case body  
                  of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>  
                       handleSumProd1 all  
                   | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>  
                       handleSumProd2 all  
                   |  _ => runGeneralCase info  
                 (* end case *))  
         (* non scalar result for sure *)  
           fun nonScalar () = (case body  
                  of E.Op2(E.Sub, E.Tensor(id1, alpha as (_::_)), E.Tensor(id2, beta as (_::_))) => let  
157                        val (n, vecIX, index') = dropIndex index                        val (n, vecIX, index') = dropIndex index
158                        in                        in
159                          case (matchLast(alpha, n), matchLast(beta, n))                          case (matchLast(alpha, n), matchLast(beta, n))
# Line 243  Line 167 
167                            | _  => runGeneralCase (index, args)                            | _  => runGeneralCase (index, args)
168                          (* end case *)                          (* end case *)
169                        end                        end
170                    | E.Opn(E.Add, (E.Tensor(_, _::_)::_)) =>              | E.Opn(E.Add, es as E.Tensor(_, _::_)::_) => let
171                        handleAdd all                  val (n, vecIX, index') = dropIndex index
172    (* QUESTION: what does the following comment mean?  What do we do if each tensor has matching indices? *)
173                  (* check that each tensor in addition list has matching indices *)
174                    fun sample ([], rest) = let
175                          val nextfnargs = (vecIX, List.rev rest)
176                          in
177                            iter (AvailRHS.new(), index, index', VecToLow.addV, nextfnargs)
178                          end
179                      | sample (E.Tensor(id1, alpha)::ts, rest) = (case matchLast(alpha, n)
180                           of SOME ix1 => sample(ts, createP(params, args, vecIX, id1, ix1)::rest)
181                            | _ => runGeneralCase (index, args)
182                          (* end case *))
183                      | sample _ = runGeneralCase (index, args)
184                    in
185                      sample es
186                    end
187                    | E.Op1(E.Neg, E.Tensor(id, alpha as (_::_))) => let                    | E.Op1(E.Neg, E.Tensor(id, alpha as (_::_))) => let
188                        val (n, vecIX, index') = dropIndex index                        val (n, vecIX, index') = dropIndex index
189                        in                        in
# Line 260  Line 199 
199                            | _ => runGeneralCase (index, args)                            | _ => runGeneralCase (index, args)
200                          (* end case *)                          (* end case *)
201                        end                        end
202                    | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, j::jx)]) =>              | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, shp as _::_)]) =>
203                        handleScale(s, v, j::jx, index, info)                  handleScale (s, v, shp, index, params, index)
204                    | E.Opn(E.Prod, [E.Tensor(v, j::jx), E.Tensor(s , [])]) =>              | E.Opn(E.Prod, [E.Tensor(v, shp as _::_), E.Tensor(s , [])]) =>
205                        handleScale(s, v, j::jx, index, info)                  handleScale (s, v, j::jx, index, params, index)
206                    | E.Opn(E.Prod, [E.Tensor(_ , _::_), E.Tensor(_, _::_)]) =>              | E.Opn(E.Prod, [E.Tensor(id1 , alpha as _::_), E.Tensor(id2, beta as _::_)]) => let
207                        handleProd all                  val (n, vecIX, index') = dropIndex index
208                    val avail = AvailRHS.new()
209                    in
210                      case (matchFindLast(alpha, n), matchFindLast(beta, n))
211                       of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
212                          (* n is the last index of alpha, beta and nowhere else, possible modulate *)
213                            val vecA = createP(params, args, vecIX, id1, ix1)
214                            val vecB = createP(params, args, vecIX, id2, ix2)
215                            val nextfnargs = (vecIX, Op.VMul vecIX, vecA, vecB)
216                            in
217                              iter (avail, index, index', VecToLow.op2, nextfnargs)
218                            end
219                        | ((NONE, NONE), (SOME ix2, NONE)) => let
220                          (* n is the last index of beta and nowhere else, possible scaleVector *)
221                            val vecA = createI(params, args, id1, alpha)
222                            val vecB = createP(params, args, vecIX, id2, ix2)
223                            val nextfnargs = (vecIX, Op.VScale vecIX, vecA, vecB)
224                            in
225                              iter (avail, index, index', VecToLow.op2 , nextfnargs)
226                            end
227                        | ((SOME ix1, NONE), (NONE, NONE)) => let
228                          (* n is the last index of alpha and nowhere else, ossile scaleVector *)
229                            val vecA = createI(params, args, id2, beta)
230                            val vecB = createP(params, args, vecIX, id1, ix1)
231                            val nextfnargs = (vecIX, Op.VScale vecIX, vecA, vecB)
232                            in
233                              iter (avail, index, index', VecToLow.op2, nextfnargs)
234                            end
235                        | _ => runGeneralCase (index, args)
236                      (* end case *)
237                    end
238    (* QUESTION: since we are guaranteed a non-scalar result, the only case in gen that can
239     * apply is the default case; right?
240     *)
241    (*
242                    |  _ => gen()                    |  _ => gen()
243    *)
244                | _ => runGeneralCase (index, args)
245              (* end case *))
246    
247      (* scan:var*E.Ein*Var list * Var list-> Var*LowIR.Assgn list
248       * scans body  for vectorization potential
249       *)
250        fun expand (y, Ein.EIN{params, index, body}, args : LowIR.var list) = let
251              val info = (e, args)
252              val all = (b, index, info)
253            (* any result type *)
254              fun gen () = (case body
255                     of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
256                          handleSumProd1 all
257                      | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
258                          handleSumProd2 all
259                      |  _ => runGeneralCase info
260                  (* end case *))                  (* end case *))
261  (* QUESTION: what is special about "3" here? *)  (* QUESTION: what is special about "3" here? *)
262            val (avail, _) = (case index            val (avail, _) = (case index
263                   of [3, 3] => runGeneralCase (index, args)                   of [3, 3] => runGeneralCase (index, args)
264                    | [3, 3, 3] => runGeneralCase (index, args)                    | [3, 3, 3] => runGeneralCase (index, args)
265                    | _::_ => nonScalar ()                    | _::_ => nonScalar (params, body, index, args)
266    (* TODO: inline gen here, once we've checked that it isn't required in nonScalar. *)
267                    | _ => gen ()                    | _ => gen ()
268                  (* end case *))                  (* end case *))
269            val (x, asgn) :: rest = AvailRHS.getAssignments avail            val (x, asgn) :: rest = AvailRHS.getAssignments avail
270            in            in
271    (* QUESTION: should we have A twice here? *)
272              List.revMap LowIR.ASSGN ((y, A)::(x, A)::rest)              List.revMap LowIR.ASSGN ((y, A)::(x, A)::rest)
273            end            end
274    

Legend:
Removed from v.3632  
changed lines
  Added in v.3641

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0