[Numpy-discussion] dtype subarray comparison
Mark Wiebe
mwwiebe at gmail.com
Thu Oct 21 15:50:11 EDT 2010
>
> <snip>
> The other issue of allowing broadcasting in sub-arrays --- it does not
> seem very useful to me. Unlike arrays, the dimensions of sub-arrays
> cannot be manipulated easily, and so many use-cases of broadcasting just
> disappear.
I implemented the a/b shape check in the structured array comparison
code, and it breaks 6 unit tests. The failures all look like the
following:
======================================================================
FAIL: test_multiarray.TestNewBufferProtocol.test_roundtrip
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/lib/python2.6/site-packages/nose/case.py", line 186, in runTest
    self.test(*self.arg)
  File "/home/mwiebe/installtest/lib64/python2.6/site-packages/numpy/core/tests/test_multiarray.py", line 1635, in test_roundtrip
    self._check_roundtrip(x)
  File "/home/mwiebe/installtest/lib64/python2.6/site-packages/numpy/core/tests/test_multiarray.py", line 1595, in _check_roundtrip
    assert_array_equal(obj, y)
  File "/home/mwiebe/installtest/lib64/python2.6/site-packages/numpy/testing/utils.py", line 686, in assert_array_equal
    verbose=verbose, header='Arrays are not equal')
  File "/home/mwiebe/installtest/lib64/python2.6/site-packages/numpy/testing/utils.py", line 618, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Arrays are not equal
(mismatch 100.0%)
 x: array(([[1, 2], [3, 4]],),
      dtype=[('a', '<i8', (2, 2))])
 y: array([([[1, 2], [3, 4]],)],
      dtype=[('a', '<i8', (2, 2))])
Here's what this test does:
>>> import numpy as np
>>> from numpy.core.multiarray import memorysimpleview as memoryview
>>> obj = np.array(([[1, 2], [3, 4]],), dtype=[('a', '<i8', (2, 2))])
>>> x = memoryview(obj)
>>> y = np.asarray(x)
>>> obj == y
False
>>> y['a'].shape
(1, 2, 2)
>>> obj['a'].shape
(2, 2)
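This happens because y.shape is (1,) while obj.shape is (). When the
field 'a' is extracted, the dtype's (2, 2) subarray shape is appended to
each array's shape, giving (1, 2, 2) for y and (2, 2) for obj, so the new
check sees two different shapes and refuses the comparison.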
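The shape mismatch can be reproduced without the Python 2.6-era `memorysimpleview` helper. The sketch below builds the same 0-d record and, as an assumption, uses `obj.reshape(1)` as a stand-in for the shape-(1,) array that the buffer round trip returned. The last line shows the broadcast that an unchecked field comparison relies on: (2, 2) against (1, 2, 2).

```python
import numpy as np

# 0-d structured array with a (2, 2) subarray field, as in the test above.
dt = np.dtype([('a', '<i8', (2, 2))])
obj = np.array(([[1, 2], [3, 4]],), dtype=dt)

# Stand-in for np.asarray(memoryview(obj)): same data, but shape (1,).
# (Assumption: reshape reproduces the shape the buffer round trip produced.)
y = obj.reshape(1)

print(obj.shape, y.shape)               # () (1,)
print(obj['a'].shape, y['a'].shape)     # (2, 2) (1, 2, 2)

# Without a shape check, comparing the extracted fields just broadcasts
# (2, 2) against (1, 2, 2).
print((obj['a'] == y['a']).shape)       # (1, 2, 2)
```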
Cheers,
Mark
p.s.: Here's the small code addition, which I can add to the branch if
desired:
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/a
index 6e5bd9a..f12012c 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -844,6 +844,14 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int
             Py_DECREF(a);
             return NULL;
         }
+        if (PyArray_NDIM(a) != PyArray_NDIM(b) ||
+                !PyArray_CompareLists(PyArray_DIMS(a), PyArray_DIMS(b), PyArr
+            Py_XDECREF(res);
+            Py_DECREF(a);
+            Py_DECREF(b);
+            Py_INCREF(Py_NotImplemented);
+            return Py_NotImplemented;
+        }
         temp = array_richcompare((PyArrayObject *)a,b,cmp_op);
         Py_DECREF(a);
         Py_DECREF(b);
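For readers who don't want to trace the C, here is a rough Python model of what the patched per-field loop does for ==. It is a simplified sketch, not NumPy's implementation: `void_compare_eq` is a hypothetical name, and collapsing the trailing subarray axes with `all()` so each record yields a single bool is an assumption about how the per-field results are combined. The key step mirrors the added `PyArray_NDIM`/`PyArray_CompareLists` test: if the extracted field arrays differ in shape, return `NotImplemented` instead of broadcasting, which is consistent with `obj == y` giving False in the session above.

```python
import numpy as np

def void_compare_eq(a, b):
    """Hypothetical model of the patched _void_compare for ==.

    Compares structured arrays field by field, but returns NotImplemented
    when the extracted field arrays differ in shape instead of broadcasting.
    """
    if a.dtype.names is None or a.dtype.names != b.dtype.names:
        return NotImplemented
    result = None
    for name in a.dtype.names:
        fa, fb = a[name], b[name]
        # Mirrors the added ndim / dims check from the patch.
        if fa.shape != fb.shape:
            return NotImplemented
        eq = fa == fb
        # Assumption: collapse the trailing subarray axes so each record
        # yields one bool.
        extra = fa.ndim - a.ndim
        if extra:
            eq = eq.all(axis=tuple(range(-extra, 0)))
        result = eq if result is None else (result & eq)
    return result

dt = np.dtype([('a', '<i8', (2, 2))])
obj = np.array(([[1, 2], [3, 4]],), dtype=dt)
print(void_compare_eq(obj, obj.copy()))       # True
print(void_compare_eq(obj, obj.reshape(1)))   # NotImplemented
```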