[pypy-commit] pypy numpypy-axisops: translation fixes

mattip noreply at buildbot.pypy.org
Sun Dec 25 01:02:08 CET 2011


Author: mattip
Branch: numpypy-axisops
Changeset: r50847:f06e38ca0d00
Date: 2011-12-25 01:58 +0200
http://bitbucket.org/pypy/pypy/changeset/f06e38ca0d00/

Log:	translation fixes

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -277,10 +277,11 @@
     descr_rdiv = _binop_right_impl("divide")
     descr_rpow = _binop_right_impl("power")
     descr_rmod = _binop_right_impl("mod")
-
+    
     def _reduce_ufunc_impl(ufunc_name):
-        def impl(self, space, args_w):
-            return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space, self, True, args_w)
+        def impl(self, space, w_dim=None):
+            return getattr(interp_ufuncs.get(space), ufunc_name).reduce(space,
+                                                       self, True, w_dim)
         return func_with_new_name(impl, "reduce_%s_impl" % ufunc_name)
 
     descr_sum = _reduce_ufunc_impl("add")
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -47,19 +47,16 @@
         return self.call(space, __args__.arguments_w)
 
     def descr_reduce(self, space, w_obj):
-        return self.reduce(space, w_obj, False)
+        return self.reduce(space, w_obj, False, space.wrap(-1))
 
-    def reduce(self, space, w_obj, multidim, args_w):
+    def reduce(self, space, w_obj, multidim, w_dim):
         from pypy.module.micronumpy.interp_numarray import convert_to_array, Scalar
         if self.argcount != 2:
             raise OperationError(space.w_ValueError, space.wrap("reduce only "
                 "supported for binary functions"))
         dim = -1
-        if multidim and len(args_w)>0:
-            dim = space.int_w(args_w[0])
-        if len(args_w)>1:
-            raise OperationError(space.w_TypeError, space.wrap(
-                 self.name + " recieved extra arguments"))
+        if not space.is_w(w_dim, space.w_None):
+            dim = space.int_w(w_dim)
         assert isinstance(self, W_Ufunc2)
         obj = convert_to_array(space, w_obj)
         if isinstance(obj, Scalar):
@@ -72,6 +69,7 @@
             promote_to_largest=True
         )
         shapelen = len(obj.shape)
+        #TODO: if dim>=0 return a ArraySignature?
         sig = find_sig(ReduceSignature(self.func, self.name, dtype,
                                        ScalarSignature(dtype),
                                        obj.create_sig(obj.shape)), obj)
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -718,7 +718,7 @@
 
         a = array([True] * 5, bool)
         assert a.sum() == 5
-        
+
         raises(TypeError, 'a.sum(2, 3)')
         a = arange(15).reshape(5, 3)
         assert (a.sum(0) == [30, 35, 40]).all()
@@ -730,19 +730,19 @@
         a = identity(0)
         assert len(a) == 0
         assert a.dtype == dtype('float64')
-        assert a.shape == (0,0)
+        assert a.shape == (0, 0)
         b = identity(1, dtype=int32)
         assert len(b) == 1
         assert b[0][0] == 1
-        assert b.shape == (1,1)
+        assert b.shape == (1, 1)
         assert b.dtype == dtype('int32')
         c = identity(2)
-        assert c.shape == (2,2)
-        assert (c == [[1,0],[0,1]]).all()
+        assert c.shape == (2, 2)
+        assert (c == [[1, 0], [0, 1]]).all()
         d = identity(3, dtype='int32')
-        assert d.shape == (3,3)
+        assert d.shape == (3, 3)
         assert d.dtype == dtype('int32')
-        assert (d == [[1,0,0],[0,1,0],[0,0,1]]).all()
+        assert (d == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]).all()
 
     def test_prod(self):
         from numpypy import array
@@ -950,13 +950,13 @@
 
     def test_tolist_view(self):
         from numpypy import array
-        a = array([[1,2],[3,4]])
+        a = array([[1, 2], [3, 4]])
         assert (a + a).tolist() == [[2, 4], [6, 8]]
 
     def test_tolist_slice(self):
         from numpypy import array
         a = array([[17.1, 27.2], [40.3, 50.3]])
-        assert a[:,0].tolist() == [17.1, 40.3]
+        assert a[:, 0].tolist() == [17.1, 40.3]
         assert a[0].tolist() == [17.1, 27.2]
 
 
@@ -1086,11 +1086,11 @@
         from numpypy import zeros, ones
         a = zeros((3, 3))
         b = ones((3, 3))
-        a[:,1:3] = b[:,1:3]
+        a[:, 1:3] = b[:, 1:3]
         assert (a == [[0, 1, 1], [0, 1, 1], [0, 1, 1]]).all()
         a = zeros((3, 3))
         b = ones((3, 3))
-        a[:,::2] = b[:,::2]
+        a[:, ::2] = b[:, ::2]
         assert (a == [[1, 0, 1], [1, 0, 1], [1, 0, 1]]).all()
 
     def test_broadcast_ufunc(self):
@@ -1271,17 +1271,17 @@
         assert g[1] == 2
         assert g[2] == 3
         h = fromstring("1, , 2, 3", dtype=uint8, sep=",")
-        assert (h == [1,0,2,3]).all()
+        assert (h == [1, 0, 2, 3]).all()
         i = fromstring("1    2 3", dtype=uint8, sep=" ")
-        assert (i == [1,2,3]).all()
+        assert (i == [1, 2, 3]).all()
         j = fromstring("1\t\t\t\t2\t3", dtype=uint8, sep="\t")
-        assert (j == [1,2,3]).all()
+        assert (j == [1, 2, 3]).all()
         k = fromstring("1,x,2,3", dtype=uint8, sep=",")
-        assert (k == [1,0]).all()
+        assert (k == [1, 0]).all()
         l = fromstring("1,x,2,3", dtype='float32', sep=",")
-        assert (l == [1.0,-1.0]).all()
+        assert (l == [1.0, -1.0]).all()
         m = fromstring("1,,2,3", sep=",")
-        assert (m == [1.0,-1.0,2.0,3.0]).all()
+        assert (m == [1.0, -1.0, 2.0, 3.0]).all()
         n = fromstring("3.4 2.0 3.8 2.2", dtype=int32, sep=" ")
         assert (n == [3]).all()
         o = fromstring("1.0 2f.0f 3.8 2.2", dtype=float32, sep=" ")
@@ -1329,7 +1329,6 @@
         j = fromstring(self.ulongval, dtype='L')
         assert j[0] == 12
 
-
     def test_fromstring_invalid(self):
         from numpypy import fromstring, uint16, uint8, int32
         #default dtype is 64-bit float, so 3 bytes should fail


More information about the pypy-commit mailing list