[pypy-commit] pypy default: (snus, alex) Added the comparison functions to micronumpy. This is mostly the work from the numpy-comparisons branch, refactored by me.

alex_gaynor noreply at buildbot.pypy.org
Mon Sep 5 19:03:13 CEST 2011


Author: Alex Gaynor <alex.gaynor at gmail.com>
Branch: 
Changeset: r47087:618b0bba96a2
Date: 2011-09-05 10:02 -0700
http://bitbucket.org/pypy/pypy/changeset/618b0bba96a2/

Log:	(snus, alex) Added the comparison functions to micronumpy. This is
	mostly the work from the numpy-comparisons branch, refactored by me.

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -26,13 +26,19 @@
         ("copysign", "copysign"),
         ("cos", "cos"),
         ("divide", "divide"),
+        ("equal", "equal"),
         ("exp", "exp"),
         ("fabs", "fabs"),
         ("floor", "floor"),
+        ("greater", "greater"),
+        ("greater_equal", "greater_equal"),
+        ("less", "less"),
+        ("less_equal", "less_equal"),
         ("maximum", "maximum"),
         ("minimum", "minimum"),
         ("multiply", "multiply"),
         ("negative", "negative"),
+        ("not_equal", "not_equal"),
         ("reciprocal", "reciprocal"),
         ("sign", "sign"),
         ("sin", "sin"),
diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py
--- a/pypy/module/micronumpy/interp_dtype.py
+++ b/pypy/module/micronumpy/interp_dtype.py
@@ -129,6 +129,16 @@
         ))
     return impl
 
+def raw_binop(func):
+    # Unbox v1 and v2, convert each with for_computation, and return func's raw result without re-boxing it.
+    @functools.wraps(func)
+    def impl(self, v1, v2):
+        return func(self,
+            self.for_computation(self.unbox(v1)),
+            self.for_computation(self.unbox(v2))
+        )
+    return impl
+
 def unaryop(func):
     @functools.wraps(func)
     def impl(self, v):
@@ -170,8 +180,24 @@
 
     def bool(self, v):
         return bool(self.for_computation(self.unbox(v)))
+    @raw_binop  # comparison methods below: operands arrive unboxed, and the bool result is returned unboxed
+    def eq(self, v1, v2):
+        return v1 == v2
+    @raw_binop
     def ne(self, v1, v2):
-        return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2))
+        return v1 != v2
+    @raw_binop
+    def lt(self, v1, v2):
+        return v1 < v2
+    @raw_binop
+    def le(self, v1, v2):
+        return v1 <= v2
+    @raw_binop
+    def gt(self, v1, v2):
+        return v1 > v2
+    @raw_binop
+    def ge(self, v1, v2):
+        return v1 >= v2
 
 
 class FloatArithmeticDtype(ArithmaticTypeMixin):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -74,6 +74,13 @@
     descr_pow = _binop_impl("power")
     descr_mod = _binop_impl("mod")
 
+    descr_eq = _binop_impl("equal")
+    descr_ne = _binop_impl("not_equal")
+    descr_lt = _binop_impl("less")
+    descr_le = _binop_impl("less_equal")
+    descr_gt = _binop_impl("greater")
+    descr_ge = _binop_impl("greater_equal")
+
     def _binop_right_impl(ufunc_name):
         def impl(self, space, w_other):
             w_other = scalar_w(space,
@@ -404,10 +411,11 @@
     """
     Intermediate class for performing binary operations.
     """
-    def __init__(self, signature, res_dtype, left, right):
+    def __init__(self, signature, calc_dtype, res_dtype, left, right):
         VirtualArray.__init__(self, signature, res_dtype)
         self.left = left
         self.right = right
+        self.calc_dtype = calc_dtype
 
     def _del_sources(self):
         self.left = None
@@ -421,14 +429,14 @@
         return self.right.find_size()
 
     def _eval(self, i):
-        lhs = self.left.eval(i).convert_to(self.res_dtype)
-        rhs = self.right.eval(i).convert_to(self.res_dtype)
+        lhs = self.left.eval(i).convert_to(self.calc_dtype)
+        rhs = self.right.eval(i).convert_to(self.calc_dtype)
 
         sig = jit.promote(self.signature)
         assert isinstance(sig, signature.Signature)
         call_sig = sig.components[0]
         assert isinstance(call_sig, signature.Call2)
-        return call_sig.func(self.res_dtype, lhs, rhs)
+        return call_sig.func(self.calc_dtype, lhs, rhs)
 
 class ViewArray(BaseArray):
     """
@@ -573,18 +581,28 @@
     __pos__ = interp2app(BaseArray.descr_pos),
     __neg__ = interp2app(BaseArray.descr_neg),
     __abs__ = interp2app(BaseArray.descr_abs),
+
     __add__ = interp2app(BaseArray.descr_add),
     __sub__ = interp2app(BaseArray.descr_sub),
     __mul__ = interp2app(BaseArray.descr_mul),
     __div__ = interp2app(BaseArray.descr_div),
     __pow__ = interp2app(BaseArray.descr_pow),
     __mod__ = interp2app(BaseArray.descr_mod),
+
     __radd__ = interp2app(BaseArray.descr_radd),
     __rsub__ = interp2app(BaseArray.descr_rsub),
     __rmul__ = interp2app(BaseArray.descr_rmul),
     __rdiv__ = interp2app(BaseArray.descr_rdiv),
     __rpow__ = interp2app(BaseArray.descr_rpow),
     __rmod__ = interp2app(BaseArray.descr_rmod),
+
+    __eq__ = interp2app(BaseArray.descr_eq),
+    __ne__ = interp2app(BaseArray.descr_ne),
+    __lt__ = interp2app(BaseArray.descr_lt),
+    __le__ = interp2app(BaseArray.descr_le),
+    __gt__ = interp2app(BaseArray.descr_gt),
+    __ge__ = interp2app(BaseArray.descr_ge),
+
     __repr__ = interp2app(BaseArray.descr_repr),
     __str__ = interp2app(BaseArray.descr_str),
 
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -113,10 +113,11 @@
     argcount = 2
 
     def __init__(self, func, name, promote_to_float=False, promote_bools=False,
-        identity=None):
+        identity=None, comparison_func=False):
 
         W_Ufunc.__init__(self, name, promote_to_float, promote_bools, identity)
         self.func = func
+        self.comparison_func = comparison_func
         self.signature = signature.Call2(func)
         self.reduce_signature = signature.BaseSignature()
 
@@ -127,18 +128,25 @@
         [w_lhs, w_rhs] = args_w
         w_lhs = convert_to_array(space, w_lhs)
         w_rhs = convert_to_array(space, w_rhs)
-        res_dtype = find_binop_result_dtype(space,
+        calc_dtype = find_binop_result_dtype(space,
             w_lhs.find_dtype(), w_rhs.find_dtype(),
             promote_to_float=self.promote_to_float,
             promote_bools=self.promote_bools,
         )
+        if self.comparison_func:
+            res_dtype = space.fromcache(interp_dtype.W_BoolDtype)
+        else:
+            res_dtype = calc_dtype
         if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar):
-            return self.func(res_dtype, w_lhs.value, w_rhs.value).wrap(space)
+            return self.func(calc_dtype,
+                w_lhs.value.convert_to(calc_dtype),
+                w_rhs.value.convert_to(calc_dtype)
+            ).wrap(space)
 
         new_sig = signature.Signature.find_sig([
             self.signature, w_lhs.signature, w_rhs.signature
         ])
-        w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs)
+        w_res = Call2(new_sig, calc_dtype, res_dtype, w_lhs, w_rhs)
         w_lhs.add_invalidates(w_res)
         w_rhs.add_invalidates(w_res)
         return w_res
@@ -209,13 +217,16 @@
     return space.fromcache(interp_dtype.W_Float64Dtype)
 
 
-def ufunc_dtype_caller(ufunc_name, op_name, argcount):
+def ufunc_dtype_caller(space, ufunc_name, op_name, argcount, comparison_func):
     if argcount == 1:
         def impl(res_dtype, value):
             return getattr(res_dtype, op_name)(value)
     elif argcount == 2:
         def impl(res_dtype, lvalue, rvalue):
-            return getattr(res_dtype, op_name)(lvalue, rvalue)
+            res = getattr(res_dtype, op_name)(lvalue, rvalue)
+            if comparison_func:
+                res = space.fromcache(interp_dtype.W_BoolDtype).box(res)
+            return res
     return func_with_new_name(impl, ufunc_name)
 
 class UfuncState(object):
@@ -229,6 +240,13 @@
             ("mod", "mod", 2, {"promote_bools": True}),
             ("power", "pow", 2, {"promote_bools": True}),
 
+            ("equal", "eq", 2, {"comparison_func": True}),
+            ("not_equal", "ne", 2, {"comparison_func": True}),
+            ("less", "lt", 2, {"comparison_func": True}),
+            ("less_equal", "le", 2, {"comparison_func": True}),
+            ("greater", "gt", 2, {"comparison_func": True}),
+            ("greater_equal", "ge", 2, {"comparison_func": True}),
+
             ("maximum", "max", 2),
             ("minimum", "min", 2),
 
@@ -262,7 +280,9 @@
             identity = space.fromcache(interp_dtype.W_Int64Dtype).adapt_val(identity)
         extra_kwargs["identity"] = identity
 
-        func = ufunc_dtype_caller(ufunc_name, op_name, argcount)
+        func = ufunc_dtype_caller(space, ufunc_name, op_name, argcount,
+            comparison_func=extra_kwargs.get("comparison_func", False)
+        )
         if argcount == 1:
             ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs)
         elif argcount == 2:
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -557,6 +557,26 @@
         assert array([1.2, 5]).dtype is dtype(float)
         assert array([]).dtype is dtype(float)
 
+    def test_comparison(self):
+        import operator
+        from numpy import array, dtype
+
+        a = array(range(5))
+        b = array(range(5), float)  # same values, explicit float dtype
+        for func in [
+            operator.eq, operator.ne, operator.lt, operator.le, operator.gt,
+            operator.ge
+        ]:
+            c = func(a, 3)  # array-vs-scalar comparison via the Python operator
+            assert c.dtype is dtype(bool)  # result is a bool array
+            for i in xrange(5):
+                assert c[i] == func(a[i], 3)  # each element matches the element-wise comparison
+
+            c = func(b, 3)  # float array vs. int scalar
+            assert c.dtype is dtype(bool)  # still bool, independent of the operand dtype
+            for i in xrange(5):
+                assert c[i] == func(b[i], 3)
+
 
 class AppTestSupport(object):
     def setup_class(cls):
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -310,4 +310,30 @@
         assert add.reduce([1, 2, 3]) == 6
         assert maximum.reduce([1]) == 1
         assert maximum.reduce([1, 2, 3]) == 3
-        raises(ValueError, maximum.reduce, [])
\ No newline at end of file
+        raises(ValueError, maximum.reduce, [])
+
+    def test_comparisons(self):
+        import operator
+        from numpy import equal, not_equal, less, less_equal, greater, greater_equal
+
+        for ufunc, func in [  # each comparison ufunc paired with its Python operator equivalent
+            (equal, operator.eq),
+            (not_equal, operator.ne),
+            (less, operator.lt),
+            (less_equal, operator.le),
+            (greater, operator.gt),
+            (greater_equal, operator.ge),
+        ]:
+            for a, b in [  # int/int, float/float, then mixed int/float operand pairs
+                (3, 3),
+                (3, 4),
+                (4, 3),
+                (3.0, 3.0),
+                (3.0, 3.5),
+                (3.5, 3.0),
+                (3.0, 3),
+                (3, 3.0),
+                (3.5, 3),
+                (3, 3.5),
+            ]:
+                assert ufunc(a, b) is func(a, b)  # `is`: scalar inputs must yield the identical True/False object
\ No newline at end of file


More information about the pypy-commit mailing list