[pypy-commit] pypy scalar-operations: simplify handling of np.array()'s ndmin parameter

rlamy noreply at buildbot.pypy.org
Sat Jun 28 20:30:57 CEST 2014


Author: Ronan Lamy <ronan.lamy at gmail.com>
Branch: scalar-operations
Changeset: r72271:4f27ccc29f40
Date: 2014-06-28 19:30 +0100
http://bitbucket.org/pypy/pypy/changeset/4f27ccc29f40/

Log:	simplify handling of np.array()'s ndmin parameter

diff --git a/pypy/module/micronumpy/ctors.py b/pypy/module/micronumpy/ctors.py
--- a/pypy/module/micronumpy/ctors.py
+++ b/pypy/module/micronumpy/ctors.py
@@ -41,6 +41,18 @@
 @unwrap_spec(ndmin=int, copy=bool, subok=bool)
 def array(space, w_object, w_dtype=None, copy=True, w_order=None, subok=False,
           ndmin=0):
+    w_res = _array(space, w_object, w_dtype, copy, w_order, subok)
+    shape = w_res.get_shape()
+    if len(shape) < ndmin:
+        shape = [1] * (ndmin - len(shape)) + shape
+        impl = w_res.implementation.set_shape(space, w_res, shape)
+        if w_res is w_object:
+            return W_NDimArray(impl)
+        else:
+            w_res.implementation = impl
+    return w_res
+
+def _array(space, w_object, w_dtype=None, copy=True, w_order=None, subok=False):
     from pypy.module.micronumpy import strides
 
     # for anything that isn't already an array, try __array__ method first
@@ -65,19 +77,10 @@
     # arrays with correct dtype
     if isinstance(w_object, W_NDimArray) and \
             (space.is_none(w_dtype) or w_object.get_dtype() is dtype):
-        shape = w_object.get_shape()
         if copy:
-            w_ret = w_object.descr_copy(space)
+            return w_object.descr_copy(space)
         else:
-            if ndmin <= len(shape):
-                return w_object
-            new_impl = w_object.implementation.set_shape(space, w_object, shape)
-            w_ret = W_NDimArray(new_impl)
-        if ndmin > len(shape):
-            shape = [1] * (ndmin - len(shape)) + shape
-            w_ret.implementation = w_ret.implementation.set_shape(space,
-                                                                  w_ret, shape)
-        return w_ret
+            return w_object
 
     # not an array or incorrect dtype
     shape, elems_w = strides.find_shape_and_elems(space, w_object, dtype)
@@ -89,8 +92,6 @@
             # promote S0 -> S1, U0 -> U1
             dtype = descriptor.variable_dtype(space, dtype.char + '1')
 
-    if ndmin > len(shape):
-        shape = [1] * (ndmin - len(shape)) + shape
     w_arr = W_NDimArray.from_shape(space, shape, dtype, order=order)
     if len(elems_w) == 1:
         w_arr.set_scalar_value(dtype.coerce(space, elems_w[0]))


More information about the pypy-commit mailing list