"""Classes to represent arbitrary sets (including sets of sets).

This module implements sets using dictionaries.  Unlike the standard
library 'sets' module it is derived from, the dictionary values are not
ignored: each value records the position at which the element was
added, so that iteration and repr() follow insertion order.  The usual
operations (union, intersection, deletion, etc.) are provided as both
methods and operators.

Important: sets are not sequences!  While they support 'x in s',
'len(s)', and 'for x in s', none of those operations are unique for
sequences; for example, mappings support all three as well.  The
characteristic operation for sequences is subscripting with small
integers: s[i], for i in range(len(s)).  Sets don't support
subscripting at all.  Also, sequences allow multiple occurrences; sets
on the other hand don't record multiple occurrences (unless they are
added explicitly with Set.addDuplicate).

The following classes are provided:

BaseSet -- All the operations common to both mutable and immutable
    sets. This is an abstract class, not meant to be directly
    instantiated.

Set -- Mutable sets, subclass of BaseSet; not hashable.

ImmutableSet -- Immutable sets, subclass of BaseSet; hashable.
    An iterable argument is mandatory to create an ImmutableSet.

_TemporarilyImmutableSet -- A wrapper around a Set, hashable,
    giving the same hash value as the immutable set equivalent
    would have.  Do not use this class directly.

Only hashable objects can be added to a Set. In particular, you cannot
really add a Set as an element to another Set; if you try, what is
actually added is an ImmutableSet built from it (it compares equal to
the one you tried adding).

When you ask if `x in y' where x is a Set and y is a Set or
ImmutableSet, x is wrapped into a _TemporarilyImmutableSet z, and
what's tested is actually `z in y'.

"""

# Code history:
#
# - Greg V. Wilson wrote the first version, using a different approach
#   to the mutable/immutable problem, and inheriting from dict.
#
# - Alex Martelli modified Greg's version to implement the current
#   Set/ImmutableSet approach, and make the data an attribute.
#
# - Guido van Rossum rewrote much of the code, made some API changes,
#   and cleaned up the docstrings.
#
# - Raymond Hettinger added a number of speedups and other
#   improvements.

from __future__ import generators
try:
    from itertools import ifilter, ifilterfalse
except ImportError:
    # Fallbacks for interpreters whose itertools lacks these names
    # (Python 2.2, which has no itertools, and Python 3, which renamed
    # them).
    def ifilter(predicate, iterable):
        if predicate is None:
            def predicate(x):
                return x
        for x in iterable:
            if predicate(x):
                yield x

    def ifilterfalse(predicate, iterable):
        if predicate is None:
            def predicate(x):
                return x
        for x in iterable:
            if not predicate(x):
                yield x

# Built-in types whose iterators are known not to raise TypeError; see
# BaseSet._update().  'xrange' only exists on Python 2.
try:
    _FAST_ITERABLE_TYPES = (list, tuple, xrange)
except NameError:
    _FAST_ITERABLE_TYPES = (list, tuple, range)

__all__ = ['BaseSet', 'Set', 'ImmutableSet']


class BaseSet(object):
    """Common base class for mutable and immutable sets."""

    # _data maps each element to its insertion position: an int, or a
    # tuple of ints when the element was added more than once through
    # Set.addDuplicate().
    #
    # NOTE(review): new positions are derived from len(self._data), so
    # after a removal (or after duplicates have been added) a new
    # element can be given a position that is already in use; the
    # relative order of such elements is then not guaranteed.
    __slots__ = ['_data']

    # Constructor

    def __init__(self):
        """This is an abstract class."""
        # Don't call this from a concrete subclass!
        if self.__class__ is BaseSet:
            raise TypeError("BaseSet is an abstract class.  "
                            "Use Set or ImmutableSet.")

    def _getItems(self):
        """Return a list of the elements in insertion order.

        An element stored with several positions (see
        Set.addDuplicate) appears once per recorded position.
        """
        items = []
        for key, value in self._data.items():
            if isinstance(value, tuple):
                items.extend([(key, v) for v in value])
            else:
                items.append((key, value))
        # Order by recorded position only; the elements themselves need
        # not be comparable with each other.
        items.sort(key=lambda a: a[1])
        return [i[0] for i in items]

    # Standard protocols: __len__, __repr__, __str__, __iter__

    def __len__(self):
        """Return the number of distinct elements of a set."""
        return len(self._data)

    def __repr__(self):
        """Return string representation of a set.

        This looks like 'Set([<list of elements>])'.
        """
        return self._repr()

    # __str__ is the same as __repr__
    __str__ = __repr__

    def _repr(self, sorted=False):
        """Build the repr; if 'sorted' is true, sort the elements."""
        elements = self._getItems()
        if sorted:
            elements.sort()
        return '%s(%r)' % (self.__class__.__name__, elements)

    def __iter__(self):
        """Return an iterator over the elements of a set.

        Elements are produced in insertion order (see _getItems).
        """
        return iter(self._getItems())

    # Three-way comparison is not supported.  However, because __eq__ is
    # tried before __cmp__, if Set x == Set y, x.__eq__(y) returns True and
    # then cmp(x, y) returns 0 (Python doesn't actually call __cmp__ in this
    # case).

    def __cmp__(self, other):
        raise TypeError("can't compare sets using cmp()")

    # Equality comparisons using the underlying dicts.  Mixed-type comparisons
    # are allowed here, where Set == z for non-Set z always returns False,
    # and Set != z always True.  This allows expressions like "x in y" to
    # give the expected result when y is a sequence of mixed types, not
    # raising a pointless TypeError just because y contains a Set, or x is
    # a Set and y contains a non-set ("in" invokes only __eq__).
    # Subtle:  it would be nicer if __eq__ and __ne__ could return
    # NotImplemented instead of True or False.  Then the other comparand
    # would get a chance to determine the result, and if the other comparand
    # also returned NotImplemented then it would fall back to object address
    # comparison (which would always return False for __eq__ and always
    # True for __ne__).  However, that doesn't work, because this type
    # *also* implements __cmp__:  if, e.g., __eq__ returns NotImplemented,
    # Python tries __cmp__ next, and the __cmp__ here then raises TypeError.

    def _same_keys(self, other):
        """Report whether self and other (a BaseSet) hold the same elements.

        Only membership is compared.  The stored positions are ignored,
        and so is the order in which the dicts happen to list their
        keys (comparing keys() lists would be order-sensitive).
        """
        mine = self._data
        theirs = other._data
        if len(mine) != len(theirs):
            return False
        for key in mine:
            if key not in theirs:
                return False
        return True

    def __eq__(self, other):
        if isinstance(other, BaseSet):
            return self._same_keys(other)
        else:
            return False

    def __ne__(self, other):
        if isinstance(other, BaseSet):
            return not self._same_keys(other)
        else:
            return True

    # Copying operations

    def copy(self):
        """Return a shallow copy of a set."""
        result = self.__class__()
        # The target is empty, so the positions can be copied verbatim.
        result._data.update(self._data)
        return result

    __copy__ = copy  # For the copy module

    def __deepcopy__(self, memo):
        """Return a deep copy of a set; used by copy module."""
        # This pre-creates the result and inserts it in the memo
        # early, in case the deep copy recurses into another reference
        # to this same set.  A set can't be an element of itself, but
        # it can certainly contain an object that has a reference to
        # itself.
        from copy import deepcopy
        result = self.__class__()
        memo[id(self)] = result
        data = result._data
        for e, elt in enumerate(self):
            data[deepcopy(elt, memo)] = e
        return result

    # Standard set operations: union, intersection, both differences.
    # Each has an operator version (e.g. __or__, invoked with |) and a
    # method version (e.g. union).
    # Subtle:  Each pair requires distinct code so that the outcome is
    # correct when the type of other isn't suitable.  For example, if
    # we did "union = __or__" instead, then Set().union(3) would return
    # NotImplemented instead of raising TypeError (albeit that *why* it
    # raises TypeError as-is is also a bit subtle).

    def __or__(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.union(other)

    def union(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)
        """
        result = self.__class__(self)
        result._update(other)
        return result

    def __and__(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.intersection(other)

    def intersection(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)

        The result follows the element order of the smaller operand,
        because that is the one that gets iterated.
        """
        if not isinstance(other, BaseSet):
            other = Set(other)
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        common = ifilter(big._data.__contains__, little)
        return self.__class__(common)

    def __xor__(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.symmetric_difference(other)

    def symmetric_difference(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)

        Elements found only in this set come first, followed by the
        elements found only in the other one, each group in its own
        set's order.
        """
        if not isinstance(other, BaseSet):
            other = Set(other)
        result = self.__class__()
        data = result._data
        value = 0
        # Iterate the sets, not their dicts, so that the result keeps
        # insertion order (as difference() does).
        for elt in ifilterfalse(other._data.__contains__, self):
            data[elt] = value
            value += 1
        for elt in ifilterfalse(self._data.__contains__, other):
            data[elt] = value
            value += 1
        return result

    def __sub__(self, other):
        """Return the difference of two sets as a new Set.

        (I.e. all elements that are in this set and not in the other.)
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self.difference(other)

    def difference(self, other):
        """Return the difference of two sets as a new Set.

        (I.e. all elements that are in this set and not in the other.)
        """
        result = self.__class__()
        data = result._data
        try:
            otherdata = other._data
        except AttributeError:
            otherdata = Set(other)._data
        for e, elt in enumerate(ifilterfalse(otherdata.__contains__, self)):
            data[elt] = e
        return result

    # Membership test

    def __contains__(self, element):
        """Report whether an element is a member of a set.

        (Called in response to the expression `element in self'.)
        """
        try:
            return element in self._data
        except TypeError:
            # Unhashable element (e.g. a Set): look it up through its
            # temporarily-immutable wrapper, if it offers one.
            transform = getattr(element, "__as_temporarily_immutable__", None)
            if transform is None:
                raise  # re-raise the TypeError exception we caught
            return transform() in self._data

    # Subset and superset test

    def issubset(self, other):
        """Report whether another set contains this set."""
        self._binary_sanity_check(other)
        if len(self) > len(other):  # Fast check for obvious cases
            return False
        for elt in ifilterfalse(other._data.__contains__, self):
            return False
        return True

    def issuperset(self, other):
        """Report whether this set contains another set."""
        self._binary_sanity_check(other)
        if len(self) < len(other):  # Fast check for obvious cases
            return False
        for elt in ifilterfalse(self._data.__contains__, other):
            return False
        return True

    # Inequality comparisons using the is-subset relation.
    __le__ = issubset
    __ge__ = issuperset

    def __lt__(self, other):
        self._binary_sanity_check(other)
        return len(self) < len(other) and self.issubset(other)

    def __gt__(self, other):
        self._binary_sanity_check(other)
        return len(self) > len(other) and self.issuperset(other)

    # Assorted helpers

    def _binary_sanity_check(self, other):
        # Check that the other argument to a binary operation is also
        # a set, raising a TypeError otherwise.
        if not isinstance(other, BaseSet):
            raise TypeError("Binary operation only permitted between sets")

    def _compute_hash(self):
        # Calculate hash code for a set by xor'ing the hash codes of
        # the elements.  This ensures that the hash code does not depend
        # on the order in which elements are added to the set.  This is
        # not called __hash__ because a BaseSet should not be hashable;
        # only an ImmutableSet is hashable.
        #
        # The distinct keys are hashed rather than the iteration order:
        # iterating would visit an element added with addDuplicate()
        # several times, the xor would cancel it out, and the hash
        # would no longer match the equivalent ImmutableSet.
        result = 0
        for elt in self._data:
            result ^= hash(elt)
        return result

    def _update(self, iterable):
        # The main loop for update() and the subclass __init__() methods.
        # Elements are numbered from len(self._data) upwards, in the
        # order they are produced; an element that is already present
        # has its position overwritten.
        #
        # The dict-to-dict shortcut of the original 'sets' module
        # (data.update(iterable._data)) cannot be used here, because
        # the other set's positions would clash with ours.
        data = self._data
        value = len(data)

        # Exact type check on purpose: a subclass could override
        # __iter__ in a way that breaks the assumption below.
        if type(iterable) in _FAST_ITERABLE_TYPES:
            # Optimized: we know that __iter__() and next() can't
            # raise TypeError, so we can move 'try:' out of the loop.
            it = iter(iterable)
            while True:
                try:
                    for element in it:
                        data[element] = value
                        value += 1
                    return
                except TypeError:
                    transform = getattr(element, "__as_immutable__", None)
                    if transform is None:
                        raise  # re-raise the TypeError exception we caught
                    data[transform()] = value
                    value += 1
        else:
            # Safe: only catch TypeError where intended
            for element in iterable:
                try:
                    data[element] = value
                    value += 1
                except TypeError:
                    transform = getattr(element, "__as_immutable__", None)
                    if transform is None:
                        raise  # re-raise the TypeError exception we caught
                    data[transform()] = value
                    value += 1


class ImmutableSet(BaseSet):
    """Immutable set class."""

    __slots__ = ['_hashcode']

    # BaseSet + hashing

    def __init__(self, iterable=None):
        """Construct an immutable set from an optional iterable."""
        self._hashcode = None
        self._data = {}
        if iterable is not None:
            self._update(iterable)

    def __hash__(self):
        # Computed on first use and cached; the contents never change.
        if self._hashcode is None:
            self._hashcode = self._compute_hash()
        return self._hashcode

    def __getstate__(self):
        return self._data, self._hashcode

    def __setstate__(self, state):
        self._data, self._hashcode = state


class Set(BaseSet):
    """ Mutable set class."""

    __slots__ = []

    # BaseSet + operations requiring mutability; no hashing

    def __init__(self, iterable=None):
        """Construct a set from an optional iterable."""
        self._data = {}
        if iterable is not None:
            self._update(iterable)

    def __getstate__(self):
        # Wrap the dict in a 1-tuple: pickle ignores a __getstate__()
        # result that is false, which a bare empty dict would be.
        return self._data,

    def __setstate__(self, data):
        self._data, = data

    def __hash__(self):
        """A Set cannot be hashed."""
        # We inherit object.__hash__, so we must deny this explicitly
        raise TypeError("Can't hash a Set, only an ImmutableSet.")

    # In-place union, intersection, differences.
    # Subtle:  The xyz_update() functions deliberately return None,
    # as do all mutating operations on built-in container types.
    # The __xyz__ spellings have to return self, though.

    def __ior__(self, other):
        """Update a set with the union of itself and another."""
        self._binary_sanity_check(other)
        # Go through _update() so the incoming elements are numbered
        # after the existing ones; copying other._data directly would
        # import the other set's positions and scramble the order.
        self._update(other)
        return self

    def union_update(self, other):
        """Update a set with the union of itself and another."""
        self._update(other)

    def __iand__(self, other):
        """Update a set with the intersection of itself and another."""
        self._binary_sanity_check(other)
        self._data = (self & other)._data
        return self

    def intersection_update(self, other):
        """Update a set with the intersection of itself and another."""
        if isinstance(other, BaseSet):
            self &= other
        else:
            self._data = (self.intersection(other))._data

    def __ixor__(self, other):
        """Update a set with the symmetric difference of itself and another."""
        self._binary_sanity_check(other)
        self.symmetric_difference_update(other)
        return self

    def symmetric_difference_update(self, other):
        """Update a set with the symmetric difference of itself and another."""
        data = self._data
        value = len(data)
        if not isinstance(other, BaseSet):
            other = Set(other)
        for elt in other:
            if elt in data:
                del data[elt]
            else:
                data[elt] = value
                value += 1

    def __isub__(self, other):
        """Remove all elements of another set from this set."""
        self._binary_sanity_check(other)
        self.difference_update(other)
        return self

    def difference_update(self, other):
        """Remove all elements of another set from this set."""
        data = self._data
        if not isinstance(other, BaseSet):
            other = Set(other)
        for elt in ifilter(data.__contains__, other):
            del data[elt]

    # Python dict-like mass mutations: update, clear

    def update(self, iterable):
        """Add all values from an iterable (such as a list or file)."""
        self._update(iterable)

    def clear(self):
        """Remove all elements from this set."""
        self._data.clear()

    # Single-element mutations: add, remove, discard

    def _as_key(self, element):
        """Return the object under which 'element' is stored.

        This is the element itself when it is hashable, otherwise the
        result of its __as_immutable__() method.  Raises TypeError when
        the element is unhashable and has no such method.
        """
        try:
            hash(element)
        except TypeError:
            transform = getattr(element, "__as_immutable__", None)
            if transform is None:
                raise  # re-raise the TypeError exception we caught
            return transform()
        return element

    def add(self, element):
        """Add an element to a set.

        This has no effect if the element is already present.
        """
        key = self._as_key(element)
        data = self._data
        # Leave an existing entry alone so that the element keeps its
        # position (and any duplicates recorded by addDuplicate).
        if key not in data:
            data[key] = len(data)

    def updateDuplicate(self, iterable):
        '''Adds with possible duplication several elements to the set'''
        for i in iterable:
            self.addDuplicate(i)

    def addDuplicate(self, element):
        """Add an element to the set.

        If the element is already present, it adds the duplicate element:
        the new position is appended to the positions already recorded
        for it, so the element shows up once more when iterating.
        """
        key = self._as_key(element)
        data = self._data
        position = len(data)
        if key in data:
            previous = data[key]
            if isinstance(previous, tuple):
                data[key] = previous + (position,)
            else:
                data[key] = (previous, position)
        else:
            data[key] = position

    def remove(self, element):
        """Remove an element from a set; it must be a member.

        If the element is not a member, raise a KeyError.
        """
        try:
            del self._data[element]
        except TypeError:
            transform = getattr(element, "__as_temporarily_immutable__", None)
            if transform is None:
                raise  # re-raise the TypeError exception we caught
            del self._data[transform()]

    def discard(self, element):
        """Remove an element from a set if it is a member.

        If the element is not a member, do nothing.
        """
        try:
            self.remove(element)
        except KeyError:
            pass

    def pop(self):
        """Remove and return an arbitrary set element.

        Not supported by this implementation: always raises
        RuntimeError.
        """
        raise RuntimeError('This will not work with our current scheme. We must renumber when this happens')

    def __as_immutable__(self):
        # Return a copy of self as an immutable set
        return ImmutableSet(self)

    def __as_temporarily_immutable__(self):
        # Return self wrapped in a temporarily immutable set
        return _TemporarilyImmutableSet(self)


class _TemporarilyImmutableSet(BaseSet):
    # Wrap a mutable set as if it was temporarily immutable.
    # This only supplies hashing and equality comparisons.

    def __init__(self, set):
        self._set = set
        self._data = set._data  # Needed by ImmutableSet.__eq__()

    def __hash__(self):
        return self._set._compute_hash()


if __name__ == '__main__':
    import unittest

    class SetsTest(unittest.TestCase):
        def testOrdering(self):
            '''Verify that sets created in different orders are equivalent'''
            self.assertEqual(Set([1, 2, 3]), Set([3, 2, 1]))
            self.assertEqual(Set([1, 2, 3]), Set([2, 3, 1]))

        def testUniqueness(self):
            '''Verify that sets created with repeated elements are equivalent'''
            self.assertEqual(Set([1, 2, 3]), Set([3, 2, 1, 2]))
            self.assertEqual(Set([1, 2, 3]), Set([2, 3, 3, 3, 1, 2, 3, 1]))

        def testOrder(self):
            '''Verify that sets maintain the creation order'''
            self.assertEqual(str(Set([1, 2, 3])), 'Set([1, 2, 3])')
            self.assertEqual(str(Set([2, 3, 3, 3, 1, 2, 3, 1])), 'Set([2, 3, 1])')

    unittest.main()