[Python-checkins] gh-104223: Fix issues with inheriting from buffer classes (#104227)

JelleZijlstra webhook-mailer at python.org
Mon May 8 12:52:49 EDT 2023


https://github.com/python/cpython/commit/405eacc1b87a42e19fd176131e70537f0539e05e
commit: 405eacc1b87a42e19fd176131e70537f0539e05e
branch: main
author: Jelle Zijlstra <jelle.zijlstra at gmail.com>
committer: JelleZijlstra <jelle.zijlstra at gmail.com>
date: 2023-05-08T09:52:41-07:00
summary:

gh-104223: Fix issues with inheriting from buffer classes (#104227)

Co-authored-by: Kumar Aditya <59607654+kumaraditya303 at users.noreply.github.com>

files:
M Include/cpython/memoryobject.h
M Include/internal/pycore_memoryobject.h
M Lib/test/test_buffer.py
M Objects/bytearrayobject.c
M Objects/memoryobject.c
M Objects/typeobject.c

diff --git a/Include/cpython/memoryobject.h b/Include/cpython/memoryobject.h
index deab3cc89f72..3837fa8c6ab5 100644
--- a/Include/cpython/memoryobject.h
+++ b/Include/cpython/memoryobject.h
@@ -24,6 +24,7 @@ typedef struct {
 #define _Py_MEMORYVIEW_FORTRAN     0x004  /* Fortran contiguous layout */
 #define _Py_MEMORYVIEW_SCALAR      0x008  /* scalar: ndim = 0 */
 #define _Py_MEMORYVIEW_PIL         0x010  /* PIL-style layout */
+#define _Py_MEMORYVIEW_RESTRICTED  0x020  /* Disallow new references to the memoryview's buffer */
 
 typedef struct {
     PyObject_VAR_HEAD
diff --git a/Include/internal/pycore_memoryobject.h b/Include/internal/pycore_memoryobject.h
index acc12c927517..fe19e3f9611a 100644
--- a/Include/internal/pycore_memoryobject.h
+++ b/Include/internal/pycore_memoryobject.h
@@ -9,7 +9,8 @@ extern "C" {
 #endif
 
 PyObject *
-PyMemoryView_FromObjectAndFlags(PyObject *v, int flags);
+_PyMemoryView_FromBufferProc(PyObject *v, int flags,
+                             getbufferproc bufferproc);
 
 #ifdef __cplusplus
 }
diff --git a/Lib/test/test_buffer.py b/Lib/test/test_buffer.py
index b6e82ad4db26..2c65ae811481 100644
--- a/Lib/test/test_buffer.py
+++ b/Lib/test/test_buffer.py
@@ -4579,6 +4579,176 @@ def test_c_buffer(self):
             buf.__release_buffer__(mv)
         self.assertEqual(buf.references, 0)
 
+    def test_inheritance(self):
+        class A(bytearray):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+
+        a = A(b"hello")
+        mv = memoryview(a)
+        self.assertEqual(mv.tobytes(), b"hello")
+
+    def test_inheritance_releasebuffer(self):
+        rb_call_count = 0
+        class B(bytearray):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+            def __release_buffer__(self, view):
+                nonlocal rb_call_count
+                rb_call_count += 1
+                super().__release_buffer__(view)
+
+        b = B(b"hello")
+        with memoryview(b) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+            self.assertEqual(rb_call_count, 0)
+        self.assertEqual(rb_call_count, 1)
+
+    def test_inherit_but_return_something_else(self):
+        class A(bytearray):
+            def __buffer__(self, flags):
+                return memoryview(b"hello")
+
+        a = A(b"hello")
+        with memoryview(a) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+
+        rb_call_count = 0
+        rb_raised = False
+        class B(bytearray):
+            def __buffer__(self, flags):
+                return memoryview(b"hello")
+            def __release_buffer__(self, view):
+                nonlocal rb_call_count
+                rb_call_count += 1
+                try:
+                    super().__release_buffer__(view)
+                except ValueError:
+                    nonlocal rb_raised
+                    rb_raised = True
+
+        b = B(b"hello")
+        with memoryview(b) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+            self.assertEqual(rb_call_count, 0)
+        self.assertEqual(rb_call_count, 1)
+        self.assertIs(rb_raised, True)
+
+    def test_override_only_release(self):
+        class C(bytearray):
+            def __release_buffer__(self, buffer):
+                super().__release_buffer__(buffer)
+
+        c = C(b"hello")
+        with memoryview(c) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+
+    def test_release_saves_reference(self):
+        smuggled_buffer = None
+
+        class C(bytearray):
+            def __release_buffer__(s, buffer: memoryview):
+                with self.assertRaises(ValueError):
+                    memoryview(buffer)
+                with self.assertRaises(ValueError):
+                    buffer.cast("b")
+                with self.assertRaises(ValueError):
+                    buffer.toreadonly()
+                with self.assertRaises(ValueError):
+                    buffer[:1]
+                with self.assertRaises(ValueError):
+                    buffer.__buffer__(0)
+                nonlocal smuggled_buffer
+                smuggled_buffer = buffer
+                self.assertEqual(buffer.tobytes(), b"hello")
+                super().__release_buffer__(buffer)
+
+        c = C(b"hello")
+        with memoryview(c) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+        c.clear()
+        with self.assertRaises(ValueError):
+            smuggled_buffer.tobytes()
+
+    def test_release_saves_reference_no_subclassing(self):
+        ba = bytearray(b"hello")
+
+        class C:
+            def __buffer__(self, flags):
+                return memoryview(ba)
+
+            def __release_buffer__(self, buffer):
+                self.buffer = buffer
+
+        c = C()
+        with memoryview(c) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+        self.assertEqual(c.buffer.tobytes(), b"hello")
+
+        with self.assertRaises(BufferError):
+            ba.clear()
+        c.buffer.release()
+        ba.clear()
+
+    def test_multiple_inheritance_buffer_last(self):
+        class A:
+            def __buffer__(self, flags):
+                return memoryview(b"hello A")
+
+        class B(A, bytearray):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+
+        b = B(b"hello")
+        with memoryview(b) as mv:
+            self.assertEqual(mv.tobytes(), b"hello A")
+
+        class Releaser:
+            def __release_buffer__(self, buffer):
+                self.buffer = buffer
+
+        class C(Releaser, bytearray):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+
+        c = C(b"hello C")
+        with memoryview(c) as mv:
+            self.assertEqual(mv.tobytes(), b"hello C")
+        c.clear()
+        with self.assertRaises(ValueError):
+            c.buffer.tobytes()
+
+    def test_multiple_inheritance_buffer_last(self):
+        class A:
+            def __buffer__(self, flags):
+                raise RuntimeError("should not be called")
+
+            def __release_buffer__(self, buffer):
+                raise RuntimeError("should not be called")
+
+        class B(bytearray, A):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+
+        b = B(b"hello")
+        with memoryview(b) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+
+        class Releaser:
+            buffer = None
+            def __release_buffer__(self, buffer):
+                self.buffer = buffer
+
+        class C(bytearray, Releaser):
+            def __buffer__(self, flags):
+                return super().__buffer__(flags)
+
+        c = C(b"hello")
+        with memoryview(c) as mv:
+            self.assertEqual(mv.tobytes(), b"hello")
+        c.clear()
+        self.assertIs(c.buffer, None)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Objects/bytearrayobject.c b/Objects/bytearrayobject.c
index 49d4dd524696..c36db59baaa1 100644
--- a/Objects/bytearrayobject.c
+++ b/Objects/bytearrayobject.c
@@ -61,6 +61,7 @@ static void
 bytearray_releasebuffer(PyByteArrayObject *obj, Py_buffer *view)
 {
     obj->ob_exports--;
+    assert(obj->ob_exports >= 0);
 }
 
 static int
diff --git a/Objects/memoryobject.c b/Objects/memoryobject.c
index f008a8cc3e04..b0168044d9f8 100644
--- a/Objects/memoryobject.c
+++ b/Objects/memoryobject.c
@@ -193,6 +193,20 @@ PyTypeObject _PyManagedBuffer_Type = {
         return -1;                                                \
     }
 
+#define CHECK_RESTRICTED(mv) \
+    if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
+        PyErr_SetString(PyExc_ValueError,                                  \
+            "cannot create new view on restricted memoryview");            \
+        return NULL;                                                       \
+    }
+
+#define CHECK_RESTRICTED_INT(mv) \
+    if (((PyMemoryViewObject *)(mv))->flags & _Py_MEMORYVIEW_RESTRICTED) { \
+        PyErr_SetString(PyExc_ValueError,                                  \
+            "cannot create new view on restricted memoryview");            \
+        return -1;                                                       \
+    }
+
 /* See gh-92888. These macros signal that we need to check the memoryview
    again due to possible read after frees. */
 #define CHECK_RELEASED_AGAIN(mv) CHECK_RELEASED(mv)
@@ -781,7 +795,7 @@ PyMemoryView_FromBuffer(const Py_buffer *info)
    using the given flags.
    If the object is a memoryview, the new memoryview must be registered
    with the same managed buffer. Otherwise, a new managed buffer is created. */
-PyObject *
+static PyObject *
 PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
 {
     _PyManagedBufferObject *mbuf;
@@ -789,6 +803,7 @@ PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
     if (PyMemoryView_Check(v)) {
         PyMemoryViewObject *mv = (PyMemoryViewObject *)v;
         CHECK_RELEASED(mv);
+        CHECK_RESTRICTED(mv);
         return mbuf_add_view(mv->mbuf, &mv->view);
     }
     else if (PyObject_CheckBuffer(v)) {
@@ -806,6 +821,30 @@ PyMemoryView_FromObjectAndFlags(PyObject *v, int flags)
         Py_TYPE(v)->tp_name);
     return NULL;
 }
+
+/* Create a memoryview of v by calling the supplied getbufferproc directly
+   with the given flags. Unlike PyMemoryView_FromObjectAndFlags(), v is not
+   special-cased when it is a memoryview: a new managed buffer is always
+   allocated and filled in by bufferproc. Returns NULL on failure. */
+PyObject *
+_PyMemoryView_FromBufferProc(PyObject *v, int flags, getbufferproc bufferproc)
+{
+    _PyManagedBufferObject *mbuf = mbuf_alloc();
+    if (mbuf == NULL)
+        return NULL;
+
+    int res = bufferproc(v, &mbuf->master, flags);
+    if (res < 0) {
+        mbuf->master.obj = NULL;
+        Py_DECREF(mbuf);
+        return NULL;
+    }
+
+    PyObject *ret = mbuf_add_view(mbuf, NULL);
+    Py_DECREF(mbuf);
+    return ret;
+}
+
 /* Create a memoryview from an object that implements the buffer protocol.
    If the object is a memoryview, the new memoryview must be registered
    with the same managed buffer. Otherwise, a new managed buffer is created. */
@@ -1397,6 +1436,7 @@ memoryview_cast_impl(PyMemoryViewObject *self, PyObject *format,
     Py_ssize_t ndim = 1;
 
     CHECK_RELEASED(self);
+    CHECK_RESTRICTED(self);
 
     if (!MV_C_CONTIGUOUS(self->flags)) {
         PyErr_SetString(PyExc_TypeError,
@@ -1452,6 +1492,7 @@ memoryview_toreadonly_impl(PyMemoryViewObject *self)
 /*[clinic end generated code: output=2c7e056f04c99e62 input=dc06d20f19ba236f]*/
 {
     CHECK_RELEASED(self);
+    CHECK_RESTRICTED(self);
     /* Even if self is already readonly, we still need to create a new
      * object for .release() to work correctly.
      */
@@ -1474,6 +1515,7 @@ memory_getbuf(PyMemoryViewObject *self, Py_buffer *view, int flags)
     int baseflags = self->flags;
 
     CHECK_RELEASED_INT(self);
+    CHECK_RESTRICTED_INT(self);
 
     /* start with complete information */
     *view = *base;
@@ -2535,6 +2577,7 @@ memory_subscript(PyMemoryViewObject *self, PyObject *key)
         return memory_item(self, index);
     }
     else if (PySlice_Check(key)) {
+        CHECK_RESTRICTED(self);
         PyMemoryViewObject *sliced;
 
         sliced = (PyMemoryViewObject *)mbuf_add_view(self->mbuf, view);
diff --git a/Objects/typeobject.c b/Objects/typeobject.c
index 456b10ee01d6..98fac276a873 100644
--- a/Objects/typeobject.c
+++ b/Objects/typeobject.c
@@ -6,7 +6,7 @@
 #include "pycore_symtable.h"      // _Py_Mangle()
 #include "pycore_dict.h"          // _PyDict_KeysSize()
 #include "pycore_initconfig.h"    // _PyStatus_OK()
-#include "pycore_memoryobject.h"  // PyMemoryView_FromObjectAndFlags()
+#include "pycore_memoryobject.h"  // _PyMemoryView_FromBufferProc()
 #include "pycore_moduleobject.h"  // _PyModule_GetDef()
 #include "pycore_object.h"        // _PyType_HasFeature()
 #include "pycore_long.h"          // _PyLong_IsNegative()
@@ -56,6 +56,11 @@ typedef struct PySlot_Offset {
     short slot_offset;
 } PySlot_Offset;
 
+static void
+slot_bf_releasebuffer(PyObject *self, Py_buffer *buffer);
+
+static void
+releasebuffer_call_python(PyObject *self, Py_buffer *buffer);
 
 static PyObject *
 slot_tp_new(PyTypeObject *type, PyObject *args, PyObject *kwds);
@@ -8078,7 +8083,8 @@ wrap_buffer(PyObject *self, PyObject *args, void *wrapped)
         return NULL;
     }
 
-    return PyMemoryView_FromObjectAndFlags(self, Py_SAFE_DOWNCAST(flags, Py_ssize_t, int));
+    return _PyMemoryView_FromBufferProc(self, Py_SAFE_DOWNCAST(flags, Py_ssize_t, int),
+                                        (getbufferproc)wrapped);
 }
 
 static PyObject *
@@ -8094,6 +8100,10 @@ wrap_releasebuffer(PyObject *self, PyObject *args, void *wrapped)
         return NULL;
     }
     PyMemoryViewObject *mview = (PyMemoryViewObject *)arg;
+    if (mview->view.obj == NULL) {
+        // Already released, ignore
+        Py_RETURN_NONE;
+    }
     if (mview->view.obj != self) {
         PyErr_SetString(PyExc_ValueError,
                         "memoryview's buffer is not this object");
@@ -8978,12 +8988,26 @@ bufferwrapper_releasebuf(PyObject *self, Py_buffer *view)
 {
     PyBufferWrapper *bw = (PyBufferWrapper *)self;
 
-    assert(PyMemoryView_Check(bw->mv));
-    Py_TYPE(bw->mv)->tp_as_buffer->bf_releasebuffer(bw->mv, view);
-    if (Py_TYPE(bw->obj)->tp_as_buffer != NULL
-        && Py_TYPE(bw->obj)->tp_as_buffer->bf_releasebuffer != NULL) {
-        Py_TYPE(bw->obj)->tp_as_buffer->bf_releasebuffer(bw->obj, view);
+    if (bw->mv == NULL || bw->obj == NULL) {
+        // Already released
+        return;
+    }
+
+    PyObject *mv = bw->mv;
+    PyObject *obj = bw->obj;
+
+    assert(PyMemoryView_Check(mv));
+    Py_TYPE(mv)->tp_as_buffer->bf_releasebuffer(mv, view);
+    // We only need to call bf_releasebuffer if it's a Python function. If it's a C
+    // bf_releasebuffer, it will be called when the memoryview is released.
+    if (((PyMemoryViewObject *)mv)->view.obj != obj
+            && Py_TYPE(obj)->tp_as_buffer != NULL
+            && Py_TYPE(obj)->tp_as_buffer->bf_releasebuffer == slot_bf_releasebuffer) {
+        releasebuffer_call_python(obj, view);
     }
+
+    Py_CLEAR(bw->mv);
+    Py_CLEAR(bw->obj);
 }
 
 static PyBufferProcs bufferwrapper_as_buffer = {
@@ -9047,31 +9071,112 @@ slot_bf_getbuffer(PyObject *self, Py_buffer *buffer, int flags)
     return -1;
 }
 
+static int
+releasebuffer_maybe_call_super(PyObject *self, Py_buffer *buffer)
+{
+    PyTypeObject *self_type = Py_TYPE(self);
+    PyObject *mro = lookup_tp_mro(self_type);
+    if (mro == NULL) {
+        return -1;
+    }
+
+    assert(PyTuple_Check(mro));
+    Py_ssize_t n = PyTuple_GET_SIZE(mro);
+    Py_ssize_t i;
+
+    /* No need to check the last one: it's gonna be skipped anyway.  */
+    for (i = 0;  i < n -1; i++) {
+        if ((PyObject *)(self_type) == PyTuple_GET_ITEM(mro, i))
+            break;
+    }
+    i++;  /* skip self_type */
+    if (i >= n)
+        return -1;
+
+    releasebufferproc base_releasebuffer = NULL;
+    for (; i < n; i++) {
+        PyObject *obj = PyTuple_GET_ITEM(mro, i);
+        if (!PyType_Check(obj)) {
+            continue;
+        }
+        PyTypeObject *base_type = (PyTypeObject *)obj;
+        if (base_type->tp_as_buffer != NULL
+            && base_type->tp_as_buffer->bf_releasebuffer != NULL
+            && base_type->tp_as_buffer->bf_releasebuffer != slot_bf_releasebuffer) {
+            base_releasebuffer = base_type->tp_as_buffer->bf_releasebuffer;
+            break;
+        }
+    }
+
+    if (base_releasebuffer != NULL) {
+        base_releasebuffer(self, buffer);
+    }
+    return 0;
+}
+
 static void
-slot_bf_releasebuffer(PyObject *self, Py_buffer *buffer)
+releasebuffer_call_python(PyObject *self, Py_buffer *buffer)
 {
     PyObject *mv;
-    if (Py_TYPE(buffer->obj) == &_PyBufferWrapper_Type) {
+    bool is_buffer_wrapper = Py_TYPE(buffer->obj) == &_PyBufferWrapper_Type;
+    if (is_buffer_wrapper) {
         // Make sure we pass the same memoryview to
         // __release_buffer__() that __buffer__() returned.
-        mv = Py_NewRef(((PyBufferWrapper *)buffer->obj)->mv);
+        PyBufferWrapper *bw = (PyBufferWrapper *)buffer->obj;
+        if (bw->mv == NULL) {
+            return;
+        }
+        mv = Py_NewRef(bw->mv);
     }
     else {
+        // This means we are not dealing with a memoryview returned
+        // from a Python __buffer__ function.
         mv = PyMemoryView_FromBuffer(buffer);
         if (mv == NULL) {
             PyErr_WriteUnraisable(self);
             return;
         }
+        // Set the memoryview to restricted mode, which forbids
+        // users from saving any reference to the underlying buffer
+        // (e.g., by doing .cast()). This is necessary to ensure
+        // no Python code retains a reference to the to-be-released
+        // buffer.
+        ((PyMemoryViewObject *)mv)->flags |= _Py_MEMORYVIEW_RESTRICTED;
     }
     PyObject *stack[2] = {self, mv};
     PyObject *ret = vectorcall_method(&_Py_ID(__release_buffer__), stack, 2);
-    Py_DECREF(mv);
     if (ret == NULL) {
         PyErr_WriteUnraisable(self);
     }
     else {
         Py_DECREF(ret);
     }
+    if (!is_buffer_wrapper) {
+        PyObject_CallMethodNoArgs(mv, &_Py_ID(release));
+    }
+    Py_DECREF(mv);
+}
+
+/*
+ * bf_releasebuffer is very delicate, because we need to ensure that
+ * C bf_releasebuffer slots are called correctly (or we'll leak memory),
+ * but we cannot trust any __release_buffer__ implemented in Python to
+ * do so correctly. Therefore, if a base class has a C bf_releasebuffer
+ * slot, we call it directly here. That is safe because this function
+ * only gets called from C callers of the bf_releasebuffer slot. Python
+ * code that calls __release_buffer__ directly instead goes through
+ * wrap_releasebuffer(), which doesn't call the bf_releasebuffer slot
+ * directly but instead simply releases the associated memoryview.
+ */
+static void
+slot_bf_releasebuffer(PyObject *self, Py_buffer *buffer)
+{
+    releasebuffer_call_python(self, buffer);
+    if (releasebuffer_maybe_call_super(self, buffer) < 0) {
+        if (PyErr_Occurred()) {
+            PyErr_WriteUnraisable(self);
+        }
+    }
 }
 
 static PyObject *



More information about the Python-checkins mailing list