[Numpy-svn] r5906 - in trunk/numpy/core: code_generators src tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Fri Oct 3 11:55:53 EDT 2008
Author: oliphant
Date: 2008-10-03 10:55:52 -0500 (Fri, 03 Oct 2008)
New Revision: 5906
Modified:
trunk/numpy/core/code_generators/numpy_api_order.txt
trunk/numpy/core/src/arrayobject.c
trunk/numpy/core/src/multiarraymodule.c
trunk/numpy/core/tests/test_multiarray.py
Log:
Fix ticket #925
Modified: trunk/numpy/core/code_generators/numpy_api_order.txt
===================================================================
--- trunk/numpy/core/code_generators/numpy_api_order.txt 2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/code_generators/numpy_api_order.txt 2008-10-03 15:55:52 UTC (rev 5906)
@@ -170,3 +170,4 @@
PyArray_CheckAxis
PyArray_OverflowMultiplyList
PyArray_CompareString
+PyArray_MultiIterFromObjects
Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c 2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/src/arrayobject.c 2008-10-03 15:55:52 UTC (rev 5906)
@@ -10790,6 +10790,75 @@
/** END of Subscript Iterator **/
+/*
+  NUMPY_API
+  Get MultiIterator from array of Python objects and any additional arrays
+
+  PyObject **mps -- array of PyObjects
+  int n - number of PyObjects in the array
+  int nadd - number of additional arrays to include in the
+             iterator.
+  ... - the nadd additional objects, each read as a PyObject *
+
+  Each of the n + nadd objects is converted with PyArray_FROM_O and
+  wrapped in its own iterator (PyArray_IterNew); the iterators are then
+  broadcast against each other with PyArray_Broadcast and reset.
+
+  Returns a multi-iterator object.  Returns NULL when:
+    - n + nadd is smaller than 2 or larger than NPY_MAXARGS
+      (ValueError is set here),
+    - the multi-iterator cannot be allocated (PyErr_NoMemory),
+    - converting an object, creating one of the iterators, or the
+      broadcast itself fails.
+ */
+static PyObject *
+PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
+{
+    va_list va;
+    PyArrayMultiIterObject *multi;
+    PyObject *current;
+    PyObject *arr;
+
+    int i, ntot, err=0;
+
+    ntot = n + nadd;
+    if (ntot < 2 || ntot > NPY_MAXARGS) {
+        PyErr_Format(PyExc_ValueError,
+                     "Need between 2 and (%d) "                 \
+                     "array objects (inclusive).", NPY_MAXARGS);
+        return NULL;
+    }
+
+    multi = _pya_malloc(sizeof(PyArrayMultiIterObject));
+    if (multi == NULL) return PyErr_NoMemory();
+    PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);
+
+    /*
+     * Clear every slot before filling any of them.
+     * NOTE(review): presumably this lets Py_DECREF(multi) on the error
+     * path release a partially filled iterator safely -- confirm against
+     * the PyArrayMultiIter_Type deallocator.
+     */
+    for(i=0; i<ntot; i++) multi->iters[i] = NULL;
+    multi->numiter = ntot;
+    multi->index = 0;
+
+    va_start(va, nadd);
+    for(i=0; i<ntot; i++) {
+        /* The first n objects come from mps, the rest from the varargs. */
+        if (i < n) {
+            current = mps[i];
+        }
+        else {
+            current = va_arg(va, PyObject *);
+        }
+        arr = PyArray_FROM_O(current);
+        if (arr==NULL) {
+            err=1; break;
+        }
+        multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
+        /* Drop the local reference to arr whether or not IterNew worked. */
+        Py_DECREF(arr);
+        if (multi->iters[i] == NULL) {
+            /*
+             * Iterator creation failed: stop here instead of handing a
+             * NULL iterator slot to PyArray_Broadcast below.
+             */
+            err=1; break;
+        }
+    }
+
+    va_end(va);
+
+    if (!err && PyArray_Broadcast(multi) < 0) err=1;
+
+    if (err) {
+        Py_DECREF(multi);
+        return NULL;
+    }
+
+    PyArray_MultiIter_RESET(multi);
+
+    return (PyObject *)multi;
+}
+
/*NUMPY_API
Get MultiIterator,
*/
Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c 2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/src/multiarraymodule.c 2008-10-03 15:55:52 UTC (rev 5906)
@@ -2326,50 +2326,40 @@
PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
NPY_CLIPMODE clipmode)
{
- intp *sizes, offset;
int n, elsize;
intp i, m;
char *ret_data;
PyArrayObject **mps, *ap;
- intp *self_data, mi;
+ PyArrayMultiIterObject *multi=NULL;
+ intp mi;
int copyret=0;
ap = NULL;
/* Convert all inputs to arrays of a common type */
+ /* Also makes them C-contiguous */
mps = PyArray_ConvertToCommonType(op, &n);
if (mps == NULL) return NULL;
- sizes = (intp *)_pya_malloc(n*sizeof(intp));
- if (sizes == NULL) goto fail;
-
- ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ip,
- PyArray_INTP,
- 0, 0);
- if (ap == NULL) goto fail;
-
- /* Check the dimensions of the arrays */
for(i=0; i<n; i++) {
if (mps[i] == NULL) goto fail;
- if (ap->nd < mps[i]->nd) {
- PyErr_SetString(PyExc_ValueError,
- "too many dimensions");
- goto fail;
- }
- if (!PyArray_CompareLists(ap->dimensions+(ap->nd-mps[i]->nd),
- mps[i]->dimensions, mps[i]->nd)) {
- PyErr_SetString(PyExc_ValueError,
- "array dimensions must agree");
- goto fail;
- }
- sizes[i] = PyArray_NBYTES(mps[i]);
}
+ ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP);
+
+ if (ap == NULL) goto fail;
+
+ /* Broadcast all arrays to each other, index array at the end. */
+ multi = (PyArrayMultiIterObject *)\
+ PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap);
+ if (multi == NULL) goto fail;
+
+ /* Set-up return array */
if (!ret) {
Py_INCREF(mps[0]->descr);
ret = (PyArrayObject *)PyArray_NewFromDescr(ap->ob_type,
mps[0]->descr,
- ap->nd,
- ap->dimensions,
+ multi->nd,
+ multi->dimensions,
NULL, NULL, 0,
(PyObject *)ap);
}
@@ -2377,8 +2367,10 @@
PyArrayObject *obj;
int flags = NPY_CARRAY | NPY_UPDATEIFCOPY | NPY_FORCECAST;
- if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) {
- PyErr_SetString(PyExc_TypeError,
+ if ((PyArray_NDIM(ret) != multi->nd) ||
+ !PyArray_CompareLists(PyArray_DIMS(ret), multi->dimensions,
+ multi->nd)) {
+ PyErr_SetString(PyExc_TypeError,
"invalid shape for output array.");
ret = NULL;
goto fail;
@@ -2399,12 +2391,10 @@
if (ret == NULL) goto fail;
elsize = ret->descr->elsize;
- m = PyArray_SIZE(ret);
- self_data = (intp *)ap->data;
ret_data = ret->data;
- for (i=0; i<m; i++) {
- mi = *self_data;
+ while (PyArray_MultiIter_NOTDONE(multi)) {
+ mi = *((intp *)PyArray_MultiIter_DATA(multi, n));
if (mi < 0 || mi >= n) {
switch(clipmode) {
case NPY_RAISE:
@@ -2426,17 +2416,16 @@
break;
}
}
- offset = i*elsize;
- if (offset >= sizes[mi]) {offset = offset % sizes[mi]; }
- memmove(ret_data, mps[mi]->data+offset, elsize);
- ret_data += elsize; self_data++;
+ memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
+ ret_data += elsize;
+ PyArray_MultiIter_NEXT(multi);
}
PyArray_INCREF(ret);
+ Py_DECREF(multi);
for(i=0; i<n; i++) Py_XDECREF(mps[i]);
Py_DECREF(ap);
PyDataMem_FREE(mps);
- _pya_free(sizes);
if (copyret) {
PyObject *obj;
obj = ret->base;
@@ -2447,10 +2436,10 @@
return (PyObject *)ret;
fail:
+ Py_XDECREF(multi);
for(i=0; i<n; i++) Py_XDECREF(mps[i]);
Py_XDECREF(ap);
PyDataMem_FREE(mps);
- _pya_free(sizes);
PyArray_XDECREF_ERR(ret);
return NULL;
}
Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py 2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/tests/test_multiarray.py 2008-10-03 15:55:52 UTC (rev 5906)
@@ -946,6 +946,26 @@
assert repr(A) == reprA
+class TestChoose(TestCase):
+ def setUp(self):
+ self.x = 2*ones((3,),dtype=int)
+ self.y = 3*ones((3,),dtype=int)
+ self.x2 = 2*ones((2,3), dtype=int)
+ self.y2 = 3*ones((2,3), dtype=int)
+ self.ind = [0,0,1]
+ def test_basic(self):
+ A = np.choose(self.ind, (self.x, self.y))
+ assert_equal(A, [2,2,3])
+
+ def test_broadcast1(self):
+ A = np.choose(self.ind, (self.x2, self.y2))
+ assert_equal(A, [[2,2,3],[2,2,3]])
+
+ def test_broadcast2(self):
+ A = np.choose(self.ind, (self.x, self.y2))
+ assert_equal(A, [[2,2,3],[2,2,3]])
+
+
if __name__ == "__main__":
run_module_suite()
More information about the Numpy-svn
mailing list