[pypy-commit] pypy fix-result-types: don't use find_binop_result_dtype() in W_Ufunc2.call()
rlamy
noreply at buildbot.pypy.org
Tue May 19 03:59:22 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77387:432c15e49a7e
Date: 2015-05-17 22:23 +0100
http://bitbucket.org/pypy/pypy/changeset/432c15e49a7e/
Log: don't use find_binop_result_dtype() in W_Ufunc2.call()
diff --git a/pypy/module/micronumpy/test/test_dtypes.py b/pypy/module/micronumpy/test/test_dtypes.py
--- a/pypy/module/micronumpy/test/test_dtypes.py
+++ b/pypy/module/micronumpy/test/test_dtypes.py
@@ -288,7 +288,7 @@
types += ['g', 'G']
a = array([True], '?')
for t in types:
- assert (a + array([0], t)).dtype is dtype(t)
+ assert (a + array([0], t)).dtype == dtype(t)
def test_binop_types(self):
from numpy import array, dtype
@@ -312,7 +312,7 @@
for d1, d2, dout in tests:
# make a failed test print helpful info
d3 = (array([1], d1) + array([1], d2)).dtype
- assert (d1, d2) == (d1, d2) and d3 is dtype(dout)
+ assert (d1, d2) == (d1, d2) and d3 == dtype(dout)
def test_add(self):
import numpy as np
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -613,7 +613,7 @@
w_rdtype = w_ldtype
elif w_lhs.is_scalar() and not w_rhs.is_scalar():
w_ldtype = w_rdtype
- calc_dtype, res_dtype, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting)
+ calc_dtype, dt_out, func = self.find_specialization(space, w_ldtype, w_rdtype, out, casting)
if (isinstance(w_lhs, W_GenericBox) and
isinstance(w_rhs, W_GenericBox) and out is None):
return self.call_scalar(space, w_lhs, w_rhs, calc_dtype)
@@ -627,7 +627,7 @@
new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
w_highpriority, out_subtype = array_priority(space, w_lhs, w_rhs)
if out is None:
- w_res = W_NDimArray.from_shape(space, new_shape, res_dtype,
+ w_res = W_NDimArray.from_shape(space, new_shape, dt_out,
w_instance=out_subtype)
else:
w_res = out
@@ -648,26 +648,62 @@
return w_val
def find_specialization(self, space, l_dtype, r_dtype, out, casting):
- calc_dtype = find_binop_result_dtype(space,
- l_dtype, r_dtype,
- promote_to_float=self.promote_to_float,
- promote_bools=self.promote_bools)
- if (self.int_only and (not (l_dtype.is_int() or l_dtype.is_object()) or
- not (r_dtype.is_int() or r_dtype.is_object()) or
- not (calc_dtype.is_int() or calc_dtype.is_object())) or
- not self.allow_bool and (l_dtype.is_bool() or
+ if (not self.allow_bool and (l_dtype.is_bool() or
r_dtype.is_bool()) or
not self.allow_complex and (l_dtype.is_complex() or
r_dtype.is_complex())):
raise oefmt(space.w_TypeError,
"ufunc '%s' not supported for the input types", self.name)
- if out is not None:
- calc_dtype = out.get_dtype()
+ dt_in, dt_out = self._calc_dtype(space, l_dtype, r_dtype, out, casting)
+ return dt_in, dt_out, self.func
+
+ def _calc_dtype(self, space, l_dtype, r_dtype, out=None, casting='unsafe'):
+ use_min_scalar = False
+ if l_dtype.is_object() or r_dtype.is_object():
+ return l_dtype, l_dtype
+ in_casting = safe_casting_mode(casting)
+ for dt_in, dt_out in self.allowed_types(space):
+ if use_min_scalar:
+ if not can_cast_array(space, w_arg, dt_in, in_casting):
+ continue
+ else:
+ if not (can_cast_type(space, l_dtype, dt_in, in_casting) and
+ can_cast_type(space, r_dtype, dt_in, in_casting)):
+ continue
+ if out is not None:
+ res_dtype = out.get_dtype()
+ if not can_cast_type(space, dt_out, res_dtype, casting):
+ continue
+ return dt_in, dt_out
+
+ else:
+ raise oefmt(space.w_TypeError,
+ "No loop matching the specified signature was found "
+ "for ufunc %s", self.name)
+
+ def allowed_types(self, space):
+ dtypes = []
+ cache = get_dtype_cache(space)
+ if not self.promote_bools and not self.promote_to_float:
+ dtypes.append((cache.w_booldtype, cache.w_booldtype))
+ if not self.promote_to_float:
+ for dt in cache.integer_dtypes:
+ dtypes.append((dt, dt))
+ if not self.int_only:
+ for dt in cache.float_dtypes:
+ dtypes.append((dt, dt))
+ for dt in cache.complex_dtypes:
+ if self.complex_to_float:
+ if dt.num == NPY.CFLOAT:
+ dt_out = get_dtype_cache(space).w_float32dtype
+ else:
+ dt_out = get_dtype_cache(space).w_float64dtype
+ dtypes.append((dt, dt_out))
+ else:
+ dtypes.append((dt, dt))
if self.bool_result:
- res_dtype = get_dtype_cache(space).w_booldtype
- else:
- res_dtype = calc_dtype
- return calc_dtype, res_dtype, self.func
+ dtypes = [(dt_in, cache.w_booldtype) for dt_in, _ in dtypes]
+ return dtypes
More information about the pypy-commit mailing list