Home My Page Projects Code Snippets Project Openings SML/NJ
 Summary Activity Forums Tracker Lists Tasks Docs Surveys News SCM Files

# SCM Repository

[smlnj] Diff of /sml/trunk/src/MLRISC/Tools/MatchCompiler/match-compiler.sml
 [smlnj] / sml / trunk / src / MLRISC / Tools / MatchCompiler / match-compiler.sml

# Diff of /sml/trunk/src/MLRISC/Tools/MatchCompiler/match-compiler.sml

revision 751, Fri Dec 8 21:04:14 2000 UTC revision 752, Fri Dec 8 23:32:37 2000 UTC
# Line 52  Line 52
52     fun listify (l,s,r) list =     fun listify (l,s,r) list =
53         l^List.foldr (fn (x,"") => x | (x,y) => x^s^y) "" list^r         l^List.foldr (fn (x,"") => x | (x,y) => x^s^y) "" list^r
54
55       (* ListPair.all has the wrong semantics! *)
56       fun forall f ([], []) = true
57         | forall f (x::xs, y::ys) = f(x,y) andalso forall f (xs, ys)
58         | forall f _ = false
59
60     datatype index = INT of int | LABEL of Var.var     datatype index = INT of int | LABEL of Var.var
61
62     datatype path  = PATH of index list     datatype path  = PATH of index list
63
(* Internal rep of pattern after every variable has been renamed *)
datatype pat =
WILD                    (* wild card *)
| APP of decon * pat list
| TUPLE of pat list
| RECORD of (Var.var * pat) list

and decon = CON of Con.con
| LIT of Literal.literal

exception MatchCompiler of string

fun error msg = raise MatchCompiler msg
fun bug msg   = error("bug: "^msg)

structure Con     = Con
structure Action  = Action
structure Literal = Literal
structure Guard   = Guard
structure Var     = Var

64     structure Index =     structure Index =
65     struct     struct
66        fun compare(INT i, INT j) = Int.compare(i,j)        fun compare(INT i, INT j) = Int.compare(i,j)
# Line 111  Line 95
95            "v_"^List.foldr (fn (i,"") => Index.toString i            "v_"^List.foldr (fn (i,"") => Index.toString i
96                              | (i,s) => Index.toString i^"_"^s) "" p                              | (i,s) => Index.toString i^"_"^s) "" p
97        structure Map = RedBlackMapFn(type ord_key = path val compare = compare)        structure Map = RedBlackMapFn(type ord_key = path val compare = compare)
structure Set = RedBlackSetFn(type ord_key = path val compare = compare)
98     end     end
99
100       datatype name = VAR of Var.var | PVAR of path
101
102       structure Name =
103       struct
104          fun toString(VAR v)  = Var.toString v
105            | toString(PVAR p) = Path.toString p
106          fun compare(VAR x,VAR y) = Var.compare(x,y)
107            | compare(PVAR x,PVAR y) = Path.compare(x,y)
108            | compare(VAR  _, PVAR _) = LESS
109            | compare(PVAR  _, VAR _) = GREATER
110          fun equal(x,y) = compare(x,y) = EQUAL
111          structure Set = RedBlackSetFn(type ord_key = name val compare = compare)
112       end
113
114       structure VarSet = RedBlackSetFn(type ord_key = Var.var val compare = Var.compare)
115       structure Subst = RedBlackMapFn(type ord_key = Var.var val compare = Var.compare)
116       type subst = name Subst.map
117
118       (* Internal rep of pattern after every variable has been renamed *)
119       datatype pat =
120         WILD                               (* wild card *)
121       | APP of decon * pat list            (* constructor *)
122       | TUPLE of pat list                  (* tupling *)
123       | RECORD of (Var.var * pat) list     (* record *)
124       | OR of (subst * pat) list           (* disjunction *)
125
126       and decon = CON of Con.con
127                 | LIT of Literal.literal
128
129       exception MatchCompiler of string
130
131       fun error msg = raise MatchCompiler msg
132       fun bug msg   = error("bug: "^msg)
133
134       structure Con     = Con
135       structure Action  = Action
136       structure Literal = Literal
137       structure Guard   = Guard
138       structure Var     = Var
139
140     structure Decon =     structure Decon =
141     struct     struct
142        fun kind(CON _) = 0        fun kind(CON _) = 0
# Line 144  Line 167
167          | toString(RECORD lps) = listify("{",",","}")          | toString(RECORD lps) = listify("{",",","}")
168                                   (map (fn (l,p) =>                                   (map (fn (l,p) =>
169                                        Var.toString l^"="^toString p) lps)                                        Var.toString l^"="^toString p) lps)
170            | toString(OR ps) = listify("(","|",")") (map toString' ps)
171          and toString'(subst,p) = toString p
172
173        fun kind(WILD) = 0        fun kind(WILD) = 0
174          | kind(APP _) = 1          | kind(APP _) = 1
175          | kind(TUPLE _) = 2          | kind(TUPLE _) = 2
176          | kind(RECORD _) = 3          | kind(RECORD _) = 3
177            | kind(OR _) = 4
178
179        and compareList([], []) = EQUAL        and compareList([], []) = EQUAL
180          | compareList([], _)  = LESS          | compareList([], _)  = LESS
# Line 182  Line 208
208                    )                    )
209            in  loop(xs, ys)            in  loop(xs, ys)
210            end            end
211            | compare(OR xs, OR ys) = compareList(map #2 xs, map #2 ys)
212          | compare(x, y) = Int.compare(kind x, kind y)          | compare(x, y) = Int.compare(kind x, kind y)
213
fun isRefutable WILD = false
| isRefutable (APP(f,_)) = true
| isRefutable (TUPLE ps) = List.exists isRefutable ps
| isRefutable (RECORD ps) = List.exists (fn (l,p) => isRefutable p) ps

214        structure Set = RedBlackSetFn(type ord_key = pat val compare = compare)        structure Set = RedBlackSetFn(type ord_key = pat val compare = compare)
215        fun equal(p1,p2) = compare(p1,p2) = EQUAL        fun equal(p1,p2) = compare(p1,p2) = EQUAL
216     end     end
217
218     type subst = Var.var Path.Map.map     type rule_no = int
219
220     datatype dfa =     datatype dfa =
221         DFA of         DFA of
222         { stamp    : int,              (* unique dfa stamp *)         { stamp    : int,              (* unique dfa stamp *)
223           freeVars : Path.Set.set ref, (* free variables *)           freeVars : Name.Set.set ref, (* free variables *)
224           refCount : int ref,          (* reference count *)           refCount : int ref,          (* reference count *)
225           generated: bool ref,         (* has code been generated? *)           generated: bool ref,         (* has code been generated? *)
226             height   : int ref,          (* dag height *)
227           test     : test              (* type of tests *)           test     : test              (* type of tests *)
228         }         }
229
230     and test =     and test =
231           CASE  of path * (decon * path list * dfa) list *           CASE  of path * (decon * path list * dfa) list *
232                    dfa option (* multiway *)                    dfa option (* multiway *)
233         | WHERE of subst * guard * dfa * dfa              (* if test *)         | WHERE  of guard * dfa * dfa                  (* if test *)
234         | OK    of subst * Action.action                  (* final dfa *)         | OK     of rule_no * Action.action            (* final dfa *)
235         | BIND  of path * (path * index) list * dfa       (* bind *)         | BIND   of subst * dfa                        (* bind *)
236           | SELECT of path * (path * index) list * dfa   (* projections *)
237         | FAIL                                            (* error dfa *)         | FAIL                                            (* error dfa *)
238         | ROOT  of Path.Set.set * dfa  (* root *)
239       and compiled_dfa  =
240              ROOT of {dfa        : dfa,
241                       used       : Name.Set.set,
242                       exhaustive : bool,
243                       redundant  : IntListSet.set
244                      }
245
246     and guard = GUARD of Guard.guard ref     and guard = GUARD of Guard.guard ref
247
# Line 221  Line 251
251           paths : path list                       (* path (per column) *)           paths : path list                       (* path (per column) *)
252         }         }
253
254
255     withtype row =     withtype row =
256                {pats  : pat list,                {pats  : pat list,
257                 guard : (subst * guard) option,                 guard : (subst * guard) option,
258                 dfa   : dfa                 dfa   : dfa
259                }                }
260     type compiled_rule = pat list * guard option * subst * Action.action         and compiled_rule =
261                 rule_no * pat list * guard option * subst * Action.action
262
263           and compiled_pat = pat * subst
264
265     (* Utilities for dfas *)     (* Utilities for dfas *)
266     structure DFA =     structure DFA =
267     struct     struct
268        val itow = Word.fromInt        val itow = Word.fromInt
269
270          fun h(DFA{stamp, ...}) = itow stamp
271        fun hash(DFA{stamp, test, ...}) =        fun hash(DFA{stamp, test, ...}) =
272            (case test of            (case test of
273              FAIL    => 0w0              FAIL    => 0w0
274            | OK _    => 0w123 + itow stamp            | OK _    => 0w123 + itow stamp
275            | CASE(path, cases, default) => 0w1234            | CASE(path, cases, default) => 0w1234 +
276            | BIND(_, _, dfa) => 0w2313 + hash dfa                 foldr (fn ((_,_,x),y) => h x + y)
277            | WHERE(_, g, yes, no) => 0w2343                       (case default of SOME x => h x | NONE => 0w0) cases
278            | ROOT _ => 0w3414            | SELECT(_, _, dfa) => 0w2313 + hash dfa
279              | WHERE(g, yes, no) => 0w2343 + h yes + h no
280              | BIND(_, dfa) => 0w23234 + h dfa
281            )            )
282
283        (* pointer equality *)        (* pointer equality *)
# Line 255  Line 292
292               (case (t1, t2) of               (case (t1, t2) of
293                  (FAIL, FAIL) => true                  (FAIL, FAIL) => true
294                | (OK _, OK _) => s1 = s2                | (OK _, OK _) => s1 = s2
295                | (BIND(p1, b1, x), BIND(p2, b2, y)) =>                | (SELECT(p1, b1, x), SELECT(p2, b2, y)) =>
296                   Path.equal(p1,p2) andalso eq(x,y) andalso                   Path.equal(p1,p2) andalso eq(x,y) andalso
297                   ListPair.all(fn ((px,ix),(py,iy)) =>                   forall(fn ((px,ix),(py,iy)) =>
298                      Path.equal(px,py) andalso Index.equal(ix,iy))                      Path.equal(px,py) andalso Index.equal(ix,iy))
299                       (b1,b2)                       (b1,b2)
300                | (CASE(p1,c1,o1), CASE(p2,c2,o2)) =>                | (CASE(p1,c1,o1), CASE(p2,c2,o2)) =>
301                    Path.equal(p1,p2) andalso                    Path.equal(p1,p2) andalso
302                    ListPair.all                    forall
303                       (fn ((u,_,x),(v,_,y)) =>                       (fn ((u,_,x),(v,_,y)) =>
304                            Decon.equal(u,v) andalso eq(x,y))                            Decon.equal(u,v) andalso eq(x,y))
305                          (c1,c2) andalso                          (c1,c2) andalso
306                    eqOpt(o1,o2)                    eqOpt(o1,o2)
307                | (WHERE(_, GUARD(g1), y1, n1),                | (WHERE(GUARD(g1), y1, n1),
308                   WHERE(_, GUARD(g2), y2, n2)) =>                   WHERE(GUARD(g2), y2, n2)) =>
309                    g1 = g2 andalso eq(y1,y2) andalso eq(n1,n2)                    g1 = g2 andalso eq(y1,y2) andalso eq(n1,n2)
310                  | (BIND(s1, x), BIND(s2, y)) =>
311                      eq(x,y) andalso
312                        forall (fn ((p,x),(q,y)) =>
313                                 Var.compare(p,q) = EQUAL andalso
314                                 Name.equal(x,y))
315                          (Subst.listItemsi s1, Subst.listItemsi s2)
316                | _ => false                | _ => false
317               )               )
318
# Line 279  Line 322
322                       val hashVal = hash                       val hashVal = hash
323                      )                      )
324
325        fun toString dfa =        fun toString(ROOT{dfa, ...}) =
326        let exception NotVisited        let exception NotVisited
327            val visited = IntHashTable.mkTable(32, NotVisited)            val visited = IntHashTable.mkTable(32, NotVisited)
328            fun mark stamp = IntHashTable.insert visited (stamp, true)            fun mark stamp = IntHashTable.insert visited (stamp, true)
# Line 290  Line 333
333            fun prArgs [] = nop            fun prArgs [] = nop
334              | prArgs ps = seq(!!"(",!!",",!!")") (map (! o Path.toString) ps)              | prArgs ps = seq(!!"(",!!",",!!")") (map (! o Path.toString) ps)
335            fun walk(DFA{stamp, test=FAIL, ...}) = ! "fail"            fun walk(DFA{stamp, test=FAIL, ...}) = ! "fail"
| walk(DFA{test=ROOT(_,dfa), ...}) = walk dfa
336              | walk(DFA{stamp, test, refCount=ref n, ...}) =              | walk(DFA{stamp, test, refCount=ref n, ...}) =
337                if isVisited stamp then !"goto" ++ int stamp                if isVisited stamp then !"goto" ++ int stamp
338                else (mark stamp;                else (mark stamp;
# Line 299  Line 341
341                      (case test of                      (case test of
342                        OK(_,a) => !"Ok" ++ !(Action.toString a)                        OK(_,a) => !"Ok" ++ !(Action.toString a)
343                      | FAIL => !"Fail"                      | FAIL => !"Fail"
344                      | BIND(root,bindings,body) =>                      | SELECT(root,bindings,body) =>
345                        line(!"Let") ++                        line(!"Let") ++
346                        block(seq (nop,nl,nop)                        block(seq (nop,nl,nop)
347                                (map (fn (p,i) =>                                (map (fn (p,i) =>
# Line 325  Line 367
367                               )                               )
368                            )                            )
369                         )                         )
370                      | WHERE(_,GUARD(ref g),y,n) =>                      | WHERE(GUARD(ref g),y,n) =>
371                        line(!"If" ++ !(Guard.toString g)) ++                        line(!"If" ++ !(Guard.toString g)) ++
372                        block(tab ++ ! "then" ++ walk y ++ nl ++                        block(tab ++ ! "then" ++ walk y ++ nl ++
373                              tab ++ ! "else" ++ walk n)                              tab ++ ! "else" ++ walk n)
374                      | _ => bug "walk"                      | BIND(subst, x) =>
375                          line(Subst.foldri (fn (v,n,pp) =>
376                               tab ++ !(Var.toString v) ++ !!"<-" ++
377                                      !(Name.toString n) ++ pp)
378                                   nop subst) ++
379                               walk x
380                      )                      )
381                     )                     )
382        in  PP.text(walk dfa ++ nl)        in  PP.text(walk dfa ++ nl)
# Line 339  Line 386
386     (* Utilities for the pattern matrix *)     (* Utilities for the pattern matrix *)
387     structure Matrix =     structure Matrix =
388     struct     struct
datatype expandType = SWITCH of (decon * path list * matrix) list
* matrix option
| REBIND of path * (path * index) list * matrix

389         fun row(MATRIX{rows, ...}, i) = List.nth(rows,i)         fun row(MATRIX{rows, ...}, i) = List.nth(rows,i)
390         fun col(MATRIX{rows, ...}, i) =         fun col(MATRIX{rows, ...}, i) =
391               List.map (fn {pats, ...} => List.nth(pats, i)) rows               List.map (fn {pats, ...} => List.nth(pats, i)) rows
# Line 392  Line 435
435                 | _  =>                 | _  =>
436                   let val (cons, score) =                   let val (cons, score) =
437                      (* count distinct constructors; skip refutable cards                      (* count distinct constructors; skip refutable cards
438                       * Give records or tuples high scores so that                       * Give records, tuples and or pats, high scores so that
439                       * they are immediately expanded                       * they are immediately expanded
440                       *)                       *)
441                         List.foldr (fn (WILD, (S, n)) => (S, n)                         List.foldr (fn (WILD, (S, n)) => (S, n)
442                                   | (TUPLE _, (S, n)) => (S, 100000)                                   | (TUPLE _, (S, n)) => (S, 100000)
443                                   | (RECORD _, (S, n)) => (S, 100000)                                   | (RECORD _, (S, n)) => (S, 100000)
444                                     | (OR _, (S, n)) => (S, 100000)
445                                   | (pat, (S, n)) => (Pat.Set.add(S, pat), n))                                   | (pat, (S, n)) => (Pat.Set.add(S, pat), n))
446                             (Pat.Set.empty, 0) pats_i                             (Pat.Set.empty, 0) pats_i
447                   in score + Pat.Set.numItems cons end                   in score + Pat.Set.numItems cons end
# Line 422  Line 466
466             | NONE => NONE             | NONE => NONE
467         end         end
468
469       end (* Matrix *)
470
471       val toString = DFA.toString
472
473      (*
474        * Rename user pattern into internal pattern.
475        * The path business is hidden from the client.
476        *)
477       fun rename doIt (rule_no, pats, guard, action) : compiled_rule =
478       let val empty = Subst.empty
479
480           fun bind(subst, v, p) =
481               case Subst.find(subst, v) of
482                 SOME _ => error("duplicated pattern variable "^Var.toString v)
483               | NONE => Subst.insert(subst, v, PVAR p)
484
485           fun process(path, subst:subst, pat) : compiled_pat =
486           let fun idPat id = (WILD, bind(subst, id, path))
487               fun asPat(id, p) =
488               let val (p, subst) = process(path, subst, p)
489               in  (p, bind(subst, id, path))
490               end
491               fun wildPat() = (WILD, subst)
492               fun litPat(lit) = (APP(LIT lit, []), subst)
493
494               fun processPats(pats) =
495               let fun loop([], _, ps', subst) = (rev ps', subst)
496                     | loop(p::ps, i, ps', subst) =
497                       let val path' = Path.dot(path, INT i)
498                           val (p, subst) = process(path', subst, p)
499                       in  loop(ps, i+1, p::ps', subst)
500                       end
501               in  loop(pats, 0, [], subst) end
502
503               fun processLPats(lpats) =
504               let fun loop([], ps', subst) = (rev ps', subst)
505                     | loop((l,p)::ps, ps', subst) =
506                       let val path' = Path.dot(path, LABEL l)
507                           val (p, subst) = process(path', subst, p)
508                       in  loop(ps, (l,p)::ps', subst)
509                       end
510               in  loop(lpats, [], subst) end
511
512               fun consPat(c,args) : compiled_pat =
513               let val (pats, subst) = processPats(args)
514                   val n = case c of
515                             LIT _ => 0
516                           | CON c => Con.arity c
517               in  (* arity check *)
518                   if n <> length args
519                   then error("arity mismatch "^Decon.toString c)
520                   else ();
521                   (APP(c, pats), subst)
522               end
523
524               fun tuplePat(pats) : compiled_pat =
525               let val (pats, subst) = processPats(pats)
526               in  (TUPLE pats, subst) end
527
528               fun recordPat(lpats) : compiled_pat =
529               let val (lpats, subst) = processLPats(lpats)
530               in  (RECORD lpats, subst) end
531
532               (* Or patterns are tricky because the same variable name
533                * may be bound to different components.  We handle this by renaming
534                * all variables to some canonical set of paths,
535                * then rename all variables to these paths.
536                *)
537               fun orPat([])   = error "empty or pattern"
538                 | orPat(pats) : compiled_pat =
539               let val results  = map (fn p => process(path, empty, p)) pats
540                   val ps       = map #1 results
541                   val orSubsts = map #2 results
542                   fun sameVars([], s') = true
543                     | sameVars(s::ss, s') =
544                       forall (fn (x,y) => Var.compare(x,y) = EQUAL)
545                          (Subst.listKeys s, s') andalso
546                            sameVars(ss, s')
547                   (* make sure all patterns use the same set of
548                    * variable names
549                    *)
550                   val orNames = Subst.listKeys(hd orSubsts)
551                   val _ = if sameVars(tl orSubsts, orNames) then ()
552                           else error "not all disjuncts have the same variable bindings"
553                   val duplicated =
554                        VarSet.listItems(
555                         VarSet.intersection
558                   val _ = case duplicated of
559                             [] => ()
560                           | _ => error("duplicated pattern variables: "^
561                                       listify("",",","")
562                                         (map Var.toString duplicated))
563                   (* build the new substitution to include all names in the
564                    * or patterns.
565                    *)
566
567                   val subst = Subst.foldri
568                                (fn (v, _, subst) => Subst.insert(subst,v,VAR v)
569                                ) subst (hd orSubsts)
570               in  (OR(ListPair.zip(orSubsts,ps)), subst)
571               end
572
573
574           in  doIt {idPat=idPat,
575                     asPat=asPat,
576                     wildPat=wildPat,
577                     consPat=consPat,
578                     tuplePat=tuplePat,
579                     recordPat=recordPat,
580                     litPat=litPat,
581                     orPat=orPat
582                    } pat
583           end
584
585           fun processAllPats(i, [], subst, ps') = (rev ps', subst)
586             | processAllPats(i, p::ps, subst, ps') =
587               let val (p, subst) = process(PATH[INT i], subst, p)
588               in  processAllPats(i+1, ps, subst, p::ps')  end
589
590           val (pats, subst) = processAllPats(0, pats, empty, [])
591       in  (rule_no, pats, Option.map (GUARD o ref) guard, subst, action)
592       end
593
594       structure DFAMap =
595          RedBlackMapFn(type ord_key = dfa
596                        fun st(DFA{stamp, ...}) = stamp
597                        fun compare(x,y) = Int.compare(st x, st y)
598                       )
599
600       (*
601        * Give the arguments to case, factor out the common case and make it
602        * the default.
603        *)
604       fun factorCase(p, cases, d as SOME _) = (p, cases, d)
605         | factorCase(p, cases, NONE) =
606           let fun count(m,dfa) = getOpt(DFAMap.find(m,dfa),0)
607               fun inc((_,_,dfa),m) = DFAMap.insert(m, dfa, 1 + count(m, dfa))
608               val m = foldr inc DFAMap.empty cases
609               val best = DFAMap.foldri
610                       (fn (dfa,c,NONE) => SOME(dfa,c)
611                         | (dfa,c,best as SOME(_,c')) =>
612                           if c > c' then SOME(dfa,c) else best)
613                          NONE m
614               fun neq(DFA{stamp=x, ...},DFA{stamp=y,...}) = x<>y
615           in  case best of
616                 NONE => (p, cases, NONE)
617               | SOME(_,1) => (p, cases, NONE)
618               | SOME(defaultCase,n) =>
619                 let val others = List.filter(fn (_,_,x) => neq(x,defaultCase))
620                                    cases
621                 in  (p, others, SOME defaultCase)
622                 end
623           end
624
625         structure LSet = RedBlackSetFn(type ord_key = Var.var         structure LSet = RedBlackSetFn(type ord_key = Var.var
626                                        val compare = Var.compare)                                        val compare = Var.compare)
627
628         (*         (*
629        * The main pattern matching compiler.
630        * The dfa states are constructed with hash consing at the same time
631        * so no separate DFA minimization step is needed.
632        *)
633       fun compile{compiled_rules, compress} =
634       let exception NoSuchState
635
636           datatype expandType = SWITCH of (decon * path list * matrix) list
637                                         * matrix option
638                               | PROJECT of path * (path * index) list * matrix
639
640           fun simp x = if compress then factorCase x else x
641
642           (* Table for hash consing *)
643           val dfaTable = DFA.HashTable.mkTable(32,NoSuchState) :
644                               dfa DFA.HashTable.hash_table
645           val lookupState = DFA.HashTable.lookup dfaTable
646           val insertState = DFA.HashTable.insert dfaTable
647
648           val stampCounter = ref 0
649
650           fun mkState(test) =
651           let val stamp = !stampCounter
652           in  stampCounter := stamp + 1;
653               DFA{stamp=stamp, freeVars=ref Name.Set.empty,
654                   height=ref 0, refCount=ref 0, generated=ref false, test=test}
655           end
656
657           fun newState test =
658           let val s = mkState(test)
659           in  lookupState s handle NoSuchState => (insertState(s, s); s)
660           end
661
662           (* State constructors *)
663           val fail = newState(FAIL)
664           fun Ok x = newState(OK x)
665           fun Case(_, [], SOME x) = x
666             | Case(_, [], NONE) = fail
667             | Case(p, cases as (_,_,c)::cs, default) =
668               if List.all(fn (_,_,c') => DFA.eq(c,c')) cs andalso
669                  (case default of
670                     SOME x => DFA.eq(c,x)
671                   | NONE => true
672                  )
673               then c
674               else newState(CASE(simp(p, cases, default)))
675           fun Select(x) = newState(SELECT(x))
676           fun Where(g, yes, no) =
677               if DFA.eq(yes,no) then yes else newState(WHERE(g, yes, no))
678           fun Bind(subst, x) =
679               if Subst.numItems subst = 0 then x else newState(BIND(subst, x))
680
681           (*
682          * Expand column i,          * Expand column i,
683          * Return a new list of matrixes indexed by the deconstructors.          * Return a new list of matrixes indexed by the deconstructors.
684          *)          *)
685         fun expandColumn(m as MATRIX{rows, paths, ...}, i) =         fun expandColumn(m as MATRIX{rows, paths, ...}, i) =
686         let val ithCol = col(m, i)         let val ithCol = Matrix.col(m, i)
687             val path_i = pathOf(m, i)             val path_i = Matrix.pathOf(m, i)
688             val _ = if debug then             val _ = if debug then
689                        (print("Expanding column "^i2s i^"\n"))                        (print("Expanding column "^i2s i^"\n"))
690                     else ()                     else ()
# Line 443  Line 696
696                   | loop _ = bug "split_i"                   | loop _ = bug "split_i"
697             in  loop(0, ps, []) end             in  loop(0, ps, []) end
698
699             (* find first non-wild card constructor *)             (* If the ith column cfind out what to expand *)
700             val firstCon = List.find (fn WILD => false | _ => true) ithCol             fun expand(WILD::ps, this) = expand(ps, this)
701                 | expand((p as OR _)::ps, this) = SOME p
702                 | expand((p as TUPLE _)::ps, this) = expand(ps, SOME p)
703                 | expand((p as RECORD _)::ps, this) = expand(ps, SOME p)
704                 | expand((p as APP _)::ps, this) = expand(ps, SOME p)
705                 | expand([], this) = this
706
707              (* Split the paths *)              (* Split the paths *)
708             val (prevPaths, _, nextPaths) = split_i paths             val (prevPaths, _, nextPaths) = split_i paths
709
710         in  case firstCon of         in  case expand(ithCol, NONE) of
711               SOME(TUPLE pats) => (* expand a tuple along all the columns *)               SOME(OR _) =>
712                    (* if we have or patterns then expand all rows
713                     * with or pattern
714                     *)
715                 let fun expandOr(row as {pats, dfa, guard}) =
716                     let val (prev, pat_i, next) = split_i(pats)
717                     in  case pat_i of
718                           OR ps =>
719                             map (fn (subst,p) =>
720                                  {pats=prev@[p]@next, dfa=Bind(subst,dfa), guard=guard})
721                                 ps
722                         | _ => [row]
723                     end
724                     val newMatrix =
725                          MATRIX{rows  = List.concat (map expandOr rows),
726                                 paths = paths
727                                }
728                 in  expandColumn(newMatrix, i)
729                 end
730               | SOME(TUPLE pats) => (* expand a tuple along all the columns *)
731               let val arity = length pats               let val arity = length pats
732                   val wilds = map (fn _ => WILD) pats                   val wilds = map (fn _ => WILD) pats
733                   fun processRow{pats, dfa, guard} =                   fun processRow{pats, dfa, guard} =
# Line 470  Line 747
747                   val paths = prevPaths @ path_i' @ nextPaths                   val paths = prevPaths @ path_i' @ nextPaths
748                   val bindings = List.tabulate (arity, fn i =>                   val bindings = List.tabulate (arity, fn i =>
749                                         (Path.dot(path_i, INT i), INT i))                                         (Path.dot(path_i, INT i), INT i))
750               in  REBIND(path_i,bindings,               in  PROJECT(path_i,bindings,
751                          MATRIX{rows=rows, paths=paths}                          MATRIX{rows=rows, paths=paths}
752                         )                         )
753               end               end
# Line 533  Line 810
810                   val bindings = map (fn l =>                   val bindings = map (fn l =>
811                                         (Path.dot(path_i, LABEL l), LABEL l))                                         (Path.dot(path_i, LABEL l), LABEL l))
812                                       labels                                       labels
813               in  REBIND(path_i,bindings,               in  PROJECT(path_i,bindings,
814                          MATRIX{rows=rows, paths=paths}                          MATRIX{rows=rows, paths=paths}
815                         )                         )
816               end               end
# Line 632  Line 909
909               end               end
910             | _ => error "expandColumn.1"             | _ => error "expandColumn.1"
911         end (* expandColumn *)         end (* expandColumn *)
end (* Matrix *)

val toString = DFA.toString

(*
* Rename user pattern into internal pattern.
* The path business is hidden from the client.
*)
fun rename doIt (pats, guard, exp) =
let val empty = Path.Map.empty
val bind  = Path.Map.insert

fun process(path, subst, pat) =
let fun idPat id = (WILD, bind(subst, path, id))
fun asPat(id, p) =
let val (p, subst) = process(path, subst, p)
in  (p, bind(subst, path, id))
end
fun wildPat() = (WILD, subst)
fun litPat(lit) = (APP(LIT lit, []), subst)

fun processPats(pats) =
let fun loop([], _, ps', subst) = (rev ps', subst)
| loop(p::ps, i, ps', subst) =
let val path' = Path.dot(path, INT i)
val (p, subst) = process(path', subst, p)
in  loop(ps, i+1, p::ps', subst)
end
in  loop(pats, 0, [], subst) end

fun processLPats(lpats) =
let fun loop([], ps', subst) = (rev ps', subst)
| loop((l,p)::ps, ps', subst) =
let val path' = Path.dot(path, LABEL l)
val (p, subst) = process(path', subst, p)
in  loop(ps, (l,p)::ps', subst)
end
in  loop(lpats, [], subst) end

fun consPat(c,args) =
let val (pats, subst) = processPats(args)
in  (APP(c, pats), subst) end

fun tuplePat(pats) =
let val (pats, subst) = processPats(pats)
in  (TUPLE pats, subst) end

fun recordPat(lpats) =
let val (lpats, subst) = processLPats(lpats)
in  (RECORD lpats, subst) end

in  doIt {idPat=idPat,
asPat=asPat,
wildPat=wildPat,
consPat=consPat,
tuplePat=tuplePat,
recordPat=recordPat,
litPat=litPat
} pat
end

fun processAllPats(i, [], subst, ps') = (rev ps', subst)
| processAllPats(i, p::ps, subst, ps') =
let val (p, subst) = process(PATH[INT i], subst, p)
in  processAllPats(i+1, ps, subst, p::ps')  end

val (pats, subst) = processAllPats(0, pats, empty, [])
in  (pats, Option.map (GUARD o ref) guard, subst, exp)
end

structure DFAMap =
RedBlackMapFn(type ord_key = dfa
fun st(DFA{stamp, ...}) = stamp
fun compare(x,y) = Int.compare(st x, st y)
)

(*
* Give the arguments to case, factor out the common case and make it
* the default.
*)
fun factorCase(p, cases, d as SOME _) = (p, cases, d)
| factorCase(p, cases, NONE) =
let fun count(m,dfa) = getOpt(DFAMap.find(m,dfa),0)
fun inc((_,_,dfa),m) = DFAMap.insert(m, dfa, 1 + count(m, dfa))
val m = foldr inc DFAMap.empty cases
val best = DFAMap.foldri
(fn (dfa,c,NONE) => SOME(dfa,c)
| (dfa,c,best as SOME(_,c')) =>
if c > c' then SOME(dfa,c) else best)
NONE m
fun neq(DFA{stamp=x, ...},DFA{stamp=y,...}) = x<>y
in  case best of
NONE => (p, cases, NONE)
| SOME(_,1) => (p, cases, NONE)
| SOME(defaultCase,n) =>
let val others = List.filter(fn (_,_,x) => neq(x,defaultCase))
cases
in  (p, others, SOME defaultCase)
end
end

(*
* The main pattern matching compiler.
* The dfa states are constructed with hash consing at the same time
* so no separate DFA minimization step is needed.
*)
fun compile{compiled_rules, compress} =
let exception NoSuchState

fun simp x = if compress then factorCase x else x

(* Table for hash consing *)
val dfaTable = DFA.HashTable.mkTable(32,NoSuchState) :
dfa DFA.HashTable.hash_table
val lookupState = DFA.HashTable.lookup dfaTable
val insertState = DFA.HashTable.insert dfaTable

val stampCounter = ref 0

fun mkState(test) =
let val stamp = !stampCounter
in  stampCounter := stamp + 1;
DFA{stamp=stamp, freeVars=ref Path.Set.empty,
refCount=ref 0, generated=ref false, test=test}
end

fun newState test =
let val s = mkState(test)
in  lookupState s handle NoSuchState => (insertState(s, s); s)
end

(* State constructors *)
fun Fail() = newState(FAIL)
fun Ok(subst, action) = newState(OK(subst, action))
fun Case(_, [], SOME x) = x
| Case(_, [], NONE) = Fail()
| Case(p, cases as (_,_,c)::cs, default) =
(* if List.all(fn (_,_,c') => DFA.eq(c,c')) cs andalso
(case default of
SOME x => DFA.eq(c,x)
| NONE => true
)
then c
else *) newState(CASE(simp(p, cases, default)))
fun Root(used, dfa) = newState(ROOT(used, dfa))
fun Bind(p, bindings, x) = newState(BIND(p, bindings, x))
fun Where(subst, g, yes, no) =
if DFA.eq(yes,no) then yes else newState(WHERE(subst, g, yes, no))
912
913         (*         (*
914          * Generate the DFA          * Generate the DFA
915          *)          *)
916         fun match matrix =         fun match matrix =
917             if Matrix.isEmpty matrix then Fail()             if Matrix.isEmpty matrix then fail
918             else             else
919             case Matrix.findBestMatchColumn matrix of             case Matrix.findBestMatchColumn matrix of
920               NONE =>   (* first row is all wild cards *)               NONE =>   (* first row is all wild cards *)
921                 (case Matrix.row(matrix, 0) of                 (case Matrix.row(matrix, 0) of
922                   {guard=SOME(subst, g), dfa, ...} => (* generate guard *)                   {guard=SOME(subst, g), dfa, ...} => (* generate guard *)
923                     Where(subst, g, dfa,                     Bind(subst,
924                           match(Matrix.removeFirstRow matrix                         Where(g, dfa,
925                          ))                               match(Matrix.removeFirstRow matrix)))
926                 | {guard=NONE, dfa, ...} => dfa                 | {guard=NONE, dfa, ...} => dfa
927                 )                 )
928             | SOME i =>             | SOME i =>
929                (* mixture rule; split at column i *)                (* mixture rule; split at column i *)
930               (case Matrix.expandColumn(matrix, i) of               (case expandColumn(matrix, i) of
931                 (* splitting a constructor *)                 (* splitting a constructor *)
932                 Matrix.SWITCH(cases, default) =>                 SWITCH(cases, default) =>
933                 let val cases = map (fn (c,p,m) => (c,p,match m)) cases                 let val cases = map (fn (c,p,m) => (c,p,match m)) cases
934                 in  Case(Matrix.pathOf(matrix, i), cases,                 in  Case(Matrix.pathOf(matrix, i), cases,
935                          Option.map match default)                          Option.map match default)
# Line 808  Line 937
937                 (* splitting a tuple or record;                 (* splitting a tuple or record;
938                  * recompute new bindings.                  * recompute new bindings.
939                  *)                  *)
940               | Matrix.REBIND(p,bindings,m) => Bind(p, bindings, match m)               | PROJECT(p,bindings,m) => Select(p, bindings, match m)
941               )               )
942
943         fun makeMatrix rules =         fun makeMatrix rules =
944         let val (pats0, _, _, _) = hd rules         let val (_, pats0, _, _, _) = hd rules
945             val arity = length pats0             val arity = length pats0
946             fun makeRow(pats, NONE, subst, action) =             fun makeRow(r, pats, NONE, subst, action) =
947                 {pats=pats, guard=NONE, dfa=Ok(subst, action)}                 {pats=pats, guard=NONE, dfa=Bind(subst, Ok(r, action))}
948               | makeRow(pats, SOME g, subst, action) =               | makeRow(r, pats, SOME g, subst, action) =
949                 {pats=pats, guard=SOME(subst,g), dfa=Ok(Path.Map.empty, action)}                 {pats=pats, guard=SOME(subst,g),
950                    dfa=Ok(r, action)}
951
952         in  MATRIX{rows  = map makeRow rules,         in  MATRIX{rows  = map makeRow rules,
953                    paths = List.tabulate(arity, fn i => PATH[INT i])                    paths = List.tabulate(arity, fn i => PATH[INT i])
954                   }                   }
# Line 825  Line 956
956
957         val dfa = match(makeMatrix compiled_rules)         val dfa = match(makeMatrix compiled_rules)
958
959           val rule_nos = map #1 compiled_rules
960
961         (*         (*
962          * 1. Update the reference counts.          * 1. Update the reference counts.
963          * 2. Compute the set of free path variables at each state.          * 2. Compute the set of free path variables at each state.
964          * 3. Compute the set of path variables that are actually used.          * 3. Compute the set of path variables that are actually used.
965            * 4. Compute the height of each node.
966          *)          *)
967         exception NotVisited         exception NotVisited
968         val visited = IntHashTable.mkTable (32, NotVisited)         val visited = IntHashTable.mkTable (32, NotVisited)
# Line 836  Line 970
970         fun isVisited s = getOpt(IntHashTable.find visited s, false)         fun isVisited s = getOpt(IntHashTable.find visited s, false)
971
972         fun set(fv, s) = (fv := s; s)         fun set(fv, s) = (fv := s; s)
973         val union = Path.Set.union         fun setH(height, h) = (height := h; h)
976           val empty = Name.Set.empty
977         val used = ref Path.Set.empty
978         fun occurs s = used := Path.Set.union(!used,s)         val used = ref Name.Set.empty
979           fun occurs s = used := Name.Set.union(!used,s)
980           val redundant = ref(IntListSet.addList(IntListSet.empty, rule_nos))
981           fun ruleUsed r = redundant := IntListSet.delete(!redundant, r)
982
983         fun vars subst = Path.Set.addList(empty,Path.Map.listKeys subst)         fun vars subst = Name.Set.addList(empty,Subst.listItems subst)
984
985         fun visit(DFA{stamp, refCount, test, freeVars, ...}) =         fun visit(DFA{stamp, refCount, test, freeVars, height, ...}) =
986             (refCount := !refCount + 1;             (refCount := !refCount + 1;
987              if isVisited stamp then !freeVars              if isVisited stamp then (!freeVars, !height)
988              else (mark stamp;              else (mark stamp;
989                    case test of                    case test of
990                      FAIL => empty                      FAIL => (empty, 0)
991                    | OK(subst, _) =>                    | BIND(subst, dfa) =>
992                      let val s = vars subst                      let val (s, h) = visit dfa
993                      in  occurs s; set(freeVars, s) end                          val s = union(s, vars subst)
994                        in  occurs s;
995                            (set(freeVars, s), setH(height, h + 1))
996                        end
997                      | OK(rule_no, _) => (ruleUsed rule_no; (empty, 0))
998                    | CASE(p, cases, opt) =>                    | CASE(p, cases, opt) =>
999                      let val fvs =                      let val (fvs, h) =
1000                           List.foldr(fn ((_,_,x),s) => union(visit x,s))                           List.foldr (fn ((_,_,x),(s, h)) =>
1001                               empty cases                               let val (fv,h') = visit x
1002                          val fvs =                               in  (union(fv,s), Int.max(h,h'))
1003                              case opt of NONE => fvs                               end)
1004                                        | SOME x => union(visit x,fvs)                               (empty, 0) cases
1005                          val fvs = add(fvs, p)                          val (fvs, h) =
1006                      in  occurs fvs; set(freeVars, fvs)                              case opt of NONE => (fvs, h)
1007                      end                                        | SOME x =>
1008                    | WHERE(subst, _, y, n) =>                                          let val (fv, h') = visit x
1009                      let val s = union(vars subst, union(visit y, visit n))                                          in  (union(fvs,fv), Int.max(h,h'))
1010                      in  occurs s; set(freeVars, s) end                                          end
1011                    | BIND(p, bs, x) =>                          val fvs = add(fvs, PVAR p)
1012                      let val s = add(visit x, p)                      in  occurs fvs;
1013                          val bs = foldr (fn ((p,_),S) => add(S,p)) s bs                          (set(freeVars, fvs), setH(height, h+1))
1014                      in  occurs bs; set(freeVars, s) end                      end
1015                    | ROOT _ => bug "visit"                    | WHERE(_, y, n) =>
1016                        let val (sy, hy) = visit y
1017                            val (sn, hn) = visit n
1018                            val s = union(sy, sn)
1019                            val h = Int.max(hy,hn) + 1
1020                        in  occurs s;
1021                            (set(freeVars, s), setH(height, h))
1022                        end
1023                      | SELECT(p, bs, x) =>
1024                        let val (s, h) = visit x
1025                            val s  = add(s, PVAR p)
1026                            val bs = foldr (fn ((p,_),S) => add(S,PVAR p)) s bs
1027                        in  occurs bs;
1028                            (set(freeVars, s), setH(height,h+1))
1029                        end
1030                   )                   )
1031             )             )
1032     in  visit dfa; Root(!used, dfa)         val _ = visit dfa;
1033           val DFA{refCount=failCount, ...} = fail
1034       in  ROOT{used = !used,
1035                dfa = dfa,
1036                exhaustive= !failCount = 0,
1037                redundant= !redundant
1038               }
1039     end     end
1040
1041       fun exhaustive(ROOT{exhaustive, ...}) = exhaustive
1042       fun redundant(ROOT{redundant, ...}) = redundant
1043
1044     (*     (*
1045      * Generate final code for pattern matching.      * Generate final code for pattern matching.
1046      *)      *)
# Line 894  Line 1058
1058            genVal  : Var.var * 'exp -> 'decl            genVal  : Var.var * 'exp -> 'decl
1059          } (root, dfa) =          } (root, dfa) =
1060     let     let
1061         val DFA{test=ROOT(used,dfa), ...} = dfa         val ROOT{dfa, used, ...} = dfa
1062
1063         fun arg p = if Path.Set.member(used, p) then SOME p else NONE         fun genPat p = if Name.Set.member(used, PVAR p) then SOME p else NONE
1064         (* fun arg p = SOME p *)         (* fun arg p = SOME p *)
1065
1066         fun mkVars freeVarSet =         fun mkVars freeVarSet =
1067             map genVar (Path.Set.listItems (!freeVarSet))             map (fn PVAR p => genVar p
1068                     | VAR v  => v
1069                   ) (Name.Set.listItems (!freeVarSet))
1070
1071         fun enque((F,B),x) = (F,x::B)         fun enque(dfa,(F,B)) = (F,dfa::B)
1072           val emptyQueue = ([], [])
1073
1074         (* Walk a state, if it is shared then just generate a goto to the         (* Walk a state, if it is shared then just generate a goto to the
1075          * state function; otherwise expand it          * state function; otherwise expand it
# Line 914  Line 1080
1080                (* just generate a goto *)                (* just generate a goto *)
1081                let val code = genGoto(stamp, mkVars freeVars)                let val code = genGoto(stamp, mkVars freeVars)
1082                in  if !generated then (code, workList)                in  if !generated then (code, workList)
1083                    else (generated := true; (code, enque(workList, dfa)))                    else (generated := true; (code, enque(dfa,workList)))
1084                end                end
1085             else             else
1086                expandDfa(dfa, workList)                expandDfa(dfa, workList)
1087
1088             (* generate a new function definition *)             (* generate a new function definition *)
1089         and genNewFun(dfa as DFA{stamp, freeVars, ...}, workList) =         and genNewFun(dfa as DFA{stamp, freeVars, height, ...}, workList) =
1090             let val (body, workList) = expandDfa(dfa, workList)             let val (body, workList) = expandDfa(dfa, workList)
1091             in  (genFun(stamp, mkVars freeVars, body), workList)             in  ((!height,genFun(stamp, mkVars freeVars, body)), workList)
1092             end             end
1093
1094         and expandYesNo(yes, no, workList) =         and expandYesNo(yes, no, workList) =
# Line 935  Line 1101
1101         and expandDfa(DFA{stamp, test, ...}, workList) =         and expandDfa(DFA{stamp, test, ...}, workList) =
1102                (case test of                (case test of
1103                  (* action *)                  (* action *)
1104                  OK(subst,action) =>                  OK(rule_no, action) => (genOk(action), workList)
let val bindings =
Path.Map.foldri (fn (p,v,b) => (v,p)::b) [] subst
in  (genLet(genBind bindings, genOk(action)), workList)
end
1105                  (* failure *)                  (* failure *)
1106                | FAIL => (genFail(), workList)                | FAIL => (genFail(), workList)
1107                  (* guard *)                  (* guard *)
1108                | WHERE(subst, GUARD(ref g), yes, no) =>                | BIND(subst, dfa) =>
1109                  let val (yes, no, workList) = expandYesNo(yes, no, workList)                  let val (code, workList) = walk(dfa, workList)
1110                      val bindings =                      val bindings =
1111                         Path.Map.foldri (fn (p,v,b) => (v,p)::b) [] subst                         Subst.foldri
1112                  in  (genLet(genBind bindings, genIf(g, yes, no)), workList)                         (fn (v,PVAR p,b) => (v,p)::b
1113                             | (v,VAR v',b) => b
1114                             (* | (p,PVAR p',b) => (genVar p',p)::b *)
1115                           ) [] subst
1116                    in  (genLet(genBind bindings, code), workList)
1117                    end
1118                  | WHERE(GUARD(ref g), yes, no) =>
1119                    let val (yes, no, workList) = expandYesNo(yes, no, workList)
1120                    in  (genIf(g, yes, no), workList)
1121                  end                  end
1122                  (* case *)                  (* case *)
1123                | CASE(path, cases, default) =>                | CASE(path, cases, default) =>
# Line 955  Line 1125
1125                        List.foldr                        List.foldr
1126                        (fn ((con, paths, dfa), (cases, workList)) =>                        (fn ((con, paths, dfa), (cases, workList)) =>
1127                             let val (code, workList) = walk(dfa, workList)                             let val (code, workList) = walk(dfa, workList)
1128                             in  ((con, map arg paths, code)::cases, workList)                             in  ((con, map genPat paths, code)::cases, workList)
1129                             end                             end
1130                        ) ([], workList) cases                        ) ([], workList) cases
1131
# Line 970  Line 1140
1140
1141                  in  (genCase(genVar path, cases, default), workList)                  in  (genCase(genVar path, cases, default), workList)
1142                  end                  end
1143                | BIND(path, bindings, body) =>                | SELECT(path, bindings, body) =>
1144                  let val (body, workList) = walk(body, workList)                  let val (body, workList) = walk(body, workList)
1145                      val bindings = map (fn (p,v) => (SOME p,v)) bindings                      val bindings = map (fn (p,v) => (SOME p,v)) bindings
1146                  in  (genLet([genProj(path, bindings)], body), workList)                  in  (genLet([genProj(path, bindings)], body), workList)
1147                  end                  end
| ROOT _ => bug "expandDfa"
1148                )                )
1149
1150             (* Generate code for the dfa; accumulate all the auxiliary             (* Generate code for the dfa; accumulate all the auxiliary
1151              * functions together and generate a let.              * functions together and generate a let.
1152              *)              *)
1153         fun genAll(root,dfa) =         fun genAll(root,dfa) =
1154             let val (exp, workList) = walk(dfa, ([],[]))             let val (exp, workList) = walk(dfa, emptyQueue)
1155                 fun genAuxFunctions(([],[]), funs) = funs                 fun genAuxFunctions(([],[]), funs) = funs
1156                   | genAuxFunctions(([],B), funs) =                   | genAuxFunctions(([],B), funs) =
1157                     genAuxFunctions((rev B,[]), funs)                     genAuxFunctions((rev B,[]), funs)
1158                   | genAuxFunctions((dfa::F,B), funs) =                   | genAuxFunctions((dfa::F,B), funs) =
1159                     let val (fun1, workList) = genNewFun(dfa, (F,B))                     let val (newFun, workList) = genNewFun(dfa, (F, B))
1160                     in  genAuxFunctions(workList, fun1::funs)                     in  genAuxFunctions(workList, newFun :: funs)
1161                     end                     end
1162                 val rootDecl = genVal(genVar(PATH [INT 0]), root)                 val rootDecl = genVal(genVar(PATH [INT 0]), root)
1163                 val decls = genAuxFunctions(workList, [])                 val funs = genAuxFunctions(workList, [])
1164             in  genLet(rootDecl::decls, exp)                 (* order the functions by dependencies; sort by lowest height *)
1165                   val funs = ListMergeSort.sort
1166                               (fn ((h,_),(h',_)) => h > h') funs
1167                   val funs = map #2 funs
1168               in  genLet(rootDecl::funs, exp)
1169             end             end
1170     in  genAll(root,dfa)     in  genAll(root,dfa)
1171     end     end

Legend:
 Removed from v.751 changed lines Added in v.752