[Numpy-svn] r5418 - in branches/1.1.x/numpy/core: . tests
numpy-svn at scipy.org
Tue Jul 15 04:11:45 EDT 2008
Author: rkern
Date: 2008-07-15 03:11:43 -0500 (Tue, 15 Jul 2008)
New Revision: 5418
Modified:
branches/1.1.x/numpy/core/numeric.py
branches/1.1.x/numpy/core/tests/test_numeric.py
Log:
Backport r5357.
Modified: branches/1.1.x/numpy/core/numeric.py
===================================================================
--- branches/1.1.x/numpy/core/numeric.py 2008-07-15 08:07:53 UTC (rev 5417)
+++ branches/1.1.x/numpy/core/numeric.py 2008-07-15 08:11:43 UTC (rev 5418)
@@ -137,7 +137,7 @@
def asanyarray(a, dtype=None, order=None):
"""Returns a as an array, but will pass subclasses through.
"""
- return array(a, dtype, copy=False, order=order, subok=1)
+ return array(a, dtype, copy=False, order=order, subok=True)
def ascontiguousarray(a, dtype=None):
"""Return 'a' as an array contiguous in memory (C order).
@@ -182,9 +182,9 @@
return asanyarray(a, dtype=dtype)
if 'ENSUREARRAY' in requirements or 'E' in requirements:
- subok = 0
+ subok = False
else:
- subok = 1
+ subok = True
arr = array(a, dtype=dtype, copy=False, subok=subok)
@@ -344,12 +344,12 @@
nda = len(a.shape)
bs = b.shape
ndb = len(b.shape)
- equal = 1
- if (na != nb): equal = 0
+ equal = True
+ if (na != nb): equal = False
else:
for k in xrange(na):
if as_[axes_a[k]] != bs[axes_b[k]]:
- equal = 0
+ equal = False
break
if axes_a[k] < 0:
axes_a[k] += nda
@@ -394,10 +394,10 @@
a = asanyarray(a)
if axis is None:
n = a.size
- reshape=1
+ reshape = True
else:
n = a.shape[axis]
- reshape=0
+ reshape = False
shift %= n
indexes = concatenate((arange(n-shift,n),arange(n-shift)))
res = a.take(indexes, axis)
@@ -732,10 +732,10 @@
try:
a1, a2 = asarray(a1), asarray(a2)
except:
- return 0
+ return False
if a1.shape != a2.shape:
- return 0
- return logical_and.reduce(equal(a1,a2).ravel())
+ return False
+ return bool(logical_and.reduce(equal(a1,a2).ravel()))
def array_equiv(a1, a2):
"""Returns True if a1 and a2 are shape consistent
@@ -745,11 +745,11 @@
try:
a1, a2 = asarray(a1), asarray(a2)
except:
- return 0
+ return False
try:
- return logical_and.reduce(equal(a1,a2).ravel())
+ return bool(logical_and.reduce(equal(a1,a2).ravel()))
except ValueError:
- return 0
+ return False
_errdict = {"ignore":ERR_IGNORE,
Modified: branches/1.1.x/numpy/core/tests/test_numeric.py
===================================================================
--- branches/1.1.x/numpy/core/tests/test_numeric.py 2008-07-15 08:07:53 UTC (rev 5417)
+++ branches/1.1.x/numpy/core/tests/test_numeric.py 2008-07-15 08:11:43 UTC (rev 5418)
@@ -252,6 +252,52 @@
assert_equal(binary_repr(-1), '-1')
assert_equal(binary_repr(-1, width=8), '11111111')
+class TestArrayComparisons(NumpyTestCase):
+ def test_array_equal(self):
+ res = array_equal(array([1,2]), array([1,2]))
+ assert res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([1,2,3]))
+ assert not res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([3,4]))
+ assert not res
+ assert type(res) is bool
+ res = array_equal(array([1,2]), array([1,3]))
+ assert not res
+ assert type(res) is bool
+
+ def test_array_equiv(self):
+ res = array_equiv(array([1,2]), array([1,2]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([1,2,3]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([3,4]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([1,3]))
+ assert not res
+ assert type(res) is bool
+
+ res = array_equiv(array([1,1]), array([1]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,1]), array([[1],[1]]))
+ assert res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([2]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([[1],[2]]))
+ assert not res
+ assert type(res) is bool
+ res = array_equiv(array([1,2]), array([[1,2,3],[4,5,6],[7,8,9]]))
+ assert not res
+ assert type(res) is bool
+
+
def assert_array_strict_equal(x, y):
assert_array_equal(x, y)
# Check flags
More information about the Numpy-svn mailing list