[pypy-commit] pypy numpy-refactor: fix concatenate
fijal
noreply at buildbot.pypy.org
Wed Sep 5 22:08:50 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57161:9dbfeb3c0ff1
Date: 2012-09-05 22:04 +0200
http://bitbucket.org/pypy/pypy/changeset/9dbfeb3c0ff1/
Log: fix concatenate
diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -29,7 +29,7 @@
'fromstring': 'interp_support.fromstring',
'flatiter': 'interp_flatiter.W_FlatIterator',
'isna': 'interp_numarray.isna',
- 'concatenate': 'interp_numarray.concatenate',
+ 'concatenate': 'interp_arrayops.concatenate',
'repeat': 'interp_numarray.repeat',
'where': 'interp_arrayops.where',
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -1,7 +1,9 @@
from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
-from pypy.module.micronumpy import loop
-from pypy.interpreter.error import OperationError
+from pypy.module.micronumpy import loop, interp_ufuncs
+from pypy.module.micronumpy.iter import Chunk, Chunks
+from pypy.interpreter.error import OperationError, operationerrfmt
+from pypy.interpreter.gateway import unwrap_spec
def where(space, w_arr, w_x=None, w_y=None):
"""where(condition, [x, y])
@@ -78,3 +80,35 @@
if w_arr.is_scalar():
return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
return w_arr.descr_dot(space, w_obj2)
+
+
+@unwrap_spec(axis=int)
+def concatenate(space, w_args, axis=0):
+ args_w = space.listview(w_args)
+ if len(args_w) == 0:
+ raise OperationError(space.w_ValueError, space.wrap("need at least one array to concatenate"))
+ args_w = [convert_to_array(space, w_arg) for w_arg in args_w]
+ dtype = args_w[0].get_dtype()
+ shape = args_w[0].get_shape()[:]
+ if len(shape) <= axis:
+ raise operationerrfmt(space.w_IndexError, "axis %d out of bounds [0, %d)", axis, len(shape))
+ for arr in args_w[1:]:
+ dtype = interp_ufuncs.find_binop_result_dtype(space, dtype,
+ arr.get_dtype())
+ if len(arr.get_shape()) <= axis:
+ raise operationerrfmt(space.w_IndexError, "axis %d out of bounds [0, %d)", axis, len(shape))
+ for i, axis_size in enumerate(arr.get_shape()):
+ if len(arr.get_shape()) != len(shape) or (i != axis and axis_size != shape[i]):
+ raise OperationError(space.w_ValueError, space.wrap(
+ "all the input arrays must have same number of dimensions"))
+ elif i == axis:
+ shape[i] += axis_size
+ res = W_NDimArray.from_shape(shape, dtype, 'C')
+ chunks = [Chunk(0, i, 1, i) for i in shape]
+ axis_start = 0
+ for arr in args_w:
+ chunks[axis] = Chunk(axis_start, axis_start + arr.get_shape()[axis], 1,
+ arr.get_shape()[axis])
+ Chunks(chunks).apply(res.implementation).implementation.setslice(space, arr)
+ axis_start += arr.get_shape()[axis]
+ return res
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -475,9 +475,6 @@
def isna(space):
pass
-def concatenate(space):
- pass
-
def repeat(space):
pass
More information about the pypy-commit
mailing list