[pypy-commit] pypy missing-ndarray-attributes: Basic sorting - not RPython yet. argsort for now

fijal noreply at buildbot.pypy.org
Sun Sep 30 00:15:00 CEST 2012


Author: Maciej Fijalkowski <fijall at gmail.com>
Branch: missing-ndarray-attributes
Changeset: r57679:0017e12c21c0
Date: 2012-09-30 00:09 +0200
http://bitbucket.org/pypy/pypy/changeset/0017e12c21c0/

Log:	Basic sorting - not RPython yet. argsort for now

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -10,6 +10,7 @@
 from pypy.rpython.lltypesystem import rffi, lltype
 from pypy.rlib import jit
 from pypy.rlib.rawstorage import free_raw_storage
+from pypy.module.micronumpy.arrayimpl.sort import sort_array
 
 class ConcreteArrayIterator(base.BaseArrayIterator):
     def __init__(self, array):
@@ -372,6 +373,9 @@
     def get_storage_as_int(self, space):
         return rffi.cast(lltype.Signed, self.storage)
 
+    def get_storage(self):  # raw storage handle; sort.sort_array reads it
+        return self.storage
+
 class ConcreteArray(BaseConcreteArray):
     def __init__(self, shape, dtype, order, strides, backstrides):
         self.shape = shape
@@ -400,6 +404,9 @@
                                                     self.order)
         return SliceArray(0, strides, backstrides, new_shape, self)
 
+    def argsort(self, space):  # note: sort_array reorders self's storage in place
+        return sort_array(self, space)
+
 class SliceArray(BaseConcreteArray):
     def __init__(self, start, strides, backstrides, shape, parent, dtype=None):
         self.strides = strides
diff --git a/pypy/module/micronumpy/arrayimpl/scalar.py b/pypy/module/micronumpy/arrayimpl/scalar.py
--- a/pypy/module/micronumpy/arrayimpl/scalar.py
+++ b/pypy/module/micronumpy/arrayimpl/scalar.py
@@ -95,3 +95,6 @@
     def get_storage_as_int(self, space):
         raise OperationError(space.w_ValueError,
                              space.wrap("scalars have no address"))
+
+    def argsort(self, space):  # a scalar always sorts to index 0
+        return space.wrap(0)
diff --git a/pypy/module/micronumpy/arrayimpl/sort.py b/pypy/module/micronumpy/arrayimpl/sort.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/arrayimpl/sort.py
@@ -0,0 +1,79 @@
+
+from pypy.rpython.lltypesystem import rffi
+from pypy.rlib.listsort import make_timsort_class
+from pypy.rlib.objectmodel import specialize
+from pypy.rlib.rawstorage import raw_storage_getitem, raw_storage_setitem
+from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy import interp_dtype
+
+@specialize.memo()
+def make_sort_classes(space, TP):
+    """Return (ArgArrayRepresentation, ArgSort) specialized for item type TP.
+
+    ArgArrayRepresentation views a raw values buffer and a raw indexes
+    buffer as one sequence of (value, index) pairs; ArgSort is the TimSort
+    class built over it by make_timsort_class.
+    """
+    class ArgArrayRepresentation(object):
+        def __init__(self, itemsize, size, values, indexes):
+            self.itemsize = itemsize
+            self.size = size
+            self.values = values
+            self.indexes = indexes
+
+        def getitem(self, item):
+            # Read the (value, index) pair stored at slot `item`.
+            # NOTE(review): the index is read as TP at the values' stride, but
+            # sort_array allocates indexes as w_longdtype -- confirm they match.
+            idx = item * self.itemsize
+            return (raw_storage_getitem(TP, self.values, idx),
+                    raw_storage_getitem(TP, self.indexes, idx))
+
+        def setitem(self, idx, item):
+            # Write the (value, index) pair `item` back into both buffers.
+            idx *= self.itemsize
+            raw_storage_setitem(self.values, idx, rffi.cast(TP, item[0]))
+            raw_storage_setitem(self.indexes, idx, item[1])
+
+    def arg_getitem(lst, item):
+        return lst.getitem(item)
+
+    def arg_setitem(lst, item, value):
+        lst.setitem(item, value)
+
+    def arg_length(lst):
+        return lst.size
+
+    def arg_getitem_slice(lst, start, stop):
+        xxx    # not implemented: `xxx` is undefined, so calling this raises NameError
+
+    def arg_lt(a, b):
+        # Order by value only; the index travels along in slot [1].
+        return a[0] < b[0]
+
+    ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
+                                 arg_getitem_slice, arg_lt)
+
+    return ArgArrayRepresentation, ArgSort
+
+def sort_array(arr, space):
+    """Return a new 1-D w_longdtype array of indexes into arr.
+
+    The array is filled with 0..size-1 and then timsorted in lockstep with
+    arr's values, so it is meant to end up as the permutation ordering them.
+    The values buffer is arr.get_storage() itself, so arr is reordered in
+    place as a side effect; descr_argsort passes a copy.
+    """
+    itemsize = arr.dtype.itemtype.get_element_size()
+    # create array of indexes
+    dtype = interp_dtype.get_dtype_cache(space).w_longdtype
+    indexes = W_NDimArray.from_shape([arr.get_size()], dtype)
+    storage = indexes.implementation.get_storage()
+    for i in range(arr.get_size()):
+        # NOTE(review): stride is arr's itemsize, not w_longdtype's -- confirm.
+        raw_storage_setitem(storage, i * itemsize, i)
+    Repr, Sort = make_sort_classes(space, arr.dtype.itemtype.T)
+    r = Repr(itemsize, arr.get_size(), arr.get_storage(),
+             indexes.implementation.get_storage())
+    Sort(r).sort()
+    return indexes
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -362,9 +362,11 @@
                                                        space.w_False]))
         return w_d
 
-    def descr_argsort(self, space, w_axis=-1, w_kind='quicksort', w_order=None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "argsort not implemented yet"))
+    def descr_argsort(self, space, w_axis=None, w_kind=None, w_order=None):
+        # w_kind is ignored, and so (for now) are w_axis and w_order
+        # sort a contiguous copy: the sort reorders the array's storage in place
+        contig = self.descr_copy(space)
+        return contig.implementation.argsort(space)
 
     def descr_astype(self, space, w_type):
         raise OperationError(space.w_NotImplementedError, space.wrap(
@@ -765,6 +767,8 @@
     flat = GetSetProperty(W_NDimArray.descr_get_flatiter),
     item = interp2app(W_NDimArray.descr_item),
 
+    argsort = interp2app(W_NDimArray.descr_argsort),
+
     __array_interface__ = GetSetProperty(W_NDimArray.descr_array_iface),
 )
 
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -1549,6 +1549,16 @@
         assert (b == [20, 1, 21, 3, 4]).all() 
         raises(ValueError, "array([1, 2])[array([True, False, True])] = [1, 2, 3]")
 
+    def test_argsort(self):
+        from _numpypy import array, arange
+        a = array([6, 4, 1, 3, 8, 3])
+        assert array(2.0).argsort() == 0  # scalar input
+        res = a.argsort()
+        assert (res == [2, 3, 5, 1, 0, 4]).all()  # stable: tied 3s keep order
+        assert (a == [6, 4, 1, 3, 8, 3]).all() # not modified
+        a = arange(100)
+        assert (a.argsort() == a).all()  # sorted input -> identity permutation
+
 class AppTestMultiDim(BaseNumpyAppTest):
     def test_init(self):
         import _numpypy
diff --git a/pypy/rlib/listsort.py b/pypy/rlib/listsort.py
--- a/pypy/rlib/listsort.py
+++ b/pypy/rlib/listsort.py
@@ -7,20 +7,28 @@
 ## ------------------------------------------------------------------------
 ##         Adapted from CPython, original code and algorithms by Tim Peters
 
-def list_getitem(list, item):
-    return list[item]
+def make_timsort_class(getitem=None, setitem=None, length=None,
+                       getitem_slice=None, lt=None):  # None -> list-based default below
 
-def list_setitem(list, item, value):
-    list[item] = value
+    if getitem is None:
+        def getitem(list, item):
+            return list[item]
 
-def list_length(list):
-    return len(list)
+    if setitem is None:
+        def setitem(list, item, value):
+            list[item] = value
 
-def list_getitem_slice(list, start, stop):
-    return list[start:stop]
+    if length is None:
+        def length(list):
+            return len(list)
 
-def make_timsort_class(getitem=list_getitem, setitem=list_setitem,
-                       length=list_length, getitem_slice=list_getitem_slice):
+    if getitem_slice is None:
+        def getitem_slice(list, start, stop):
+            return list[start:stop]
+
+    if lt is None:
+        def lt(a, b):  # default ordering: the builtin < operator
+            return a < b
 
     class TimSort(object):
         """TimSort(list).sort()
@@ -38,7 +46,7 @@
             setitem(self.list, item, val)
 
         def lt(self, a, b):
-            return a < b
+            return lt(a, b)  # factory-level lt closure (default: a < b)
 
         def le(self, a, b):
             return not self.lt(b, a)   # always use self.lt() as the primitive
@@ -57,14 +65,14 @@
                 # set l to where list[start] belongs
                 l = a.base
                 r = start
-                pivot = a.list[r]
+                pivot = a.getitem(r)
                 # Invariants:
                 # pivot >= all in [base, l).
                 # pivot  < all in [r, start).
                 # The second is vacuously true at the start.
                 while l < r:
                     p = l + ((r - l) >> 1)
-                    if self.lt(pivot, a.list[p]):
+                    if self.lt(pivot, a.getitem(p)):
                         r = p
                     else:
                         l = p+1
@@ -75,8 +83,8 @@
                 # first slot after them -- that's why this sort is stable.
                 # Slide over to make room.
                 for p in xrange(start, l, -1):
-                    a.list[p] = a.list[p-1]
-                a.list[l] = pivot
+                    a.setitem(p, a.getitem(p-1))
+                a.setitem(l, pivot)
 
         # Compute the length of the run in the slice "a".
         # "A run" is the longest ascending sequence, with
@@ -100,17 +108,17 @@
                 descending = False
             else:
                 n = 2
-                if self.lt(a.list[a.base + 1], a.list[a.base]):
+                if self.lt(a.getitem(a.base + 1), a.getitem(a.base)):
                     descending = True
                     for p in xrange(a.base + 2, a.base + a.len):
-                        if self.lt(a.list[p], a.list[p-1]):
+                        if self.lt(a.getitem(p), a.getitem(p-1)):
                             n += 1
                         else:
                             break
                 else:
                     descending = False
                     for p in xrange(a.base + 2, a.base + a.len):
-                        if self.lt(a.list[p], a.list[p-1]):
+                        if self.lt(a.getitem(p), a.getitem(p-1)):
                             break
                         else:
                             n += 1
@@ -143,13 +151,13 @@
             p = a.base + hint
             lastofs = 0
             ofs = 1
-            if lower(a.list[p], key):
+            if lower(a.getitem(p), key):
                 # a[hint] < key -- gallop right, until
                 #     a[hint + lastofs] < key <= a[hint + ofs]
 
                 maxofs = a.len - hint     # a[a.len-1] is highest
                 while ofs < maxofs:
-                    if lower(a.list[p + ofs], key):
+                    if lower(a.getitem(p + ofs), key):
                         lastofs = ofs
                         try:
                             ofs = ovfcheck(ofs << 1)
@@ -171,7 +179,7 @@
                 #     a[hint - ofs] < key <= a[hint - lastofs]
                 maxofs = hint + 1   # a[0] is lowest
                 while ofs < maxofs:
-                    if lower(a.list[p - ofs], key):
+                    if lower(a.getitem(p - ofs), key):
                         break
                     else:
                         # key <= a[hint - ofs]
@@ -196,7 +204,7 @@
             lastofs += 1
             while lastofs < ofs:
                 m = lastofs + ((ofs - lastofs) >> 1)
-                if lower(a.list[a.base + m], key):
+                if lower(a.getitem(a.base + m), key):
                     lastofs = m+1   # a[m] < key
                 else:
                     ofs = m         # key <= a[m]
@@ -263,7 +271,7 @@
                     # Do the straightforward thing until (if ever) one run
                     # appears to win consistently.
                     while True:
-                        if self.lt(b.list[b.base], a.list[a.base]):
+                        if self.lt(b.getitem(b.base), a.getitem(a.base)):
                             self.setitem(dest, b.popleft())
                             dest += 1
                             if b.len == 0:
@@ -292,7 +300,7 @@
                         min_gallop -= min_gallop > 1
                         self.min_gallop = min_gallop
 
-                        acount = self.gallop(b.list[b.base], a, hint=0,
+                        acount = self.gallop(b.getitem(b.base), a, hint=0,
                                              rightmost=True)
                         for p in xrange(a.base, a.base + acount):
                             self.setitem(dest, a.getitem(p))
@@ -309,7 +317,7 @@
                         if b.len == 0:
                             return
 
-                        bcount = self.gallop(a.list[a.base], b, hint=0,
+                        bcount = self.gallop(a.getitem(a.base), b, hint=0,
                                              rightmost=False)
                         for p in xrange(b.base, b.base + bcount):
                             self.setitem(dest, b.getitem(p))
@@ -366,8 +374,8 @@
                     # Do the straightforward thing until (if ever) one run
                     # appears to win consistently.
                     while True:
-                        nexta = a.list[a.base + a.len - 1]
-                        nextb = b.list[b.base + b.len - 1]
+                        nexta = a.getitem(a.base + a.len - 1)
+                        nextb = b.getitem(b.base + b.len - 1)
                         if self.lt(nextb, nexta):
                             dest -= 1
                             self.setitem(dest, nexta)
@@ -399,7 +407,7 @@
                         min_gallop -= min_gallop > 1
                         self.min_gallop = min_gallop
 
-                        nextb = b.list[b.base + b.len - 1]
+                        nextb = b.getitem(b.base + b.len - 1)
                         k = self.gallop(nextb, a, hint=a.len-1, rightmost=True)
                         acount = a.len - k
                         for p in xrange(a.base + a.len - 1, a.base + k - 1, -1):
@@ -414,7 +422,7 @@
                         if b.len == 1:
                             return
 
-                        nexta = a.list[a.base + a.len - 1]
+                        nexta = a.getitem(a.base + a.len - 1)
                         k = self.gallop(nexta, b, hint=b.len-1, rightmost=False)
                         bcount = b.len - k
                         for p in xrange(b.base + b.len - 1, b.base + k - 1, -1):
@@ -463,14 +471,14 @@
 
             # Where does b start in a?  Elements in a before that can be
             # ignored (already in place).
-            k = self.gallop(b.list[b.base], a, hint=0, rightmost=True)
+            k = self.gallop(b.getitem(b.base), a, hint=0, rightmost=True)
             a.advance(k)
             if a.len == 0:
                 return
 
             # Where does a end in b?  Elements in b after that can be
             # ignored (already in place).
-            b.len = self.gallop(a.list[a.base+a.len-1], b, hint=b.len-1,
+            b.len = self.gallop(a.getitem(a.base+a.len-1), b, hint=b.len-1,
                                 rightmost=False)
             if b.len == 0:
                 return
@@ -589,6 +597,9 @@
         def getitem(self, item):
             return getitem(self.list, item)
 
+        def setitem(self, item, value):  # write-through mirror of getitem above
+            setitem(self.list, item, value)
+
         def popleft(self):
             result = getitem(self.list, self.base)
             self.base += 1


More information about the pypy-commit mailing list