[pypy-commit] pypy numpy-refactor: some revamp of where, not passing tests yet

fijal noreply at buildbot.pypy.org
Sat Sep 1 18:25:40 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57069:cf64611cbb94
Date: 2012-09-01 18:25 +0200
http://bitbucket.org/pypy/pypy/changeset/cf64611cbb94/

Log:	some revamp of where, not passing tests yet

diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -1,8 +1,9 @@
 
-from pypy.module.micronumpy.support import convert_to_array
+from pypy.module.micronumpy.support import convert_to_array, create_array
 from pypy.module.micronumpy import loop
+from pypy.interpreter.error import OperationError
 
-def where(space, w_arr, w_x, w_y):
+def where(space, w_arr, w_x=None, w_y=None):
     """where(condition, [x, y])
 
     Return elements, either from `x` or `y`, depending on `condition`.
@@ -62,7 +63,24 @@
     
     NOTE: support for not passing x and y is unsupported
     """
+    # Omitted x/y arrive as interp-level None (the defaults above); an
+    # explicit app-level None is treated the same way.
+    x_missing = w_x is None or space.is_w(w_x, space.w_None)
+    y_missing = w_y is None or space.is_w(w_y, space.w_None)
+    if x_missing and y_missing:
+        raise OperationError(space.w_NotImplementedError, space.wrap(
+            "1-arg where unsupported right now"))
+    if x_missing or y_missing:
+        raise OperationError(space.w_ValueError, space.wrap(
+            "Where should be called with either 1 or 3 arguments"))
     arr = convert_to_array(space, w_arr)
     x = convert_to_array(space, w_x)
     y = convert_to_array(space, w_y)
-    return loop.where(space, arr, x, y)
+    # As in numpy, the result dtype is the common dtype of x and y; the
+    # condition only selects which of them each element comes from.
+    from pypy.module.micronumpy.interp_ufuncs import find_binop_result_dtype
+    dtype = find_binop_result_dtype(space, x.get_dtype(), y.get_dtype())
+    # TODO: x and y are iterated independently of arr, so their shapes must
+    # currently match arr's; broadcasting is not implemented yet.
+    out = create_array(arr.get_shape(), dtype)
+    return loop.where(out, arr, x, y, dtype)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -57,3 +57,27 @@
     while not arr_iter.done():
         arr_iter.setitem(box)
         arr_iter.next()
+
+def where(out, arr, x, y, dtype):
+    """ Fill `out` from `x` where `arr` is true and from `y` elsewhere,
+    converting every picked element to `dtype`; returns `out`.
+    """
+    out_iter = out.create_iter()
+    arr_iter = arr.create_iter()
+    x_iter = x.create_iter()
+    y_iter = y.create_iter()
+    # The truth value of a condition element is determined by the
+    # condition's own dtype, not by the output dtype.
+    arr_dtype = arr.get_dtype()
+    while not arr_iter.done():
+        w_cond = arr_iter.getitem()
+        if arr_dtype.itemtype.bool(w_cond):
+            w_val = x_iter.getitem().convert_to(dtype)
+        else:
+            w_val = y_iter.getitem().convert_to(dtype)
+        out_iter.setitem(w_val)
+        out_iter.next()
+        arr_iter.next()
+        x_iter.next()
+        y_iter.next()
+    return out
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -22,3 +22,10 @@
         # If it's a scalar
         dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj)
         return scalar_w(space, dtype, w_obj)
+
+def create_array(shape, dtype):
+    """ Convenient W_NDimArray(shape, dtype) shortcut; the import is done
+    at call time to avoid circular imports. """
+    from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+    return W_NDimArray(shape, dtype)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -6,8 +6,23 @@
         from _numpypy import where, ones, zeros, array
         a = [1, 2, 3, 0, -3]
         a = where(array(a) > 0, ones(5), zeros(5))
         assert (a == [1, 1, 1, 0, 0]).all()
 
+    def test_where_differing_dtypes(self):
+        from _numpypy import where, ones, zeros, array
+        a = [1, 2, 3, 0, -3]
+        a = where(array(a) > 0, ones(5, dtype=int), zeros(5, dtype=float))
+        assert (a == [1, 1, 1, 0, 0]).all()
+        assert a.dtype == float
+
+    def test_where_errors(self):
+        from _numpypy import where, array
+        raises(ValueError, "where(array([1, 0]), array([3, 4]))")
+
+    def test_where_1_arg(self):
+        from _numpypy import where, array
+        raises(NotImplementedError, "where(array([1, 0]))")
+
     def test_where_invalidates(self):
         from _numpypy import where, ones, zeros, array
         a = array([1, 2, 3, 0, -3])
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -35,17 +35,17 @@
         return self.space.newtuple(args_w)
 
     def test_strides_f(self):
-        a = W_NDimArray([10, 5, 3], MockDtype(), order='F')
+        a = W_NDimArray([10, 5, 3], MockDtype(), order='F').implementation  # strides live on .implementation
         assert a.strides == [1, 10, 50]
         assert a.backstrides == [9, 40, 100]
 
     def test_strides_c(self):
-        a = W_NDimArray([10, 5, 3], MockDtype(), order='C')
+        a = W_NDimArray([10, 5, 3], MockDtype(), order='C').implementation  # strides live on .implementation
         assert a.strides == [15, 3, 1]
         assert a.backstrides == [135, 12, 2]
 
     def test_create_slice_f(self):
-        a = W_NDimArray([10, 5, 3], MockDtype(), order='F')
+        a = W_NDimArray([10, 5, 3], MockDtype(), order='F').implementation
         s = create_slice(a, [Chunk(3, 0, 0, 1)])
         assert s.start == 3
         assert s.strides == [10, 50]
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -4,6 +4,7 @@
 
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy import interp_boxes
+from pypy.module.micronumpy.support import create_array
 from pypy.objspace.std.floatobject import float2string
 from pypy.rlib import rfloat, clibffi
 from pypy.rlib.rawstorage import (alloc_raw_storage, raw_storage_setitem,
@@ -931,8 +932,6 @@
 
     @jit.unroll_safe
     def coerce(self, space, dtype, w_item):
-        from pypy.module.micronumpy.interp_numarray import W_NDimArray
-
         if isinstance(w_item, interp_boxes.W_VoidBox):
             return w_item
         # we treat every sequence as sequence, no special support
@@ -946,7 +945,7 @@
         items_w = space.fixedview(w_item)
         # XXX optimize it out one day, but for now we just allocate an
         #     array
-        arr = W_NDimArray([1], dtype)
+        arr = create_array([1], dtype)
         for i in range(len(items_w)):
             subdtype = dtype.fields[dtype.fieldnames[i]][1]
             ofs, itemtype = self.offsets_and_fields[i]


More information about the pypy-commit mailing list