[Numpy-svn] r5906 - in trunk/numpy/core: code_generators src tests

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Oct 3 11:55:53 EDT 2008


Author: oliphant
Date: 2008-10-03 10:55:52 -0500 (Fri, 03 Oct 2008)
New Revision: 5906

Modified:
   trunk/numpy/core/code_generators/numpy_api_order.txt
   trunk/numpy/core/src/arrayobject.c
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/tests/test_multiarray.py
Log:
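Fix ticket #925: make choose broadcast the index array and the choice arrays against each other; add PyArray_MultiIterFromObjects to the C-API.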

Modified: trunk/numpy/core/code_generators/numpy_api_order.txt
===================================================================
--- trunk/numpy/core/code_generators/numpy_api_order.txt	2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/code_generators/numpy_api_order.txt	2008-10-03 15:55:52 UTC (rev 5906)
@@ -170,3 +170,4 @@
 PyArray_CheckAxis
 PyArray_OverflowMultiplyList
 PyArray_CompareString
+PyArray_MultiIterFromObjects

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/src/arrayobject.c	2008-10-03 15:55:52 UTC (rev 5906)
@@ -10790,6 +10790,86 @@
 /** END of Subscript Iterator **/
 
 
+/*
+  NUMPY_API
+  Get a MultiIterator that broadcasts an array of Python objects
+  together with any additional objects passed as variadic arguments.
+
+  PyObject **mps -- array of PyObjects (converted with PyArray_FROM_O;
+                    the caller keeps its own references)
+  int n          -- number of PyObjects in mps
+  int nadd       -- number of additional PyObject * arguments that
+                    follow nadd; they are iterated after mps[0..n-1]
+
+  Returns a new multi-iterator object, or NULL with an exception set
+  (ValueError unless 2 <= n + nadd <= NPY_MAXARGS, otherwise whatever
+  error conversion, iterator creation or broadcasting raised).
+ */
+static PyObject *
+PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
+{
+    va_list va;
+    PyArrayMultiIterObject *multi;
+    PyObject *current;
+    PyObject *arr;
+
+    int i, ntot, err=0;
+
+    ntot = n + nadd;
+    if (ntot < 2 || ntot > NPY_MAXARGS) {
+        PyErr_Format(PyExc_ValueError,
+                     "Need between 2 and (%d) "
+                     "array objects (inclusive).", NPY_MAXARGS);
+        return NULL;
+    }
+
+    multi = _pya_malloc(sizeof(PyArrayMultiIterObject));
+    if (multi == NULL) return PyErr_NoMemory();
+    PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);
+
+    /* Clear every slot first: on an error path some slots are still
+       NULL when multi is released with Py_DECREF. */
+    for(i=0; i<ntot; i++) multi->iters[i] = NULL;
+    multi->numiter = ntot;
+    multi->index = 0;
+
+    va_start(va, nadd);
+    for(i=0; i<ntot; i++) {
+        /* The first n objects come from mps, the rest from the varargs. */
+        if (i < n) {
+            current = mps[i];
+        }
+        else {
+            current = va_arg(va, PyObject *);
+        }
+        arr = PyArray_FROM_O(current);
+        if (arr == NULL) {
+            err = 1;
+            break;
+        }
+        multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
+        Py_DECREF(arr);
+        /* PyArray_IterNew may return NULL; stop here instead of handing
+           an iterator with a NULL slot on to PyArray_Broadcast. */
+        if (multi->iters[i] == NULL) {
+            err = 1;
+            break;
+        }
+    }
+    va_end(va);
+
+    if (!err && PyArray_Broadcast(multi) < 0) err=1;
+
+    if (err) {
+        Py_DECREF(multi);
+        return NULL;
+    }
+
+    PyArray_MultiIter_RESET(multi);
+
+    return (PyObject *)multi;
+}
+
 /*NUMPY_API
   Get MultiIterator,
 */

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/src/multiarraymodule.c	2008-10-03 15:55:52 UTC (rev 5906)
@@ -2326,50 +2326,40 @@
 PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *ret,
                NPY_CLIPMODE clipmode)
 {
-    intp *sizes, offset;
     int n, elsize;
     intp i, m;
     char *ret_data;
     PyArrayObject **mps, *ap;
-    intp *self_data, mi;
+    PyArrayMultiIterObject *multi=NULL;
+    intp mi;
     int copyret=0;
     ap = NULL;
 
     /* Convert all inputs to arrays of a common type */
+    /* Also makes them C-contiguous */
     mps = PyArray_ConvertToCommonType(op, &n);
     if (mps == NULL) return NULL;
 
-    sizes = (intp *)_pya_malloc(n*sizeof(intp));
-    if (sizes == NULL) goto fail;
-
-    ap = (PyArrayObject *)PyArray_ContiguousFromAny((PyObject *)ip,
-                                                    PyArray_INTP,
-                                                    0, 0);
-    if (ap == NULL) goto fail;
-
-    /* Check the dimensions of the arrays */
     for(i=0; i<n; i++) {
         if (mps[i] == NULL) goto fail;
-        if (ap->nd < mps[i]->nd) {
-            PyErr_SetString(PyExc_ValueError,
-                            "too many dimensions");
-            goto fail;
-        }
-        if (!PyArray_CompareLists(ap->dimensions+(ap->nd-mps[i]->nd),
-                                  mps[i]->dimensions, mps[i]->nd)) {
-            PyErr_SetString(PyExc_ValueError,
-                            "array dimensions must agree");
-            goto fail;
-        }
-        sizes[i] = PyArray_NBYTES(mps[i]);
     }
 
+    ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP);
+
+    if (ap == NULL) goto fail;
+
+    /* Broadcast all arrays to each other, index array at the end. */ 
+    multi = (PyArrayMultiIterObject *)\
+	PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap);
+    if (multi == NULL) goto fail;
+
+    /* Set-up return array */
     if (!ret) {
         Py_INCREF(mps[0]->descr);
         ret = (PyArrayObject *)PyArray_NewFromDescr(ap->ob_type,
                                                     mps[0]->descr,
-                                                    ap->nd,
-                                                    ap->dimensions,
+                                                    multi->nd,
+                                                    multi->dimensions,
                                                     NULL, NULL, 0,
                                                     (PyObject *)ap);
     }
@@ -2377,8 +2367,10 @@
         PyArrayObject *obj;
         int flags = NPY_CARRAY | NPY_UPDATEIFCOPY | NPY_FORCECAST;
 
-        if (PyArray_SIZE(ret) != PyArray_SIZE(ap)) {
-            PyErr_SetString(PyExc_TypeError,
+        if ((PyArray_NDIM(ret) != multi->nd) || 
+	    !PyArray_CompareLists(PyArray_DIMS(ret), multi->dimensions, 
+				  multi->nd)) {
+	    PyErr_SetString(PyExc_TypeError,
                             "invalid shape for output array.");
             ret = NULL;
             goto fail;
@@ -2399,12 +2391,10 @@
 
     if (ret == NULL) goto fail;
     elsize = ret->descr->elsize;
-    m = PyArray_SIZE(ret);
-    self_data = (intp *)ap->data;
     ret_data = ret->data;
 
-    for (i=0; i<m; i++) {
-        mi = *self_data;
+    while (PyArray_MultiIter_NOTDONE(multi)) {
+	mi = *((intp *)PyArray_MultiIter_DATA(multi, n));
         if (mi < 0 || mi >= n) {
             switch(clipmode) {
             case NPY_RAISE:
@@ -2426,17 +2416,16 @@
                 break;
             }
         }
-        offset = i*elsize;
-        if (offset >= sizes[mi]) {offset = offset % sizes[mi]; }
-        memmove(ret_data, mps[mi]->data+offset, elsize);
-        ret_data += elsize; self_data++;
+        memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
+        ret_data += elsize; 
+	PyArray_MultiIter_NEXT(multi);
     }
 
     PyArray_INCREF(ret);
+    Py_DECREF(multi);
     for(i=0; i<n; i++) Py_XDECREF(mps[i]);
     Py_DECREF(ap);
     PyDataMem_FREE(mps);
-    _pya_free(sizes);
     if (copyret) {
         PyObject *obj;
         obj = ret->base;
@@ -2447,10 +2436,10 @@
     return (PyObject *)ret;
 
  fail:
+    Py_XDECREF(multi);
     for(i=0; i<n; i++) Py_XDECREF(mps[i]);
     Py_XDECREF(ap);
     PyDataMem_FREE(mps);
-    _pya_free(sizes);
     PyArray_XDECREF_ERR(ret);
     return NULL;
 }

Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py	2008-10-03 07:22:57 UTC (rev 5905)
+++ trunk/numpy/core/tests/test_multiarray.py	2008-10-03 15:55:52 UTC (rev 5906)
@@ -946,6 +946,36 @@
         assert repr(A) == reprA
 
 
+class TestChoose(TestCase):
+    """np.choose, including broadcasting of the index array and the
+    choice arrays against each other (ticket #925)."""
+    def setUp(self):
+        self.x = 2*ones((3,),dtype=int)
+        self.y = 3*ones((3,),dtype=int)
+        self.x2 = 2*ones((2,3), dtype=int)
+        self.y2 = 3*ones((2,3), dtype=int)
+        self.ind = [0,0,1]
 
+    def test_basic(self):
+        A = np.choose(self.ind, (self.x, self.y))
+        assert_equal(A, [2,2,3])
+
+    def test_broadcast1(self):
+        # Both choices are 2-D; the 1-D index is broadcast over the rows.
+        A = np.choose(self.ind, (self.x2, self.y2))
+        assert_equal(A, [[2,2,3],[2,2,3]])
+
+    def test_broadcast2(self):
+        # Choices of different rank are broadcast against each other.
+        A = np.choose(self.ind, (self.x, self.y2))
+        assert_equal(A, [[2,2,3],[2,2,3]])
+
+    def test_broadcast_index(self):
+        # A 2-D index array is broadcast against 1-D choices.
+        ind = [[0,1,1],[1,0,0]]
+        A = np.choose(ind, (self.x, self.y))
+        assert_equal(A, [[2,3,3],[3,2,2]])
+
+
 if __name__ == "__main__":
     run_module_suite()




More information about the Numpy-svn mailing list