[Numpy-svn] r3355 - trunk/numpy/core/src
numpy-svn at scipy.org
numpy-svn at scipy.org
Wed Oct 18 02:38:31 EDT 2006
Author: oliphant
Date: 2006-10-18 01:38:26 -0500 (Wed, 18 Oct 2006)
New Revision: 3355
Modified:
trunk/numpy/core/src/ufuncobject.c
Log:
Add the ability to select a specific 1-d loop when calling a ufunc, either by its output type or by an exact signature. An error is raised if no 1-d loop matches the specification.
Modified: trunk/numpy/core/src/ufuncobject.c
===================================================================
--- trunk/numpy/core/src/ufuncobject.c 2006-10-18 04:44:28 UTC (rev 3354)
+++ trunk/numpy/core/src/ufuncobject.c 2006-10-18 06:38:26 UTC (rev 3355)
@@ -657,6 +657,142 @@
return -1;
}
+/*
+ * Select the 1-d inner loop explicitly requested by the caller.
+ *
+ * If type_tup is a single type (or a tuple of length 1), it names the
+ * "first" output data-type: the first registered loop whose argument at
+ * index self->nin has that type is chosen.
+ *
+ * If type_tup is a tuple of length self->nargs, it is an exact signature:
+ * every input and output type must match a registered loop exactly, or
+ * an error occurs.
+ *
+ * userdef: when > 0, only the loops stored in self->userloops under the
+ *          key userdef are searched (presumably the type number of a
+ *          user-defined input type found by the caller; the caller passes
+ *          -1 when there is none).  Otherwise self->functions / self->types
+ *          are searched.
+ *
+ * On success, *function, *data and arg_types[0..nargs-1] are filled in
+ * from the matching loop and 0 is returned.  On failure a Python
+ * exception is set and -1 is returned: ValueError for a type-tuple of
+ * the wrong length, TypeError when no loop matches.
+ */
+static int
+extract_specified_loop(PyUFuncObject *self, int *arg_types,
+                       PyUFuncGenericFunction *function, void **data,
+                       PyObject *type_tup, int userdef)
+{
+    static char msg[] = "loop written to specified type(s) not found";
+    Py_ssize_t n;
+    int *rtypenums;
+    PyArray_Descr *dtype;
+    int nargs;
+    int i, j;
+    int matched;
+    int ret = -1;
+
+    nargs = self->nargs;
+
+    if (PyTuple_Check(type_tup)) {
+        n = PyTuple_GET_SIZE(type_tup);
+        if (n != 1 && n != nargs) {
+            PyErr_Format(PyExc_ValueError,
+                         "a type-tuple must be specified " \
+                         "of length 1 or %d for %s", nargs,
+                         self->name ? self->name : "(unknown)");
+            return -1;
+        }
+    }
+    else {
+        n = 1;
+    }
+
+    rtypenums = (int *)_pya_malloc(n*sizeof(int));
+    if (rtypenums == NULL) {
+        PyErr_NoMemory();
+        return -1;
+    }
+
+    /* Convert the requested data-type(s) into type numbers. */
+    if (PyTuple_Check(type_tup)) {
+        for (i = 0; i < n; i++) {
+            if (PyArray_DescrConverter(PyTuple_GET_ITEM(type_tup, i),
+                                       &dtype) == NPY_FAIL) {
+                goto done;
+            }
+            rtypenums[i] = dtype->type_num;
+            Py_DECREF(dtype);
+        }
+    }
+    else {
+        if (PyArray_DescrConverter(type_tup, &dtype) == NPY_FAIL) {
+            goto done;
+        }
+        rtypenums[0] = dtype->type_num;
+        Py_DECREF(dtype);
+    }
+
+    if (userdef > 0) {
+        /* Search only the loops registered for the user-defined type. */
+        PyObject *key, *obj;
+        PyUFunc_Loop1d *funcdata;
+
+        key = PyInt_FromLong((long) userdef);
+        if (key == NULL) {
+            goto done;
+        }
+        /*
+         * PyDict_GetItem returns a BORROWED reference: obj must not be
+         * DECREF'd here, or the dict's own reference would be released.
+         */
+        obj = PyDict_GetItem(self->userloops, key);
+        Py_DECREF(key);
+        if (obj == NULL) {
+            PyErr_SetString(PyExc_TypeError,
+                            "user-defined type used in ufunc" \
+                            " with no registered loops");
+            goto done;
+        }
+
+        /* Walk the linked list of registered loops for this type. */
+        funcdata = (PyUFunc_Loop1d *)PyCObject_AsVoidPtr(obj);
+        for (; funcdata != NULL; funcdata = funcdata->next) {
+            if (n == 1) {
+                /* Only the first output type was specified. */
+                matched = (rtypenums[0] == funcdata->arg_types[self->nin]);
+            }
+            else {
+                /* Full signature: every argument type must match. */
+                matched = 1;
+                for (i = 0; i < nargs; i++) {
+                    if (rtypenums[i] != funcdata->arg_types[i]) {
+                        matched = 0;
+                        break;
+                    }
+                }
+            }
+            if (matched) {
+                *function = funcdata->func;
+                *data = funcdata->data;
+                for (i = 0; i < nargs; i++) {
+                    arg_types[i] = funcdata->arg_types[i];
+                }
+                ret = 0;
+                goto done;
+            }
+        }
+        PyErr_SetString(PyExc_TypeError, msg);
+        goto done;
+    }
+
+    /* Look for a match among the built-in loops in self->functions. */
+    for (j = 0; j < self->ntypes; j++) {
+        if (n == 1) {
+            /* Only the first output type was specified. */
+            matched = (rtypenums[0] == self->types[j*nargs + self->nin]);
+        }
+        else {
+            /* Full signature: every argument type must match. */
+            matched = 1;
+            for (i = 0; i < nargs; i++) {
+                if (rtypenums[i] != self->types[j*nargs + i]) {
+                    matched = 0;
+                    break;
+                }
+            }
+        }
+        if (matched) {
+            *function = self->functions[j];
+            *data = self->data[j];
+            for (i = 0; i < nargs; i++) {
+                arg_types[i] = self->types[j*nargs + i];
+            }
+            ret = 0;
+            goto done;
+        }
+    }
+    PyErr_SetString(PyExc_TypeError, msg);
+
+ done:
+    _pya_free(rtypenums);
+    return ret;
+}
+
+
/* Called to determine coercion
Can change arg_types.
*/
@@ -665,18 +801,12 @@
select_types(PyUFuncObject *self, int *arg_types,
PyUFuncGenericFunction *function, void **data,
PyArray_SCALARKIND *scalars,
- PyArray_Descr *dtype)
+ PyObject *typetup)
{
int i, j;
char start_type;
int userdef=-1;
- /*
- if (dtype != NULL)
- return extract_specified_loop(self, arg_types, function, data,
- scalars, dtype)
- */
-
if (self->userloops) {
for (i=0; i<self->nin; i++) {
if (PyTypeNum_ISUSERDEF(arg_types[i])) {
@@ -685,7 +815,11 @@
}
}
}
-
+
+ if (typetup != NULL)
+ return extract_specified_loop(self, arg_types, function, data,
+ typetup, userdef);
+
if (userdef > 0) {
PyObject *key, *obj;
int ret;
@@ -753,7 +887,8 @@
static int
-_extract_pyvals(PyObject *ref, char *name, int *bufsize, int *errmask, PyObject **errobj)
+_extract_pyvals(PyObject *ref, char *name, int *bufsize,
+ int *errmask, PyObject **errobj)
{
PyObject *retval;
@@ -908,7 +1043,7 @@
static int
construct_arrays(PyUFuncLoopObject *loop, PyObject *args, PyArrayObject **mps,
- PyArray_Descr *dtype)
+ PyObject *typetup)
{
int nargs, i, maxsize;
int arg_types[NPY_MAXARGS];
@@ -976,7 +1111,7 @@
/* Select an appropriate function for these argument types. */
if (select_types(loop->ufunc, arg_types, &(loop->function),
- &(loop->funcdata), scalars, dtype) == -1)
+ &(loop->funcdata), scalars, typetup) == -1)
return -1;
/* FAIL with NotImplemented if the other object has
@@ -1376,7 +1511,7 @@
{
PyUFuncLoopObject *loop;
int i;
- PyArray_Descr *dtype=NULL;
+ PyObject *typetup=NULL;
PyObject *extobj=NULL;
char *name;
@@ -1416,8 +1551,7 @@
extobj = value;
}
else if (strncmp(PyString_AS_STRING(key), "dtype", 5) == 0) {
- if (PyArray_DescrConverter2(value, &dtype) == PY_FAIL)
- goto fail;
+ typetup = value;
}
else {
PyErr_Format(PyExc_TypeError, "'%s' is an invalid keyword",
@@ -1440,7 +1574,7 @@
}
/* Setup the arrays */
- if (construct_arrays(loop, args, mps, dtype) < 0) goto fail;
+ if (construct_arrays(loop, args, mps, typetup) < 0) goto fail;
PyUFunc_clearfperr();
@@ -3134,10 +3268,10 @@
/* Find the location of the matching signature */
for (i=0; i<func->ntypes; i++) {
for (j=0; j<func->nargs; j++) {
- if (signature[j] == func->types[i*func->nargs+j])
+ if (signature[j] != func->types[i*func->nargs+j])
break;
}
- if (j >= func->nargs) continue;
+ if (j < func->nargs) continue;
if (oldfunc != NULL) {
*oldfunc = func->functions[i];
@@ -3147,7 +3281,6 @@
break;
}
return res;
-
}
/*UFUNC_API*/
More information about the Numpy-svn
mailing list