[pypy-commit] pypy default: provide ndarray.squeeze
bdkearns
noreply at buildbot.pypy.org
Mon Nov 11 04:17:18 CET 2013
Author: Brian Kearns <bdkearns at gmail.com>
Branch:
Changeset: r67943:00ec1be89de6
Date: 2013-11-10 22:14 -0500
http://bitbucket.org/pypy/pypy/changeset/00ec1be89de6/
Log: provide ndarray.squeeze
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -694,8 +694,16 @@
return self.implementation.sort(space, w_axis, w_order)
def descr_squeeze(self, space, w_axis=None):
    """Return the array with length-1 dimensions removed.

    Without ``axis`` every dimension of length 1 is dropped.  With
    ``axis`` (an int or a tuple of ints, negative values counting from
    the end) only the named dimensions are dropped; each must exist,
    appear once, and have length 1, otherwise ValueError is raised.

    If no dimension is removed the array itself is returned; otherwise
    the result is a view created via ``self.implementation.get_view``
    sharing this array's data.
    """
    cur_shape = self.get_shape()
    ndim = len(cur_shape)
    if space.is_none(w_axis):
        new_shape = [s for s in cur_shape if s != 1]
    else:
        # Accept either a single axis or a tuple of axes, like numpy.
        if space.isinstance_w(w_axis, space.w_tuple):
            axes_w = space.fixedview(w_axis)
        else:
            axes_w = [w_axis]
        squeeze_axis = [False] * ndim
        for w_ax in axes_w:
            orig_ax = space.int_w(w_ax)
            ax = orig_ax
            if ax < 0:
                ax += ndim
            if ax < 0 or ax >= ndim:
                raise OperationError(space.w_ValueError, space.wrap(
                    "axis %d is out of bounds for array of dimension %d"
                    % (orig_ax, ndim)))
            if squeeze_axis[ax]:
                raise OperationError(space.w_ValueError, space.wrap(
                    "duplicate value in 'axis'"))
            if cur_shape[ax] != 1:
                raise OperationError(space.w_ValueError, space.wrap(
                    "cannot select an axis to squeeze out "
                    "which has size not equal to one"))
            squeeze_axis[ax] = True
        new_shape = [cur_shape[i] for i in range(ndim)
                     if not squeeze_axis[i]]
    # Nothing was removed: hand back the array unchanged rather than
    # allocating a new view object.
    if len(cur_shape) == len(new_shape):
        return self
    return wrap_impl(space, space.type(self), self,
                     self.implementation.get_view(
                         self, self.get_dtype(), new_shape))
def descr_strides(self, space):
raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -705,7 +713,7 @@
raise OperationError(space.w_NotImplementedError, space.wrap(
"tofile not implemented yet"))
- def descr_view(self, space, w_dtype=None, w_type=None) :
+ def descr_view(self, space, w_dtype=None, w_type=None):
if not w_type and w_dtype:
try:
if space.is_true(space.issubtype(w_dtype, space.gettypefor(W_NDimArray))):
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1723,9 +1723,13 @@
def test_squeeze(self):
    import numpy as np
    # With no length-1 dimensions there is nothing to squeeze, so the
    # very same array object is handed back.
    flat = np.array([1, 2, 3])
    assert flat.squeeze() is flat
    # A leading length-1 dimension is dropped and the values survive.
    row = np.array([[1, 2, 3]])
    squeezed = row.squeeze()
    assert squeezed.shape == (3,)
    assert (squeezed == row).all()
    # The result is a view: writing through it must update the source.
    squeezed[1] = -1
    assert row[0][1] == -1
def test_swapaxes(self):
from numpypy import array
More information about the pypy-commit mailing list