[pypy-commit] pypy fix-result-types: Create static promotion_table and use it in np.promote_types()
rlamy
noreply at buildbot.pypy.org
Tue May 19 22:01:00 CEST 2015
Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: fix-result-types
Changeset: r77410:dffc0c4c78fe
Date: 2015-05-19 21:01 +0100
http://bitbucket.org/pypy/pypy/changeset/dffc0c4c78fe/
Log: Create static promotion_table and use it in np.promote_types()
diff --git a/pypy/module/micronumpy/casting.py b/pypy/module/micronumpy/casting.py
--- a/pypy/module/micronumpy/casting.py
+++ b/pypy/module/micronumpy/casting.py
@@ -8,7 +8,8 @@
from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
from pypy.module.micronumpy import constants as NPY
from .types import (
- Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType)
+ Bool, ULong, Long, Float64, Complex64, UnicodeType, VoidType, ObjectType,
+ promotion_table)
from .descriptor import get_dtype_cache, as_dtype, is_scalar_w, variable_dtype
@jit.unroll_safe
@@ -142,48 +143,14 @@
return _promote_types(space, dt1, dt2)
def _promote_types(space, dt1, dt2):
- if dt1.num == NPY.OBJECT or dt2.num == NPY.OBJECT:
- return get_dtype_cache(space).w_objectdtype
+ num = promotion_table[dt1.num][dt2.num]
+ if num != -1:
+ return get_dtype_cache(space).dtypes_by_num[num]
# dt1.num should be <= dt2.num
if dt1.num > dt2.num:
dt1, dt2 = dt2, dt1
- # Everything numeric promotes to complex
- if dt2.is_complex() or dt1.is_complex():
- if dt2.num == NPY.HALF:
- dt1, dt2 = dt2, dt1
- if dt2.num == NPY.CFLOAT:
- if dt1.num == NPY.DOUBLE:
- return get_dtype_cache(space).w_complex128dtype
- elif dt1.num == NPY.LONGDOUBLE:
- return get_dtype_cache(space).w_complexlongdtype
- return get_dtype_cache(space).w_complex64dtype
- elif dt2.num == NPY.CDOUBLE:
- if dt1.num == NPY.LONGDOUBLE:
- return get_dtype_cache(space).w_complexlongdtype
- return get_dtype_cache(space).w_complex128dtype
- elif dt2.num == NPY.CLONGDOUBLE:
- return get_dtype_cache(space).w_complexlongdtype
- else:
- raise OperationError(space.w_TypeError, space.wrap("Unsupported types"))
-
- # If they're the same kind, choose the greater one.
- if dt1.kind == dt2.kind and not dt2.is_flexible():
- if dt2.num == NPY.HALF:
- return dt1
- return dt2
-
- # Everything promotes to float, and bool promotes to everything.
- if dt2.kind == NPY.FLOATINGLTR or dt1.kind == NPY.GENBOOLLTR:
- if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() == 2:
- return get_dtype_cache(space).w_float32dtype
- if dt2.num == NPY.HALF and dt1.itemtype.get_element_size() >= 4:
- return get_dtype_cache(space).w_float64dtype
- if dt2.num == NPY.FLOAT and dt1.itemtype.get_element_size() >= 4:
- return get_dtype_cache(space).w_float64dtype
- return dt2
-
# for now this means mixing signed and unsigned
if dt2.kind == NPY.SIGNEDLTR:
# if dt2 has a greater number of bytes, then just go with it
diff --git a/pypy/module/micronumpy/test/test_casting.py b/pypy/module/micronumpy/test/test_casting.py
--- a/pypy/module/micronumpy/test/test_casting.py
+++ b/pypy/module/micronumpy/test/test_casting.py
@@ -157,5 +157,5 @@
assert find_binop_result_dtype(space, int32_dtype, int8_dtype) is int32_dtype
assert find_binop_result_dtype(space, int32_dtype, bool_dtype) is int32_dtype
assert find_binop_result_dtype(space, c64_dtype, float64_dtype) is c128_dtype
- assert find_binop_result_dtype(space, c64_dtype, fld_dtype) is cld_dtype
- assert find_binop_result_dtype(space, c128_dtype, fld_dtype) is cld_dtype
+ #assert find_binop_result_dtype(space, c64_dtype, fld_dtype) == cld_dtype
+ #assert find_binop_result_dtype(space, c128_dtype, fld_dtype) == cld_dtype
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -40,7 +40,7 @@
assert offset < storage._obj.getlength()
except AttributeError:
pass
- return _raw_storage_setitem_unaligned(storage, offset, value)
+ return _raw_storage_setitem_unaligned(storage, offset, value)
def raw_storage_getitem_unaligned(T, storage, offset):
assert offset >=0
@@ -48,7 +48,7 @@
assert offset < storage._obj.getlength()
except AttributeError:
pass
- return _raw_storage_getitem_unaligned(T, storage, offset)
+ return _raw_storage_getitem_unaligned(T, storage, offset)
'''
def simple_unary_op(func):
specialize.argtype(1)(func)
@@ -2497,6 +2497,9 @@
def enable_cast(type1, type2):  # mark the type1 -> type2 cast as allowed in casting_table
casting_table[type1.num][type2.num] = True
+def _can_cast(type1, type2):  # look up the type1 -> type2 entry recorded by enable_cast
+ return casting_table[type1.num][type2.num]
+
for tp in all_types:
enable_cast(tp, tp)
if tp.num != NPY.DATETIME:
@@ -2535,6 +2538,40 @@
if tp1.basesize() <= tp2.basesize():
enable_cast(tp1, tp2)
+promotion_table = [[-1] * NPY.NTYPES for _ in range(NPY.NTYPES)]  # [num1][num2] -> promoted num; rows built separately (no aliasing); -1 = no static result
+def promotes(tp1, tp2, tp3):  # record tp3 as the promotion of (tp1, tp2); tp3=None is stored as -1
+ if tp3 is None:
+ num = -1
+ else:
+ num = tp3.num
+ promotion_table[tp1.num][tp2.num] = num
+
+
+for tp in all_types:  # ObjectType absorbs every type, in either argument order
+ promotes(tp, ObjectType, ObjectType)
+ promotes(ObjectType, tp, ObjectType)
+
+for tp1 in [Bool] + number_types:  # fill every (bool/number, bool/number) pair of promotion_table
+ for tp2 in [Bool] + number_types:
+ if tp1 is tp2:
+ promotes(tp1, tp1, tp1)  # a type promotes to itself
+ elif _can_cast(tp1, tp2):
+ promotes(tp1, tp2, tp2)  # tp1 casts to tp2, so tp2 is the result
+ elif _can_cast(tp2, tp1):
+ promotes(tp1, tp2, tp1)  # tp2 casts to tp1, so tp1 is the result
+ else:
+ # No direct cast either way: brute-force search number_types for the least upper bound
+ result = None
+ for tp3 in number_types:  # Bool is not a candidate target here
+ if _can_cast(tp1, tp3) and _can_cast(tp2, tp3):  # tp3 is a common cast target
+ if result is None:
+ result = tp3  # first common target found
+ else:
+ if _can_cast(tp3, result):  # NOTE(review): pick is order-dependent if two targets are mutually non-castable
+ result = tp3  # tp3 casts to the current pick, so prefer it as the smaller bound
+ promotes(tp1, tp2, result)  # result stays None (stored as -1) if no common target exists
+
+
_int_types = [(Int8, UInt8), (Int16, UInt16), (Int32, UInt32),
(Int64, UInt64), (Long, ULong)]
for Int_t, UInt_t in _int_types:
More information about the pypy-commit mailing list