[pypy-commit] pypy numpy-refactor: fixes for where

fijal noreply at buildbot.pypy.org
Tue Sep 11 14:30:12 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57266:1efcfcf768c7
Date: 2012-09-11 14:25 +0200
http://bitbucket.org/pypy/pypy/changeset/1efcfcf768c7/

Log:	fixes for where

diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -2,6 +2,7 @@
 from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
 from pypy.module.micronumpy import loop, interp_ufuncs
 from pypy.module.micronumpy.iter import Chunk, Chunks
+from pypy.module.micronumpy.strides import shape_agreement
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.interpreter.gateway import unwrap_spec
 
@@ -65,15 +66,21 @@
     
     NOTE: support for not passing x and y is unsupported
     """
-    if space.is_w(w_x, space.w_None) or space.is_w(w_y, space.w_None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "1-arg where unsupported right now"))
+    if space.is_w(w_y, space.w_None):
+        if space.is_w(w_x, space.w_None):
+            raise OperationError(space.w_NotImplementedError, space.wrap(
+                "1-arg where unsupported right now"))
+        raise OperationError(space.w_ValueError, space.wrap(
+            "Where should be called with either 1 or 3 arguments"))
     arr = convert_to_array(space, w_arr)
     x = convert_to_array(space, w_x)
     y = convert_to_array(space, w_y)
-    dtype = arr.get_dtype()
-    out = W_NDimArray.from_shape(arr.get_shape(), dtype)
-    return loop.where(out, arr, x, y, dtype)
+    dtype = interp_ufuncs.find_binop_result_dtype(space, x.get_dtype(),
+                                                  y.get_dtype())
+    shape = shape_agreement(space, arr.get_shape(), x)
+    shape = shape_agreement(space, shape, y)
+    out = W_NDimArray.from_shape(shape, dtype)
+    return loop.where(out, shape, arr, x, y, dtype)
 
 def dot(space, w_obj1, w_obj2):
     w_arr = convert_to_array(space, w_obj1)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -69,18 +69,20 @@
         arr_iter.setitem(box)
         arr_iter.next()
 
-def where(out, arr, x, y, dtype):
-    out_iter = out.create_iter()
-    arr_iter = arr.create_iter()
-    x_iter = x.create_iter()
-    y_iter = y.create_iter()
-    while not arr_iter.done():
+def where(out, shape, arr, x, y, dtype):
+    out_iter = out.create_iter(shape)
+    arr_iter = arr.create_iter(shape)
+    arr_dtype = arr.get_dtype()
+    x_iter = x.create_iter(shape)
+    y_iter = y.create_iter(shape)
+    while not x_iter.done():
         w_cond = arr_iter.getitem()
-        if dtype.itemtype.bool(w_cond):
+        if arr_dtype.itemtype.bool(w_cond):
             w_val = x_iter.getitem().convert_to(dtype)
         else:
             w_val = y_iter.getitem().convert_to(dtype)
         out_iter.setitem(w_val)
+        out_iter.next()
         arr_iter.next()
         x_iter.next()
         y_iter.next()
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -6,17 +6,28 @@
         from _numpypy import where, ones, zeros, array
         a = [1, 2, 3, 0, -3]
         a = where(array(a) > 0, ones(5), zeros(5))
-        print a
         assert (a == [1, 1, 1, 0, 0]).all()
 
     def test_where_differing_dtypes(self):
-        xxx
+        from _numpypy import array, ones, zeros, where
+        a = [1, 2, 3, 0, -3]
+        a = where(array(a) > 0, ones(5, dtype=int), zeros(5, dtype=float))
+        assert (a == [1, 1, 1, 0, 0]).all()
+
+    def test_where_broadcast(self):
+        from _numpypy import array, where
+        a = where(array([[1, 2, 3], [4, 5, 6]]) > 3, [1, 1, 1], 2)
+        assert (a == [[2, 2, 2], [1, 1, 1]]).all()
+        a = where(True, [1, 1, 1], 2)
+        assert (a == [1, 1, 1]).all()
 
     def test_where_errors(self):
-        xxx
+        from _numpypy import where, array
+        raises(ValueError, "where([1, 2, 3], [3, 4, 5])")
+        raises(ValueError, "where([1, 2, 3], [3, 4, 5], [6, 7])")
 
-    def test_where_1_arg(self):
-        xxx
+    #def test_where_1_arg(self):
+    #    xxx
 
     def test_where_invalidates(self):
         from _numpypy import where, ones, zeros, array


More information about the pypy-commit mailing list