[pypy-commit] pypy default: merge numpy-concatenate and move it to interp-level

fijal noreply at buildbot.pypy.org
Thu Jan 26 16:22:19 CET 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: 
Changeset: r51800:8d9c1bd7e3be
Date: 2012-01-26 17:21 +0200
http://bitbucket.org/pypy/pypy/changeset/8d9c1bd7e3be/

Log:	merge numpy-concatenate and move it to interp-level

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -28,6 +28,7 @@
         'fromstring': 'interp_support.fromstring',
         'flatiter': 'interp_numarray.W_FlatIterator',
         'isna': 'interp_numarray.isna',
+        'concatenate': 'interp_numarray.concatenate',
 
         'True_': 'types.Bool.True',
         'False_': 'types.Bool.False',
diff --git a/pypy/module/micronumpy/app_numpy.py b/pypy/module/micronumpy/app_numpy.py
--- a/pypy/module/micronumpy/app_numpy.py
+++ b/pypy/module/micronumpy/app_numpy.py
@@ -59,7 +59,7 @@
     if not hasattr(a, "max"):
         a = _numpypy.array(a)
     return a.max(axis)
-
+    
 def arange(start, stop=None, step=1, dtype=None):
     '''arange([start], stop[, step], dtype=None)
     Generate values in the half-interval [start, stop).
diff --git a/pypy/module/micronumpy/interp_iter.py b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -17,6 +17,10 @@
         if self.step != 0:
             shape.append(self.lgt)
 
+    def __repr__(self):
+        return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step,
+                                          self.lgt)
+
 class BaseTransform(object):
     pass
 
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -1,6 +1,6 @@
 from pypy.interpreter.baseobjspace import Wrappable
 from pypy.interpreter.error import OperationError, operationerrfmt
-from pypy.interpreter.gateway import interp2app, NoneNotWrapped
+from pypy.interpreter.gateway import interp2app, NoneNotWrapped, unwrap_spec
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature,\
      interp_boxes
@@ -1341,6 +1341,42 @@
         return convert_to_array(space, w_obj2).descr_dot(space, w_arr)
     return w_arr.descr_dot(space, w_obj2)
 
+ at unwrap_spec(axis=int)
+def concatenate(space, w_args, axis=0):
+    args_w = space.listview(w_args)
+    if len(args_w) == 0:
+        raise OperationError(space.w_ValueError, space.wrap("concatenation of zero-length sequences is impossible"))
+    args_w = [convert_to_array(space, w_arg) for w_arg in args_w]
+    dtype = args_w[0].find_dtype()
+    shape = args_w[0].shape[:]
+    if len(shape) <= axis:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("bad axis argument"))
+    for arr in args_w[1:]:
+        dtype = interp_ufuncs.find_binop_result_dtype(space, dtype,
+                                                      arr.find_dtype())
+        if len(arr.shape) <= axis:
+            raise OperationError(space.w_ValueError,
+                                 space.wrap("bad axis argument"))
+        for i, axis_size in enumerate(arr.shape):
+            if len(arr.shape) != len(shape) or (i != axis and axis_size != shape[i]):
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "array dimensions must agree except for axis being concatenated"))
+            elif i == axis:
+                shape[i] += axis_size
+    size = 1
+    for elem in shape:
+        size *= elem
+    res = W_NDimArray(size, shape, dtype, 'C')
+    chunks = [Chunk(0, i, 1, i) for i in shape]
+    axis_start = 0
+    for arr in args_w:
+        chunks[axis] = Chunk(axis_start, axis_start + arr.shape[axis], 1,
+                             arr.shape[axis])
+        res.create_slice(chunks).setslice(space, arr)
+        axis_start += arr.shape[axis]
+    return res
+
 BaseArray.typedef = TypeDef(
     'ndarray',
     __module__ = "numpypy",
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1039,6 +1039,40 @@
         #assert (a.var(0) == [8, 8]).all()
         #assert (a.var(1) == [.25] * 5).all()
 
+    def test_concatenate(self):
+        from numpypy import array, concatenate, dtype
+        a1 = array([0,1,2])
+        a2 = array([3,4,5])
+        a = concatenate((a1, a2))
+        assert len(a) == 6
+        assert (a == [0,1,2,3,4,5]).all()
+        assert a.dtype is dtype(int)
+        b1 = array([[1, 2], [3, 4]])
+        b2 = array([[5, 6]])
+        b = concatenate((b1, b2), axis=0)
+        assert (b == [[1, 2],[3, 4],[5, 6]]).all()
+        c = concatenate((b1, b2.T), axis=1)
+        assert (c == [[1, 2, 5],[3, 4, 6]]).all()
+        d = concatenate(([0],[1]))
+        assert (d == [0,1]).all()
+        e1 = array([[0,1],[2,3]])
+        e = concatenate(e1)
+        assert (e == [0,1,2,3]).all()
+        f1 = array([0,1])
+        f = concatenate((f1, [2], f1, [7]))
+        assert (f == [0,1,2,0,1,7]).all()
+        
+        bad_axis = raises(ValueError, concatenate, (a1,a2), axis=1)
+        assert str(bad_axis.value) == "bad axis argument"
+        
+        concat_zero = raises(ValueError, concatenate, ())
+        assert str(concat_zero.value) == \
+            "concatenation of zero-length sequences is impossible"
+        
+        dims_disagree = raises(ValueError, concatenate, (a1, b1), axis=0)
+        assert str(dims_disagree.value) == \
+            "array dimensions must agree except for axis being concatenated"
+
     def test_std(self):
         from _numpypy import array
         a = array(range(10))
@@ -1062,7 +1096,6 @@
         a = array([[1, 2], [3, 4]])
         assert (a.T.flatten() == [1, 3, 2, 4]).all()
 
-
 class AppTestMultiDim(BaseNumpyAppTest):
     def test_init(self):
         import _numpypy


More information about the pypy-commit mailing list