[pypy-commit] pypy numpy-back-to-applevel: implement isnan/isinf

fijal noreply at buildbot.pypy.org
Tue Jan 24 19:21:31 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-back-to-applevel
Changeset: r51735:200be74121ce
Date: 2012-01-24 20:21 +0200
http://bitbucket.org/pypy/pypy/changeset/200be74121ce/

Log:	implement isnan/isinf

diff --git a/lib_pypy/numpypy/core/arrayprint.py b/lib_pypy/numpypy/core/arrayprint.py
--- a/lib_pypy/numpypy/core/arrayprint.py
+++ b/lib_pypy/numpypy/core/arrayprint.py
@@ -14,9 +14,9 @@
 
 import sys
 import _numpypy as _nt
-from _numpypy import maximum, minimum, absolute, not_equal #, isnan, isinf
+from _numpypy import maximum, minimum, absolute, not_equal, isinf, isnan
 #from _numpypy import format_longfloat, datetime_as_string, datetime_data, isna
-from fromnumeric import ravel
+from .fromnumeric import ravel
 
 
 def product(x, y): return x*y
diff --git a/lib_pypy/numpypy/core/fromnumeric.py b/lib_pypy/numpypy/core/fromnumeric.py
--- a/lib_pypy/numpypy/core/fromnumeric.py
+++ b/lib_pypy/numpypy/core/fromnumeric.py
@@ -30,7 +30,7 @@
            'rank', 'size', 'around', 'round_', 'mean', 'std', 'var', 'squeeze',
            'amax', 'amin',
           ]
-          
+
 def take(a, indices, axis=None, out=None, mode='raise'):
     """
     Take elements from an array along an axis.
diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -93,6 +93,8 @@
         ("tan", "tan"),
         ('bitwise_and', 'bitwise_and'),
         ('bitwise_or', 'bitwise_or'),
+        ('isnan', 'isnan'),
+        ('isinf', 'isinf'),
     ]:
         interpleveldefs[exposed] = "interp_ufuncs.get(space).%s" % impl
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -735,11 +735,12 @@
 
 
 class Call1(VirtualArray):
-    def __init__(self, ufunc, name, shape, res_dtype, values):
+    def __init__(self, ufunc, name, shape, calc_dtype, res_dtype, values):
         VirtualArray.__init__(self, name, shape, res_dtype)
         self.values = values
         self.size = values.size
         self.ufunc = ufunc
+        self.calc_dtype = calc_dtype
 
     def _del_sources(self):
         self.values = None
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -248,10 +248,11 @@
     _immutable_fields_ = ["func", "name"]
 
     def __init__(self, func, name, promote_to_float=False, promote_bools=False,
-        identity=None):
+        identity=None, bool_result=False):
 
         W_Ufunc.__init__(self, name, promote_to_float, promote_bools, identity)
         self.func = func
+        self.bool_result = bool_result
 
     def call(self, space, args_w):
         from pypy.module.micronumpy.interp_numarray import (Call1,
@@ -259,15 +260,19 @@
 
         [w_obj] = args_w
         w_obj = convert_to_array(space, w_obj)
-        res_dtype = find_unaryop_result_dtype(space,
-            w_obj.find_dtype(),
-            promote_to_float=self.promote_to_float,
-            promote_bools=self.promote_bools,
-        )
+        calc_dtype = find_unaryop_result_dtype(space,
+                                  w_obj.find_dtype(),
+                                  promote_to_float=self.promote_to_float,
+                                  promote_bools=self.promote_bools)
+        if self.bool_result:
+            res_dtype = interp_dtype.get_dtype_cache(space).w_booldtype
+        else:
+            res_dtype = calc_dtype
         if isinstance(w_obj, Scalar):
-            return self.func(res_dtype, w_obj.value.convert_to(res_dtype))
+            return self.func(calc_dtype, w_obj.value.convert_to(calc_dtype))
 
-        w_res = Call1(self.func, self.name, w_obj.shape, res_dtype, w_obj)
+        w_res = Call1(self.func, self.name, w_obj.shape, calc_dtype, res_dtype,
+                      w_obj)
         w_obj.add_invalidates(w_res)
         return w_res
 
@@ -433,12 +438,16 @@
     return interp_dtype.get_dtype_cache(space).w_float64dtype
 
 
-def ufunc_dtype_caller(space, ufunc_name, op_name, argcount, comparison_func):
+def ufunc_dtype_caller(space, ufunc_name, op_name, argcount, comparison_func,
+                       bool_result):
+    dtype_cache = interp_dtype.get_dtype_cache(space)
     if argcount == 1:
         def impl(res_dtype, value):
-            return getattr(res_dtype.itemtype, op_name)(value)
+            res = getattr(res_dtype.itemtype, op_name)(value)
+            if bool_result:
+                return dtype_cache.w_booldtype.box(res)
+            return res
     elif argcount == 2:
-        dtype_cache = interp_dtype.get_dtype_cache(space)
         def impl(res_dtype, lvalue, rvalue):
             res = getattr(res_dtype.itemtype, op_name)(lvalue, rvalue)
             if comparison_func:
@@ -468,6 +477,8 @@
             ("less_equal", "le", 2, {"comparison_func": True}),
             ("greater", "gt", 2, {"comparison_func": True}),
             ("greater_equal", "ge", 2, {"comparison_func": True}),
+            ("isnan", "isnan", 1, {"bool_result": True}),
+            ("isinf", "isinf", 1, {"bool_result": True}),
 
             ("maximum", "max", 2),
             ("minimum", "min", 2),
@@ -510,6 +521,7 @@
 
         func = ufunc_dtype_caller(space, ufunc_name, op_name, argcount,
             comparison_func=extra_kwargs.get("comparison_func", False),
+            bool_result=extra_kwargs.get("bool_result", False),
         )
         if argcount == 1:
             ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs)
diff --git a/pypy/module/micronumpy/signature.py b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -293,8 +293,8 @@
     def eval(self, frame, arr):
         from pypy.module.micronumpy.interp_numarray import Call1
         assert isinstance(arr, Call1)
-        v = self.child.eval(frame, arr.values).convert_to(arr.res_dtype)
-        return self.unfunc(arr.res_dtype, v)
+        v = self.child.eval(frame, arr.values).convert_to(arr.calc_dtype)
+        return self.unfunc(arr.calc_dtype, v)
 
 class Call2(Signature):
     _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -413,3 +413,17 @@
     def test_true_divide(self):
         from _numpypy import arange, array, true_divide
         assert (true_divide(arange(3), array([2, 2, 2])) == array([0, 0.5, 1])).all()
+
+    def test_isnan_isinf(self):
+        from _numpypy import isnan, isinf, float64, array
+        assert isnan(float('nan'))
+        assert isnan(float64(float('nan')))
+        assert not isnan(3)
+        assert isinf(float('inf'))
+        assert not isnan(3.5)
+        assert not isinf(3.5)
+        assert not isnan(float('inf'))
+        assert not isinf(float('nan'))
+        assert (isnan(array([0.2, float('inf'), float('nan')])) == [False, False, True]).all()
+        assert (isinf(array([0.2, float('inf'), float('nan')])) == [False, True, False]).all()
+        assert isinf(array([0.2])).dtype.kind == 'b'
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -23,6 +23,16 @@
         )
     return dispatcher
 
+def raw_unary_op(func):
+    specialize.argtype(1)(func)
+    @functools.wraps(func)
+    def dispatcher(self, v):
+        return func(
+            self,
+            self.for_computation(self.unbox(v))
+        )
+    return dispatcher
+
 def simple_binary_op(func):
     specialize.argtype(1, 2)(func)
     @functools.wraps(func)
@@ -137,6 +147,14 @@
     def abs(self, v):
         return abs(v)
 
+    @raw_unary_op
+    def isnan(self, v):
+        return False
+
+    @raw_unary_op
+    def isinf(self, v):
+        return False
+
     @raw_binary_op
     def eq(self, v1, v2):
         return v1 == v2
@@ -448,6 +466,14 @@
         except ValueError:
             return rfloat.NAN
 
+    @raw_unary_op
+    def isnan(self, v):
+        return rfloat.isnan(v)
+
+    @raw_unary_op
+    def isinf(self, v):
+        return rfloat.isinf(v)
+
 
 class Float32(BaseType, Float):
     T = rffi.FLOAT


More information about the pypy-commit mailing list