[pypy-commit] pypy numpy-refactor: fixes and tests

fijal noreply at buildbot.pypy.org
Tue Sep 11 14:30:13 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57267:2908d6fc2373
Date: 2012-09-11 14:29 +0200
http://bitbucket.org/pypy/pypy/changeset/2908d6fc2373/

Log:	fixes and tests

diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -75,6 +75,10 @@
     arr = convert_to_array(space, w_arr)
     x = convert_to_array(space, w_x)
     y = convert_to_array(space, w_y)
+    if x.is_scalar() and y.is_scalar() and arr.is_scalar():
+        if arr.get_dtype().itemtype.bool(arr.get_scalar_value()):
+            return x
+        return y
     dtype = interp_ufuncs.find_binop_result_dtype(space, x.get_dtype(),
                                                   y.get_dtype())
     shape = shape_agreement(space, arr.get_shape(), x)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -75,7 +75,14 @@
     arr_dtype = arr.get_dtype()
     x_iter = x.create_iter(shape)
     y_iter = y.create_iter(shape)
-    while not x_iter.done():
+    if x.is_scalar():
+        if y.is_scalar():
+            iter = arr_iter
+        else:
+            iter = y_iter
+    else:
+        iter = x_iter
+    while not iter.done():
         w_cond = arr_iter.getitem()
         if arr_dtype.itemtype.bool(w_cond):
             w_val = x_iter.getitem().convert_to(dtype)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -25,6 +25,11 @@
         from _numpypy import where, array
         raises(ValueError, "where([1, 2, 3], [3, 4, 5])")
         raises(ValueError, "where([1, 2, 3], [3, 4, 5], [6, 7])")
+        assert where(True, 1, 2) == array(1)
+        assert where(False, 1, 2) == array(2)
+        assert (where(True, [1, 2, 3], 2) == [1, 2, 3]).all()
+        assert (where(False, 1, [1, 2, 3]) == [1, 2, 3]).all()
+        assert (where([1, 2, 3], True, False) == [True, True, True]).all()
 
     #def test_where_1_arg(self):
     #    xxx


More information about the pypy-commit mailing list