[pypy-commit] pypy default: provide ndarray.squeeze

bdkearns noreply at buildbot.pypy.org
Mon Nov 11 04:17:18 CET 2013


Author: Brian Kearns <bdkearns at gmail.com>
Branch: 
Changeset: r67943:00ec1be89de6
Date: 2013-11-10 22:14 -0500
http://bitbucket.org/pypy/pypy/changeset/00ec1be89de6/

Log:	provide ndarray.squeeze

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -694,8 +694,16 @@
         return self.implementation.sort(space, w_axis, w_order)
 
     def descr_squeeze(self, space, w_axis=None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "squeeze not implemented yet"))
+        if not space.is_none(w_axis):
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                "axis unsupported for squeeze"))
+        cur_shape = self.get_shape()
+        new_shape = [s for s in cur_shape if s != 1]
+        if len(cur_shape) == len(new_shape):
+            return self
+        return wrap_impl(space, space.type(self), self,
+                         self.implementation.get_view(
+                             self, self.get_dtype(), new_shape))
 
     def descr_strides(self, space):
         raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -705,7 +713,7 @@
         raise OperationError(space.w_NotImplementedError, space.wrap(
             "tofile not implemented yet"))
 
-    def descr_view(self, space, w_dtype=None, w_type=None) :
+    def descr_view(self, space, w_dtype=None, w_type=None):
         if not w_type and w_dtype:
             try:
                 if space.is_true(space.issubtype(w_dtype, space.gettypefor(W_NDimArray))):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1723,9 +1723,13 @@
     def test_squeeze(self):
         import numpy as np
         a = np.array([1,2,3])
-        import sys
-        if '__pypy__' in sys.builtin_module_names:
-            raises(NotImplementedError, a.squeeze)
+        assert a.squeeze() is a
+        a = np.array([[1,2,3]])
+        b = a.squeeze()
+        assert b.shape == (3,)
+        assert (b == a).all()
+        b[1] = -1
+        assert a[0][1] == -1
 
     def test_swapaxes(self):
         from numpypy import array


More information about the pypy-commit mailing list