[Numpy-svn] r2726 - trunk/numpy/core/src

numpy-svn at scipy.org numpy-svn at scipy.org
Sun Jul 2 17:44:27 EDT 2006


Author: oliphant
Date: 2006-07-02 16:44:24 -0500 (Sun, 02 Jul 2006)
New Revision: 2726

Modified:
   trunk/numpy/core/src/arrayobject.c
Log:
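Fixed missing case in the copy and casting calls: copying or casting one 0-d array to another (where PyArray_RemoveLargest returns a negative axis) is now handled instead of failing with -1. The cast path's error cleanup also releases buffers[0] with _pya_free instead of free, matching its _pya_malloc allocation.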

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-07-02 17:19:26 UTC (rev 2725)
+++ trunk/numpy/core/src/arrayobject.c	2006-07-02 21:44:24 UTC (rev 2726)
@@ -897,7 +897,13 @@
 	multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, dest, src);
 	if (multi == NULL) return -1;
 	maxaxis = PyArray_RemoveLargest(multi);
-	if (maxaxis < 0) return -1;
+	if (maxaxis < 0) { /* copy 1 0-d array to another */
+		PyArray_XDECREF(dest);
+		memcpy(dest->data, src->data, elsize);
+		if (swap) byte_swap_vector(dest->data, 1, elsize);
+		PyArray_INCREF(dest);
+		return 0;
+	}
 	maxdim = multi->dimensions[maxaxis];
 
 	PyArray_XDECREF(dest);
@@ -6866,7 +6872,7 @@
 {
 	int delsize, selsize, maxaxis, i, N;
 	PyArrayMultiIterObject *multi;
-	intp maxdim;
+	intp maxdim, ostrides, istrides;
 	char *buffers[2];
 	PyArray_CopySwapNFunc *ocopyfunc, *icopyfunc;	
 	char *obptr;
@@ -6875,10 +6881,22 @@
 	selsize = PyArray_ITEMSIZE(in);
 	multi = (PyArrayMultiIterObject *)PyArray_MultiIterNew(2, out, in);
 	if (multi == NULL) return -1;
+	icopyfunc = in->descr->f->copyswapn;
+	ocopyfunc = out->descr->f->copyswapn;
 	maxaxis = PyArray_RemoveLargest(multi);
-	if (maxaxis < 0) return -1;
-	maxdim = multi->dimensions[maxaxis];
-	N = (int) (MIN(maxdim, PyArray_BUFSIZE));
+	if (maxaxis < 0) { /* cast 1 0-d array to another */
+		N = 1;
+		maxdim = 1;
+		ostrides = delsize;
+		istrides = selsize;
+	}
+	else {
+		maxdim = multi->dimensions[maxaxis];
+		N = (int) (MIN(maxdim, PyArray_BUFSIZE));
+		ostrides = multi->iters[0]->strides[maxaxis];
+		istrides = multi->iters[1]->strides[maxaxis];
+
+	}
 	buffers[0] = _pya_malloc(N*delsize);
 	if (buffers[0] == NULL) {
 		PyErr_NoMemory();
@@ -6886,7 +6904,7 @@
 	}
 	buffers[1] = _pya_malloc(N*selsize);
 	if (buffers[1] == NULL) {
-		free(buffers[0]);
+		_pya_free(buffers[0]);
 		PyErr_NoMemory();
 		return -1;
 	}
@@ -6895,15 +6913,12 @@
 	if (in->descr->hasobject) 
 		memset(buffers[1], 0, N*selsize);
 
-	icopyfunc = in->descr->f->copyswapn;
-	ocopyfunc = out->descr->f->copyswapn;
-
 	while(multi->index < multi->size) {
 		_strided_buffered_cast(multi->iters[0]->dataptr,
-				       multi->iters[0]->strides[maxaxis],
+				       ostrides,
 				       delsize, oswap, ocopyfunc,
 				       multi->iters[1]->dataptr, 
-				       multi->iters[1]->strides[maxaxis],
+				       istrides,
 				       selsize, iswap, icopyfunc,
 				       maxdim, buffers, N,
 				       castfunc, out, in);




More information about the Numpy-svn mailing list