(* hash-cons-set.sml * * COPYRIGHT (c) 2011 The Fellowship of SML/NJ (http://www.smlnj.org) * All rights reserved. * * This is an implementation of the HASH_CONS_SET signature that uses Red-Black * trees. Eventually, it should be replaced by an implmementation that uses * Patricia trees. *) structure HashConsSet : HASH_CONS_SET = struct structure HC = HashCons datatype color = R | B type 'a obj = 'a HC.obj datatype 'a tree = E | T of (color * 'a tree * 'a obj * 'a tree) datatype 'a set = SET of (int * 'a tree) (* NOTE: we should use the Word.< and = operations instead of Word.compare *) fun cmpObj (a : 'a obj, b : 'a obj) = Word.compare(#tag a, #tag b) fun isEmpty (SET(_, E)) = true | isEmpty _ = false val empty = SET(0, E) fun singleton x = SET(1, T(B, E, x, E)) fun add (SET(nItems, m), x) = let val nItems' = ref nItems fun ins E = (nItems' := nItems+1; T(R, E, x, E)) | ins (s as T(color, a, y, b)) = (case cmpObj(x, y) of LESS => (case a of T(R, c, z, d) => (case cmpObj(x, z) of LESS => (case ins c of T(R, e, w, f) => T(R, T(B,e,w,f), z, T(B,d,y,b)) | c => T(B, T(R,c,z,d), y, b) (* end case *)) | EQUAL => T(color, T(R, c, x, d), y, b) | GREATER => (case ins d of T(R, e, w, f) => T(R, T(B,c,z,e), w, T(B,f,y,b)) | d => T(B, T(R,c,z,d), y, b) (* end case *)) (* end case *)) | _ => T(B, ins a, y, b) (* end case *)) | EQUAL => T(color, a, x, b) | GREATER => (case b of T(R, c, z, d) => (case cmpObj(x, z) of LESS => (case ins c of T(R, e, w, f) => T(R, T(B,a,y,e), w, T(B,f,z,d)) | c => T(B, a, y, T(R,c,z,d)) (* end case *)) | EQUAL => T(color, a, y, T(R, c, x, d)) | GREATER => (case ins d of T(R, e, w, f) => T(R, T(B,a,y,c), z, T(B,e,w,f)) | d => T(B, a, y, T(R,c,z,d)) (* end case *)) (* end case *)) | _ => T(B, a, y, ins b) (* end case *)) (* end case *)) val T(_, a, y, b) = ins m in SET(!nItems', T(B, a, y, b)) end fun add' (x, m) = add (m, x) fun addList (s, []) = s | addList (s, x::r) = addList(add(s, x), r) (* Remove an item. Raises LibBase.NotFound if not found. *) local datatype 'a zipper = TOP | LEFT of (color * 'a obj * 'a tree * 'a zipper) | RIGHT of (color * 'a tree * 'a obj * 'a zipper) in fun delete (SET(nItems, t), k) = let (* zip the zipper *) fun zip (TOP, t) = t | zip (LEFT(color, x, b, p), a) = zip(p, T(color, a, x, b)) | zip (RIGHT(color, a, x, p), b) = zip(p, T(color, a, x, b)) (* zip the zipper while resolving a black deficit *) fun fixupZip (TOP, t) = (true, t) (* case 1 from CLR *) | fixupZip (LEFT(B, x, T(R, a, y, b), p), t) = (case a of T(_, T(R, a11, w, a12), z, a2) => (* case 1L ==> case 3L ==> case 4L *) (false, zip (p, T(B, T(R, T(B, t, x, a11), w, T(B, a12, z, a2)), y, b))) | T(_, a1, z, T(R, a21, w, t22)) => (* case 1L ==> case 4L *) (false, zip (p, T(B, T(R, T(B, t, x, a1), z, T(B, a21, w, t22)), y, b))) | T(_, a1, z, a2) => (* case 1L ==> case 2L; rotate + recolor fixes deficit *) (false, zip (p, T(B, T(B, t, x, T(R, a1, z, a2)), y, b))) | _ => fixupZip (LEFT(R, x, a, LEFT(B, y, b, p)), t) (* end case *)) | fixupZip (RIGHT(B, T(R, a, x, b), y, p), t) = (case b of T(_, b1, z, T(R, b21, w, b22)) => (* case 1R ==> case 3R ==> case 4R *) (false, zip (p, T(B, a, x, T(R, T(B, b1, z, b21), w, T(B, b22, y, t))))) | T(_, T(R, b11, w, b12), z, b2) => (* case 1R ==> case 4R *) (false, zip (p, T(B, a, x, T(R, T(B, b11, w, b12), z, T(B, b2, y, t))))) | T(_, b1, z, b2) => (* case 1L ==> case 2L; rotate + recolor fixes deficit *) (false, zip (p, T(B, a, x, T(B, T(R, b1, z, b2), y, t)))) | _ => fixupZip (RIGHT(R, b, y, RIGHT(B, a, x, p)), t) (* end case *)) (* case 3 from CLR *) | fixupZip (LEFT(color, x, T(B, T(R, a1, y, a2), z, b), p), t) = (* case 3L ==> case 4L *) (false, zip (p, T(color, T(B, t, x, a1), y, T(B, a2, z, b)))) | fixupZip (RIGHT(color, T(B, a, x, T(R, b1, y, b2)), z, p), t) = (* case 3R ==> case 4R; rotate, recolor, plus rotate fixes deficit *) (false, zip (p, T(color, T(B, a, x, b1), y, T(B, b2, z, t)))) (* case 4 from CLR *) | fixupZip (LEFT(color, x, T(B, a, y, T(R, b1, z, b2)), p), t) = (false, zip (p, T(color, T(B, t, x, a), y, T(B, b1, z, b2)))) | fixupZip (RIGHT(color, T(B, T(R, a1, z, a2), x, b), y, p), t) = (false, zip (p, T(color, T(B, a1, z, a2), x, T(B, b, y, t)))) (* case 2 from CLR; note that "a" and "b" are guaranteed to be black, since we did * not match cases 3 or 4. *) | fixupZip (LEFT(R, x, T(B, a, y, b), p), t) = (false, zip (p, T(B, t, x, T(R, a, y, b)))) | fixupZip (LEFT(B, x, T(B, a, y, b), p), t) = fixupZip (p, T(B, t, x, T(R, a, y, b))) | fixupZip (RIGHT(R, T(B, a, x, b), y, p), t) = (false, zip (p, T(B, T(R, a, x, b), y, t))) | fixupZip (RIGHT(B, T(B, a, x, b), y, p), t) = fixupZip (p, T(B, T(R, a, x, b), y, t)) (* push deficit up the tree by recoloring a black node as red *) | fixupZip (LEFT(_, y, E, p), t) = fixupZip (p, T(R, t, y, E)) | fixupZip (RIGHT(_, E, y, p), t) = fixupZip (p, T(R, E, y, t)) (* impossible cases that violate the red invariant *) | fixupZip _ = raise Fail "Red invariant violation" (* delete the minimum value from a non-empty tree, returning a triple * (elem, bd, tr), where elem is the minimum element, tr is the residual * tree with elem removed, and bd is true if tr has a black-depth that is * less than the original tree. *) fun delMin (T(R, E, y, b), p) = (* replace the node by its right subtree (which must be E) *) (y, false, zip(p, b)) | delMin (T(B, E, y, T(R, a', y', b')), p) = (* replace the node with its right child, while recoloring the child black to * preserve the black invariant. *) (y, false, zip (p, T(B, a', y', b'))) | delMin (T(B, E, y, E), p) = let (* delete the node, which reduces the black-depth by one, so we attempt to fix * the deficit on the path back. *) val (blkDeficit, t) = fixupZip (p, E) in (y, blkDeficit, t) end | delMin (T(color, a, y, b), z) = delMin(a, LEFT(color, y, b, z)) | delMin (E, _) = raise Match fun del (E, z) = raise LibBase.NotFound | del (T(color, a, y, b), p) = (case cmpObj(k, y) of LESS => del (a, LEFT(color, y, b, p)) | EQUAL => (case (color, a, b) of (R, E, E) => zip(p, E) | (B, E, E) => #2 (fixupZip (p, E)) | (_, T(_, a', y', b'), E) => (* node is black and left child is red; we replace the node with its * left child recolored to black. *) zip(p, T(B, a', y', b')) | (_, E, T(_, a', y', b')) => (* node is black and right child is red; we replace the node with its * right child recolored to black. *) zip(p, T(B, a', y', b')) | _ => let val (minSucc, blkDeficit, b) = delMin (b, TOP) in if blkDeficit then #2 (fixupZip (RIGHT(color, a, minSucc, p), b)) else zip (p, T(color, a, minSucc, b)) end (* end case *)) | GREATER => del (b, RIGHT(color, a, y, p)) (* end case *)) in case del(t, TOP) of T(R, a, x, b) => SET(nItems-1, T(B, a, x, b)) | t => SET(nItems-1, t) (* end case *) end end (* local *) (* Return true if and only if item is an element in the set *) fun member (SET(_, t), k) = let fun find' E = false | find' (T(_, a, y, b)) = (case cmpObj(k, y) of LESS => find' a | EQUAL => true | GREATER => find' b (* end case *)) in find' t end (* Return the number of items in the map *) fun numItems (SET(n, _)) = n fun fold f = let fun foldf (E, accum) = accum | foldf (T(_, a, x, b), accum) = foldf(b, f(x, foldf(a, accum))) in fn init => fn (SET(_, m)) => foldf(m, init) end val foldl = fold (* DEPRECATED *) val foldr = fold (* DEPRECATED *) (* return an ordered list of the items in the set. *) fun toList s = foldr (fn (x, l) => x::l) [] s (* functions for walking the tree while keeping a stack of parents * to be visited. *) fun next ((t as T(_, _, _, b))::rest) = (t, left(b, rest)) | next _ = (E, []) and left (E, rest) = rest | left (t as T(_, a, _, _), rest) = left(a, t::rest) fun start m = left(m, []) (* Return true if and only if the two sets are equal *) fun equal (SET(_, s1), SET(_, s2)) = let fun cmp (t1, t2) = (case (next t1, next t2) of ((E, _), (E, _)) => true | ((E, _), _) => false | (_, (E, _)) => false | ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of EQUAL => cmp (r1, r2) | _ => false (* end case *)) (* end case *)) in cmp (start s1, start s2) end (* Return the lexical order of two sets *) fun compare (SET(_, s1), SET(_, s2)) = let fun cmp (t1, t2) = (case (next t1, next t2) of ((E, _), (E, _)) => EQUAL | ((E, _), _) => LESS | (_, (E, _)) => GREATER | ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of EQUAL => cmp (r1, r2) | order => order (* end case *)) (* end case *)) in cmp (start s1, start s2) end (* Return true if and only if the first set is a subset of the second *) fun isSubset (SET(_, s1), SET(_, s2)) = let fun cmp (t1, t2) = (case (next t1, next t2) of ((E, _), (E, _)) => true | ((E, _), _) => true | (_, (E, _)) => false | ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of LESS => false | EQUAL => cmp (r1, r2) | GREATER => cmp (t1, r2) (* end case *)) (* end case *)) in cmp (start s1, start s2) end (* Return true if the two sets are disjoint *) fun disjoint (SET(0, _), _) = true | disjoint (_, SET(0, _)) = true | disjoint (SET(_, s1), SET(_, s2)) = let fun walk ((E, _), _) = true | walk (_, (E, _)) = true | walk (t1 as (T(_, _, x, _), r1), t2 as (T(_, _, y, _), r2)) = ( case cmpObj(x, y) of LESS => walk (next r1, t2) | EQUAL => false | GREATER => walk (t1, next r2) (* end case *)) in walk (next (start s1), next (start s2)) end (* support for constructing red-black trees in linear time from increasing * ordered sequences (based on a description by R. Hinze). Note that the * elements in the digits are ordered with the largest on the left, whereas * the elements of the trees are ordered with the largest on the right. *) datatype 'a digit = ZERO | ONE of ('a obj * 'a tree * 'a digit) | TWO of ('a obj * 'a tree * 'a obj * 'a tree * 'a digit) (* add an item that is guaranteed to be larger than any in l *) fun addItem (a, l) = let fun incr (a, t, ZERO) = ONE(a, t, ZERO) | incr (a1, t1, ONE(a2, t2, r)) = TWO(a1, t1, a2, t2, r) | incr (a1, t1, TWO(a2, t2, a3, t3, r)) = ONE(a1, t1, incr(a2, T(B, t3, a3, t2), r)) in incr(a, E, l) end (* link the digits into a tree *) fun linkAll t = let fun link (t, ZERO) = t | link (t1, ONE(a, t2, r)) = link(T(B, t2, a, t1), r) | link (t, TWO(a1, t1, a2, t2, r)) = link(T(B, T(R, t2, a2, t1), a1, t), r) in link (E, t) end (* create a set from a list of items; this function works in linear time if the list * is in increasing order. *) fun fromList [] = empty | fromList (first::rest) = let fun add (prev, x::xs, n, accum) = (case cmpObj(prev, x) of LESS => add(x, xs, n+1, addItem(x, accum)) | _ => (* list not in order, so fall back to addList code *) addList(SET(n, linkAll accum), x::xs) (* end case *)) | add (_, [], n, accum) = SET(n, linkAll accum) in add (first, rest, 1, addItem(first, ZERO)) end (* return the union of the two sets *) fun union (SET(_, s1), SET(_, s2)) = let fun ins ((E, _), n, result) = (n, result) | ins ((T(_, _, x, _), r), n, result) = ins(next r, n+1, addItem(x, result)) fun union' (t1, t2, n, result) = (case (next t1, next t2) of ((E, _), (E, _)) => (n, result) | ((E, _), t2) => ins(t2, n, result) | (t1, (E, _)) => ins(t1, n, result) | ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of LESS => union' (r1, t2, n+1, addItem(x, result)) | EQUAL => union' (r1, r2, n+1, addItem(x, result)) | GREATER => union' (t1, r2, n+1, addItem(y, result)) (* end case *)) (* end case *)) val (n, result) = union' (start s1, start s2, 0, ZERO) in SET(n, linkAll result) end (* return the intersection of the two sets *) fun intersection (SET(_, s1), SET(_, s2)) = let fun intersect (t1, t2, n, result) = (case (next t1, next t2) of ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of LESS => intersect (r1, t2, n, result) | EQUAL => intersect (r1, r2, n+1, addItem(x, result)) | GREATER => intersect (t1, r2, n, result) (* end case *)) | _ => (n, result) (* end case *)) val (n, result) = intersect (start s1, start s2, 0, ZERO) in SET(n, linkAll result) end (* return the set difference *) fun difference (SET(_, s1), SET(_, s2)) = let fun ins ((E, _), n, result) = (n, result) | ins ((T(_, _, x, _), r), n, result) = ins(next r, n+1, addItem(x, result)) fun diff (t1, t2, n, result) = (case (next t1, next t2) of ((E, _), _) => (n, result) | (t1, (E, _)) => ins(t1, n, result) | ((T(_, _, x, _), r1), (T(_, _, y, _), r2)) => ( case cmpObj(x, y) of LESS => diff (r1, t2, n+1, addItem(x, result)) | EQUAL => diff (r1, r2, n, result) | GREATER => diff (t1, r2, n, result) (* end case *)) (* end case *)) val (n, result) = diff (start s1, start s2, 0, ZERO) in SET(n, linkAll result) end fun subtract (s, item) = difference (s, singleton item) fun subtract' (item, s) = subtract (s, item) fun subtractList (l, items) = let val items' = List.foldl (fn (x, set) => add(set, x)) (SET(0, E)) items in difference (l, items') end fun app f = let fun appf E = () | appf (T(_, a, x, b)) = (appf a; f x; appf b) in fn (SET(_, m)) => appf m end fun map f = let fun addf (x, m) = add(m, f x) in foldl addf empty end fun mapPartial f (SET(_, m)) = let fun mapf (E, acc) = acc | mapf (T(_, a, x, b), acc) = (case f x of SOME y => mapf (b, mapf (a, add (acc, y))) | NONE => mapf (b, mapf (a, acc)) (* end case *)) in mapf (m, empty) end (* Filter out those elements of the set that do not satisfy the * predicate. The filtering is done in increasing map order. *) fun filter pred (SET(_, t)) = let fun walk (E, n, result) = (n, result) | walk (T(_, a, x, b), n, result) = let val (n, result) = walk(a, n, result) in if (pred x) then walk(b, n+1, addItem(x, result)) else walk(b, n, result) end val (n, result) = walk (t, 0, ZERO) in SET(n, linkAll result) end fun partition pred (SET(_, t)) = let fun walk (E, n1, result1, n2, result2) = (n1, result1, n2, result2) | walk (T(_, a, x, b), n1, result1, n2, result2) = let val (n1, result1, n2, result2) = walk(a, n1, result1, n2, result2) in if (pred x) then walk(b, n1+1, addItem(x, result1), n2, result2) else walk(b, n1, result1, n2+1, addItem(x, result2)) end val (n1, result1, n2, result2) = walk (t, 0, ZERO, 0, ZERO) in (SET(n1, linkAll result1), SET(n2, linkAll result2)) end fun exists pred = let fun test E = false | test (T(_, a, x, b)) = test a orelse pred x orelse test b in fn (SET(_, t)) => test t end fun all pred = let fun test E = true | test (T(_, a, x, b)) = test a andalso pred x andalso test b in fn (SET(_, t)) => test t end fun find pred = let fun test E = NONE | test (T(_, a, x, b)) = (case test a of NONE => if pred x then SOME x else test b | someItem => someItem (* end case *)) in fn (SET(_, t)) => test t end (* deprecated *) val listItems = toList end