[pypy-commit] pypy fix-result-types: precompute W_Ufunc1.allowed_types()

rlamy noreply at buildbot.pypy.org
Tue May 26 21:55:30 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77598:3c16f8ae6898
Date: 2015-05-26 20:05 +0100
http://bitbucket.org/pypy/pypy/changeset/3c16f8ae6898/

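Log:	precompute W_Ufunc1.allowed_types(): the (input, output) dtype pairs are now built once in unary_ufunc() and stored as W_Ufunc1.dtypes, replacing the per-call allowed_types() method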

diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -1,5 +1,5 @@
 from pypy.module.micronumpy.test.test_base import BaseNumpyAppTest
-from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, W_Ufunc1
+from pypy.module.micronumpy.ufuncs import W_UfuncGeneric, unary_ufunc
 from pypy.module.micronumpy.support import _parse_signature
 from pypy.module.micronumpy.descriptor import get_dtype_cache
 from pypy.module.micronumpy.base import W_NDimArray
@@ -58,16 +58,16 @@
         dt_bool = get_dtype_cache(space).w_booldtype
         dt_float16 = get_dtype_cache(space).w_float16dtype
         dt_int32 = get_dtype_cache(space).w_int32dtype
-        ufunc = W_Ufunc1(None, 'x', int_only=True)
+        ufunc = unary_ufunc(space, None, 'x', int_only=True)
         assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_bool, dt_bool)
-        assert ufunc.allowed_types(space)  # XXX: shouldn't contain too much stuff
+        assert ufunc.dtypes  # XXX: shouldn't contain too much stuff
 
-        ufunc = W_Ufunc1(None, 'x', promote_to_float=True)
+        ufunc = unary_ufunc(space, None, 'x', promote_to_float=True)
         assert ufunc._calc_dtype(space, dt_bool, out=None) == (dt_float16, dt_float16)
         assert ufunc._calc_dtype(space, dt_bool, casting='same_kind') == (dt_float16, dt_float16)
         raises(OperationError, ufunc._calc_dtype, space, dt_bool, casting='no')
 
-        ufunc = W_Ufunc1(None, 'x')
+        ufunc = unary_ufunc(space, None, 'x')
         assert ufunc._calc_dtype(space, dt_int32, out=None) == (dt_int32, dt_int32)
 
 class AppTestUfuncs(BaseNumpyAppTest):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -428,7 +428,7 @@
         return casting
 
 class W_Ufunc1(W_Ufunc):
-    _immutable_fields_ = ["func", "bool_result"]
+    _immutable_fields_ = ["func", "bool_result", "dtypes"]
     nin = 1
     nout = 1
     nargs = 2
@@ -495,7 +495,7 @@
         if arg_dtype.is_object():
             return arg_dtype, arg_dtype
         in_casting = safe_casting_mode(casting)
-        for dt_in, dt_out in self.allowed_types(space):
+        for dt_in, dt_out in self.dtypes:
             if use_min_scalar:
                 if not can_cast_array(space, w_arg, dt_in, in_casting):
                     continue
@@ -512,30 +512,6 @@
             raise oefmt(space.w_TypeError,
                 "ufunc '%s' not supported for the input types", self.name)
 
-    def allowed_types(self, space):
-        dtypes = []
-        cache = get_dtype_cache(space)
-        if not self.promote_bools and not self.promote_to_float:
-            dtypes.append((cache.w_booldtype, cache.w_booldtype))
-        if not self.promote_to_float:
-            for dt in cache.integer_dtypes:
-                dtypes.append((dt, dt))
-        if not self.int_only:
-            for dt in cache.float_dtypes:
-                dtypes.append((dt, dt))
-            for dt in cache.complex_dtypes:
-                if self.complex_to_float:
-                    if dt.num == NPY.CFLOAT:
-                        dt_out = get_dtype_cache(space).w_float32dtype
-                    else:
-                        dt_out = get_dtype_cache(space).w_float64dtype
-                    dtypes.append((dt, dt_out))
-                else:
-                    dtypes.append((dt, dt))
-        if self.bool_result:
-            dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
-        return dtypes
-
 
 class W_Ufunc2(W_Ufunc):
     _immutable_fields_ = ["func", "bool_result", "done_func", "simple_binary"]
@@ -1332,11 +1308,40 @@
             bool_result=extra_kwargs.get("bool_result", False),
         )
         if nin == 1:
-            ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs)
+            ufunc = unary_ufunc(space, func, ufunc_name, **extra_kwargs)
         elif nin == 2:
             ufunc = W_Ufunc2(func, ufunc_name, **extra_kwargs)
         setattr(self, ufunc_name, ufunc)
 
+def unary_ufunc(space, func, ufunc_name, **kwargs):
+    ufunc = W_Ufunc1(func, ufunc_name, **kwargs)
+    ufunc.dtypes = _ufunc1_dtypes(ufunc, space)
+    return ufunc
+
+def _ufunc1_dtypes(ufunc, space):
+    dtypes = []
+    cache = get_dtype_cache(space)
+    if not ufunc.promote_bools and not ufunc.promote_to_float:
+        dtypes.append((cache.w_booldtype, cache.w_booldtype))
+    if not ufunc.promote_to_float:
+        for dt in cache.integer_dtypes:
+            dtypes.append((dt, dt))
+    if not ufunc.int_only:
+        for dt in cache.float_dtypes:
+            dtypes.append((dt, dt))
+        for dt in cache.complex_dtypes:
+            if ufunc.complex_to_float:
+                if dt.num == NPY.CFLOAT:
+                    dt_out = get_dtype_cache(space).w_float32dtype
+                else:
+                    dt_out = get_dtype_cache(space).w_float64dtype
+                dtypes.append((dt, dt_out))
+            else:
+                dtypes.append((dt, dt))
+    if ufunc.bool_result:
+        dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+    return dtypes
+
 
 def get(space):
     return space.fromcache(UfuncState)


More information about the pypy-commit mailing list