[pypy-commit] pypy fix-result-types: don't use find_binop_result_dtype() in W_Ufunc2.call()

rlamy noreply at buildbot.pypy.org
Tue May 19 03:59:22 CEST 2015


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77387:432c15e49a7e
Date: 2015-05-17 22:23 +0100
http://bitbucket.org/pypy/pypy/changeset/432c15e49a7e/

Log:	don't use find_binop_result_dtype() in W_Ufunc2.call()

diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -288,7 +288,7 @@
             types += ['g', 'G']
         a = array([True], '?')
         for t in types:
-            assert (a + array([0], t)).dtype is dtype(t)
+            assert (a + array([0], t)).dtype == dtype(t)
 
     def test_binop_types(self):
         from numpy import array, dtype
@@ -312,7 +312,7 @@
         for d1, d2, dout in tests:
             # make a failed test print helpful info
             d3 = (array([1], d1) + array([1], d2)).dtype
-            assert (d1, d2) == (d1, d2) and d3 is dtype(dout)
+            assert (d1, d2) == (d1, d2) and d3 == dtype(dout)
 
     def test_add(self):
         import numpy as np
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -613,7 +613,7 @@
                 w_rdtype = w_ldtype
             elif w_lhs.is_scalar() and not w_rhs.is_scalar():
                 w_ldtype = w_rdtype
-        calc_dtype, res_dtype, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting)
+        calc_dtype, dt_out, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting)
         if (isinstance(w_lhs, W_GenericBox) and
                 isinstance(w_rhs, W_GenericBox) and out is None):
             return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
@@ -627,7 +627,7 @@
         new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
         w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
         if out is None:
-            w_res = W_NDimArray.from_shape(space, new_shape, res_dtype,
+            w_res = W_NDimArray.from_shape(space, new_shape, dt_out,
                                            w_instance=out_subtype)
         else:
             w_res = out
@@ -648,26 +648,62 @@
         return w_val
 
     def find_specialization(self, space, l_dtype, r_dtype, out, casting):
-        calc_dtype = find_binop_result_dtype(space,
-            l_dtype, r_dtype,
-            promote_to_float=self.promote_to_float,
-            promote_bools=self.promote_bools)
-        if (self.int_only and (not (l_dtype.is_int() or l_dtype.is_object()) or
-                               not (r_dtype.is_int() or r_dtype.is_object()) or
-                               not (calc_dtype.is_int() or calc_dtype.is_object())) or
-                not self.allow_bool and (l_dtype.is_bool() or
+        if (not self.allow_bool and (l_dtype.is_bool() or
                                          r_dtype.is_bool()) or
                 not self.allow_complex and (l_dtype.is_complex() or
                                             r_dtype.is_complex())):
             raise oefmt(space.w_TypeError,
                 "ufunc '%s' not supported for the input types", self.name)
-        if out is not None:
-            calc_dtype = out.get_dtype()
+        dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
+        return dt_in, dt_out, self.func
+
+    def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'):
+        use_min_scalar = False
+        if l_dtype.is_object() or r_dtype.is_object():
+            return l_dtype, l_dtype
+        in_casting = safe_casting_mode(casting)
+        for dt_in, dt_out in self.allowed_types(space):
+            if use_min_scalar:
+                if not can_cast_array(space, w_arg, dt_in, in_casting):
+                    continue
+            else:
+                if not (can_cast_type(space, l_dtype, dt_in, in_casting) and
+                        can_cast_type(space, r_dtype, dt_in, in_casting)):
+                    continue
+            if out is not None:
+                res_dtype = out.get_dtype()
+                if not can_cast_type(space, dt_out, res_dtype, casting):
+                    continue
+            return dt_in, dt_out
+
+        else:
+            raise oefmt(space.w_TypeError,
+                "No loop matching the specified signature was found "
+                "for ufunc %s", self.name)
+
+    def allowed_types(self, space):
+        dtypes = []
+        cache = get_dtype_cache(space)
+        if not self.promote_bools and not self.promote_to_float:
+            dtypes.append((cache.w_booldtype, cache.w_booldtype))
+        if not self.promote_to_float:
+            for dt in cache.integer_dtypes:
+                dtypes.append((dt, dt))
+        if not self.int_only:
+            for dt in cache.float_dtypes:
+                dtypes.append((dt, dt))
+            for dt in cache.complex_dtypes:
+                if self.complex_to_float:
+                    if dt.num == NPY.CFLOAT:
+                        dt_out = get_dtype_cache(space).w_float32dtype
+                    else:
+                        dt_out = get_dtype_cache(space).w_float64dtype
+                    dtypes.append((dt, dt_out))
+                else:
+                    dtypes.append((dt, dt))
         if self.bool_result:
-            res_dtype = get_dtype_cache(space).w_booldtype
-        else:
-            res_dtype = calc_dtype
-        return calc_dtype, res_dtype, self.func
+            dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+        return dtypes
 
 
 


More information about the pypy-commit mailing list