[Jython-checkins] jython (merge default -> default): Merged

jim.baker jython-checkins at python.org
Thu May 22 22:02:47 CEST 2014


http://hg.python.org/jython/rev/77dcd0a92088
changeset:   7270:77dcd0a92088
parent:      7264:06161ebf74ee
parent:      7269:374dd2f0a5ef
user:        Jim Baker <jim.baker at rackspace.com>
date:        Thu May 22 13:50:00 2014 -0600
summary:
  Merged https://bitbucket.org/jython/jython/pull-request/39/2088-fix-defaultdict-so-that-derived/

files:
  Lib/test/test_defaultdict.py                                  |   59 ++++-
  Lib/test/test_defaultdict_jy.py                               |  104 +++++++--
  src/org/python/modules/_collections/PyDefaultDict.java        |   27 +-
  src/org/python/modules/_collections/PyDefaultDictDerived.java |    9 +
  src/templates/defaultdict.derived                             |   11 +-
  5 files changed, 155 insertions(+), 55 deletions(-)


diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py
--- a/Lib/test/test_defaultdict.py
+++ b/Lib/test/test_defaultdict.py
@@ -24,21 +24,21 @@
         d1[13]
         d1[14]
         self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
-        self.assert_(d1[12] is not d1[13] is not d1[14])
+        self.assertTrue(d1[12] is not d1[13] is not d1[14])
         d2 = defaultdict(list, foo=1, bar=2)
         self.assertEqual(d2.default_factory, list)
         self.assertEqual(d2, {"foo": 1, "bar": 2})
         self.assertEqual(d2["foo"], 1)
         self.assertEqual(d2["bar"], 2)
         self.assertEqual(d2[42], [])
-        self.assert_("foo" in d2)
-        self.assert_("foo" in d2.keys())
-        self.assert_("bar" in d2)
-        self.assert_("bar" in d2.keys())
-        self.assert_(42 in d2)
-        self.assert_(42 in d2.keys())
-        self.assert_(12 not in d2)
-        self.assert_(12 not in d2.keys())
+        self.assertIn("foo", d2)
+        self.assertIn("foo", d2.keys())
+        self.assertIn("bar", d2)
+        self.assertIn("bar", d2.keys())
+        self.assertIn(42, d2)
+        self.assertIn(42, d2.keys())
+        self.assertNotIn(12, d2)
+        self.assertNotIn(12, d2.keys())
         d2.default_factory = None
         self.assertEqual(d2.default_factory, None)
         try:
@@ -59,6 +59,7 @@
         d1 = defaultdict()
         self.assertEqual(d1.default_factory, None)
         self.assertEqual(repr(d1), "defaultdict(None, {})")
+        self.assertEqual(eval(repr(d1)), d1)
         d1[11] = 41
         self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
         d2 = defaultdict(int)
@@ -67,7 +68,7 @@
         self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
         def foo(): return 43
         d3 = defaultdict(foo)
-        self.assert_(d3.default_factory is foo)
+        self.assertTrue(d3.default_factory is foo)
         d3[13]
         self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
 
@@ -111,6 +112,12 @@
         d4[12]
         self.assertEqual(d4, {42: [], 12: []})
 
+        # Issue 6637: Copy fails for empty default dict
+        d = defaultdict()
+        d['a'] = 42
+        e = d.copy()
+        self.assertEqual(e['a'], 42)
+
     def test_shallow_copy(self):
         d1 = defaultdict(foobar, {1: 1})
         d2 = copy.copy(d1)
@@ -126,7 +133,7 @@
         d2 = copy.deepcopy(d1)
         self.assertEqual(d2.default_factory, foobar)
         self.assertEqual(d2, d1)
-        self.assert_(d1[1] is not d2[1])
+        self.assertTrue(d1[1] is not d2[1])
         d1.default_factory = list
         d2 = copy.deepcopy(d1)
         self.assertEqual(d2.default_factory, list)
@@ -137,12 +144,36 @@
         try:
             d1[(1,)]
         except KeyError, err:
-            # XXX: Exception.message is only supported in Python 2.5
-            #self.assertEqual(err.message, (1,))
-            pass
+            self.assertEqual(err.args[0], (1,))
         else:
             self.fail("expected KeyError")
 
+    @unittest.skipIf(test_support.is_jython, "FIXME: incorrect formatting of default_factory when it's a bound method")
+    def test_recursive_repr(self):
+        # Issue2045: stack overflow when default_factory is a bound method
+        class sub(defaultdict):
+            def __init__(self):
+                self.default_factory = self._factory
+            def _factory(self):
+                return []
+        d = sub()
+        self.assertTrue(repr(d).startswith(
+            "defaultdict(<bound method sub._factory of defaultdict(..."))
+
+        # NOTE: printing a subclass of a builtin type does not call its
+        # tp_print slot. So this part is essentially the same test as above.
+        tfn = test_support.TESTFN
+        try:
+            f = open(tfn, "w+")
+            try:
+                print >>f, d
+            finally:
+                f.close()
+        finally:
+            os.remove(tfn)
+
+    def test_callable_arg(self):
+        self.assertRaises(TypeError, defaultdict, {})
 
 def test_main():
     test_support.run_unittest(TestDefaultDict)
diff --git a/Lib/test/test_defaultdict_jy.py b/Lib/test/test_defaultdict_jy.py
--- a/Lib/test/test_defaultdict_jy.py
+++ b/Lib/test/test_defaultdict_jy.py
@@ -38,30 +38,29 @@
         for t in threads:
             self.assertFalse(t.isAlive())
 
+    class Counter(object):
+        def __init__(self, initial=0):
+            self.atomic = AtomicInteger(initial)
+            # sleeping here widens the window in which concurrent
+            # defaultdict factory calls can step on each other
+            time.sleep(0.001)
+
+        def decrementAndGet(self):
+            return self.atomic.decrementAndGet()
+
+        def incrementAndGet(self):
+            return self.atomic.incrementAndGet()
+
+        def get(self):
+            return self.atomic.get()
+
+        def __repr__(self):
+            return "Counter<%s>" % (self.atomic.get())
+
     def test_inc_dec(self):
+        counters = defaultdict(ThreadSafetyTestCase.Counter)
+        size = 17
 
-        class Counter(object):
-            def __init__(self):
-                self.atomic = AtomicInteger()
-                 # waiting is important here to ensure that
-                 # defaultdict factories can step on each other
-                time.sleep(0.001)
-
-            def decrementAndGet(self):
-                return self.atomic.decrementAndGet()
-
-            def incrementAndGet(self):
-                return self.atomic.incrementAndGet()
-
-            def get(self):
-                return self.atomic.get()
-
-            def __repr__(self):
-                return "Counter<%s>" % (self.atomic.get())
-
-        counters = defaultdict(Counter)
-        size = 17
-        
         def tester():
             for i in xrange(1000):
                 j = (i + randint(0, size)) % size
@@ -70,10 +69,36 @@
                 counters[j].incrementAndGet()
 
         self.run_threads(tester, 20)
-        
+
         for i in xrange(size):
             self.assertEqual(counters[i].get(), 0, counters)
 
+    def test_derived_inc_dec(self):
+        class DerivedDefaultDict(defaultdict):
+            def __missing__(self, key):
+                if self.default_factory is None:
+                    raise KeyError("Invalid key '{0}' and no default factory was set".format(key))
+                # unlike plain defaultdict, hand the missing key to the factory
+                val = self.default_factory(key)
+
+                self[key] = val
+                return val
+
+        counters = DerivedDefaultDict(lambda key: ThreadSafetyTestCase.Counter(key))
+        size = 17
+
+        def tester():
+            for i in xrange(1000):
+                j = (i + randint(0, size)) % size
+                counters[j].decrementAndGet()
+                time.sleep(0.0001)
+                counters[j].incrementAndGet()
+
+        self.run_threads(tester, 20)
+
+        for i in xrange(size):
+            self.assertEqual(counters[i].get(), i, counters)
+
 class GetVariantsTestCase(unittest.TestCase):
 
     #http://bugs.jython.org/issue2133
@@ -94,8 +119,39 @@
         self.assertEquals(d.items(), [("vivify", [])]) 
 
 
+
+class OverrideMissingTestCase(unittest.TestCase):
+    class KeyDefaultDict(defaultdict):
+        """defaultdict that passes the requested key to its factory function."""
+        def __missing__(self, key):
+            if self.default_factory is None:
+                raise KeyError("Invalid key '{0}' and no default factory was set".format(key))
+            else:
+                val = self.default_factory(key)
+
+            self[key] = val
+            return val
+
+        @classmethod
+        def double(cls, k):
+            return k + k
+
+    def setUp(self):
+        self.kdd = OverrideMissingTestCase.KeyDefaultDict(OverrideMissingTestCase.KeyDefaultDict.double)
+
+    def test_dont_call_derived_missing(self):
+        self.kdd[3] = 5
+        self.assertEqual(self.kdd[3], 5)
+
+    # http://bugs.jython.org/issue2088
+    def test_override_missing(self):
+        # before the fix, the line below raised KeyError: Jython ignored the overridden __missing__
+        self.assertEqual(self.kdd[3], 6)
+        self.assertEqual(self.kdd['ab'], 'abab')
+
+
 def test_main():
-    test_support.run_unittest(PickleTestCase, ThreadSafetyTestCase, GetVariantsTestCase)
+    test_support.run_unittest(PickleTestCase, ThreadSafetyTestCase, GetVariantsTestCase, OverrideMissingTestCase)
 
 
 if __name__ == '__main__':
diff --git a/src/org/python/modules/_collections/PyDefaultDict.java b/src/org/python/modules/_collections/PyDefaultDict.java
--- a/src/org/python/modules/_collections/PyDefaultDict.java
+++ b/src/org/python/modules/_collections/PyDefaultDict.java
@@ -58,13 +58,9 @@
         backingMap = CacheBuilder.newBuilder().build(
                 new CacheLoader<PyObject, PyObject>() {
                     public PyObject load(PyObject key) {
-                        if (defaultFactory == Py.None) {
-                            throw Py.KeyError(key);
-                        }
-                        return defaultFactory.__call__();
+                        return __missing__(key);
                     }
-                }
-        );
+                });
     }
 
     public PyDefaultDict(PyType subtype, Map<PyObject, PyObject> map) {
@@ -78,7 +74,7 @@
         int nargs = args.length - kwds.length;
         if (nargs != 0) {
             defaultFactory = args[0];
-            if (!defaultFactory.isCallable()) {
+            if (!(defaultFactory == Py.None || defaultFactory.isCallable())) {
                 throw Py.TypeError("first argument must be callable");
             }
             PyObject newargs[] = new PyObject[args.length - 1];
@@ -87,22 +83,21 @@
         }
     }
 
+    public PyObject __missing__(PyObject key) {
+        return defaultdict___missing__(key);
+    }
+
     /**
-     * This method is NOT called by the __getitem__ method of the dict class when the
-     * requested key is not found. It is simply here as an alternative to the atomic
-     * construction of that factory. (We actually inline it in.)
+     * Does NOT call __setitem__; it relies on being invoked from CacheLoader#load,
+     * whose returned value the backing cache then stores in the dict. Throws
+     * KeyError for the missing key when default_factory is None.
      */
     @ExposedMethod
     final PyObject defaultdict___missing__(PyObject key) {
         if (defaultFactory == Py.None) {
             throw Py.KeyError(key);
         }
-        PyObject value = defaultFactory.__call__();
-        if (value == null) {
-            return value;
-        }
-        __setitem__(key, value);
-        return value;
+        return defaultFactory.__call__();
     }
 
     @Override
diff --git a/src/org/python/modules/_collections/PyDefaultDictDerived.java b/src/org/python/modules/_collections/PyDefaultDictDerived.java
--- a/src/org/python/modules/_collections/PyDefaultDictDerived.java
+++ b/src/org/python/modules/_collections/PyDefaultDictDerived.java
@@ -1122,6 +1122,15 @@
         return super.__coerce_ex__(o);
     }
 
+    public PyObject __missing__(PyObject key) {
+        PyType self_type=getType();
+        PyObject impl=self_type.lookup("__missing__");
+        if (impl!=null) {
+            return impl.__get__(this,self_type).__call__(key);
+        }
+        return super.__missing__(key);
+    }
+
     public String toString() {
         PyType self_type=getType();
         PyObject impl=self_type.lookup("__repr__");
diff --git a/src/templates/defaultdict.derived b/src/templates/defaultdict.derived
--- a/src/templates/defaultdict.derived
+++ b/src/templates/defaultdict.derived
@@ -1,4 +1,13 @@
 base_class: PyDefaultDict
 want_dict: true
 ctr:
-incl: object
+incl: dict
+rest:
+    public PyObject __missing__(PyObject key) {
+        PyType self_type=getType();
+        PyObject impl=self_type.lookup("__missing__");
+        if (impl!=null) {
+            return impl.__get__(this,self_type).__call__(key);
+        }
+        return super.__missing__(key);
+    }

-- 
Repository URL: http://hg.python.org/jython


More information about the Jython-checkins mailing list