[pypy-commit] pypy default: (snus, alex) Added the comparison functions to micronumpy. This is mostly the work from the numpy-comparisons branch, refactored by me.
alex_gaynor
noreply at buildbot.pypy.org
Mon Sep 5 19:03:13 CEST 2011
Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch:
Changeset: r47087:618b0bba96a2
Date: 2011-09-05 10:02 -0700
http://bitbucket.org/pypy/pypy/changeset/618b0bba96a2/
Log: (snus, alex) Added the comparison functions to micronumpy. This is
mostly the work from the numpy-comparisons branch, refactored by me.
diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -26,13 +26,19 @@
("copysign", "copysign"),
("cos", "cos"),
("divide", "divide"),
+ ("equal", "equal"),
("exp", "exp"),
("fabs", "fabs"),
("floor", "floor"),
+ ("greater", "greater"),
+ ("greater_equal", "greater_equal"),
+ ("less", "less"),
+ ("less_equal", "less_equal"),
("maximum", "maximum"),
("minimum", "minimum"),
("multiply", "multiply"),
("negative", "negative"),
+ ("not_equal", "not_equal"),
("reciprocal", "reciprocal"),
("sign", "sign"),
("sin", "sin"),
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -129,6 +129,16 @@
))
return impl
+def raw_binop(func):
+ # Returns the result unwrapped.
+ @functools.wraps(func)
+ def impl(self, v1, v2):
+ return func(self,
+ self.for_computation(self.unbox(v1)),
+ self.for_computation(self.unbox(v2))
+ )
+ return impl
+
def unaryop(func):
@functools.wraps(func)
def impl(self, v):
@@ -170,8 +180,24 @@
def bool(self, v):
return bool(self.for_computation(self.unbox(v)))
+ @raw_binop
+ def eq(self, v1, v2):
+ return v1 == v2
+ @raw_binop
def ne(self, v1, v2):
- return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2))
+ return v1 != v2
+ @raw_binop
+ def lt(self, v1, v2):
+ return v1 < v2
+ @raw_binop
+ def le(self, v1, v2):
+ return v1 <= v2
+ @raw_binop
+ def gt(self, v1, v2):
+ return v1 > v2
+ @raw_binop
+ def ge(self, v1, v2):
+ return v1 >= v2
class FloatArithmeticDtype(ArithmaticTypeMixin):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -74,6 +74,13 @@
descr_pow = _binop_impl("power")
descr_mod = _binop_impl("mod")
+ descr_eq = _binop_impl("equal")
+ descr_ne = _binop_impl("not_equal")
+ descr_lt = _binop_impl("less")
+ descr_le = _binop_impl("less_equal")
+ descr_gt = _binop_impl("greater")
+ descr_ge = _binop_impl("greater_equal")
+
def _binop_right_impl(ufunc_name):
def impl(self, space, w_other):
w_other = scalar_w(space,
@@ -404,10 +411,11 @@
"""
Intermediate class for performing binary operations.
"""
- def __init__(self, signature, res_dtype, left, right):
+ def __init__(self, signature, calc_dtype, res_dtype, left, right):
VirtualArray.__init__(self, signature, res_dtype)
self.left = left
self.right = right
+ self.calc_dtype = calc_dtype
def _del_sources(self):
self.left = None
@@ -421,14 +429,14 @@
return self.right.find_size()
def _eval(self, i):
- lhs = self.left.eval(i).convert_to(self.res_dtype)
- rhs = self.right.eval(i).convert_to(self.res_dtype)
+ lhs = self.left.eval(i).convert_to(self.calc_dtype)
+ rhs = self.right.eval(i).convert_to(self.calc_dtype)
sig = jit.promote(self.signature)
assert isinstance(sig, signature.Signature)
call_sig = sig.components[0]
assert isinstance(call_sig, signature.Call2)
- return call_sig.func(self.res_dtype, lhs, rhs)
+ return call_sig.func(self.calc_dtype, lhs, rhs)
class ViewArray(BaseArray):
"""
@@ -573,18 +581,28 @@
__pos__ = interp2app(BaseArray.descr_pos),
__neg__ = interp2app(BaseArray.descr_neg),
__abs__ = interp2app(BaseArray.descr_abs),
+
__add__ = interp2app(BaseArray.descr_add),
__sub__ = interp2app(BaseArray.descr_sub),
__mul__ = interp2app(BaseArray.descr_mul),
__div__ = interp2app(BaseArray.descr_div),
__pow__ = interp2app(BaseArray.descr_pow),
__mod__ = interp2app(BaseArray.descr_mod),
+
__radd__ = interp2app(BaseArray.descr_radd),
__rsub__ = interp2app(BaseArray.descr_rsub),
__rmul__ = interp2app(BaseArray.descr_rmul),
__rdiv__ = interp2app(BaseArray.descr_rdiv),
__rpow__ = interp2app(BaseArray.descr_rpow),
__rmod__ = interp2app(BaseArray.descr_rmod),
+
+ __eq__ = interp2app(BaseArray.descr_eq),
+ __ne__ = interp2app(BaseArray.descr_ne),
+ __lt__ = interp2app(BaseArray.descr_lt),
+ __le__ = interp2app(BaseArray.descr_le),
+ __gt__ = interp2app(BaseArray.descr_gt),
+ __ge__ = interp2app(BaseArray.descr_ge),
+
__repr__ = interp2app(BaseArray.descr_repr),
__str__ = interp2app(BaseArray.descr_str),
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -113,10 +113,11 @@
argcount = 2
def __init__(self, func, name, promote_to_float=False, promote_bools=False,
- identity=None):
+ identity=None, comparison_func=False):
W_Ufunc.__init__(self, name, promote_to_float, promote_bools, identity)
self.func = func
+ self.comparison_func = comparison_func
self.signature = signature.Call2(func)
self.reduce_signature = signature.BaseSignature()
@@ -127,18 +128,25 @@
[w_lhs, w_rhs] = args_w
w_lhs = convert_to_array(space, w_lhs)
w_rhs = convert_to_array(space, w_rhs)
- res_dtype = find_binop_result_dtype(space,
+ calc_dtype = find_binop_result_dtype(space,
w_lhs.find_dtype(), w_rhs.find_dtype(),
promote_to_float=self.promote_to_float,
promote_bools=self.promote_bools,
)
+ if self.comparison_func:
+ res_dtype = space.fromcache(interp_dtype.W_BoolDtype)
+ else:
+ res_dtype = calc_dtype
if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
- return self.func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
+ return self.func(calc_dtype,
+ w_lhs.value.convert_to(calc_dtype),
+ w_rhs.value.convert_to(calc_dtype)
+ ).wrap(space)
new_sig = signature.Signature.find_sig([
self.signature, w_lhs.signature, w_rhs.signature
])
- w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
+ w_res = Call2(new_sig, calc_dtype, res_dtype, w_lhs, w_rhs)
w_lhs.add_invalidates(w_res)
w_rhs.add_invalidates(w_res)
return w_res
@@ -209,13 +217,16 @@
return space.fromcache(interp_dtype.W_Float64Dtype)
-def ufunc_dtype_caller(ufunc_name, op_name, argcount):
+def ufunc_dtype_caller(space, ufunc_name, op_name, argcount, comparison_func):
if argcount == 1:
def impl(res_dtype, value):
return getattr(res_dtype, op_name)(value)
elif argcount == 2:
def impl(res_dtype, lvalue, rvalue):
- return getattr(res_dtype, op_name)(lvalue, rvalue)
+ res = getattr(res_dtype, op_name)(lvalue, rvalue)
+ if comparison_func:
+ res = space.fromcache(interp_dtype.W_BoolDtype).box(res)
+ return res
return func_with_new_name(impl, ufunc_name)
class UfuncState(object):
@@ -229,6 +240,13 @@
("mod", "mod", 2, {"promote_bools": True}),
("power", "pow", 2, {"promote_bools": True}),
+ ("equal", "eq", 2, {"comparison_func": True}),
+ ("not_equal", "ne", 2, {"comparison_func": True}),
+ ("less", "lt", 2, {"comparison_func": True}),
+ ("less_equal", "le", 2, {"comparison_func": True}),
+ ("greater", "gt", 2, {"comparison_func": True}),
+ ("greater_equal", "ge", 2, {"comparison_func": True}),
+
("maximum", "max", 2),
("minimum", "min", 2),
@@ -262,7 +280,9 @@
identity = space.fromcache(interp_dtype.W_Int64Dtype).adapt_val(identity)
extra_kwargs["identity"] = identity
- func = ufunc_dtype_caller(ufunc_name, op_name, argcount)
+ func = ufunc_dtype_caller(space, ufunc_name, op_name, argcount,
+ comparison_func=extra_kwargs.get("comparison_func", False)
+ )
if argcount == 1:
ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs)
elif argcount == 2:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -557,6 +557,26 @@
assert array([1.2, 5]).dtype is dtype(float)
assert array([]).dtype is dtype(float)
+ def test_comparison(self):
+ import operator
+ from numpy import array, dtype
+
+ a = array(range(5))
+ b = array(range(5), float)
+ for func in [
+ operator.eq, operator.ne, operator.lt, operator.le, operator.gt,
+ operator.ge
+ ]:
+ c = func(a, 3)
+ assert c.dtype is dtype(bool)
+ for i in xrange(5):
+ assert c[i] == func(a[i], 3)
+
+ c = func(b, 3)
+ assert c.dtype is dtype(bool)
+ for i in xrange(5):
+ assert c[i] == func(b[i], 3)
+
class AppTestSupport(object):
def setup_class(cls):
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -310,4 +310,30 @@
assert add.reduce([1, 2, 3]) == 6
assert maximum.reduce([1]) == 1
assert maximum.reduce([1, 2, 3]) == 3
- raises(ValueError, maximum.reduce, [])
\ No newline at end of file
+ raises(ValueError, maximum.reduce, [])
+
+ def test_comparisons(self):
+ import operator
+ from numpy import equal, not_equal, less, less_equal, greater, greater_equal
+
+ for ufunc, func in [
+ (equal, operator.eq),
+ (not_equal, operator.ne),
+ (less, operator.lt),
+ (less_equal, operator.le),
+ (greater, operator.gt),
+ (greater_equal, operator.ge),
+ ]:
+ for a, b in [
+ (3, 3),
+ (3, 4),
+ (4, 3),
+ (3.0, 3.0),
+ (3.0, 3.5),
+ (3.5, 3.0),
+ (3.0, 3),
+ (3, 3.0),
+ (3.5, 3),
+ (3, 3.5),
+ ]:
+ assert ufunc(a, b) is func(a, b)
\ No newline at end of file
More information about the pypy-commit mailing list