[pypy-commit] pypy PyBuffer: Create BufferViewND to avoid mutating the memoryview in _cast_to_ND()
rlamy
pypy.commits at gmail.com
Tue Mar 28 10:55:01 EDT 2017
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: PyBuffer
Changeset: r90843:dbc12a5f29c5
Date: 2017-03-28 15:54 +0100
http://bitbucket.org/pypy/pypy/changeset/dbc12a5f29c5/
Log: Create BufferViewND to avoid mutating the memoryview in _cast_to_ND()
diff --git a/pypy/objspace/std/memoryobject.py b/pypy/objspace/std/memoryobject.py
--- a/pypy/objspace/std/memoryobject.py
+++ b/pypy/objspace/std/memoryobject.py
@@ -547,11 +547,11 @@
origfmt = self.getformat()
newbuf = self._cast_to_1D(space, buf, origfmt, fmt)
- mv = W_MemoryView(newbuf, newbuf.getformat(), newbuf.getitemsize())
if w_shape:
fview = space.fixedview(w_shape)
shape = [space.int_w(w_obj) for w_obj in fview]
- mv._cast_to_ND(space, shape, ndim)
+ newbuf = self._cast_to_ND(space, newbuf, shape, ndim)
+ mv = W_MemoryView(newbuf, newbuf.getformat(), newbuf.getitemsize())
return mv
def _init_flags(self):
@@ -634,19 +634,16 @@
return None
- def _cast_to_ND(self, space, shape, ndim):
- buf = self.buf
+ def _cast_to_ND(self, space, buf, shape, ndim):
length = itemsize = buf.getitemsize()
for i in range(ndim):
length *= shape[i]
- if length != self.buf.getlength():
+ if length != buf.getlength():
raise oefmt(space.w_TypeError,
"memoryview: product(shape) * itemsize != buffer size")
- self.ndim = ndim
- self.shape = shape
- self.strides = self._strides_from_shape(shape, itemsize)
- self._init_flags()
+ strides = self._strides_from_shape(shape, itemsize)
+ return BufferViewND(buf, ndim, shape, strides)
@staticmethod
def _strides_from_shape(shape, itemsize):
@@ -813,3 +810,55 @@
def getstrides(self):
return self.strides
+
+class BufferViewND(Buffer):
+ def __init__(self, parent, ndim, shape, strides):
+ assert parent.getndim() == 1
+ assert len(shape) == len(strides) == ndim
+ self.parent = parent
+ self.readonly = parent.readonly
+ self.ndim = ndim
+ self.shape = shape
+ self.strides = strides
+
+ def getlength(self):
+ return self.parent.getlength()
+
+ def as_str(self):
+ return self.parent.as_str()
+
+ def as_str_and_offset_maybe(self):
+ return self.parent.as_str_and_offset_maybe()
+
+ def as_binary(self):
+ return self.parent.as_binary()
+
+ def getitem(self, index):
+ return self.parent.getitem(index)
+
+ def setitem(self, index, char):
+ self.parent.setitem(index, char)
+
+ def getslice(self, start, stop, step, size):
+ return self.parent.getslice(start, stop, step, size)
+
+ def setslice(self, start, string):
+ self.parent.setslice(start, string)
+
+ def get_raw_address(self):
+ return self.parent.get_raw_address()
+
+ def getformat(self):
+ return self.parent.getformat()
+
+ def getitemsize(self):
+ return self.parent.getitemsize()
+
+ def getndim(self):
+ return self.ndim
+
+ def getshape(self):
+ return self.shape
+
+ def getstrides(self):
+ return self.strides
More information about the pypy-commit mailing list