[pypy-commit] pypy numpy-refactor: fix concatenate

fijal noreply at buildbot.pypy.org
Wed Sep 5 22:08:50 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57161:9dbfeb3c0ff1
Date: 2012-09-05 22:04 +0200
http://bitbucket.org/pypy/pypy/changeset/9dbfeb3c0ff1/

Log:	fix concatenate

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -29,7 +29,7 @@
         'fromstring': 'interp_support.fromstring',
         'flatiter': 'interp_flatiter.W_FlatIterator',
         'isna': 'interp_numarray.isna',
-        'concatenate': 'interp_numarray.concatenate',
+        'concatenate': 'interp_arrayops.concatenate',
         'repeat': 'interp_numarray.repeat',
         'where': 'interp_arrayops.where',
 
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -1,7 +1,9 @@
 
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
-from pypy.module.micronumpy import loop
-from pypy.interpreter.error import OperationError
+from pypy.module.micronumpy import loop, interp_ufuncs
+from pypy.module.micronumpy.iter import Chunk, Chunks
+from pypy.interpreter.error import OperationError, operationerrfmt
+from pypy.interpreter.gateway import unwrap_spec
 
 def where(space, w_arr, w_x=None, w_y=None):
     """where(condition, [x, y])
@@ -78,3 +80,35 @@
     if w_arr.is_scalar():
         return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
     return w_arr.descr_dot(space, w_obj2)
+
+
+@unwrap_spec(axis=int)
+def concatenate(space, w_args, axis=0):
+    """concatenate(arrays, axis=0)
+
+    Join a sequence of arrays along an existing axis.
+
+    w_args is an app-level sequence; every item is passed through
+    convert_to_array.  All inputs must have the same number of dimensions
+    and the same size in every dimension except `axis`.  A negative `axis`
+    counts from the last dimension of the first array.  The result dtype
+    is obtained by folding find_binop_result_dtype over the input dtypes.
+
+    Raises ValueError if the sequence is empty or the shapes are
+    incompatible, and IndexError if `axis` is out of bounds for an input.
+    """
+    args_w = space.listview(w_args)
+    if len(args_w) == 0:
+        raise OperationError(space.w_ValueError, space.wrap("need at least one array to concatenate"))
+    args_w = [convert_to_array(space, w_arg) for w_arg in args_w]
+    dtype = args_w[0].get_dtype()
+    # copy, because the concatenation dimension is grown in place below
+    shape = args_w[0].get_shape()[:]
+    ndim = len(shape)
+    # keep the caller-supplied value for error messages
+    requested_axis = axis
+    if axis < 0:
+        # a negative axis used to slip past the bounds check and made every
+        # input overwrite the same region of an under-sized result
+        axis += ndim
+    if axis < 0 or axis >= ndim:
+        raise operationerrfmt(space.w_IndexError, "axis %d out of bounds [0, %d)", requested_axis, ndim)
+    for arr in args_w[1:]:
+        dtype = interp_ufuncs.find_binop_result_dtype(space, dtype,
+                                                      arr.get_dtype())
+        arr_shape = arr.get_shape()
+        if len(arr_shape) <= axis:
+            # report the rank of the offending array, which is the bound
+            # that was actually violated
+            raise operationerrfmt(space.w_IndexError, "axis %d out of bounds [0, %d)", requested_axis, len(arr_shape))
+        if len(arr_shape) != ndim:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "all the input arrays must have same number of dimensions"))
+        for i, axis_size in enumerate(arr_shape):
+            if i == axis:
+                shape[i] += axis_size
+            elif axis_size != shape[i]:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "all the input array dimensions except for the "
+                    "concatenation axis must match exactly"))
+    res = W_NDimArray.from_shape(shape, dtype, 'C')
+    # one full-extent chunk per dimension; only the entry for `axis` is
+    # replaced per input.  NOTE(review): Chunk arguments look like
+    # (start, stop, step, length) -- confirm against iter.Chunk
+    chunks = [Chunk(0, i, 1, i) for i in shape]
+    axis_start = 0
+    for arr in args_w:
+        length = arr.get_shape()[axis]
+        chunks[axis] = Chunk(axis_start, axis_start + length, 1, length)
+        # select the destination region of res and copy arr into it
+        Chunks(chunks).apply(res.implementation).implementation.setslice(space, arr)
+        axis_start += length
+    return res
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -475,9 +475,6 @@
 def isna(space):
     pass
 
-def concatenate(space):
-    pass
-
 def repeat(space):
     pass
 


More information about the pypy-commit mailing list