[Python-checkins] r55321 - sandbox/trunk/abc/abc.py sandbox/trunk/abc/test_abc.py

guido.van.rossum python-checkins at python.org
Mon May 14 23:04:01 CEST 2007


Author: guido.van.rossum
Date: Mon May 14 23:03:55 2007
New Revision: 55321

Modified:
   sandbox/trunk/abc/abc.py
   sandbox/trunk/abc/test_abc.py
Log:
Add overloading example that works across ABC registrations.
Add mutable sequence and mapping.
Register some built-in types.


Modified: sandbox/trunk/abc/abc.py
==============================================================================
--- sandbox/trunk/abc/abc.py	(original)
+++ sandbox/trunk/abc/abc.py	Mon May 14 23:03:55 2007
@@ -9,6 +9,7 @@
 __author__ = "Guido van Rossum <guido at python.org>"
 
 import sys
+import inspect
 import itertools
 
 
@@ -149,7 +150,7 @@
         elif subclass in cls.__abc_negative_cache__:
             return False
         # Check if it's a direct subclass
-        if cls in subclass.mro():
+        if cls in subclass.__mro__:
             cls.__abc_cache__.add(subclass)
             return True
         # Check if it's a subclass of a registered class (recursive)
@@ -398,15 +399,9 @@
 
 ### MAPPINGS ###
 
+# XXX Get rid of _BasicMapping and view types
 
-class BasicMapping(Container):
-
-    """A basic mapping has __getitem__(), __contains__() and get().
-
-    The idea is that you only need to override __getitem__().
-
-    Other dict methods are not supported.
-    """
+class _BasicMapping(Container, Iterable):
 
     @abstractmethod
     def __getitem__(self, key):
@@ -425,9 +420,6 @@
         except KeyError:
             return False
 
-
-class IterableMapping(BasicMapping, Iterable):
-
     def keys(self):
         return KeysView(self)
 
@@ -484,7 +476,7 @@
             yield self._mapping[key]
 
 
-class Mapping(IterableMapping, Sized):
+class Mapping(_BasicMapping, Sized):
 
     def keys(self):
         return KeysView(self)
@@ -555,6 +547,84 @@
         return False
 
 
+class MutableMapping(Mapping):
+
+    """A Mapping that also supports item assignment and deletion.
+
+    Concrete subclasses must implement __setitem__() and __delitem__()
+    (in addition to the abstract methods required by Mapping); pop(),
+    popitem(), clear() and update() are derived from those.
+    """
+
+    @abstractmethod
+    def __setitem__(self, key, value):
+        raise NotImplementedError
+
+    @abstractmethod
+    def __delitem__(self, key):
+        raise NotImplementedError
+
+    # Private sentinel so pop() can tell "no default given" apart from
+    # an explicit default of None.
+    __marker = object()
+
+    def pop(self, key, default=__marker):
+        """Remove key and return its value.
+
+        If key is missing, return default if one was given, otherwise
+        re-raise the KeyError from the lookup.
+        """
+        try:
+            value = self[key]
+        except KeyError:
+            if default is self.__marker:
+                raise
+            return default
+        else:
+            del self[key]
+            return value
+
+    def popitem(self):
+        """Remove and return some (key, value) pair.
+
+        The pair removed is the one for the first key produced by
+        iter(self).  Raises KeyError if the mapping is empty.
+        """
+        try:
+            key = next(iter(self))
+        except StopIteration:
+            raise KeyError
+        value = self[key]
+        del self[key]
+        return key, value
+
+    def clear(self):
+        """Remove all items by calling popitem() until it raises KeyError."""
+        try:
+            while True:
+                self.popitem()
+        except KeyError:
+            pass
+
+    def update(self, other=(), **kwds):
+        """Copy items from other, then from keyword arguments, into self.
+
+        other may be a Mapping, any object with a keys() method, or an
+        iterable of (key, value) pairs.
+        """
+        if isinstance(other, Mapping):
+            for key in other:
+                self[key] = other[key]
+        elif hasattr(other, "keys"):
+            for key in other.keys():
+                self[key] = other[key]
+        else:
+            for key, value in other:
+                self[key] = value
+        for key, value in kwds.items():
+            self[key] = value
+
+
 ### SEQUENCES ###
 
 
@@ -708,8 +753,92 @@
         return len(self) <= len(other)
 
 
+class MutableSequence(Sequence):
+
+    """A Sequence that also supports assignment, deletion and insertion.
+
+    Concrete subclasses must implement __setitem__(), __delitem__() and
+    insert() (in addition to the abstract methods required by Sequence);
+    append(), reverse(), extend(), pop() and remove() are derived from
+    those.
+    """
+
+    @abstractmethod
+    def __setitem__(self, i, value):
+        raise NotImplementedError
+
+    @abstractmethod
+    def __delitem__(self, i):
+        raise NotImplementedError
+
+    @abstractmethod
+    def insert(self, i, value):
+        raise NotImplementedError
+
+    def append(self, value):
+        """Add value at the end, via insert(len(self), value)."""
+        self.insert(len(self), value)
+
+    def reverse(self):
+        """Reverse the items in place by swapping mirror-image pairs."""
+        n = len(self)
+        for i in range(n//2):
+            j = n-i-1
+            self[i], self[j] = self[j], self[i]
+
+    def extend(self, it):
+        """Append each item of the iterable it, in order."""
+        if it is self:
+            # Take a snapshot so the items appended below are not fed back
+            # into this loop (which could otherwise never terminate).
+            it = list(it)
+        for x in it:
+            self.append(x)
+
+    def pop(self, i=None):
+        """Remove and return the item at index i (default: the last item)."""
+        if i is None:
+            i = len(self) - 1
+        value = self[i]
+        del self[i]
+        return value
+
+    def remove(self, value):
+        """Delete the first item equal to value; raise ValueError if none."""
+        for i in range(len(self)):
+            if self[i] == value:
+                del self[i]
+                return
+        raise ValueError
+
+
+### PRE-DEFINED REGISTRATIONS ###
+
+# Declare built-in types as virtual subclasses of the ABCs defined above,
+# so that issubclass() checks against these ABCs accept them.
+Hashable.register(int)
+Hashable.register(float)
+Hashable.register(complex)
+Hashable.register(basestring)
+Hashable.register(tuple)
+Hashable.register(frozenset)
+Hashable.register(type)
+
+Set.register(frozenset)
+MutableSet.register(set)
+
+MutableMapping.register(dict)
+
+Sequence.register(tuple)
+Sequence.register(basestring)
+MutableSequence.register(list)
+MutableSequence.register(bytes)
+
+
 ### ADAPTERS ###
 
+# This is just an example, not something to go into the stdlib
+
 
 class AdaptToSequence(Sequence):
 
@@ -762,3 +873,176 @@
 
     def __len__(self):
         return len(self.adaptee)
+
+
+### OVERLOADING ###
+
+# This is a modest alternative proposal to PEP 3124.  It uses
+# issubclass() exclusively, meaning that any issubclass() overloading
+# (such as ABC registration) automatically works.  If accepted it
+# probably ought to go into a separate module (overloading.py?) as it
+# has nothing to do directly with ABCs.  The code here is an evolution
+# from my earlier attempt in sandbox/overload/overloading.py.
+
+
+class overloadable:
+
+    """An implementation of overloadable functions.
+
+    The decorated function becomes the default implementation.  More
+    implementations are added with the .overload() decorator; the
+    annotations on each one's named positional parameters (object when
+    unannotated) form the signature it is registered under.  A call
+    dispatches on the classes of its positional arguments: among the
+    registered signatures of the same length that match entry by entry
+    under issubclass(), the most specific one is used, so ABC
+    registrations are honoured.  If none match, the default
+    implementation is called; if no single most specific match exists,
+    TypeError is raised.
+
+    Usage example:
+
+    @overloadable
+    def flatten(x):
+        yield x
+
+    @flatten.overload
+    def _(it: Iterable):
+        for x in it:
+            yield x
+
+    @flatten.overload
+    def _(x: basestring):
+        yield x
+
+    """
+
+    def __init__(self, default_func):
+        # Decorator to declare new overloaded function.
+        self.registry = {}  # signature (tuple of types) -> implementation
+        self.cache = {}     # tuple of argument classes -> list of funcs
+        self.default_func = default_func
+
+    def __get__(self, obj, cls=None):
+        """Descriptor hook: when looked up on an instance, return this
+        object bound to that instance (passed as the first argument)."""
+        if obj is None:
+            return self
+        import types
+        return types.MethodType(self, obj)
+
+    def overload(self, func):
+        """Decorator to overload a function using its argument annotations.
+
+        Returns self if func has the same name as the default function,
+        so reusing that name keeps it bound to this object; otherwise
+        returns func unchanged.
+        """
+        self.register_func(self.extract_types(func), func)
+        if func.__name__ == self.default_func.__name__:
+            return self
+        else:
+            return func
+
+    def extract_types(self, func):
+        """Helper to extract argument annotations as a tuple of types.
+
+        Only named positional parameters are included; unannotated ones
+        map to object.
+        """
+        args, varargs, varkw, defaults, kwonlyargs, kwdefaults, annotations = \
+              inspect.getfullargspec(func)
+        return tuple(annotations.get(arg, object) for arg in args)
+
+    def register_func(self, types, func):
+        """Helper to register an implementation.
+
+        Registering a signature again replaces the earlier implementation.
+        """
+        self.registry[types] = func
+        # Clear the cache (later we might optimize this).  This is the only
+        # place it is cleared, so subclass relationships that change later
+        # (e.g. new ABC registrations) are not reflected in cached entries.
+        self.cache = {}
+
+    def __call__(self, *args):
+        """Call the overloaded function.
+
+        Dispatches on the classes of the positional arguments.  The
+        ordered list produced by find_funcs() is cached per tuple of
+        argument classes, and its first entry is called.
+        """
+        types = tuple(arg.__class__ for arg in args)
+        funcs = self.cache.get(types)
+        if funcs is None:
+            self.cache[types] = funcs = list(self.find_funcs(types))
+        return funcs[0](*args)
+
+    def find_funcs(self, types):
+        """Yield the appropriate overloaded functions, in order.
+
+        An exact signature match is yielded on its own.  Otherwise the
+        matching candidates are yielded most specific first, followed by
+        the default function; if at some point no remaining candidate
+        implies all the others, raise_ambiguity is yielded instead and
+        iteration stops.
+        """
+        func = self.registry.get(types)
+        if func is not None:
+            # Easy case -- direct hit in registry.
+            yield func
+            return
+
+        candidates = [cand
+                      for cand in self.registry
+                      if self.implies(types, cand)]
+
+        if not candidates:
+            # Easy case -- return the default function
+            yield self.default_func
+            return
+
+        if len(candidates) == 1:
+            # Easy case -- return this and the default function
+            yield self.registry[candidates[0]]
+            yield self.default_func
+            return
+
+        # implies() is only a partial order, so rather than sorting (which
+        # assumes a total order; list.sort() also lost its cmp argument in
+        # Python 3) repeatedly pick a candidate that implies all the rest.
+        while candidates:
+            for cand in candidates:
+                if all(self.implies(cand, c) for c in candidates):
+                    break
+            else:
+                yield self.raise_ambiguity
+                return
+            candidates.remove(cand)
+            yield self.registry[cand]
+        yield self.default_func
+
+    def comparator(self, xs, ys):
+        """cmp-style comparison of signatures: -1 if xs is strictly more
+        specific than ys, 1 if less, 0 if equal or unrelated.
+
+        Not used by find_funcs(); kept for external callers.
+        """
+        return self.implies(ys, xs) - self.implies(xs, ys)
+
+    def implies(self, xs, ys):
+        """Return True if signatures xs and ys have the same length and
+        each class in xs is a subclass of the corresponding one in ys."""
+        return len(xs) == len(ys) and all(issubclass(x, y)
+                                          for x, y in zip(xs, ys))
+
+    def raise_ambiguity(self, *args):
+        # Stand-in implementation yielded by find_funcs() when no single
+        # most specific candidate exists.
+        # XXX Should be more specific
+        raise TypeError("ambiguous signature of overloadable function")
+
+    def raise_exhausted(self, *args):
+        # Not referenced anywhere in this class.
+        # XXX Should be more specific
+        raise TypeError("no remaining candidates for overloadable function")

Modified: sandbox/trunk/abc/test_abc.py
==============================================================================
--- sandbox/trunk/abc/test_abc.py	(original)
+++ sandbox/trunk/abc/test_abc.py	Mon May 14 23:03:55 2007
@@ -76,6 +76,45 @@
         self.assertEqual(42 in a, False)
         self.assertEqual(len(a), 3)
 
+    def test_overloading(self):
+        # Basic 'flatten' example
+        @abc.overloadable
+        def flatten(x):
+            yield x
+        @flatten.overload
+        def _(x: abc.Iterable):
+            for a in x:
+                for b in flatten(a):
+                    yield b
+        # Treat strings as atoms: a string's items are themselves strings,
+        # so the Iterable overload alone would recurse without end.
+        @flatten.overload
+        def _(x: basestring):
+            yield x
+        self.assertEqual(list(flatten([1, 2, 3])), [1, 2, 3])
+        self.assertEqual(list(flatten([1,[2],3])), [1, 2, 3])
+        self.assertEqual(list(flatten([1,[2,3]])), [1, 2, 3])
+        self.assertEqual(list(flatten([[1,[2]],3])), [1, 2, 3])
+        self.assertEqual(list(flatten([1,"abc",3])), [1, "abc", 3])
+
+        # Add 2-arg version
+        @flatten.overload
+        def _(t: type, x):
+            return t(flatten(x))
+        self.assertEqual(flatten(tuple, [1, 2, 3]), (1, 2, 3))
+        self.assertEqual(flatten(tuple, [1,[2],3]), (1, 2, 3))
+        self.assertEqual(flatten(tuple, [1,"abc",3]), (1, "abc", 3))
+
+        # Change an overload: registering the same signature again
+        # replaces the old implementation.  Naming it 'flatten' keeps
+        # that name bound to the overloadable, since overload() returns
+        # the overloadable when the names match.
+        @flatten.overload
+        def flatten(x: basestring):
+            for c in x:
+                yield c
+        self.assertEqual(list(flatten([1, "abc", 3])), [1, "a", "b", "c", 3])
+
 
 if __name__ == "__main__":
     unittest.main()


More information about the Python-checkins mailing list