[Numpy-svn] r3142 - in trunk/numpy/core: include/numpy src

numpy-svn at scipy.org numpy-svn at scipy.org
Mon Sep 11 19:22:57 EDT 2006


Author: oliphant
Date: 2006-09-11 18:22:54 -0500 (Mon, 11 Sep 2006)
New Revision: 3142

Modified:
   trunk/numpy/core/include/numpy/ufuncobject.h
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/src/ufuncobject.c
Log:
Improve the getting and setting of ufunc loops for user-defined types. 

Modified: trunk/numpy/core/include/numpy/ufuncobject.h
===================================================================
--- trunk/numpy/core/include/numpy/ufuncobject.h	2006-09-11 22:32:39 UTC (rev 3141)
+++ trunk/numpy/core/include/numpy/ufuncobject.h	2006-09-11 23:22:54 UTC (rev 3142)
@@ -187,7 +187,17 @@
         PyObject *callable;
 } PyUFunc_PyFuncData;
 
+/* A linked-list node describing one user-defined 1-d loop:
+   the loop function, its associated data, and its type signature.
+ */
+typedef struct _loop1d_info {
+        PyUFuncGenericFunction func;   /* the 1-d loop function */
+        void *data;                    /* data selected together with func */
+        int *arg_types;                /* one type number per ufunc argument */
+        struct _loop1d_info *next;     /* next registered loop, or NULL */
+} PyUFunc_Loop1d;
 
+
 #include "__ufunc_api.h"
 
 #define UFUNC_PYVALS_NAME "UFUNC_PYVALS"

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2006-09-11 22:32:39 UTC (rev 3141)
+++ trunk/numpy/core/src/multiarraymodule.c	2006-09-11 23:22:54 UTC (rev 3142)
@@ -1731,7 +1731,6 @@
 }
 
 
-
 /*OBJECT_API*/
 static PyArrayObject **
 PyArray_ConvertToCommonType(PyObject *op, int *retn)

Modified: trunk/numpy/core/src/ufuncobject.c
===================================================================
--- trunk/numpy/core/src/ufuncobject.c	2006-09-11 22:32:39 UTC (rev 3141)
+++ trunk/numpy/core/src/ufuncobject.c	2006-09-11 23:22:54 UTC (rev 3142)
@@ -616,7 +616,49 @@
         }
 }
 
+static char *_types_msg =  "function not supported for these types, "   \
+        "and can't coerce safely to supported types";
+
+/* Select a user-defined 1-d loop.  obj is a CObject whose pointer is the
+   head of a PyUFunc_Loop1d list.  The first loop whose input types all
+   nin inputs can be coerced to (per PyArray_CanCoerceScalar) is chosen:
+   *function, *data and all nargs entries of arg_types are set from it
+   and 0 is returned.  If no loop matches, TypeError is set, -1 returned.
+ */
+static int
+_find_matching_userloop(PyObject *obj, int *arg_types, 
+                        PyArray_SCALARKIND *scalars,
+                        PyUFuncGenericFunction *function, void **data, 
+                        int nargs, int nin)
+{
+        PyUFunc_Loop1d *funcdata;
+        int i;
+        funcdata = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(obj);
+        while (funcdata != NULL) {
+                for (i=0; i<nin; i++) { /* outputs are dictated by the loop */
+			if (!PyArray_CanCoerceScalar(arg_types[i],
+						     funcdata->arg_types[i],
+						     scalars[i])) 
+                                break;
+                }
+                if (i==nin) { /* match found */
+                        *function = funcdata->func;
+                        *data = funcdata->data;
+                        /* Make sure actual arg_types supported
+                           by the loop are used (inputs and outputs) */
+                        for (i=0; i<nargs; i++) {
+                                arg_types[i] = funcdata->arg_types[i];
+                        }
+                        return 0;
+                }
+                funcdata = funcdata->next;
+        } 
+        PyErr_SetString(PyExc_TypeError, _types_msg);
+        return -1;
+}
+
 /* Called to determine coercion
+   Can change arg_types.
  */

 static int
@@ -639,8 +681,7 @@

 	if (userdef > 0) {
 		PyObject *key, *obj;
-		int *this_types=NULL;
-
+                int ret;
 		obj = NULL;
 		key = PyInt_FromLong((long) userdef);
 		if (key == NULL) return -1;
@@ -652,37 +693,13 @@
 					" with no registered loops");
 			return -1;
 		}
-		if PyTuple_Check(obj) {
-			PyObject *item;
-			*function = (PyUFuncGenericFunction)		\
-				PyCObject_AsVoidPtr(PyTuple_GET_ITEM(obj,0));
-			item = PyTuple_GET_ITEM(obj, 2);
-			if (PyCObject_Check(item)) {
-				*data = PyCObject_AsVoidPtr(item);
-			}
-			item = PyTuple_GET_ITEM(obj, 1);
-			if (PyCObject_Check(item)) {
-					this_types = PyCObject_AsVoidPtr(item);
-			}
-		}
-		else {
-			*function = (PyUFuncGenericFunction)		\
-				PyCObject_AsVoidPtr(obj);
-			*data = NULL;
-		}
-
-		if (this_types == NULL) {
-			for (i=1; i<self->nargs; i++) {
-				arg_types[i] = userdef;
-			}
-		}
-		else {
-			for (i=1; i<self->nargs; i++) {
-				arg_types[i] = this_types[i];
-			}
-		}
-		Py_DECREF(obj);
-			return 0;
+                /* pick the registered loop matching the inputs; it
+                   sets function, data and all argument types
+                */
+                ret = _find_matching_userloop(obj, arg_types, scalars,
+                                              function, data, self->nargs, self->nin);
+                Py_DECREF(obj);
+                return ret;
 	}
 
 	start_type = arg_types[0];
@@ -707,9 +724,7 @@
 		if (j == self->nin) break;
 	}
 	if(i>=self->ntypes) {
-		PyErr_SetString(PyExc_TypeError,
-				"function not supported for these types, "\
-				"and can't coerce safely to supported types");
+		PyErr_SetString(PyExc_TypeError, _types_msg);
 		return -1;
 	}
 	for(j=0; j<self->nargs; j++)
@@ -876,7 +891,7 @@
 #undef _GETATTR_
 
 static int
-construct_matrices(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps)
+construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps)
 {
         int nargs, i, maxsize;
         int arg_types[NPY_MAXARGS];
@@ -963,7 +978,8 @@
                 }
         }
 
-	/* Create copies for some of the arrays if appropriate */
+	/* Create copies for some of the arrays if they are small
+           enough and not already contiguous */
 	if (_create_copies(loop, arg_types, mps) < 0) return -1;
 
 	/* Create Iterators for the Inputs */
@@ -1344,8 +1360,8 @@
 				&(loop->bufsize), &(loop->errormask),
 				&(loop->errobj)) < 0) goto fail;
 
-	/* Setup the matrices */
-	if (construct_matrices(loop, args, mps) < 0) goto fail;
+	/* Setup the arrays */
+	if (construct_arrays(loop, args, mps) < 0) goto fail;
 
 	PyUFunc_clearfperr();
 
@@ -1421,7 +1437,7 @@
 
 /* This generic function is called with the ufunc object, the arguments to it,
    and an array of (pointers to) PyArrayObjects which are NULL.  The
-   arguments are parsed and placed in mps in construct_loop (construct_matrices)
+   arguments are parsed and placed in mps in construct_loop (construct_arrays)
 */
 
 /*UFUNC_API*/
@@ -3087,6 +3103,45 @@
 	return (PyObject *)self;
 }
 
+typedef struct { /* NOTE(review): assumed to mirror the head of CPython's PyCObject */
+        PyObject_HEAD
+        void *c_obj;    /* assumed to alias the CObject's stored void pointer */
+} _simple_cobj;
+
+#define _SETCPTR(cobj, val) ((_simple_cobj *)(cobj))->c_obj = (val) /* overwrite the stored pointer in place */
+
+/* Lexicographically compare the first n type numbers: return 1 if
+   arg1 > arg2, 0 if arg1 == arg2, and -1 if arg1 < arg2. */
+static int 
+cmp_arg_types(int *arg1, int *arg2, int n)
+{
+        while (n--) {
+                if (*arg1 > *arg2)
+                        return 1;
+                else if (*arg1 < *arg2)
+                        return -1;
+                arg1++; arg2++;
+        }
+        return 0;
+}
+
+/* CObject destructor for a ufunc->userloops entry: frees every
+   PyUFunc_Loop1d node of the list starting at ptr, together with
+   its arg_types array, when the CObject itself is destroyed.
+*/
+static void
+_loop1d_list_free(void *ptr)
+{
+        PyUFunc_Loop1d *funcdata;
+        if (ptr == NULL) return;
+        funcdata = (PyUFunc_Loop1d *)ptr;
+        if (funcdata == NULL) return; /* redundant: ptr was checked above */
+        _pya_free(funcdata->arg_types);
+        _loop1d_list_free(funcdata->next); /* recursion depth == list length */
+        _pya_free(funcdata);
+}
+
+
 /*UFUNC_API*/
 static int
 PyUFunc_RegisterLoopForType(PyUFuncObject *ufunc,
@@ -3096,9 +3151,12 @@
 			    void *data)
 {
 	PyArray_Descr *descr;
+        PyUFunc_Loop1d *funcdata;
     	PyObject *key, *cobj;
-	int ret;
+	int i, ret;
+        int *newtypes=NULL;

+        /* Register (function, data, arg_types) as a 1-d loop for usertype */
 	descr=PyArray_DescrFromType(usertype);
 	if ((usertype < PyArray_USERDEF) || (descr==NULL)) {
 		PyErr_SetString(PyExc_TypeError,
@@ -3112,49 +3170,90 @@
 	}
 	key = PyInt_FromLong((long) usertype);
 	if (key == NULL) return -1;
-	cobj = PyCObject_FromVoidPtr((void *)function, NULL);
-	if (cobj == NULL) {Py_DECREF(key); return -1;}
-	if (data == NULL && arg_types == NULL) {
-		ret = PyDict_SetItem(ufunc->userloops, key, cobj);
-		Py_DECREF(cobj);
-		Py_DECREF(key);
-		return ret;
-	}
-	else {
-		PyObject *cobj2, *cobj3, *tmp;
-		if (arg_types == NULL) {
-			cobj2 = Py_None;
-			Py_INCREF(cobj2);
-		}
-		else {
-			cobj2 = PyCObject_FromVoidPtr((void *)arg_types, NULL);
-			if (cobj2 == NULL) {
-				Py_DECREF(cobj);
-				Py_DECREF(key);
-				return -1;
-			}
-		}
-		if (data == NULL) {
-			cobj3 = Py_None;
-			Py_INCREF(cobj3);
-		}
-		else {
-			cobj3 = PyCObject_FromVoidPtr(data, NULL);
-			if (cobj3 == NULL) {
-				Py_DECREF(cobj2);
-				Py_DECREF(cobj);
-				Py_DECREF(key);
-				return -1;
-			}
-		}
-		tmp=Py_BuildValue("NNN", cobj, cobj2, cobj3);
-		ret = PyDict_SetItem(ufunc->userloops, key, tmp);
-		Py_DECREF(tmp);
-		Py_DECREF(key);
-		return ret;
-	}
+        funcdata = _pya_malloc(sizeof(PyUFunc_Loop1d));
+        if (funcdata == NULL) goto fail;
+        newtypes = _pya_malloc(sizeof(int)*ufunc->nargs);
+        if (newtypes == NULL) goto fail;
+        if (arg_types != NULL) {
+                for (i=0; i<ufunc->nargs; i++) {
+                        newtypes[i] = arg_types[i];
+                }
+        }
+        else { /* default signature: every argument is usertype */
+                for (i=0; i<ufunc->nargs; i++) {
+                        newtypes[i] = usertype;
+                }
+        }
+
+        funcdata->func = function;
+        funcdata->arg_types = newtypes;
+        funcdata->data = data;
+        funcdata->next = NULL;
+
+        /* Get entry (borrowed reference) for this user-defined type */
+        cobj = PyDict_GetItem(ufunc->userloops, key);
+
+        /* If it's not there, store a CObject owning the new list head. */
+        if (cobj == NULL) {
+                cobj = PyCObject_FromVoidPtr((void *)funcdata,
+                                             _loop1d_list_free);
+                if (cobj == NULL) goto fail;
+                ret = PyDict_SetItem(ufunc->userloops, key, cobj);
+                Py_DECREF(cobj); /* on SetItem failure this frees funcdata */
+                Py_DECREF(key);
+                return ret;
+        }
+        else {
+                PyUFunc_Loop1d *current, *prev=NULL;
+                int cmp=-1; /* an empty list means "insert at the front" */
+                /* There is already at least 1 loop.  Keep the list sorted
+                   lexicographically by signature: walk to the first entry
+                   whose signature is >= the new one.  If it is equal, just
+                   replace its function and data; otherwise insert before it.
+                */
+                current = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(cobj);
+                while (current != NULL) {
+                        cmp = cmp_arg_types(current->arg_types, newtypes, 
+                                            ufunc->nargs);
+                        if (cmp >= 0) break;
+                        prev = current;
+                        current = current->next;
+                }
+                if (cmp == 0) { /* just replace it with new function */
+                        current->func = function;
+                        current->data = data;
+                        _pya_free(newtypes);
+                        _pya_free(funcdata);
+                }
+                else { /* insert it before the current one.  A new head
+                          is written straight into the CObject's stored
+                          pointer: PyCObject_SetVoidPtr cannot be used
+                          because a destructor is set.
+                       */
+                        funcdata->next = current;
+                        if (prev == NULL) { /* place this at front */
+                                _SETCPTR(cobj, funcdata);
+                        }
+                        else {
+                                prev->next = funcdata;
+                        }
+                }
+        }
+        Py_DECREF(key);
+        return 0;
+
+
+ fail: /* only reached before funcdata is owned by a CObject */
+        Py_DECREF(key);
+        _pya_free(funcdata);
+        _pya_free(newtypes);
+        if (!PyErr_Occurred()) PyErr_NoMemory();
+        return -1;
 }
 
+#undef _SETCPTR
+
+
 static void
 ufunc_dealloc(PyUFuncObject *self)
 {




More information about the Numpy-svn mailing list