[Numpy-svn] r4156 - in trunk/numpy: core/blasdot core/src core/tests lib
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Oct 8 23:08:33 EDT 2007
Author: oliphant
Date: 2007-10-08 22:08:26 -0500 (Mon, 08 Oct 2007)
New Revision: 4156
Modified:
trunk/numpy/core/blasdot/_dotblas.c
trunk/numpy/core/src/arraytypes.inc.src
trunk/numpy/core/tests/test_regression.py
trunk/numpy/lib/function_base.py
Log:
Fix Ticket #588: problem with negative striding and fast blas implementation of dot
Modified: trunk/numpy/core/blasdot/_dotblas.c
===================================================================
--- trunk/numpy/core/blasdot/_dotblas.c 2007-10-04 19:40:50 UTC (rev 4155)
+++ trunk/numpy/core/blasdot/_dotblas.c 2007-10-09 03:08:26 UTC (rev 4156)
@@ -20,7 +20,8 @@
register int nb = strideb / sizeof(float);
if ((sizeof(float) * na == stridea) &&
- (sizeof(float) * nb == strideb))
+ (sizeof(float) * nb == strideb) &&
+ (na >= 0) && (nb >= 0))
*((float *)res) = cblas_sdot((int)n, (float *)a, na, (float *)b, nb);
else
@@ -35,7 +36,8 @@
register int nb = strideb / sizeof(double);
if ((sizeof(double) * na == stridea) &&
- (sizeof(double) * nb == strideb))
+ (sizeof(double) * nb == strideb) &&
+ (na >= 0) && (nb >= 0))
*((double *)res) = cblas_ddot((int)n, (double *)a, na, (double *)b, nb);
else
oldFunctions[PyArray_DOUBLE](a, stridea, b, strideb, res, n, tmp);
@@ -50,7 +52,8 @@
register int nb = strideb / sizeof(cfloat);
if ((sizeof(cfloat) * na == stridea) &&
- (sizeof(cfloat) * nb == strideb))
+ (sizeof(cfloat) * nb == strideb) &&
+ (na >= 0) && (nb >= 0))
cblas_cdotu_sub((int)n, (float *)a, na, (float *)b, nb, (float *)res);
else
oldFunctions[PyArray_CFLOAT](a, stridea, b, strideb, res, n, tmp);
@@ -64,7 +67,8 @@
register int nb = strideb / sizeof(cdouble);
if ((sizeof(cdouble) * na == stridea) &&
- (sizeof(cdouble) * nb == strideb))
+ (sizeof(cdouble) * nb == strideb) &&
+ (na >= 0) && (nb >= 0))
cblas_zdotu_sub((int)n, (double *)a, na, (double *)b, nb, (double *)res);
else
oldFunctions[PyArray_CDOUBLE](a, stridea, b, strideb, res, n, tmp);
@@ -172,6 +176,21 @@
}
+static int
+_bad_strides(PyArrayObject *ap)
+{
+ register int itemsize = PyArray_ITEMSIZE(ap);
+ register int i, N=PyArray_NDIM(ap);
+ register intp *strides = PyArray_STRIDES(ap);
+
+ for (i=0; i<N; i++) {
+ if ((strides[i] < 0) || (strides[i] % itemsize) != 0)
+ return 1;
+ }
+
+ return 0;
+}
+
static char doc_matrixproduct[] = "dot(a,b)\nReturns the dot product of a and b for arrays of floating point types.\nLike the generic numpy equivalent the product sum is over\nthe last dimension of a and the second-to-last dimension of b.\nNB: The first argument is not conjugated.";
static PyObject *
@@ -216,8 +235,10 @@
ap2 = (PyArrayObject *)PyArray_FromAny(op2, dtype, 0, 0, ALIGNED, NULL);
if (ap2 == NULL) goto fail;
+
if ((ap1->nd > 2) || (ap2->nd > 2)) {
- /* This function doesn't handle dimensions greater than 2 -- other
+ /* This function doesn't handle dimensions greater than 2
+ (or negative striding) -- other
than to ensure the dot function is altered
*/
if (!altered) {
@@ -235,13 +256,13 @@
return PyArray_Return(ret);
}
- if (!PyArray_ElementStrides((PyObject *)ap1)) {
+ if (_bad_strides(ap1)) {
op1 = PyArray_NewCopy(ap1, PyArray_ANYORDER);
Py_DECREF(ap1);
ap1 = (PyArrayObject *)op1;
if (ap1 == NULL) goto fail;
}
- if (!PyArray_ElementStrides((PyObject *)ap2)) {
+ if (_bad_strides(ap2)) {
op2 = PyArray_NewCopy(ap2, PyArray_ANYORDER);
Py_DECREF(ap2);
ap2 = (PyArrayObject *)op2;
Modified: trunk/numpy/core/src/arraytypes.inc.src
===================================================================
--- trunk/numpy/core/src/arraytypes.inc.src 2007-10-04 19:40:50 UTC (rev 4155)
+++ trunk/numpy/core/src/arraytypes.inc.src 2007-10-09 03:08:26 UTC (rev 4156)
@@ -397,6 +397,9 @@
return -1;
}
memcpy(ov, ptr, MIN(ap->descr->elsize,len));
+ /* If string length is smaller than room in array
+ Then fill the rest of the element size
+ with NULL */
if (ap->descr->elsize > len) {
memset(ov + len, 0, (ap->descr->elsize - len));
}
Modified: trunk/numpy/core/tests/test_regression.py
===================================================================
--- trunk/numpy/core/tests/test_regression.py 2007-10-04 19:40:50 UTC (rev 4155)
+++ trunk/numpy/core/tests/test_regression.py 2007-10-09 03:08:26 UTC (rev 4156)
@@ -725,6 +725,13 @@
"""Ticket #572"""
N.lib.place(1,1,1)
+ def check_dot_negative_stride(self, level=rlevel):
+ """Ticket #588"""
+ x = N.array([[1,5,25,125.,625]])
+ y = N.array([[20.],[160.],[640.],[1280.],[1024.]])
+ z = y[::-1].copy()
+ y2 = y[::-1]
+ assert_equal(N.dot(x,z),N.dot(x,y2))
if __name__ == "__main__":
NumpyTest().run()
Modified: trunk/numpy/lib/function_base.py
===================================================================
--- trunk/numpy/lib/function_base.py 2007-10-04 19:40:50 UTC (rev 4155)
+++ trunk/numpy/lib/function_base.py 2007-10-09 03:08:26 UTC (rev 4156)
@@ -1,7 +1,7 @@
__docformat__ = "restructuredtext en"
__all__ = ['logspace', 'linspace',
'select', 'piecewise', 'trim_zeros',
- 'copy', 'iterable', #'base_repr', 'binary_repr',
+ 'copy', 'iterable',
'diff', 'gradient', 'angle', 'unwrap', 'sort_complex', 'disp',
'unique', 'extract', 'place', 'nansum', 'nanmax', 'nanargmax',
'nanargmin', 'nanmin', 'vectorize', 'asarray_chkfinite', 'average',
More information about the Numpy-svn
mailing list