[Numpy-svn] r4566 - in trunk/numpy/core: src tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Mon Dec 10 23:48:40 EST 2007
Author: oliphant
Date: 2007-12-10 22:48:31 -0600 (Mon, 10 Dec 2007)
New Revision: 4566
Modified:
trunk/numpy/core/src/arraymethods.c
trunk/numpy/core/src/arraytypes.inc.src
trunk/numpy/core/src/multiarraymodule.c
trunk/numpy/core/tests/test_multiarray.py
Log:
Allow the clip method to be called with only min or only max (at least one of them must be given).
Modified: trunk/numpy/core/src/arraymethods.c
===================================================================
--- trunk/numpy/core/src/arraymethods.c 2007-12-11 02:12:30 UTC (rev 4565)
+++ trunk/numpy/core/src/arraymethods.c 2007-12-11 04:48:31 UTC (rev 4566)
@@ -1646,16 +1646,20 @@
static PyObject *
array_clip(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *min, *max;
+ PyObject *min=NULL, *max=NULL;
PyArrayObject *out=NULL;
static char *kwlist[] = {"min", "max", "out", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO|O&", kwlist,
+ if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOO&", kwlist,
&min, &max,
PyArray_OutputConverter,
&out))
return NULL;
+ if (max == NULL && min == NULL) {
+ PyErr_SetString(PyExc_ValueError, "One of max or min must be given.");
+ return NULL;
+ }
return _ARET(PyArray_Clip(self, min, max, out));
}
Modified: trunk/numpy/core/src/arraytypes.inc.src
===================================================================
--- trunk/numpy/core/src/arraytypes.inc.src 2007-12-11 02:12:30 UTC (rev 4565)
+++ trunk/numpy/core/src/arraytypes.inc.src 2007-12-11 04:48:31 UTC (rev 4566)
@@ -2061,9 +2061,29 @@
register npy_intp i;
@type@ max_val, min_val;
- max_val = *max;
- min_val = *min;
+ if (max != NULL)
+ max_val = *max;
+ if (min != NULL)
+ min_val = *min;
+ if (max == NULL) {
+ for (i = 0; i < ni; i++) {
+ if (in[i] < min_val) {
+ out[i] = min_val;
+ }
+ }
+ return;
+ }
+
+ if (min == NULL) {
+ for (i = 0; i < ni; i++) {
+ if (in[i] > max_val) {
+ out[i] = max_val;
+ }
+ }
+ return;
+ }
+
for (i = 0; i < ni; i++) {
if (in[i] < min_val) {
out[i] = min_val;
@@ -2085,10 +2105,33 @@
{
register npy_intp i;
@type@ max_val, min_val;
-
+
min_val = *min;
max_val = *max;
+ if (max != NULL)
+ max_val = *max;
+ if (min != NULL)
+ min_val = *min;
+
+ if (max == NULL) {
+ for (i = 0; i < ni; i++) {
+ if (PyArray_CLT(in[i],min_val)) {
+ out[i] = min_val;
+ }
+ }
+ return;
+ }
+
+ if (min == NULL) {
+ for (i = 0; i < ni; i++) {
+ if (PyArray_CGT(in[i], max_val)) {
+ out[i] = max_val;
+ }
+ }
+ return;
+ }
+
for (i = 0; i < ni; i++) {
if (PyArray_CLT(in[i], min_val)) {
out[i] = min_val;
Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c 2007-12-11 02:12:30 UTC (rev 4565)
+++ trunk/numpy/core/src/multiarraymodule.c 2007-12-11 04:48:31 UTC (rev 4566)
@@ -1105,27 +1105,39 @@
PyObject *res1=NULL, *res2=NULL, *res3=NULL;
PyObject *two;
- two = PyInt_FromLong((long)2);
- res1 = PyArray_GenericBinaryFunction(self, max, n_ops.greater);
- res2 = PyArray_GenericBinaryFunction(self, min, n_ops.less);
- if ((res1 == NULL) || (res2 == NULL)) {
- Py_DECREF(two);
- Py_XDECREF(res1);
- Py_XDECREF(res2);
- return NULL;
+ if (max != NULL) {
+ res1 = PyArray_GenericBinaryFunction(self, max, n_ops.greater);
+ if (res1 == NULL) return NULL;
}
- res3 = PyNumber_Multiply(two, res1);
- Py_DECREF(two);
- Py_DECREF(res1);
- if (res3 == NULL) return NULL;
+ if (min != NULL) {
+ res2 = PyArray_GenericBinaryFunction(self, min, n_ops.less);
+ if (res2 == NULL) {Py_XDECREF(res1); return NULL;}
+ }
- selector = PyArray_EnsureAnyArray(PyNumber_Add(res2, res3));
- Py_DECREF(res2);
- Py_DECREF(res3);
- if (selector == NULL) return NULL;
+ if (max == NULL) {
+ selector = res2; /* Steal the reference */
+ newtup = Py_BuildValue("(OO)", (PyObject *)self, min);
+ }
+ else if (min == NULL) {
+ selector = res1; /* Steal the reference */
+ newtup = Py_BuildValue("(OO)", (PyObject *)self, max);
+ }
+ else {
+ two = PyInt_FromLong((long)2);
+ res3 = PyNumber_Multiply(two, res1);
+ Py_DECREF(two);
+ Py_DECREF(res1);
+ if (res3 == NULL) return NULL;
+ selector = PyArray_EnsureAnyArray(PyNumber_Add(res2, res3));
+ Py_DECREF(res2);
+ Py_DECREF(res3);
+ if (selector == NULL) return NULL;
- newtup = Py_BuildValue("(OOO)", (PyObject *)self, min, max);
+ newtup = Py_BuildValue("(OOO)", (PyObject *)self, min, max);
+ }
+
if (newtup == NULL) {Py_DECREF(selector); return NULL;}
+
ret = PyArray_Choose((PyAO *)selector, newtup, out, NPY_RAISE);
Py_DECREF(selector);
Py_DECREF(newtup);
@@ -1144,22 +1156,38 @@
PyArrayObject *mina=NULL;
PyArrayObject *newout=NULL, *newin=NULL;
PyArray_Descr *indescr, *newdescr;
+ char *max_data, *min_data;
PyObject *zero;
+ if ((max == NULL) && (min == NULL)) {
+ PyErr_SetString(PyExc_ValueError, "array_clip: must set either max "\
+ "or min");
+ return NULL;
+ }
+
func = self->descr->f->fastclip;
- if (func == NULL || !PyArray_CheckAnyScalar(min) ||
- !PyArray_CheckAnyScalar(max))
+ if (func == NULL || (min != NULL && !PyArray_CheckAnyScalar(min)) ||
+ (max != NULL && !PyArray_CheckAnyScalar(max)))
return _slow_array_clip(self, min, max, out);
/* Use the fast scalar clip function */
/* First we need to figure out the correct type */
- indescr = PyArray_DescrFromObject(min, NULL);
- if (indescr == NULL) return NULL;
- newdescr = PyArray_DescrFromObject(max, indescr);
- Py_DECREF(indescr);
+ indescr = NULL;
+ if (min != NULL) {
+ indescr = PyArray_DescrFromObject(min, NULL);
+ if (indescr == NULL) return NULL;
+ }
+ if (max != NULL) {
+ newdescr = PyArray_DescrFromObject(max, indescr);
+ Py_XDECREF(indescr);
+ if (newdescr == NULL) return NULL;
+ }
+ else {
+ newdescr = indescr; /* Steal the reference */
+ }
+
- if (newdescr == NULL) return NULL;
/* Use the scalar descriptor only if it is of a bigger
KIND than the input array (and then find the
type that matches both).
@@ -1184,9 +1212,15 @@
}
/* Convert max to an array */
- maxa = (NPY_AO *)PyArray_FromAny(max, indescr, 0, 0,
- NPY_DEFAULT, NULL);
- if (maxa == NULL) return NULL;
+ if (max != NULL) {
+ maxa = (NPY_AO *)PyArray_FromAny(max, indescr, 0, 0,
+ NPY_DEFAULT, NULL);
+ if (maxa == NULL) return NULL;
+ }
+ else {
+ /* Side-effect of PyArray_FromAny */
+ Py_DECREF(indescr);
+ }
/* If we are unsigned, then make sure min is not <0 */
@@ -1197,31 +1231,33 @@
for other data-types in which case they
are interpreted as their modular counterparts.
*/
- if (PyArray_ISUNSIGNED(self)) {
- int cmp;
- zero = PyInt_FromLong(0);
- cmp = PyObject_RichCompareBool(min, zero, Py_LT);
- if (cmp == -1) { Py_DECREF(zero); goto fail;}
- if (cmp == 1) {
- min = zero;
- }
- else {
- Py_DECREF(zero);
- Py_INCREF(min);
- }
+ if (min != NULL) {
+ if (PyArray_ISUNSIGNED(self)) {
+ int cmp;
+ zero = PyInt_FromLong(0);
+ cmp = PyObject_RichCompareBool(min, zero, Py_LT);
+ if (cmp == -1) { Py_DECREF(zero); goto fail;}
+ if (cmp == 1) {
+ min = zero;
+ }
+ else {
+ Py_DECREF(zero);
+ Py_INCREF(min);
+ }
+ }
+ else {
+ Py_INCREF(min);
+ }
+
+ /* Convert min to an array */
+ Py_INCREF(indescr);
+ mina = (NPY_AO *)PyArray_FromAny(min, indescr, 0, 0,
+ NPY_DEFAULT, NULL);
+ Py_DECREF(min);
+ if (mina == NULL) goto fail;
}
- else {
- Py_INCREF(min);
- }
+
- /* Convert min to an array */
- Py_INCREF(indescr);
- mina = (NPY_AO *)PyArray_FromAny(min, indescr, 0, 0,
- NPY_DEFAULT, NULL);
- Py_DECREF(min);
- if (mina == NULL) goto fail;
-
-
/* Check to see if input is single-segment, aligned,
and in native byteorder */
if (PyArray_ISONESEGMENT(self) && PyArray_CHKFLAGS(self, ALIGNED) &&
@@ -1311,12 +1347,18 @@
/* Now we can call the fast-clip function */
- func(newin->data, PyArray_SIZE(newin), mina->data, maxa->data,
+ min_data = max_data = NULL;
+ if (mina != NULL)
+ min_data = mina->data;
+ if (maxa != NULL)
+ max_data = maxa->data;
+
+ func(newin->data, PyArray_SIZE(newin), min_data, max_data,
newout->data);
/* Clean up temporary variables */
- Py_DECREF(mina);
- Py_DECREF(maxa);
+ Py_XDECREF(mina);
+ Py_XDECREF(maxa);
Py_DECREF(newin);
/* Copy back into out if out was not already a nice array. */
Py_DECREF(newout);
Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py 2007-12-11 02:12:30 UTC (rev 4565)
+++ trunk/numpy/core/tests/test_multiarray.py 2007-12-11 04:48:31 UTC (rev 4566)
@@ -422,6 +422,15 @@
y = rec['x'].clip(-0.3,0.5)
self._check_range(y,-0.3,0.5)
+ def check_max_or_min(self):
+ val = N.array([0,1,2,3,4,5,6,7])
+ x = val.clip(3)
+ assert N.all(x >= 3)
+ x = val.clip(min=3)
+ assert N.all(x >= 3)
+ x = val.clip(max=4)
+ assert N.all(x <= 4)
+
class TestPutmask(ParametricTestCase):
def tst_basic(self,x,T,mask,val):
N.putmask(x,mask,val)
More information about the Numpy-svn mailing list