[pypy-commit] pypy ndarray-round: implement ndarray.round, add failing tests for scalar.round

mattip noreply at buildbot.pypy.org
Tue Jun 25 22:20:22 CEST 2013


Author: Matti Picus <matti.picus at gmail.com>
Branch: ndarray-round
Changeset: r64985:c6d01b824d87
Date: 2013-06-25 21:48 +0300
http://bitbucket.org/pypy/pypy/changeset/c6d01b824d87/

Log:	implement ndarray.round, add failing tests for scalar.round

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -561,13 +561,23 @@
     @unwrap_spec(decimals=int)
     def descr_round(self, space, decimals=0, w_out=None):
         if space.is_none(w_out):
-            w_out = None
+            if self.get_dtype().is_bool_type():
+                #numpy promotes bool.round() to float16. Go figure.
+                w_out = W_NDimArray.from_shape(self.get_shape(),
+                       interp_dtype.get_dtype_cache(space).w_float16dtype)
+            else:
+                w_out = None
         elif not isinstance(w_out, W_NDimArray):
             raise OperationError(space.w_TypeError, space.wrap(
                 "return arrays must be of ArrayType"))
         out = interp_dtype.dtype_agreement(space, [self], self.get_shape(),
                                            w_out)
-        loop.round(space, self, self.get_shape(), decimals, out)
+        if out.get_dtype().is_bool_type() and self.get_dtype().is_bool_type():
+            calc_dtype = interp_dtype.get_dtype_cache(space).w_longdtype
+        else:
+            calc_dtype = out.get_dtype()
+
+        loop.round(space, self, calc_dtype, self.get_shape(), decimals, out)
         return out
 
     def descr_searchsorted(self, space, w_v, w_side='left'):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,7 +12,6 @@
 from pypy.module.micronumpy.iter import PureShapeIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
-from rpython.rlib.rfloat import round_double
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
@@ -576,16 +575,14 @@
 round_driver = jit.JitDriver(greens = ['shapelen', 'dtype'],
                                     reds = 'auto')
 
-def round(space, arr, shape, decimals, out):
+def round(space, arr, dtype, shape, decimals, out):
     arr_iter = arr.create_iter(shape)
-    dtype = out.get_dtype()
     shapelen = len(shape)
     out_iter = out.create_iter(shape)
     while not arr_iter.done():
         round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        #w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(dtype),
-        #             decimals)
-        w_v = arr_iter.getitem().convert_to(dtype)
+        w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(dtype),
+                     decimals)
         out_iter.setitem(w_v)
         arr_iter.next()
         out_iter.next()
diff --git a/pypy/module/micronumpy/test/test_scalar.py b/pypy/module/micronumpy/test/test_scalar.py
--- a/pypy/module/micronumpy/test/test_scalar.py
+++ b/pypy/module/micronumpy/test/test_scalar.py
@@ -21,3 +21,19 @@
 
         a = zeros(3)
         assert loads(dumps(sum(a))) == sum(a)
+
+    def test_round(self):
+        from numpypy import int32, float64, complex128, bool
+        i = int32(1337)
+        f = float64(13.37)
+        c = complex128(13 + 37.j)
+        b = bool(0)
+        assert i.round(decimals=-2) == 1300
+        assert i.round(decimals=1) == 1337
+        assert c.round() == c
+        assert f.round() == 13.
+        assert f.round(decimals=-1) == 10.
+        assert f.round(decimals=1) == 13.4
+        exc = raises(AttributeError, 'b.round()')
+        assert exc.value[0] == "'bool' object has no attribute 'round'"
+
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -335,10 +335,18 @@
         assert all([math.copysign(1, f(-abs(float("nan")))) == -1 for f in floor, ceil, trunc])
 
     def test_round(self):
-        from numpypy import array
+        from numpypy import array, dtype
         ninf, inf = float("-inf"), float("inf")
         a = array([ninf, -1.4, -1.5, -1.0, 0.0, 1.0, 1.4, 0.5, inf])
         assert ([ninf, -1.0, -2.0, -1.0, 0.0, 1.0, 1.0, 0.0, inf] == a.round()).all()
+        i = array([-1000, -100, -1, 0, 1, 111, 1111, 11111], dtype=int)
+        assert (i == i.round()).all()
+        assert (i.round(decimals=4) == i).all()
+        assert (i.round(decimals=-4) == [0, 0, 0, 0, 0, 0, 0, 10000]).all()
+        b = array([True, False], dtype=bool)
+        bround = b.round()
+        assert (bround == [1., 0.]).all()
+        assert bround.dtype is dtype('float16')
 
 
     def test_copysign(self):
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -37,7 +37,7 @@
         return self.box(
             func(
                 self,
-                self.for_computation(raw)
+                self.for_computation(raw),
             )
         )
     return dispatcher
@@ -521,6 +521,20 @@
             return v
         return 0
 
+    @specialize.argtype(1)
+    def round(self, v, decimals=0):
+        raw = self.unbox(v)
+        if decimals < 0:
+            factor = int(10 ** -decimals)
+            #int does floor division, we want toward zero
+            if raw < 0:
+                ans = - (-raw / factor * factor)
+            else:
+                ans = raw / factor * factor
+        else:
+            ans = raw
+        return self.box(ans)
+
     @raw_unary_op
     def signbit(self, v):
         return v < 0
@@ -798,6 +812,16 @@
     def ceil(self, v):
         return math.ceil(v)
 
+    @specialize.argtype(1)
+    def round(self, v, decimals=0):
+        raw = self.unbox(v)
+        if rfloat.isinf(raw):
+            return v
+        elif rfloat.isnan(raw):
+            return v
+        ans = rfloat.round_double(raw, decimals, half_even=True)
+        return self.box(ans)
+
     @simple_unary_op
     def trunc(self, v):
         if v < 0:


More information about the pypy-commit mailing list