[pypy-commit] pypy reflex-support: allow operator== to be defined virtual

wlav noreply at buildbot.pypy.org
Thu Apr 18 23:37:37 CEST 2013


Author: Wim Lavrijsen <WLavrijsen at lbl.gov>
Branch: reflex-support
Changeset: r63491:6025b73044af
Date: 2013-04-18 12:20 -0700
http://bitbucket.org/pypy/pypy/changeset/6025b73044af/

Log:	allow operator== to be defined virtual

diff --git a/pypy/module/cppyy/pythonify.py b/pypy/module/cppyy/pythonify.py
--- a/pypy/module/cppyy/pythonify.py
+++ b/pypy/module/cppyy/pythonify.py
@@ -313,9 +313,11 @@
     if '__eq__' in pyclass.__dict__:
         def __eq__(self, other):
             if other is None: return not self
-            if type(self) is not type(other): return False
             if not self and not other: return True
-            return self._cxx_eq(other)
+            try:
+                return self._cxx_eq(other)
+            except TypeError:
+                return NotImplemented
         pyclass._cxx_eq = pyclass.__dict__['__eq__']
         pyclass.__eq__ = __eq__
 
diff --git a/pypy/module/cppyy/test/operators.cxx b/pypy/module/cppyy/test/operators.cxx
--- a/pypy/module/cppyy/test/operators.cxx
+++ b/pypy/module/cppyy/test/operators.cxx
@@ -1,1 +1,16 @@
 #include "operators.h"
+
+// for testing the case of virtual operator==
+v_opeq_base::v_opeq_base(int val) : m_val(val) {}
+v_opeq_base::~v_opeq_base() {}
+
+bool v_opeq_base::operator==(const v_opeq_base& other) {
+   return m_val == other.m_val;
+}
+
+v_opeq_derived::v_opeq_derived(int val) : v_opeq_base(val) {}
+v_opeq_derived::~v_opeq_derived() {}
+
+bool v_opeq_derived::operator==(const v_opeq_derived& other) {
+   return m_val != other.m_val;
+}
diff --git a/pypy/module/cppyy/test/operators.h b/pypy/module/cppyy/test/operators.h
--- a/pypy/module/cppyy/test/operators.h
+++ b/pypy/module/cppyy/test/operators.h
@@ -93,3 +93,23 @@
    operator float() { return m_float; }
    float m_float;
 };
+
+//----------------------------------------------------------------------------
+class v_opeq_base {
+public:
+   v_opeq_base(int val);
+   virtual ~v_opeq_base();
+
+   virtual bool operator==(const v_opeq_base& other);
+
+protected:
+   int m_val;
+};
+
+class v_opeq_derived : public v_opeq_base {
+public:
+   v_opeq_derived(int val);
+   virtual ~v_opeq_derived();
+
+   virtual bool operator==(const v_opeq_derived& other);
+};
diff --git a/pypy/module/cppyy/test/operators.xml b/pypy/module/cppyy/test/operators.xml
--- a/pypy/module/cppyy/test/operators.xml
+++ b/pypy/module/cppyy/test/operators.xml
@@ -2,5 +2,6 @@
 
   <class name="number" />
   <class pattern="operator_*" />
+  <class pattern="v_opeq_*" />
 
 </lcgdict>
diff --git a/pypy/module/cppyy/test/test_operators.py b/pypy/module/cppyy/test/test_operators.py
--- a/pypy/module/cppyy/test/test_operators.py
+++ b/pypy/module/cppyy/test/test_operators.py
@@ -136,3 +136,43 @@
         o = gbl.operator_float(); o.m_float = 3.14
         assert round(o.m_float - 3.14, 5) == 0.
         assert round(float(o) - 3.14, 5)  == 0.
+
+    def test07_virtual_operator_eq(self):
+        """Test use of virtual bool operator=="""
+
+        import cppyy
+
+        b1  = cppyy.gbl.v_opeq_base(1)
+        b1a = cppyy.gbl.v_opeq_base(1)
+        b2  = cppyy.gbl.v_opeq_base(2)
+        b2a = cppyy.gbl.v_opeq_base(2)
+
+        assert b1 == b1
+        assert b1 == b1a
+        assert not b1 == b2
+        assert not b1 == b2a
+        assert b2 == b2
+        assert b2 == b2a
+
+        d1  = cppyy.gbl.v_opeq_derived(1)
+        d1a = cppyy.gbl.v_opeq_derived(1)
+        d2  = cppyy.gbl.v_opeq_derived(2)
+        d2a = cppyy.gbl.v_opeq_derived(2)
+
+        # derived operator== returns the opposite of the base comparison
+        assert not d1 == d1
+        assert not d1 == d1a
+        assert d1 == d2
+        assert d1 == d2a
+        assert not d2 == d2
+        assert not d2 == d2a
+
+        # mixed base/derived comparisons are subtle: python's operator
+        # resolution rules interact with C++ inheritance such that python
+        # never selects the derived comparison, because a base instance
+        # cannot be passed through a const derived&
+        assert b1 == d1
+        assert d1 == b1
+        assert not b1 == d2
+        assert not d2 == b1
+        


More information about the pypy-commit mailing list