[Numpy-svn] r2740 - trunk/numpy/core/src

numpy-svn at scipy.org numpy-svn at scipy.org
Thu Jul 6 04:01:30 EDT 2006


Author: oliphant
Date: 2006-07-06 03:01:27 -0500 (Thu, 06 Jul 2006)
New Revision: 2740

Modified:
   trunk/numpy/core/src/arrayobject.c
Log:
Add support for == and != comparison of void-type arrays.

Modified: trunk/numpy/core/src/arrayobject.c
===================================================================
--- trunk/numpy/core/src/arrayobject.c	2006-07-06 07:06:44 UTC (rev 2739)
+++ trunk/numpy/core/src/arrayobject.c	2006-07-06 08:01:27 UTC (rev 2740)
@@ -4123,11 +4123,11 @@
 				      NULL);
 	if (result == NULL) goto finish;
 
-	if (self->descr->type_num == PyArray_STRING) {
-		val = _compare_strings(result, mit, cmp_op, _mystrncmp);
+	if (self->descr->type_num == PyArray_UNICODE) {
+		val = _compare_strings(result, mit, cmp_op, _myunincmp);
 	}
 	else {
-		val = _compare_strings(result, mit, cmp_op, _myunincmp);
+		val = _compare_strings(result, mit, cmp_op, _mystrncmp);
 	}
 	
 	if (val < 0) {Py_DECREF(result); result = NULL;}
@@ -4137,10 +4137,64 @@
         return result;
 }
 
-/* What do we do about VOID type arrays?
+/* VOID-type arrays can only be compared for equality (==) and inequality (!=).
+    Arrays with fields are compared field by field: each field is extracted
+    from both arrays and the two are compared separately, one at a time.
+    For ==, the per-field results are combined with logical_and;
+    for !=, the per-field results are combined with logical_or.
+
+    VOID-type arrays without fields are compared by treating the raw memory
+    of each element as a string (reusing the string-comparison code).
  */
 
 static PyObject *
+_void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
+{
+	if (!(cmp_op == Py_EQ || cmp_op == Py_NE)) {
+		PyErr_SetString(PyExc_ValueError, "Void-arrays can only "
+				"be compared for equality.");
+		return NULL;
+	}
+	if (PyArray_HASFIELDS(self)) {
+		PyObject *res = NULL, *temp, *temp2, *a, *b;
+		PyObject *key, *value, *op, *op2;
+		Py_ssize_t pos = 0; /* PyDict_Next takes Py_ssize_t *, not int * */
+		op = (cmp_op == Py_EQ ? n_ops.equal : n_ops.not_equal);
+		op2 = (cmp_op == Py_EQ ? n_ops.logical_and : n_ops.logical_or);
+		while (PyDict_Next(self->descr->fields, &pos, &key, &value)) {
+			if (!PyString_Check(key)) continue; /* only string keys are used as field names */
+			a = array_subscript(self, key);
+			if (a == NULL) {Py_XDECREF(res); return NULL;}
+			b = array_subscript(other, key);
+			if (b == NULL) {Py_XDECREF(res); Py_DECREF(a); return NULL;}
+			temp = PyObject_CallFunction(op, "OO", a, b);
+			Py_DECREF(a);
+			Py_DECREF(b);
+			if (temp == NULL) {Py_XDECREF(res); return NULL;}
+			if (res == NULL) { /* the first field seeds the result */
+				res = temp;
+				continue;
+			}
+			/* fold this field into the running result (and for ==, or for !=) */
+			temp2 = PyObject_CallFunction(op2, "OO", res, temp);
+			Py_DECREF(temp);
+			Py_DECREF(res);
+			if (temp2 == NULL) return NULL;
+			res = temp2;
+		}
+		/* every error path above returns early, so a NULL res here means
+		   no string-keyed field was found in the descriptor */
+		if (res == NULL && !PyErr_Occurred()) {
+			PyErr_SetString(PyExc_ValueError, "No fields found.");
+		}
+		return res;
+	}
+	/* no fields: compare each element's memory via the string-comparison
+	   path; the array_richcompare callers first check the descrs are equal */
+	return _strings_richcompare(self, other, cmp_op);
+}
+
+static PyObject *
 array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
 {
         PyObject *array_other, *result = NULL;
@@ -4190,6 +4244,29 @@
                         result = PyArray_GenericBinaryFunction(self,
 							       array_other,
 							       n_ops.equal);
+			if ((result == Py_NotImplemented) && 
+			    (self->descr->type_num == PyArray_VOID)) {
+				int _res;
+				_res = PyObject_RichCompareBool	\
+					((PyObject *)self->descr, 
+					 (PyObject *)\
+					 PyArray_DESCR(array_other),
+					 Py_EQ);
+				if (_res < 0) {
+					Py_DECREF(result);
+					Py_DECREF(array_other);
+					return NULL;
+				}
+				if (_res) {
+					Py_DECREF(result);
+					result = _void_compare\
+						(self, 
+						 (PyArrayObject *)array_other,
+						 cmp_op);
+					Py_DECREF(array_other);
+				}
+				return result;
+			}
                         /* If the comparison results in NULL, then the
 			   two array objects can not be compared together so
 			   return zero
@@ -4233,6 +4310,30 @@
 			result = PyArray_GenericBinaryFunction(self,
 							       array_other,
 							       n_ops.not_equal);
+			if ((result == Py_NotImplemented) && 
+			    (self->descr->type_num == PyArray_VOID)) {
+				int _res;
+				_res = PyObject_RichCompareBool\
+					((PyObject *)self->descr, 
+					 (PyObject *)\
+					 PyArray_DESCR(array_other),
+					 Py_EQ);
+				if (_res < 0) {
+					Py_DECREF(result);
+					Py_DECREF(array_other);
+					return NULL;
+				}
+				if (_res) {
+					Py_DECREF(result);
+					result = _void_compare\
+						(self, 
+						 (PyArrayObject *)array_other, 
+						 cmp_op);
+					Py_DECREF(array_other);
+				}
+				return result;
+			}
+
 			Py_DECREF(array_other);
                         if (result == NULL) {
                                 PyErr_Clear();
@@ -4257,6 +4358,7 @@
                 if (self->descr->type_num == PyArray_OBJECT) return result;
                 array_other = PyArray_FromObject(other,PyArray_NOTYPE, 0, 0);
                 if (PyArray_ISSTRING(self) && PyArray_ISSTRING(array_other)) {
+			Py_DECREF(result);
                         result = _strings_richcompare(self, (PyArrayObject *)
 						      array_other, cmp_op);
                 }




More information about the Numpy-svn mailing list