[pypy-commit] cffi default: issue #255: comparing primitive cdatas

arigo pypy.commits at gmail.com
Sun Feb 19 08:37:50 EST 2017


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r2890:91fcbd69bce2
Date: 2017-02-19 14:37 +0100
http://bitbucket.org/cffi/cffi/changeset/91fcbd69bce2/

Log:	issue #255: comparing primitive cdatas

diff --git a/c/_cffi_backend.c b/c/_cffi_backend.c
--- a/c/_cffi_backend.c
+++ b/c/_cffi_backend.c
@@ -2031,47 +2031,97 @@
 
 static PyObject *cdata_richcompare(PyObject *v, PyObject *w, int op)
 {
-    int res;
+    int v_is_ptr, w_is_ptr;
     PyObject *pyres;
-    char *v_cdata, *w_cdata;
 
     assert(CData_Check(v));
-    if (!CData_Check(w)) {
+
+    /* Comparisons involving a primitive cdata work differently than
+     * comparisons involving a struct/array/pointer.
+     *
+     * If v or w is a struct/array/pointer, then the other must be too
+     * (otherwise we return NotImplemented and leave the case to
+     * Python).  If both are, then we compare the addresses.
+     *
+     * If v and/or w is a primitive cdata, then we convert the cdata(s)
+     * to regular Python objects and redo the comparison there.
+     */
+
+    v_is_ptr = !(((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY);
+    w_is_ptr = CData_Check(w) &&
+                  !(((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY);
+
+    if (v_is_ptr && w_is_ptr) {
+        int res;
+        char *v_cdata = ((CDataObject *)v)->c_data;
+        char *w_cdata = ((CDataObject *)w)->c_data;
+
+        switch (op) {
+        case Py_EQ: res = (v_cdata == w_cdata); break;
+        case Py_NE: res = (v_cdata != w_cdata); break;
+        case Py_LT: res = (v_cdata <  w_cdata); break;
+        case Py_LE: res = (v_cdata <= w_cdata); break;
+        case Py_GT: res = (v_cdata >  w_cdata); break;
+        case Py_GE: res = (v_cdata >= w_cdata); break;
+        default: res = -1;
+        }
+        pyres = res ? Py_True : Py_False;
+    }
+    else if (v_is_ptr || w_is_ptr) {
         pyres = Py_NotImplemented;
-        goto done;
-    }
-
-    if ((op != Py_EQ && op != Py_NE) &&
-        ((((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY) ||
-         (((CDataObject *)w)->c_type->ct_flags & CT_PRIMITIVE_ANY)))
-        goto Error;
-
-    v_cdata = ((CDataObject *)v)->c_data;
-    w_cdata = ((CDataObject *)w)->c_data;
-
-    switch (op) {
-    case Py_EQ: res = (v_cdata == w_cdata); break;
-    case Py_NE: res = (v_cdata != w_cdata); break;
-    case Py_LT: res = (v_cdata <  w_cdata); break;
-    case Py_LE: res = (v_cdata <= w_cdata); break;
-    case Py_GT: res = (v_cdata >  w_cdata); break;
-    case Py_GE: res = (v_cdata >= w_cdata); break;
-    default: res = -1;
-    }
-    pyres = res ? Py_True : Py_False;
- done:
+    }
+    else {
+        PyObject *aa[2];
+        int i;
+
+        aa[0] = v; Py_INCREF(v);
+        aa[1] = w; Py_INCREF(w);
+        pyres = NULL;
+
+        for (i = 0; i < 2; i++) {
+            v = aa[i];
+            if (!CData_Check(v))
+                continue;
+            w = convert_to_object(((CDataObject *)v)->c_data,
+                                  ((CDataObject *)v)->c_type);
+            if (w == NULL)
+                goto error;
+            if (CData_Check(w)) {
+                Py_DECREF(w);
+                PyErr_Format(PyExc_NotImplementedError,
+                             "cannot use <cdata '%s'> in a comparison",
+                             ((CDataObject *)v)->c_type->ct_name);
+                goto error;
+            }
+            aa[i] = w;
+            Py_DECREF(v);
+        }
+        pyres = PyObject_RichCompare(aa[0], aa[1], op);
+     error:
+        Py_DECREF(aa[1]);
+        Py_DECREF(aa[0]);
+        return pyres;
+    }
+
     Py_INCREF(pyres);
     return pyres;
-
- Error:
-    PyErr_SetString(PyExc_TypeError,
-                    "cannot do comparison on a primitive cdata");
-    return NULL;
-}
-
-static long cdata_hash(CDataObject *cd)
-{
-    return _Py_HashPointer(cd->c_data);
+}
+
+static long cdata_hash(CDataObject *v)
+{
+    if (((CDataObject *)v)->c_type->ct_flags & CT_PRIMITIVE_ANY) {
+        PyObject *vv = convert_to_object(((CDataObject *)v)->c_data,
+                                         ((CDataObject *)v)->c_type);
+        if (vv == NULL)
+            return -1;
+        if (!CData_Check(vv)) {
+            long hash = PyObject_Hash(vv);
+            Py_DECREF(vv);
+            return hash;
+        }
+        Py_DECREF(vv);
+    }
+    return _Py_HashPointer(v->c_data);
 }
 
 static Py_ssize_t
diff --git a/c/test_c.py b/c/test_c.py
--- a/c/test_c.py
+++ b/c/test_c.py
@@ -27,6 +27,7 @@
                                        .replace(r'\\U', r'\U'))
     u = U()
     str2bytes = str
+    strict_compare = False
 else:
     type_or_class = "class"
     long = int
@@ -38,6 +39,7 @@
     bitem2bchr = bytechr
     u = ""
     str2bytes = lambda s: bytes(s, "ascii")
+    strict_compare = True
 
 def size_of_int():
     BInt = new_primitive_type("int")
@@ -106,11 +108,11 @@
     x = cast(p, -66 + (1<<199)*256)
     assert repr(x) == "<cdata 'signed char' -66>"
     assert int(x) == -66
-    assert (x == cast(p, -66)) is False
-    assert (x != cast(p, -66)) is True
+    assert (x == cast(p, -66)) is True
+    assert (x != cast(p, -66)) is False
     q = new_primitive_type("short")
-    assert (x == cast(q, -66)) is False
-    assert (x != cast(q, -66)) is True
+    assert (x == cast(q, -66)) is True
+    assert (x != cast(q, -66)) is False
 
 def test_sizeof_type():
     py.test.raises(TypeError, sizeof, 42.5)
@@ -175,7 +177,7 @@
             assert float(cast(p, 1.1)) != 1.1     # rounding error
             assert float(cast(p, 1E200)) == INF   # limited range
 
-        assert cast(p, -1.1) != cast(p, -1.1)
+        assert cast(p, -1.1) == cast(p, -1.1)
         assert repr(float(cast(p, -0.0))) == '-0.0'
         assert float(cast(p, b'\x09')) == 9.0
         assert float(cast(p, u+'\x09')) == 9.0
@@ -219,7 +221,7 @@
     p = new_primitive_type("char")
     assert bool(cast(p, 'A')) is True
     assert bool(cast(p, '\x00')) is False    # since 1.7
-    assert cast(p, '\x00') != cast(p, -17*256)
+    assert cast(p, '\x00') == cast(p, -17*256)
     assert int(cast(p, 'A')) == 65
     assert long(cast(p, 'A')) == 65
     assert type(int(cast(p, 'A'))) is int
@@ -382,23 +384,6 @@
     # that it is already loaded too, so it should work
     assert x.load_function(BVoidP, 'sqrt')
 
-def test_hash_differences():
-    BChar = new_primitive_type("char")
-    BInt = new_primitive_type("int")
-    BFloat = new_primitive_type("float")
-    for i in range(1, 20):
-        x1 = cast(BChar, chr(i))
-        x2 = cast(BInt, i)
-        if hash(x1) != hash(x2):
-            break
-    else:
-        raise AssertionError("hashes are equal")
-    for i in range(1, 20):
-        if hash(cast(BFloat, i)) != hash(float(i)):
-            break
-    else:
-        raise AssertionError("hashes are equal")
-
 def test_no_len_on_nonarray():
     p = new_primitive_type("int")
     py.test.raises(TypeError, len, cast(p, 42))
@@ -2261,12 +2246,17 @@
     BVoidP = new_pointer_type(new_void_type())
     p = newp(BIntP, 123)
     q = cast(BInt, 124)
-    py.test.raises(TypeError, "p < q")
-    py.test.raises(TypeError, "p <= q")
     assert (p == q) is False
     assert (p != q) is True
-    py.test.raises(TypeError, "p > q")
-    py.test.raises(TypeError, "p >= q")
+    assert (q == p) is False
+    assert (q != p) is True
+    if strict_compare:
+        py.test.raises(TypeError, "p < q")
+        py.test.raises(TypeError, "p <= q")
+        py.test.raises(TypeError, "q < p")
+        py.test.raises(TypeError, "q <= p")
+        py.test.raises(TypeError, "p > q")
+        py.test.raises(TypeError, "p >= q")
     r = cast(BVoidP, p)
     assert (p <  r) is False
     assert (p <= r) is True
@@ -3840,3 +3830,86 @@
         assert len(w) == 2
     # check that the warnings are associated with lines in this file
     assert w[1].lineno == w[0].lineno + 4
+
+def test_primitive_comparison():
+    def assert_eq(a, b):
+        assert (a == b) is True
+        assert (b == a) is True
+        assert (a != b) is False
+        assert (b != a) is False
+        assert (a < b) is False
+        assert (a <= b) is True
+        assert (a > b) is False
+        assert (a >= b) is True
+        assert (b < a) is False
+        assert (b <= a) is True
+        assert (b > a) is False
+        assert (b >= a) is True
+        assert hash(a) == hash(b)
+    def assert_lt(a, b):
+        assert (a == b) is False
+        assert (b == a) is False
+        assert (a != b) is True
+        assert (b != a) is True
+        assert (a < b) is True
+        assert (a <= b) is True
+        assert (a > b) is False
+        assert (a >= b) is False
+        assert (b < a) is False
+        assert (b <= a) is False
+        assert (b > a) is True
+        assert (b >= a) is True
+        assert hash(a) != hash(b)    # (or at least, it is unlikely)
+    def assert_gt(a, b):
+        assert_lt(b, a)
+    def assert_ne(a, b):
+        assert (a == b) is False
+        assert (b == a) is False
+        assert (a != b) is True
+        assert (b != a) is True
+        if strict_compare:
+            py.test.raises(TypeError, "a < b")
+            py.test.raises(TypeError, "a <= b")
+            py.test.raises(TypeError, "a > b")
+            py.test.raises(TypeError, "a >= b")
+            py.test.raises(TypeError, "b < a")
+            py.test.raises(TypeError, "b <= a")
+            py.test.raises(TypeError, "b > a")
+            py.test.raises(TypeError, "b >= a")
+        elif a < b:
+            assert_lt(a, b)
+        else:
+            assert_lt(b, a)
+    assert_eq(5, 5)
+    assert_lt(3, 5)
+    assert_ne('5', 5)
+    #
+    t1 = new_primitive_type("char")
+    t2 = new_primitive_type("int")
+    t3 = new_primitive_type("unsigned char")
+    t4 = new_primitive_type("unsigned int")
+    t5 = new_primitive_type("float")
+    t6 = new_primitive_type("double")
+    assert_eq(cast(t1, 65), b'A')
+    assert_lt(cast(t1, 64), b'\x99')
+    assert_gt(cast(t1, 200), b'A')
+    assert_ne(cast(t1, 65), 65)
+    assert_eq(cast(t2, -25), -25)
+    assert_lt(cast(t2, -25), -24)
+    assert_gt(cast(t2, -25), -26)
+    assert_eq(cast(t3, 65), 65)
+    assert_ne(cast(t3, 65), b'A')
+    assert_ne(cast(t3, 65), cast(t1, 65))
+    assert_gt(cast(t4, -1), -1)
+    assert_gt(cast(t4, -1), cast(t2, -1))
+    assert_gt(cast(t4, -1), 99999)
+    assert_eq(cast(t4, -1), 256 ** size_of_int() - 1)
+    assert_eq(cast(t5, 3.0), 3)
+    assert_eq(cast(t5, 3.5), 3.5)
+    assert_lt(cast(t5, 3.3), 3.3)   # imperfect rounding
+    assert_eq(cast(t6, 3.3), 3.3)
+    assert_eq(cast(t5, 3.5), cast(t6, 3.5))
+    assert_lt(cast(t5, 3.1), cast(t6, 3.1))   # imperfect rounding
+    assert_eq(cast(t5, 7.0), cast(t3, 7))
+    assert_lt(cast(t5, 3.1), 3.101)
+    assert_gt(cast(t5, 3.1), 3)
diff --git a/cffi/backend_ctypes.py b/cffi/backend_ctypes.py
--- a/cffi/backend_ctypes.py
+++ b/cffi/backend_ctypes.py
@@ -112,11 +112,20 @@
     def _make_cmp(name):
         cmpfunc = getattr(operator, name)
         def cmp(self, other):
-            if isinstance(other, CTypesData):
+            v_is_ptr = not isinstance(self, CTypesGenericPrimitive)
+            w_is_ptr = (isinstance(other, CTypesData) and
+                           not isinstance(other, CTypesGenericPrimitive))
+            if v_is_ptr and w_is_ptr:
                 return cmpfunc(self._convert_to_address(None),
                                other._convert_to_address(None))
+            elif v_is_ptr or w_is_ptr:
+                return NotImplemented
             else:
-                return NotImplemented
+                if isinstance(self, CTypesGenericPrimitive):
+                    self = self._value
+                if isinstance(other, CTypesGenericPrimitive):
+                    other = other._value
+                return cmpfunc(self, other)
         cmp.func_name = name
         return cmp
 
@@ -128,7 +137,7 @@
     __ge__ = _make_cmp('__ge__')
 
     def __hash__(self):
-        return hash(type(self)) ^ hash(self._convert_to_address(None))
+        return hash(self._convert_to_address(None))
 
     def _to_string(self, maxlen):
         raise TypeError("string(): %r" % (self,))
@@ -137,14 +146,8 @@
 class CTypesGenericPrimitive(CTypesData):
     __slots__ = []
 
-    def __eq__(self, other):
-        return self is other
-
-    def __ne__(self, other):
-        return self is not other
-
     def __hash__(self):
-        return object.__hash__(self)
+        return hash(self._value)
 
     def _get_own_repr(self):
         return repr(self._from_ctypes(self._value))
diff --git a/doc/source/ref.rst b/doc/source/ref.rst
--- a/doc/source/ref.rst
+++ b/doc/source/ref.rst
@@ -602,21 +602,21 @@
 |    C type     |   writing into         | reading from     |other operations|
 +===============+========================+==================+================+
 |   integers    | an integer or anything | a Python int or  | int(), bool()  |
-|   and enums   | on which int() works   | long, depending  | `(******)`     |
-|   `(*****)`   | (but not a float!).    | on the type      |                |
+|   and enums   | on which int() works   | long, depending  | `(******)`,    |
+|   `(*****)`   | (but not a float!).    | on the type      | ``<``          |
 |               | Must be within range.  | (ver. 1.10: or a |                |
 |               |                        | bool)            |                |
 +---------------+------------------------+------------------+----------------+
-|   ``char``    | a string of length 1   | a string of      | int(), bool()  |
-|               | or another <cdata char>| length 1         |                |
+|   ``char``    | a string of length 1   | a string of      | int(), bool(), |
+|               | or another <cdata char>| length 1         | ``<``          |
 +---------------+------------------------+------------------+----------------+
 |  ``wchar_t``  | a unicode of length 1  | a unicode of     |                |
-|               | (or maybe 2 if         | length 1         | int(), bool()  |
-|               | surrogates) or         | (or maybe 2 if   |                |
+|               | (or maybe 2 if         | length 1         | int(), bool(), |
+|               | surrogates) or         | (or maybe 2 if   | ``<``          |
 |               | another <cdata wchar_t>| surrogates)      |                |
 +---------------+------------------------+------------------+----------------+
 |  ``float``,   | a float or anything on | a Python float   | float(), int(),|
-|  ``double``   | which float() works    |                  | bool()         |
+|  ``double``   | which float() works    |                  | bool(), ``<``  |
 +---------------+------------------------+------------------+----------------+
 |``long double``| another <cdata> with   | a <cdata>, to    | float(), int(),|
 |               | a ``long double``, or  | avoid loosing    | bool()         |
diff --git a/doc/source/whatsnew.rst b/doc/source/whatsnew.rst
--- a/doc/source/whatsnew.rst
+++ b/doc/source/whatsnew.rst
@@ -50,6 +50,13 @@
   only in out-of-line mode.  This is useful for taking the address of
   global variables.
 
+* Issue #255: ``cdata`` objects of a primitive type (integers, floats,
+  char) are now compared and ordered by value.  For example, ``<cdata
+  'int' 42>`` compares equal to ``42`` and ``<cdata 'char' b'A'>``
+  compares equal to ``b'A'``.  Unlike C, ``<cdata 'int' -1>`` does not
+  compare equal to ``ffi.cast("unsigned int", -1)``: it compares
+  smaller, because ``-1 < 4294967295``.
+
 
 v1.9
 ====
diff --git a/testing/cffi0/backend_tests.py b/testing/cffi0/backend_tests.py
--- a/testing/cffi0/backend_tests.py
+++ b/testing/cffi0/backend_tests.py
@@ -54,7 +54,8 @@
         min = int(min)
         max = int(max)
         p = ffi.cast(c_decl, min)
-        assert p != min       # no __eq__(int)
+        assert p == min
+        assert hash(p) == hash(min)
         assert bool(p) is bool(min)
         assert int(p) == min
         p = ffi.cast(c_decl, max)
@@ -65,9 +66,9 @@
         assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max
         q = ffi.cast(c_decl, long(min - 1))
         assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max
-        assert q != p
+        assert q == p
         assert int(q) == int(p)
-        assert hash(q) != hash(p)   # unlikely
+        assert hash(q) == hash(p)
         c_decl_ptr = '%s *' % c_decl
         py.test.raises(OverflowError, ffi.new, c_decl_ptr, min - 1)
         py.test.raises(OverflowError, ffi.new, c_decl_ptr, max + 1)
@@ -882,9 +883,9 @@
         assert ffi.string(ffi.cast("enum bar", -2)) == "B1"
         assert ffi.string(ffi.cast("enum bar", -1)) == "CC1"
         assert ffi.string(ffi.cast("enum bar", 1)) == "E1"
-        assert ffi.cast("enum bar", -2) != ffi.cast("enum bar", -2)
-        assert ffi.cast("enum foo", 0) != ffi.cast("enum bar", 0)
-        assert ffi.cast("enum bar", 0) != ffi.cast("int", 0)
+        assert ffi.cast("enum bar", -2) == ffi.cast("enum bar", -2)
+        assert ffi.cast("enum foo", 0) == ffi.cast("enum bar", 0)
+        assert ffi.cast("enum bar", 0) == ffi.cast("int", 0)
         assert repr(ffi.cast("enum bar", -1)) == "<cdata 'enum bar' -1: CC1>"
         assert repr(ffi.cast("enum foo", -1)) == (  # enums are unsigned, if
             "<cdata 'enum foo' 4294967295>")        # they contain no neg value
@@ -1113,15 +1114,15 @@
         assert (q == None) is False
         assert (q != None) is True
 
-    def test_no_integer_comparison(self):
+    def test_integer_comparison(self):
         ffi = FFI(backend=self.Backend())
         x = ffi.cast("int", 123)
         y = ffi.cast("int", 456)
-        py.test.raises(TypeError, "x < y")
+        assert x < y
         #
         z = ffi.cast("double", 78.9)
-        py.test.raises(TypeError, "x < z")
-        py.test.raises(TypeError, "z < y")
+        assert x > z
+        assert y > z
 
     def test_ffi_buffer_ptr(self):
         ffi = FFI(backend=self.Backend())
diff --git a/testing/cffi1/test_new_ffi_1.py b/testing/cffi1/test_new_ffi_1.py
--- a/testing/cffi1/test_new_ffi_1.py
+++ b/testing/cffi1/test_new_ffi_1.py
@@ -137,7 +137,7 @@
         min = int(min)
         max = int(max)
         p = ffi.cast(c_decl, min)
-        assert p != min       # no __eq__(int)
+        assert p == min
         assert bool(p) is bool(min)
         assert int(p) == min
         p = ffi.cast(c_decl, max)
@@ -148,9 +148,9 @@
         assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max
         q = ffi.cast(c_decl, long(min - 1))
         assert ffi.typeof(q) is ffi.typeof(p) and int(q) == max
-        assert q != p
+        assert q == p
         assert int(q) == int(p)
-        assert hash(q) != hash(p)   # unlikely
+        assert hash(q) == hash(p)
         c_decl_ptr = '%s *' % c_decl
         py.test.raises(OverflowError, ffi.new, c_decl_ptr, min - 1)
         py.test.raises(OverflowError, ffi.new, c_decl_ptr, max + 1)
@@ -896,9 +896,9 @@
         assert ffi.string(ffi.cast("enum bar", -2)) == "B1"
         assert ffi.string(ffi.cast("enum bar", -1)) == "CC1"
         assert ffi.string(ffi.cast("enum bar", 1)) == "E1"
-        assert ffi.cast("enum bar", -2) != ffi.cast("enum bar", -2)
-        assert ffi.cast("enum foq", 0) != ffi.cast("enum bar", 0)
-        assert ffi.cast("enum bar", 0) != ffi.cast("int", 0)
+        assert ffi.cast("enum bar", -2) == ffi.cast("enum bar", -2)
+        assert ffi.cast("enum foq", 0) == ffi.cast("enum bar", 0)
+        assert ffi.cast("enum bar", 0) == ffi.cast("int", 0)
         assert repr(ffi.cast("enum bar", -1)) == "<cdata 'enum bar' -1: CC1>"
         assert repr(ffi.cast("enum foq", -1)) == (  # enums are unsigned, if
             "<cdata 'enum foq' 4294967295>") or (   # they contain no neg value
@@ -1105,14 +1105,14 @@
         assert (q == None) is False
         assert (q != None) is True
 
-    def test_no_integer_comparison(self):
+    def test_integer_comparison(self):
         x = ffi.cast("int", 123)
         y = ffi.cast("int", 456)
-        py.test.raises(TypeError, "x < y")
+        assert x < y
         #
         z = ffi.cast("double", 78.9)
-        py.test.raises(TypeError, "x < z")
-        py.test.raises(TypeError, "z < y")
+        assert x > z
+        assert y > z
 
     def test_ffi_buffer_ptr(self):
         a = ffi.new("short *", 100)


More information about the pypy-commit mailing list