[pypy-commit] pypy numpy-refactor: fixes and tests
fijal
noreply at buildbot.pypy.org
Tue Sep 11 14:30:13 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57267:2908d6fc2373
Date: 2012-09-11 14:29 +0200
http://bitbucket.org/pypy/pypy/changeset/2908d6fc2373/
Log: fixes and tests
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -75,6 +75,10 @@
arr = convert_to_array(space, w_arr)
x = convert_to_array(space, w_x)
y = convert_to_array(space, w_y)
+ if x.is_scalar() and y.is_scalar() and arr.is_scalar():
+ if arr.get_dtype().itemtype.bool(arr.get_scalar_value()):
+ return x
+ return y
dtype = interp_ufuncs.find_binop_result_dtype(space, x.get_dtype(),
y.get_dtype())
shape = shape_agreement(space, arr.get_shape(), x)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -75,7 +75,14 @@
arr_dtype = arr.get_dtype()
x_iter = x.create_iter(shape)
y_iter = y.create_iter(shape)
- while not x_iter.done():
+ if x.is_scalar():
+ if y.is_scalar():
+ iter = arr_iter
+ else:
+ iter = y_iter
+ else:
+ iter = x_iter
+ while not iter.done():
w_cond = arr_iter.getitem()
if arr_dtype.itemtype.bool(w_cond):
w_val = x_iter.getitem().convert_to(dtype)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -25,6 +25,11 @@
from _numpypy import where, array
raises(ValueError, "where([1, 2, 3], [3, 4, 5])")
raises(ValueError, "where([1, 2, 3], [3, 4, 5], [6, 7])")
+ assert where(True, 1, 2) == array(1)
+ assert where(False, 1, 2) == array(2)
+ assert (where(True, [1, 2, 3], 2) == [1, 2, 3]).all()
+ assert (where(False, 1, [1, 2, 3]) == [1, 2, 3]).all()
+ assert (where([1, 2, 3], True, False) == [True, True, True]).all()
#def test_where_1_arg(self):
# xxx
More information about the pypy-commit mailing list