[pypy-commit] pypy numpy-refactor: some revamp of where, not passing tests yet
fijal
noreply at buildbot.pypy.org
Sat Sep 1 18:25:40 CEST 2012
Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57069:cf64611cbb94
Date: 2012-09-01 18:25 +0200
http://bitbucket.org/pypy/pypy/changeset/cf64611cbb94/
Log: some revamp of where, not passing tests yet
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -1,8 +1,9 @@
-from pypy.module.micronumpy.support import convert_to_array
+from pypy.module.micronumpy.support import convert_to_array, create_array
from pypy.module.micronumpy import loop
+from pypy.interpreter.error import OperationError
-def where(space, w_arr, w_x, w_y):
+def where(space, w_arr, w_x=None, w_y=None):
"""where(condition, [x, y])
Return elements, either from `x` or `y`, depending on `condition`.
@@ -62,7 +63,18 @@
NOTE: support for not passing x and y is unsupported
"""
+ # Omitted x/y presumably arrive as the interp-level default None from
+ # the signature, which is not space.w_None, so test both forms before
+ # convert_to_array() is handed a bare None.
+ if (w_x is None or w_y is None or
+ space.is_w(w_x, space.w_None) or space.is_w(w_y, space.w_None)):
+ raise OperationError(space.w_NotImplementedError, space.wrap(
+ "1-arg where unsupported right now"))
arr = convert_to_array(space, w_arr)
x = convert_to_array(space, w_x)
y = convert_to_array(space, w_y)
- return loop.where(space, arr, x, y)
+ # NOTE(review): result dtype/shape come from the condition array, not
+ # from x/y; presumably this should be the promoted x/y dtype -- TODO.
+ dtype = arr.get_dtype()
+ out = create_array(arr.get_shape(), dtype)
+ return loop.where(out, arr, x, y, dtype)
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -57,3 +57,20 @@
while not arr_iter.done():
arr_iter.setitem(box)
arr_iter.next()
+
+def where(out, arr, x, y, dtype):
+ """Fill `out` elementwise with x where arr is true, else y; return out.
+
+ Each selected element is converted to `dtype` before being stored.
+ NOTE(review): arr, x, y and out are walked with independent iterators
+ and only arr_iter.done() is checked, so this presumably assumes all
+ four have the same shape (no broadcasting yet) -- TODO confirm.
+ """
+ out_iter = out.create_iter()
+ arr_iter = arr.create_iter()
+ x_iter = x.create_iter()
+ y_iter = y.create_iter()
+ # Judge truthiness with the condition's own dtype, not the output
+ # dtype: the two only coincide when the caller passes arr's dtype.
+ arr_dtype = arr.get_dtype()
+ while not arr_iter.done():
+ w_cond = arr_iter.getitem()
+ if arr_dtype.itemtype.bool(w_cond):
+ w_val = x_iter.getitem().convert_to(dtype)
+ else:
+ w_val = y_iter.getitem().convert_to(dtype)
+ out_iter.setitem(w_val)
+ arr_iter.next()
+ x_iter.next()
+ y_iter.next()
+ return out
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -22,3 +22,10 @@
# If it's a scalar
dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj)
return scalar_w(space, dtype, w_obj)
+
+def create_array(shape, dtype):
+ """ Convenient shortcut for building a W_NDimArray of the given shape
+ and dtype; W_NDimArray is imported at call time to avoid circular
+ imports.
+ """
+ from pypy.module.micronumpy.interp_numarray import W_NDimArray
+
+ return W_NDimArray(shape, dtype)
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -6,8 +6,18 @@
from _numpypy import where, ones, zeros, array
a = [1, 2, 3, 0, -3]
a = where(array(a) > 0, ones(5), zeros(5))
+ print a
assert (a == [1, 1, 1, 0, 0]).all()
+ def test_where_differing_dtypes(self):
+ # XXX placeholder: presumably `xxx` is an undefined name used to make
+ # this test fail until it is written.
+ xxx
+
+ def test_where_errors(self):
+ # XXX placeholder: presumably `xxx` is an undefined name used to make
+ # this test fail until it is written.
+ xxx
+
+ def test_where_1_arg(self):
+ # XXX placeholder: presumably `xxx` is an undefined name used to make
+ # this test fail until it is written.
+ xxx
+
def test_where_invalidates(self):
from _numpypy import where, ones, zeros, array
a = array([1, 2, 3, 0, -3])
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -35,17 +35,17 @@
return self.space.newtuple(args_w)
def test_strides_f(self):
+ # Strides are now read from the array's .implementation object.
+ # Fortran order: first axis varies fastest -> strides [1, 10, 10*5].
- a = W_NDimArray([10, 5, 3], MockDtype(), order='F')
+ a = W_NDimArray([10, 5, 3], MockDtype(), order='F').implementation
assert a.strides == [1, 10, 50]
assert a.backstrides == [9, 40, 100]
def test_strides_c(self):
+ # Strides are now read from the array's .implementation object.
+ # C order: last axis varies fastest -> strides [5*3, 3, 1].
- a = W_NDimArray([10, 5, 3], MockDtype(), order='C')
+ a = W_NDimArray([10, 5, 3], MockDtype(), order='C').implementation
assert a.strides == [15, 3, 1]
assert a.backstrides == [135, 12, 2]
def test_create_slice_f(self):
- a = W_NDimArray([10, 5, 3], MockDtype(), order='F')
+ a = W_NDimArray([10, 5, 3], MockDtype(), order='F').implementation
s = create_slice(a, [Chunk(3, 0, 0, 1)])
assert s.start == 3
assert s.strides == [10, 50]
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -4,6 +4,7 @@
from pypy.interpreter.error import OperationError
from pypy.module.micronumpy import interp_boxes
+from pypy.module.micronumpy.support import create_array
from pypy.objspace.std.floatobject import float2string
from pypy.rlib import rfloat, clibffi
from pypy.rlib.rawstorage import (alloc_raw_storage, raw_storage_setitem,
@@ -931,8 +932,6 @@
@jit.unroll_safe
def coerce(self, space, dtype, w_item):
- from pypy.module.micronumpy.interp_numarray import W_NDimArray
-
if isinstance(w_item, interp_boxes.W_VoidBox):
return w_item
# we treat every sequence as sequence, no special support
@@ -946,7 +945,7 @@
items_w = space.fixedview(w_item)
# XXX optimize it out one day, but for now we just allocate an
# array
- arr = W_NDimArray([1], dtype)
+ arr = create_array([1], dtype)
for i in range(len(items_w)):
subdtype = dtype.fields[dtype.fieldnames[i]][1]
ofs, itemtype = self.offsets_and_fields[i]
More information about the pypy-commit mailing list