[pypy-commit] pypy numpy-refactor: getitem/setitem_filter

fijal noreply at buildbot.pypy.org
Thu Sep 6 16:59:20 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: numpy-refactor
Changeset: r57179:f8c83b52da47
Date: 2012-09-06 16:58 +0200
http://bitbucket.org/pypy/pypy/changeset/f8c83b52da47/

Log:	getitem/setitem_filter

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -23,6 +23,9 @@
     def getitem(self):
         return self.array.getitem(self.offset)
 
+    def getitem_bool(self):  # element at the current offset, read as a bool by the dtype
+        return self.dtype.getitem_bool(self.array, self.offset)
+
     def next(self):
         self.offset += self.skip
 
@@ -120,8 +123,6 @@
         return self._done
 
 def int_w(space, w_obj):
-    # a special version that respects both __index__ and __int__
-    # XXX add __index__ support
     try:
         return space.int_w(space.index(w_obj))
     except OperationError:
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -51,6 +51,17 @@
     def descr_get_nbytes(self, space):
         return space.wrap(self.get_size() * self.get_dtype().itemtype.get_element_size())
 
+    def getitem_filter(self, space, arr):  # self[arr], with `arr` used as a mask
+        if arr.get_size() > self.get_size():  # the mask may not be longer than self
+            raise OperationError(space.w_IndexError,
+                                 space.wrap("index out of range for array"))
+        size = loop.count_all_true(arr)  # number of selected elements
+        res = W_NDimArray.from_shape([size], self.get_dtype())  # new 1-D result, same dtype as self
+        return loop.getitem_filter(res, self, arr)
+
+    def setitem_filter(self, space, idx, val):  # self[idx] = val for a mask idx; no size check of val here
+        loop.setitem_filter(self, idx, val)
+
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and
             w_idx.get_dtype().is_bool_type()):
@@ -61,10 +72,10 @@
             raise OperationError(space.w_IndexError, space.wrap("wrong index"))
 
     def descr_setitem(self, space, w_idx, w_value):
-        if (isinstance(w_idx, W_NDimArray) and w_idx.shape == self.shape and
+        if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and  # same-shape bool array -> mask assignment
             w_idx.get_dtype().is_bool_type()):
             return self.setitem_filter(space, w_idx,
-                                       support.convert_to_array(space, w_value))
+                                       convert_to_array(space, w_value))  # NOTE(review): confirm convert_to_array is imported in this module
         self.implementation.descr_setitem(space, w_idx, w_value)
 
     def descr_len(self, space):
@@ -96,7 +107,9 @@
         s.append('])')
         return s.build()
 
-    def create_iter(self, shape):
+    def create_iter(self, shape=None):  # shape defaults to this array's own shape
+        if shape is None:
+            shape = self.get_shape()
+        return self.implementation.create_iter(shape)
 
     def create_axis_iter(self, shape, dim):
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -155,3 +155,56 @@
         righti.next()
         lefti.next()
     return result
+
+def count_all_true(arr):
+    """Return how many elements of `arr` read as True (getitem_bool)."""
+    s = 0
+    it = arr.create_iter()
+    while not it.done():
+        if it.getitem_bool():
+            s += 1
+        it.next()
+    return s
+
+def getitem_filter(res, arr, index):
+    """Copy the elements of `arr` whose matching `index` entry is True
+    into `res`, in iteration order, and return `res`.
+
+    `res` must have room for count_all_true(index) elements and `index`
+    must not have more elements than `arr`.
+    """
+    res_iter = res.create_iter()
+    index_iter = index.create_iter()
+    arr_iter = arr.create_iter()
+    while not index_iter.done():
+        if index_iter.getitem_bool():
+            res_iter.setitem(arr_iter.getitem())
+            res_iter.next()
+        index_iter.next()
+        arr_iter.next()
+    return res
+
+def setitem_filter(arr, index, value):
+    """Assign into every position of `arr` where `index` is True.
+
+    A single-element `value` is broadcast to every selected position.
+    Otherwise the elements of `value` are consumed in order, one per
+    True entry of `index` (numpy's ``arr[mask] = value``), rather than
+    being matched position-by-position with `arr`.
+    """
+    arr_iter = arr.create_iter()
+    index_iter = index.create_iter()
+    if value.get_size() == 1:
+        value_iter = value.create_iter(arr.get_shape())
+    else:
+        value_iter = value.create_iter()
+    while not arr_iter.done():
+        if index_iter.getitem_bool():
+            if value_iter.done():
+                # more True entries than values: stop rather than read
+                # past the end of `value` (callers should reject this)
+                break
+            arr_iter.setitem(value_iter.getitem())
+            value_iter.next()
+        arr_iter.next()
+        index_iter.next()


More information about the pypy-commit mailing list