Home My Page Projects Code Snippets Project Openings diderot
Summary Activity Tracker Tasks SCM

SCM Repository

[diderot] Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml
ViewVC logotype

Diff of /branches/vis15/src/compiler/mid-to-low/ein-to-low.sml

Parent Directory Parent Directory | Revision Log Revision Log | View Patch Patch

revision 3626, Fri Jan 29 23:10:09 2016 UTC revision 3627, Sun Jan 31 14:15:41 2016 UTC
# Line 18  Line 18 
18  *           Creates Low-IL scalar operators  *           Creates Low-IL scalar operators
19  * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body  * Note. The Iter function creates LowIL.CONS and therefore binds the indices in the EIN.body
20  *)  *)
21    
22  structure EinToLow : sig  structure EinToLow : sig
23    
24        val expand : LowIL.var * Ein.ein * LowIL.var list -> LowIL.assign list
25    
26    end = struct    end = struct
27    
28      structure Var = LowIL.Var      structure Var = LowIL.Var
29      structure E = Ein      structure E = Ein
     structure P = Printer  
30      structure Iter = Iter      structure Iter = Iter
31      structure EtoSca = ScaToLow      structure EtoSca = ScaToLow
32      structure EtoVec = VecToLow      structure EtoVec = VecToLow
# Line 34  Line 36 
36      fun iter e = Iter.prodIter e      fun iter e = Iter.prodIter e
37      fun intToReal n = H.intToReal n      fun intToReal n = H.intToReal n
38    
39      (*dropIndex: a list-> int*a*alist    (* `dropIndex alpha` returns the (len, i, alpha') where
40      * alpha::i->returns  length of list-1, i, alpha     * len = length(alpha') and alpha = alpha'@[i].
41      *)      *)
42      fun dropIndex alpha = let      fun dropIndex alpha = let
43          val (e1::es) = List.rev(alpha)            fun drop ([], _, _) = raise Fail "dropIndex[]"
44          in (length alpha-1, e1, List.rev es)              | drop ([idx], n, idxs') = (n, idx, List.rev idxs')
45                | drop (idx::idxs, n, idx') = drop (idxs, n+1, idx::idx')
46              in
47                drop (alpha, 0, [])
48          end          end
49    
50      (*matchLast:E.alpha*int -> (E.alpha) Option      (*matchLast:E.alpha*int -> (E.alpha) Option
51      * Is the last index of alpha E.V n.      * Is the last index of alpha E.V n.
52      * If so, return the rest of the list      * If so, return the rest of the list
53      *)      *)
54      fun matchLast (alpha, n) =      fun matchLast (alpha, n) = (case List.rev alpha
55          case List.rev(alpha)             of (E.V v)::es => if (n = v) then SOME(List.rev es) else NONE
             of (E.V v)::es => (case (n = v)  
                 of true => SOME(List.rev es)  
56                  | _ =>   NONE                  | _ =>   NONE
57                  (*end case*))                  (*end case*))
             | _ => NONE  
58    
59      (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option      (*matchFindLast:E.alpha *int -> E.alpha option* E.mu option
60      * Is the last index of alpha = n.      * Is the last index of alpha = n.
# Line 219  Line 221 
221      * Sigma_{i, j} A_ij B_ij      * Sigma_{i, j} A_ij B_ij
222      *)      *)
223      fun handleSumProd2 (body, index, info) = let      fun handleSumProd2 (body, index, info) = let
224          val E.Sum([(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)], E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])) = body            val E.Sum(
225                    [(E.V v1, lb1, ub1), (E.V v2, lb2, ub2)],
226                    E.Opn(E.Prod, [E.Tensor(id1 , alpha), E.Tensor(id2, beta)])
227                  ) = body
228          fun check(v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))          fun check(v, ub, sx) = (case (matchFindLast(alpha, v), matchFindLast(beta, v))
229              of ((SOME ix1, NONE), (SOME ix2, NONE)) => let              of ((SOME ix1, NONE), (SOME ix2, NONE)) => let
230                  (*v is the last index of alpha, beta and nowhere else, possible sumProd*)                  (*v is the last index of alpha, beta and nowhere else, possible sumProd*)
# Line 246  Line 251 
251      (*scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list      (*scan:var*E.Ein*Var list * Var list-> Var*LowIL.Assgn list
252      *scans body  for vectorization potential      *scans body  for vectorization potential
253      *)      *)
254      fun scan (y, e:Ein.ein, args:LowIL.var list) = let      fun expand (y, Ein.EIN{params, index, body}, args : LowIL.var list) = let
         val b = Ein.body e  
         val index = Ein.index e  
255          val info = (e, args)          val info = (e, args)
256          val all = (b, index, info)          val all = (b, index, info)
   
257          (*any result type*)          (*any result type*)
258          fun gen() = case b            fun gen () = (case body
259              of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) => handleSumProd1 all                   of E.Sum([_], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
260              | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) => handleSumProd2 all                        handleSumProd1 all
261                      | E.Sum([_, _], E.Opn(E.Prod, [E.Tensor(_ , i::ix), E.Tensor(_, j::jx)])) =>
262                          handleSumProd2 all
263              |  _ => runGeneralCase info              |  _ => runGeneralCase info
264                    (* end case *))
265          (*non scalar result for sure*)          (*non scalar result for sure*)
266          fun nonScalar() = case b            fun nonScalar () = (case body
267              of E.Op2(E.Sub, E.Tensor(_, _::_), E.Tensor(_, _::_)) => handleSub all                   of E.Op2(E.Sub, E.Tensor(_, _::_), E.Tensor(_, _::_)) =>
268              | E.Opn(E.Add, (E.Tensor(_, _::_)::_)) => handleAdd all                        handleSub all
269              | E.Op1(E.Neg, E.Tensor(_ , _::_)) => handleNeg all                    | E.Opn(E.Add, (E.Tensor(_, _::_)::_)) =>
270              | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, j::jx)]) => handleScale(s, v, j::jx, index, info)                        handleAdd all
271              | E.Opn(E.Prod, [E.Tensor(v, j::jx), E.Tensor(s , [])]) => handleScale(s, v, j::jx, index, info)                    | E.Op1(E.Neg, E.Tensor(_ , _::_)) =>
272              | E.Opn(E.Prod, [E.Tensor(_ , _::_), E.Tensor(_, _::_)]) => handleProd all                        handleNeg all
273                      | E.Opn(E.Prod, [E.Tensor(s, []), E.Tensor(v, j::jx)]) =>
274                          handleScale(s, v, j::jx, index, info)
275                      | E.Opn(E.Prod, [E.Tensor(v, j::jx), E.Tensor(s , [])]) =>
276                          handleScale(s, v, j::jx, index, info)
277                      | E.Opn(E.Prod, [E.Tensor(_ , _::_), E.Tensor(_, _::_)]) =>
278                          handleProd all
279              |  _ => gen()              |  _ => gen()
280                    (* end case *))
281            val (avail, _) = case index  (* QUESTION: what is special about "3" here? *)
282              val (avail, _) = (case index
283              of [3, 3] => runGeneralCase info              of [3, 3] => runGeneralCase info
284              | [3, 3, 3] => runGeneralCase info              | [3, 3, 3] => runGeneralCase info
285              | _::_ => nonScalar()              | _::_ => nonScalar()
286              | _ => gen()              | _ => gen()
287          in  case AvailRHS.getAssignments avail                  (* end case *))
288              of (x, A)::rest =>  let (*need to reassign the last assgn*)            val (x, asgn):: rest = AvailRHS.getAssignments avail
289                  val code = List.rev((x, A)::rest)            in
290                  val rtn = code@[(y, A)]              List.revMap LowIL.ASSGN ((y, A)::(x, A)::rest)
                 in List.map (fn e =>LowIL.ASSGN(e)) rtn end  
291          end          end
292    
293      end      end

Legend:
Removed from v.3626  
changed lines
  Added in v.3627

root@smlnj-gforge.cs.uchicago.edu
ViewVC Help
Powered by ViewVC 1.0.0