[pypy-commit] pypy fix-result-types: precompute W_Ufunc1.allowed_types()
rlamy
noreply at buildbot.pypy.org
Tue May 26 21:55:30 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77598:3c16f8ae6898
Date: 2015-05-26 20:05 +0100
http://bitbucket.org/pypy/pypy/changeset/3c16f8ae6898/
Log: precompute W_Ufunc1.allowed_types()
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1,5 +1,5 @@
from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, W_Ufunc1
+from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, unary_ufunc
from pypy.module.micronumpy.support import _parse_signature
from pypy.module.micronumpy.descriptor import get_dtype_cache
from pypy.module.micronumpy.base import W_NDimArray
@@ -58,16 +58,16 @@
dt_bool = get_dtype_cache(space).w_booldtype
dt_float16 = get_dtype_cache(space).w_float16dtype
dt_int32 = get_dtype_cache(space).w_int32dtype
- ufunc = W_Ufunc1(None, 'x', int_only=True)
+ ufunc = unary_ufunc(space, None, 'x', int_only=True)
assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_bool, dt_bool)
- assert ufunc.allowed_types(space) # XXX: shouldn't contain too much stuff
+ assert ufunc.dtypes # XXX: shouldn't contain too much stuff
- ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
+ ufunc = unary_ufunc(space, None, 'x', promote_to_float=True)
assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
assert ufunc._calc_dtype(space, dt_bool, casting='same_kind') == (dt_float16, dt_float16)
raises(OperationError, ufunc._calc_dtype, space, dt_bool, casting='no')
- ufunc = W_Ufunc1(None, 'x')
+ ufunc = unary_ufunc(space, None, 'x')
assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
class AppTestUfuncs(BaseNumpyAppTest):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -428,7 +428,7 @@
return casting
class W_Ufunc1(W_Ufunc):
- _immutable_fields_ = ["func", "bool_result"]
+ _immutable_fields_ = ["func", "bool_result", "dtypes"]
nin = 1
nout = 1
nargs = 2
@@ -495,7 +495,7 @@
if arg_dtype.is_object():
return arg_dtype, arg_dtype
in_casting = safe_casting_mode(casting)
- for dt_in, dt_out in self.allowed_types(space):
+ for dt_in, dt_out in self.dtypes:
if use_min_scalar:
if not can_cast_array(space, w_arg, dt_in, in_casting):
continue
@@ -512,30 +512,6 @@
raise oefmt(space.w_TypeError,
"ufunc '%s' not supported for the input types", self.name)
- def allowed_types(self, space):
- dtypes = []
- cache = get_dtype_cache(space)
- if not self.promote_bools and not self.promote_to_float:
- dtypes.append((cache.w_booldtype, cache.w_booldtype))
- if not self.promote_to_float:
- for dt in cache.integer_dtypes:
- dtypes.append((dt, dt))
- if not self.int_only:
- for dt in cache.float_dtypes:
- dtypes.append((dt, dt))
- for dt in cache.complex_dtypes:
- if self.complex_to_float:
- if dt.num == NPY.CFLOAT:
- dt_out = get_dtype_cache(space).w_float32dtype
- else:
- dt_out = get_dtype_cache(space).w_float64dtype
- dtypes.append((dt, dt_out))
- else:
- dtypes.append((dt, dt))
- if self.bool_result:
- dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
- return dtypes
-
class W_Ufunc2(W_Ufunc):
_immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
@@ -1332,11 +1308,40 @@
bool_result=extra_kwargs.get("bool_result", False),
)
if nin == 1:
- ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs)
+ ufunc = unary_ufunc(space, func, ufunc_name, **extra_kwargs)
elif nin == 2:
ufunc = W_Ufunc2(func, ufunc_name, **extra_kwargs)
setattr(self, ufunc_name, ufunc)
+def unary_ufunc(space, func, ufunc_name, **kwargs):
+ ufunc = W_Ufunc1(func, ufunc_name, **kwargs)
+ ufunc.dtypes = _ufunc1_dtypes(ufunc, space)
+ return ufunc
+
+def _ufunc1_dtypes(ufunc, space):
+ dtypes = []
+ cache = get_dtype_cache(space)
+ if not ufunc.promote_bools and not ufunc.promote_to_float:
+ dtypes.append((cache.w_booldtype, cache.w_booldtype))
+ if not ufunc.promote_to_float:
+ for dt in cache.integer_dtypes:
+ dtypes.append((dt, dt))
+ if not ufunc.int_only:
+ for dt in cache.float_dtypes:
+ dtypes.append((dt, dt))
+ for dt in cache.complex_dtypes:
+ if ufunc.complex_to_float:
+ if dt.num == NPY.CFLOAT:
+ dt_out = get_dtype_cache(space).w_float32dtype
+ else:
+ dt_out = get_dtype_cache(space).w_float64dtype
+ dtypes.append((dt, dt_out))
+ else:
+ dtypes.append((dt, dt))
+ if ufunc.bool_result:
+ dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+ return dtypes
+
def get(space):
return space.fromcache(UfuncState)
More information about the pypy-commit mailing list