[pypy-commit] pypy ndarray-round: start to implement round(), need to add it to types.py

mattip noreply at buildbot.pypy.org
Mon Jun 24 23:56:09 CEST 2013


Author: Matti Picus <matti.picus at gmail.com>
Branch: ndarray-round
Changeset: r64970:03aaf579cabc
Date: 2013-06-25 00:54 +0300
http://bitbucket.org/pypy/pypy/changeset/03aaf579cabc/

Log:	start to implement round(), need to add it to types.py

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -558,9 +558,17 @@
         raise OperationError(space.w_NotImplementedError, space.wrap(
             "resize not implemented yet"))
 
-    def descr_round(self, space, w_decimals=0, w_out=None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "round not implemented yet"))
+    @unwrap_spec(decimals=int)
+    def descr_round(self, space, decimals=0, w_out=None):
+        if space.is_none(w_out):
+            w_out = None
+        elif not isinstance(w_out, W_NDimArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                "return arrays must be of ArrayType"))
+        out = interp_dtype.dtype_agreement(space, [self], self.get_shape(),
+                                           w_out)
+        loop.round(space, self, self.get_shape(), decimals, out)
+        return out
 
     def descr_searchsorted(self, space, w_v, w_side='left'):
         raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -975,6 +983,7 @@
     byteswap = interp2app(W_NDimArray.descr_byteswap),
     choose   = interp2app(W_NDimArray.descr_choose),
     clip     = interp2app(W_NDimArray.descr_clip),
+    round    = interp2app(W_NDimArray.descr_round),
     data     = GetSetProperty(W_NDimArray.descr_get_data),
     diagonal = interp2app(W_NDimArray.descr_diagonal),
 
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -12,6 +12,7 @@
 from pypy.module.micronumpy.iter import PureShapeIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
+from rpython.rlib.rfloat import round_double
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
@@ -173,7 +174,7 @@
         iter = x_iter
     shapelen = len(shape)
     while not iter.done():
-        where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype, 
+        where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                         arr_dtype=arr_dtype)
         w_cond = arr_iter.getitem()
         if arr_dtype.itemtype.bool(w_cond):
@@ -188,7 +189,7 @@
     return out
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
-                                    greens=['shapelen', 
+                                    greens=['shapelen',
                                             'func', 'dtype',
                                             'identity'],
                                     reds='auto')
@@ -228,7 +229,7 @@
     arg_driver = jit.JitDriver(name='numpy_' + op_name,
                                greens = ['shapelen', 'dtype'],
                                reds = 'auto')
-    
+
     def argmin_argmax(arr):
         result = 0
         idx = 1
@@ -265,7 +266,7 @@
      result.shape == [3, 5, 2, 4]
      broadcast shape should be [3, 5, 2, 7, 4]
      result should skip dims 3 which is len(result_shape) - 1
-        (note that if right is 1d, result should 
+        (note that if right is 1d, result should
                   skip len(result_shape))
      left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
           except where it==(right.ndims-2)
@@ -283,9 +284,9 @@
     righti = right.create_dot_iter(broadcast_shape, right_skip)
     while not outi.done():
         dot_driver.jit_merge_point(dtype=dtype)
-        lval = lefti.getitem().convert_to(dtype) 
-        rval = righti.getitem().convert_to(dtype) 
-        outval = outi.getitem().convert_to(dtype) 
+        lval = lefti.getitem().convert_to(dtype)
+        rval = righti.getitem().convert_to(dtype)
+        outval = outi.getitem().convert_to(dtype)
         v = dtype.itemtype.mul(lval, rval)
         value = dtype.itemtype.add(v, outval).convert_to(dtype)
         outi.setitem(value)
@@ -355,7 +356,7 @@
         setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
-                                             ) 
+                                             )
         if index_iter.getitem_bool():
             arr_iter.setitem(value_iter.getitem())
             value_iter.next()
@@ -572,6 +573,23 @@
         out_iter.next()
         min_iter.next()
 
+round_driver = jit.JitDriver(greens = ['shapelen', 'dtype'],
+                                    reds = 'auto')
+
+def round(space, arr, shape, decimals, out):
+    arr_iter = arr.create_iter(shape)
+    dtype = out.get_dtype()
+    shapelen = len(shape)
+    out_iter = out.create_iter(shape)
+    while not arr_iter.done():
+        round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        #w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(dtype),
+        #             decimals)
+        w_v = arr_iter.getitem().convert_to(dtype)
+        out_iter.setitem(w_v)
+        arr_iter.next()
+        out_iter.next()
+
 diagonal_simple_driver = jit.JitDriver(greens = ['axis1', 'axis2'],
                                        reds = 'auto')
 
@@ -613,4 +631,4 @@
         out_iter.setitem(arr.getitem_index(space, indexes))
         iter.next()
         out_iter.next()
-       
+
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -334,7 +334,7 @@
         assert all([math.copysign(1, f(abs(float("nan")))) == 1 for f in floor, ceil, trunc])
         assert all([math.copysign(1, f(-abs(float("nan")))) == -1 for f in floor, ceil, trunc])
 
-    def test_around(self):
+    def test_round(self):
         from numpypy import array
         ninf, inf = float("-inf"), float("inf")
         a = array([ninf, -1.4, -1.5, -1.0, 0.0, 1.0, 1.4, 0.5, inf])


More information about the pypy-commit mailing list