[Numpy-svn] r3355 - trunk/numpy/core/src

numpy-svn at scipy.org numpy-svn at scipy.org
Wed Oct 18 02:38:31 EDT 2006


Author: oliphant
Date: 2006-10-18 01:38:26 -0500 (Wed, 18 Oct 2006)
New Revision: 3355

Modified:
   trunk/numpy/core/src/ufuncobject.c
Log:
Add the ability to select a specific 1-d loop when calling a ufunc (via the dtype keyword), either by its output type or by an exact signature.  An error is raised if no 1-d loop matches the specification.

Modified: trunk/numpy/core/src/ufuncobject.c
===================================================================
--- trunk/numpy/core/src/ufuncobject.c	2006-10-18 04:44:28 UTC (rev 3354)
+++ trunk/numpy/core/src/ufuncobject.c	2006-10-18 06:38:26 UTC (rev 3355)
@@ -657,6 +657,142 @@
         return -1;
 }
 
+/* Select the 1-d loop requested through the ufunc "dtype" keyword (type_tup).
+   A single type (or a 1-tuple) names the *first output* type: the first loop
+   whose arg_types[nin] equals it is used.  A tuple of length nargs must
+   match a loop's whole signature exactly.  On success sets *function, *data
+   and arg_types[0..nargs-1] and returns 0; on failure returns -1.  Searches
+   self->userloops when userdef > 0, otherwise self->functions. */
+static int
+extract_specified_loop(PyUFuncObject *self, int *arg_types, 
+                       PyUFuncGenericFunction *function, void **data,
+                       PyObject *type_tup, int userdef)
+{
+        Py_ssize_t n=1;
+        int *rtypenums; /* requested type numbers, n entries */
+        static char msg[] = "loop written to specified type(s) not found";
+        PyArray_Descr *dtype;
+        int nargs;
+        int i, j;
+
+        nargs = self->nargs;
+ 
+        if (PyTuple_Check(type_tup)) {
+                n = PyTuple_GET_SIZE(type_tup);
+                if (n != 1 && n != nargs) { /* 1: first output type only; nargs: full signature */
+                        PyErr_Format(PyExc_ValueError, 
+                                     "a type-tuple must be specified "  \
+                                     "of length 1 or %d for %s", nargs, 
+                                     self->name ? self->name : "(unknown)");
+                        return -1;
+                }
+        }
+        else {
+                n = 1; /* a bare type names the first output type */
+        }
+        rtypenums = (int *)_pya_malloc(n*sizeof(int));
+        if (rtypenums==NULL) {
+                PyErr_NoMemory();
+                return -1;
+        }
+        
+        if (PyTuple_Check(type_tup)) {
+                for (i=0; i<n; i++) {
+                        if (PyArray_DescrConverter(PyTuple_GET_ITEM     \
+                                                   (type_tup, i),
+                                                   &dtype) == NPY_FAIL) 
+                                goto fail;
+                        rtypenums[i] = dtype->type_num;
+                        Py_DECREF(dtype);
+                }
+        }
+        else {
+                if (PyArray_DescrConverter(type_tup, &dtype) == NPY_FAIL) {
+                        goto fail;
+                }
+                rtypenums[0] = dtype->type_num;
+                Py_DECREF(dtype);
+        }
+
+        if (userdef > 0) { /* search the loops registered for user-defined type number userdef */
+		PyObject *key, *obj;
+                PyUFunc_Loop1d *funcdata;
+		obj = NULL;
+		key = PyInt_FromLong((long) userdef);
+		if (key == NULL) goto fail;
+		obj = PyDict_GetItem(self->userloops, key); /* borrowed reference per the Python C API */
+		Py_DECREF(key);
+		if (obj == NULL) {
+			PyErr_SetString(PyExc_TypeError,
+					"user-defined type used in ufunc" \
+					" with no registered loops");
+                        goto fail;
+		}
+                /* walk the registered loops and take the first one
+                   matching the request, copying its function,
+                   data and argtypes */
+                funcdata = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(obj); /* head of a list walked via ->next */
+                while (funcdata != NULL) {
+                        if (n != 1) { /* exact full-signature match */
+                                for (i=0; i<nargs; i++) {
+                                        if (rtypenums[i] != funcdata->arg_types[i])
+                                                break;
+                                }
+                        }
+                        else if (rtypenums[0] == funcdata->arg_types[self->nin]) { /* first-output match */
+                                i = nargs;
+                        }
+                        if (i == nargs) { /* NOTE(review): if n == 1 and the output type did not match, i was not set on this pass (it is 1 from the tuple loop, or uninitialized for a bare type); i should be reset before this test */
+                                *function = funcdata->func;
+                                *data = funcdata->data;
+                                for (i=0; i<nargs; i++) {
+                                        arg_types[i] = funcdata->arg_types[i];
+                                }
+                                Py_DECREF(obj); /* NOTE(review): obj is borrowed from PyDict_GetItem, so this DECREF looks unbalanced -- confirm */
+                                goto finish;
+                        }
+                        funcdata = funcdata->next;
+                }
+                Py_DECREF(obj); /* NOTE(review): borrowed reference; same unbalanced-DECREF concern as above */
+                PyErr_SetString(PyExc_TypeError, msg);
+                goto fail;
+        }
+
+        /* no user-defined type: search self->functions / self->types */
+
+        for (j=0; j<self->ntypes; j++) { /* self->types holds ntypes rows of nargs type numbers */
+                if (n != 1) { /* exact full-signature match */
+                        for (i=0; i<nargs; i++) {
+                                if (rtypenums[i] != self->types[j*nargs + i])
+                                        break;
+                        }
+                }
+                else if (rtypenums[0] == self->types[j*nargs+self->nin]) { /* first-output match */
+                        i = nargs;
+                }
+                if (i == nargs) { /* NOTE(review): same stale/uninitialized-i hazard as the user-loop test */
+                        *function = self->functions[j];
+                        *data = self->data[j];
+                        for (i=0; i<nargs; i++) {
+                                arg_types[i] = self->types[j*nargs+i];
+                        }
+                        goto finish;
+                }
+        }
+        PyErr_SetString(PyExc_TypeError, msg); /* nothing matched; falls through to fail */
+
+ fail:
+        _pya_free(rtypenums);
+        return -1;
+
+ finish:
+        _pya_free(rtypenums);
+        return 0;
+
+                
+}
+
+
 /* Called to determine coercion
    Can change arg_types. 
  */
@@ -665,18 +801,12 @@
 select_types(PyUFuncObject *self, int *arg_types,
              PyUFuncGenericFunction *function, void **data,
 	     PyArray_SCALARKIND *scalars,
-             PyArray_Descr *dtype)
+             PyObject *typetup)
 {
 	int i, j;
 	char start_type;
 	int userdef=-1;
 
-        /*
-          if (dtype != NULL)
-          return extract_specified_loop(self, arg_types, function, data,
-          scalars, dtype)
-        */
-
 	if (self->userloops) {
 		for (i=0; i<self->nin; i++) {
 			if (PyTypeNum_ISUSERDEF(arg_types[i])) {
@@ -685,7 +815,11 @@
 			}
 		}
 	}
-
+        
+        if (typetup != NULL)
+                return extract_specified_loop(self, arg_types, function, data,
+                                              typetup, userdef);
+                        
 	if (userdef > 0) {
 		PyObject *key, *obj;
                 int ret;
@@ -753,7 +887,8 @@
 
 
 static int
-_extract_pyvals(PyObject *ref, char *name, int *bufsize, int *errmask, PyObject **errobj)
+_extract_pyvals(PyObject *ref, char *name, int *bufsize, 
+                int *errmask, PyObject **errobj)
 {
         PyObject *retval;
 
@@ -908,7 +1043,7 @@
 
 static int
 construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps,
-                 PyArray_Descr *dtype)
+                 PyObject *typetup)
 {
         int nargs, i, maxsize;
         int arg_types[NPY_MAXARGS];
@@ -976,7 +1111,7 @@
 
         /* Select an appropriate function for these argument types. */
         if (select_types(loop->ufunc, arg_types, &(loop->function),
-                         &(loop->funcdata), scalars, dtype) == -1)
+                         &(loop->funcdata), scalars, typetup) == -1)
 		return -1;
 
         /* FAIL with NotImplemented if the other object has
@@ -1376,7 +1511,7 @@
 {
 	PyUFuncLoopObject *loop;
 	int i;
-        PyArray_Descr *dtype=NULL;
+        PyObject *typetup=NULL;
         PyObject *extobj=NULL;
         char *name;
 
@@ -1416,8 +1551,7 @@
                                 extobj = value;
                         }
                         else if (strncmp(PyString_AS_STRING(key), "dtype", 5) == 0) {
-                                if (PyArray_DescrConverter2(value, &dtype) == PY_FAIL)
-                                        goto fail;
+                                typetup = value;
                         }
                         else {
                                 PyErr_Format(PyExc_TypeError, "'%s' is an invalid keyword",
@@ -1440,7 +1574,7 @@
         }
         
 	/* Setup the arrays */
-	if (construct_arrays(loop, args, mps, dtype) < 0) goto fail;
+	if (construct_arrays(loop, args, mps, typetup) < 0) goto fail;
 
 	PyUFunc_clearfperr();
 
@@ -3134,10 +3268,10 @@
 	/* Find the location of the matching signature */
 	for (i=0; i<func->ntypes; i++) {
 		for (j=0; j<func->nargs; j++) {
-			if (signature[j] == func->types[i*func->nargs+j])
+			if (signature[j] != func->types[i*func->nargs+j])
 				break;
 		}
-		if (j >= func->nargs) continue;
+		if (j < func->nargs) continue;
 		
 		if (oldfunc != NULL) {
 			*oldfunc = func->functions[i];
@@ -3147,7 +3281,6 @@
                 break;
 	}
 	return res;
-
 }
 
 /*UFUNC_API*/




More information about the Numpy-svn mailing list