[Scipy-svn] r7116 - in trunk/scipy/linalg: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Jan 31 16:52:58 EST 2011


Author: ptvirtan
Date: 2011-01-31 15:52:58 -0600 (Mon, 31 Jan 2011)
New Revision: 7116

Added:
   trunk/scipy/linalg/_testutils.py
Modified:
   trunk/scipy/linalg/basic.py
   trunk/scipy/linalg/decomp.py
   trunk/scipy/linalg/decomp_cholesky.py
   trunk/scipy/linalg/decomp_lu.py
   trunk/scipy/linalg/decomp_qr.py
   trunk/scipy/linalg/decomp_schur.py
   trunk/scipy/linalg/decomp_svd.py
   trunk/scipy/linalg/misc.py
   trunk/scipy/linalg/tests/test_basic.py
   trunk/scipy/linalg/tests/test_decomp.py
   trunk/scipy/linalg/tests/test_decomp_cholesky.py
Log:
BUG: linalg: more robust data overwrite behavior

Some routines in linalg used an incorrect check to determine whether input
data can be overwritten; the check fails for non-ndarray objects that provide
an array interface (__array_interface__) but no __array__ method.

Also add new tests checking the data overwrite behavior.

Added: trunk/scipy/linalg/_testutils.py
===================================================================
--- trunk/scipy/linalg/_testutils.py	                        (rev 0)
+++ trunk/scipy/linalg/_testutils.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -0,0 +1,57 @@
+import numpy as np
+
+class _FakeMatrix(object):  # array-like exposing only __array_interface__, no __array__ method
+    def __init__(self, data):
+        self._data = data  # not read elsewhere here; presumably keeps data's buffer alive
+        self.__array_interface__ = data.__array_interface__  # shares data's buffer, no copy
+
+class _FakeMatrix2(object):  # array-like exposing only __array__, no __array_interface__
+    def __init__(self, data):
+        self._data = data
+    def __array__(self):  # NOTE(review): accepts no dtype/copy arguments -- confirm against the NumPy version in use
+        return self._data  # the wrapped ndarray itself, not a copy
+
+def _get_array(shape, dtype):
+    """
+    Get a test array of given shape and data type.
+    Returned NxN matrices are posdef, and 2xN are banded-posdef (2x2 included).
+    Any other shape is filled with seeded standard-normal values.
+    """
+    if len(shape) == 2 and shape[0] == 2:  # tested first, so a 2x2 request gets the banded form
+        # banded positive definite: tridiagonal (2 on the diagonal, -1 beside it) in 2-row band form
+        x = np.zeros(shape, dtype=dtype)
+        x[0,1:] = -1  # row 0: off-diagonal band; entry 0 stays 0 as padding
+        x[1] = 2  # row 1: main diagonal
+        return x
+    elif len(shape) == 2 and shape[0] == shape[1]:
+        # always yield a positive definite matrix: tridiagonal, 2 on the diagonal, -1 beside it
+        x = np.zeros(shape, dtype=dtype)
+        j = np.arange(shape[0])
+        x[j,j] = 2  # main diagonal
+        x[j[:-1],j[:-1]+1] = -1  # superdiagonal
+        x[j[:-1]+1,j[:-1]] = -1  # subdiagonal
+        return x
+    else:
+        np.random.seed(1234)  # NOTE(review): reseeds NumPy's global RNG as a side effect
+        return np.random.randn(*shape).astype(dtype)  # complex dtypes get a zero imaginary part
+
+def _id(x):  # identity "wrapper": the plain ndarray is passed through unchanged
+    return x
+
+def assert_no_overwrite(call, shapes, dtypes=None):
+    """
+    Assert that `call` leaves its input arguments unmodified (raises AssertionError otherwise).
+    """
+    # one input per entry of `shapes`, retried for every dtype, memory order and wrapper
+    if dtypes is None:
+        dtypes = [np.float32, np.float64, np.complex64, np.complex128]  # single/double, real/complex
+
+    for dtype in dtypes:
+        for order in ["C", "F"]:  # C- and Fortran-contiguous copies
+            for faker in [_id, _FakeMatrix, _FakeMatrix2]:  # plain ndarray plus two array-like wrappers
+                orig_inputs = [_get_array(s, dtype) for s in shapes]  # reference values, never passed to call
+                inputs = [faker(x.copy(order)) for x in orig_inputs]  # call only sees wrapped copies
+                call(*inputs)  # return value is ignored
+                msg = "call modified inputs [%r, %r]" % (dtype, faker)  # NOTE(review): order is not included in the message
+                for a, b in zip(inputs, orig_inputs):
+                    np.testing.assert_equal(a, b, err_msg=msg)  # wrappers share the copy's memory, so in-place writes show up here

Modified: trunk/scipy/linalg/basic.py
===================================================================
--- trunk/scipy/linalg/basic.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/basic.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -12,7 +12,7 @@
 
 from flinalg import get_flinalg_funcs
 from lapack import get_lapack_funcs
-from misc import LinAlgError
+from misc import LinAlgError, _datacopied
 from scipy.linalg import calc_lwork
 import decomp_svd
 
@@ -49,8 +49,8 @@
         raise ValueError('expected square matrix')
     if a1.shape[0] != b1.shape[0]:
         raise ValueError('incompatible dimensions')
-    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+    overwrite_a = overwrite_a or _datacopied(a1, a)
+    overwrite_b = overwrite_b or _datacopied(b1, b)
     if debug:
         print 'solve:overwrite_a=',overwrite_a
         print 'solve:overwrite_b=',overwrite_b
@@ -117,7 +117,7 @@
         raise ValueError('expected square matrix')
     if a1.shape[0] != b1.shape[0]:
         raise ValueError('incompatible dimensions')
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+    overwrite_b = overwrite_b or _datacopied(b1, b)
     if debug:
         print 'solve:overwrite_b=',overwrite_b
     trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
@@ -174,7 +174,7 @@
         raise ValueError("invalid values for the number of lower and upper diagonals:"
                 " l+u+1 (%d) does not equal ab.shape[0] (%d)" % (l+u+1, ab.shape[0]))
 
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+    overwrite_b = overwrite_b or _datacopied(b1, b)
 
     gbsv, = get_lapack_funcs(('gbsv',), (a1, b1))
     a2 = zeros((2*l+u+1, a1.shape[1]), dtype=gbsv.dtype)
@@ -285,7 +285,7 @@
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
+    overwrite_a = overwrite_a or _datacopied(a1, a)
     #XXX: I found no advantage or disadvantage of using finv.
 ##     finv, = get_flinalg_funcs(('inv',),(a1,))
 ##     if finv is not None:
@@ -350,7 +350,7 @@
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
+    overwrite_a = overwrite_a or _datacopied(a1, a)
     fdet, = get_flinalg_funcs(('det',), (a1,))
     a_det, info = fdet(a1, overwrite_a=overwrite_a)
     if info < 0:
@@ -426,8 +426,8 @@
         else:
             b2[:m,0] = b1
         b1 = b2
-    overwrite_a = overwrite_a or (a1 is not a and not hasattr(a,'__array__'))
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+    overwrite_a = overwrite_a or _datacopied(a1, a)
+    overwrite_b = overwrite_b or _datacopied(b1, b)
     if gelss.module_name[:7] == 'flapack':
         lwork = calc_lwork.gelss(gelss.prefix, m, n, nrhs)[1]
         v, x, s, rank, info = gelss(a1, b1, cond=cond, lwork=lwork,

Modified: trunk/scipy/linalg/decomp.py
===================================================================
--- trunk/scipy/linalg/decomp.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -21,7 +21,7 @@
 
 # Local imports
 from scipy.linalg import calc_lwork
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
 from lapack import get_lapack_funcs
 from blas import get_blas_funcs
 
@@ -43,7 +43,7 @@
 
 def _geneig(a1, b, left, right, overwrite_a, overwrite_b):
     b1 = asarray(b)
-    overwrite_b = overwrite_b or _datanotshared(b1, b)
+    overwrite_b = overwrite_b or _datacopied(b1, b)
     if len(b1.shape) != 2 or b1.shape[0] != b1.shape[1]:
         raise ValueError('expected square matrix')
     ggev, = get_lapack_funcs(('ggev',), (a1, b1))
@@ -135,7 +135,7 @@
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     if b is not None:
         b = asarray_chkfinite(b)
         if b.shape != a1.shape:
@@ -265,14 +265,14 @@
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     if iscomplexobj(a1):
         cplx = True
     else:
         cplx = False
     if b is not None:
         b1 = asarray_chkfinite(b)
-        overwrite_b = overwrite_b or _datanotshared(b1, b)
+        overwrite_b = overwrite_b or _datacopied(b1, b)
         if len(b1.shape) != 2 or b1.shape[0] != b1.shape[1]:
             raise ValueError('expected square matrix')
 
@@ -455,7 +455,7 @@
     """
     if eigvals_only or overwrite_a_band:
         a1 = asarray_chkfinite(a_band)
-        overwrite_a_band = overwrite_a_band or (_datanotshared(a1, a_band))
+        overwrite_a_band = overwrite_a_band or (_datacopied(a1, a_band))
     else:
         a1 = array(a_band)
         if issubclass(a1.dtype.type, inexact) and not isfinite(a1).all():
@@ -734,9 +734,9 @@
     a1 = asarray(a)
     if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     gehrd,gebal = get_lapack_funcs(('gehrd','gebal'), (a1,))
-    ba, lo, hi, pivscale, info = gebal(a, permute=1, overwrite_a=overwrite_a)
+    ba, lo, hi, pivscale, info = gebal(a1, permute=1, overwrite_a=overwrite_a)
     if info < 0:
         raise ValueError('illegal value in %d-th argument of internal gebal '
                                                     '(hessenberg)' % -info)

Modified: trunk/scipy/linalg/decomp_cholesky.py
===================================================================
--- trunk/scipy/linalg/decomp_cholesky.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_cholesky.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -3,7 +3,7 @@
 from numpy import asarray_chkfinite
 
 # Local imports
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
 from lapack import get_lapack_funcs
 
 __all__ = ['cholesky', 'cho_factor', 'cho_solve', 'cholesky_banded',
@@ -17,7 +17,7 @@
     if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]:
         raise ValueError('expected square matrix')
 
-    overwrite_a = overwrite_a or _datanotshared(a1, a)
+    overwrite_a = overwrite_a or _datacopied(a1, a)
     potrf, = get_lapack_funcs(('potrf',), (a1,))
     c, info = potrf(a1, lower=lower, overwrite_a=overwrite_a, clean=clean)
     if info > 0:
@@ -104,7 +104,7 @@
 
     See also
     --------
-    cho_solve : Solve a linear set equations using the Cholesky factorization 
+    cho_solve : Solve a linear set equations using the Cholesky factorization
                 of a matrix.
 
     """
@@ -140,7 +140,7 @@
     if c.shape[1] != b1.shape[0]:
         raise ValueError("incompatible dimensions.")
 
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b,'__array__'))
+    overwrite_b = overwrite_b or _datacopied(b1, b)
 
     potrs, = get_lapack_funcs(('potrs',), (c, b1))
     x, info = potrs(c, b1, lower=lower, overwrite_b=overwrite_b)
@@ -208,7 +208,7 @@
     b : array
         Right-hand side
     overwrite_b : bool
-        If True, the function will overwrite the values in `b`.    
+        If True, the function will overwrite the values in `b`.
 
     Returns
     -------
@@ -221,7 +221,7 @@
 
     Notes
     -----
-    
+
     .. versionadded:: 0.8.0
 
     """

Modified: trunk/scipy/linalg/decomp_lu.py
===================================================================
--- trunk/scipy/linalg/decomp_lu.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_lu.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
 from numpy import asarray, asarray_chkfinite
 
 # Local imports
-from misc import _datanotshared
+from misc import _datacopied
 from lapack import get_lapack_funcs
 from flinalg import get_flinalg_funcs
 
@@ -48,9 +48,9 @@
     a1 = asarray(a)
     if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
         raise ValueError('expected square matrix')
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     getrf, = get_lapack_funcs(('getrf',), (a1,))
-    lu, piv, info = getrf(a, overwrite_a=overwrite_a)
+    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
     if info < 0:
         raise ValueError('illegal value in %d-th argument of '
                                 'internal getrf (lu_factor)' % -info)
@@ -91,7 +91,7 @@
 
     """
     b1 = asarray_chkfinite(b)
-    overwrite_b = overwrite_b or (b1 is not b and not hasattr(b, '__array__'))
+    overwrite_b = overwrite_b or _datacopied(b1, b)
     if lu.shape[0] != b1.shape[0]:
         raise ValueError("incompatible dimensions.")
 
@@ -148,7 +148,7 @@
     a1 = asarray_chkfinite(a)
     if len(a1.shape) != 2:
         raise ValueError('expected matrix')
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     flu, = get_flinalg_funcs(('lu',), (a1,))
     p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
     if info < 0:

Modified: trunk/scipy/linalg/decomp_qr.py
===================================================================
--- trunk/scipy/linalg/decomp_qr.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_qr.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -7,7 +7,7 @@
 import special_matrices
 from blas import get_blas_funcs
 from lapack import get_lapack_funcs, find_best_lapack_type
-from misc import _datanotshared
+from misc import _datacopied
 
 
 def qr(a, overwrite_a=False, lwork=None, mode='full'):
@@ -77,7 +77,7 @@
     if len(a1.shape) != 2:
         raise ValueError("expected 2D array")
     M, N = a1.shape
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
 
     geqrf, = get_lapack_funcs(('geqrf',), (a1,))
     if lwork is None or lwork == -1:
@@ -157,7 +157,7 @@
     if len(a1.shape) != 2:
         raise ValueError('expected matrix')
     M,N = a1.shape
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     geqrf, = get_lapack_funcs(('geqrf',), (a1,))
     if lwork is None or lwork == -1:
         # get optimal work array
@@ -235,7 +235,7 @@
     if len(a1.shape) != 2:
         raise ValueError('expected matrix')
     M, N = a1.shape
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
 
     gerqf, = get_lapack_funcs(('gerqf',), (a1,))
     if lwork is None or lwork == -1:

Modified: trunk/scipy/linalg/decomp_schur.py
===================================================================
--- trunk/scipy/linalg/decomp_schur.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_schur.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
 
 # Local imports.
 import misc
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
 from lapack import get_lapack_funcs
 from decomp import eigvals
 
@@ -63,13 +63,13 @@
         else:
             a1 = a1.astype('F')
             typ = 'F'
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     gees, = get_lapack_funcs(('gees',), (a1,))
     if lwork is None or lwork == -1:
         # get optimal work array
-        result = gees(lambda x: None, a, lwork=-1)
+        result = gees(lambda x: None, a1, lwork=-1)
         lwork = result[-2][0].real.astype(numpy.int)
-    result = gees(lambda x: None, a, lwork=lwork, overwrite_a=overwrite_a)
+    result = gees(lambda x: None, a1, lwork=lwork, overwrite_a=overwrite_a)
     info = result[-1]
     if info < 0:
         raise ValueError('illegal value in %d-th argument of internal gees'

Modified: trunk/scipy/linalg/decomp_svd.py
===================================================================
--- trunk/scipy/linalg/decomp_svd.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/decomp_svd.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -5,7 +5,7 @@
 from scipy.linalg import calc_lwork
 
 # Local imports.
-from misc import LinAlgError, _datanotshared
+from misc import LinAlgError, _datacopied
 from lapack import get_lapack_funcs
 
 
@@ -73,7 +73,7 @@
     if len(a1.shape) != 2:
         raise ValueError('expected matrix')
     m,n = a1.shape
-    overwrite_a = overwrite_a or (_datanotshared(a1, a))
+    overwrite_a = overwrite_a or (_datacopied(a1, a))
     gesdd, = get_lapack_funcs(('gesdd',), (a1,))
     if gesdd.module_name[:7] == 'flapack':
         lwork = calc_lwork.gesdd(gesdd.prefix, m, n, compute_uv)[1]

Modified: trunk/scipy/linalg/misc.py
===================================================================
--- trunk/scipy/linalg/misc.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/misc.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -9,13 +9,14 @@
     return np.linalg.norm(np.asarray_chkfinite(a), ord=ord)
 norm.__doc__ = np.linalg.norm.__doc__
 
+def _datacopied(arr, original):
+    """
+    Strict check for `arr` not sharing any data with `original`,
+    under the assumption that arr = asarray(original)
 
-def _datanotshared(a1,a):
-    if a1 is a:
+    """
+    if arr is original:  # asarray returned the input itself
         return False
-    else:
-        #try comparing data pointers
-        try:
-            return a1.__array_interface__['data'][0] != a.__array_interface__['data'][0]
-        except:
-            return True
\ No newline at end of file
+    if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):  # __array__ may return existing memory (e.g. a wrapped ndarray)
+        return False
+    return arr.base is None  # views (e.g. of a matrix or an __array_interface__ object) have a base

Modified: trunk/scipy/linalg/tests/test_basic.py
===================================================================
--- trunk/scipy/linalg/tests/test_basic.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_basic.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -29,6 +29,7 @@
 from scipy.linalg import solve, inv, det, lstsq, pinv, pinv2, norm,\
         solve_banded, solveh_banded, solve_triangular
 
+from scipy.linalg._testutils import assert_no_overwrite
 
 def random(size):
     return rand(*size)
@@ -561,5 +562,26 @@
         assert_equal(norm([1,0,3], 0), 2)
         assert_equal(norm([1,2,3], 0), 3)
 
+class TestOverwrite(object):  # each routine must leave its (possibly wrapped) inputs unmodified
+    def test_solve(self):
+        assert_no_overwrite(solve, [(3,3), (3,)])
+    def test_solve_triangular(self):
+        assert_no_overwrite(solve_triangular, [(3,3), (3,)])
+    def test_solve_banded(self):
+        assert_no_overwrite(lambda ab, b: solve_banded((2,1), ab, b),
+                            [(4,6), (6,)])  # (l, u) = (2, 1) needs l+u+1 = 4 band rows
+    def test_solveh_banded(self):
+        assert_no_overwrite(solveh_banded, [(2,6), (6,)])  # 2xN band input is generated positive definite
+    def test_inv(self):
+        assert_no_overwrite(inv, [(3,3)])
+    def test_det(self):
+        assert_no_overwrite(det, [(3,3)])
+    def test_lstsq(self):
+        assert_no_overwrite(lstsq, [(3,2), (3,)])  # overdetermined 3x2 system
+    def test_pinv(self):
+        assert_no_overwrite(pinv, [(3,3)])
+    def test_pinv2(self):
+        assert_no_overwrite(pinv2, [(3,3)])
+
 if __name__ == "__main__":
     run_module_suite()

Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_decomp.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -20,7 +20,7 @@
 
 from scipy.linalg import eig, eigvals, lu, svd, svdvals, cholesky, qr, \
      schur, rsf2csf, lu_solve, lu_factor, solve, diagsvd, hessenberg, rq, \
-     eig_banded, eigvals_banded, eigh
+     eig_banded, eigvals_banded, eigh, eigvalsh
 from scipy.linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs, \
      dsbev, dsbevd, dsbevx, zhbevd, zhbevx
 
@@ -32,6 +32,8 @@
 
 from numpy.random import rand, normal, seed
 
+from scipy.linalg._testutils import assert_no_overwrite
+
 # digit precision to use in asserts for different types
 DIGITS = {'d':11, 'D':11, 'f':4, 'F':4}
 
@@ -1101,24 +1103,36 @@
 
 
 
-class TestDataNotShared(TestCase):
+class TestDatacopied(TestCase):  # _datacopied on an ndarray, a matrix, a list and array-like wrappers
 
-    def test_datanotshared(self):
-        from scipy.linalg.decomp import _datanotshared
+    def test_datacopied(self):
+        from scipy.linalg.decomp import _datacopied
 
         M = matrix([[0,1],[2,3]])
         A = asarray(M)
         L = M.tolist()
         M2 = M.copy()
 
-        assert_equal(_datanotshared(M,M),False)
-        assert_equal(_datanotshared(M,A),False)
+        class Fake1:  # array-like via __array__ only
+            def __array__(self):
+                return A
 
-        assert_equal(_datanotshared(M,L),True)
-        assert_equal(_datanotshared(M,M2),True)
-        assert_equal(_datanotshared(A,M2),True)
+        class Fake2:  # array-like via __array_interface__ only
+            __array_interface__ = A.__array_interface__
 
+        F1 = Fake1()
+        F2 = Fake2()
 
+        AF1 = asarray(F1)  # NOTE(review): unused below
+        AF2 = asarray(F2)  # NOTE(review): unused below
+
+        for item, status in [(M, False), (A, False), (L, True),  # only the list forces a copy
+                             (M2, False), (F1, False), (F2, False)]:
+            arr = asarray(item)
+            assert_equal(_datacopied(arr, item), status,
+                         err_msg=repr(item))
+
+
 def test_aligned_mem_float():
     """Check linalg works with non-aligned memory"""
     # Allocate 402 bytes of memory (allocated on boundary)
@@ -1207,5 +1221,45 @@
 # not properly tested
 # cholesky, rsf2csf, lu_solve, solve, eig_banded, eigvals_banded, eigh, diagsvd
 
+
+class TestOverwrite(object):  # each decomposition must leave its (possibly wrapped) inputs unmodified
+    def test_eig(self):
+        assert_no_overwrite(eig, [(3,3)])
+        assert_no_overwrite(eig, [(3,3), (3,3)])  # generalized problem: both a and b checked
+    def test_eigh(self):
+        assert_no_overwrite(eigh, [(3,3)])
+        assert_no_overwrite(eigh, [(3,3), (3,3)])  # generalized problem: both a and b checked
+    def test_eig_banded(self):
+        assert_no_overwrite(eig_banded, [(3,2)])  # (3,2) is neither 2xN nor square, so the band data is random
+    def test_eigvals(self):
+        assert_no_overwrite(eigvals, [(3,3)])
+    def test_eigvalsh(self):
+        assert_no_overwrite(eigvalsh, [(3,3)])
+    def test_eigvals_banded(self):
+        assert_no_overwrite(eigvals_banded, [(3,2)])
+    def test_hessenberg(self):
+        assert_no_overwrite(hessenberg, [(3,3)])
+    def test_lu_factor(self):
+        assert_no_overwrite(lu_factor, [(3,3)])
+    def test_lu_solve(self):
+        x = np.array([[1,2,3], [4,5,6], [7,8,8]])
+        xlu = lu_factor(x)  # factored once; only b goes through the overwrite check
+        assert_no_overwrite(lambda b: lu_solve(xlu, b), [(3,)])
+    def test_lu(self):
+        assert_no_overwrite(lu, [(3,3)])
+    def test_qr(self):
+        assert_no_overwrite(qr, [(3,3)])
+    def test_rq(self):
+        assert_no_overwrite(rq, [(3,3)])
+    def test_schur(self):
+        assert_no_overwrite(schur, [(3,3)])
+    def test_schur_complex(self):
+        assert_no_overwrite(lambda a: schur(a, 'complex'), [(3,3)],
+                            dtypes=[np.float32, np.float64])  # real inputs only, complex output requested
+    def test_svd(self):
+        assert_no_overwrite(svd, [(3,3)])
+    def test_svdvals(self):
+        assert_no_overwrite(svdvals, [(3,3)])
+
 if __name__ == "__main__":
     run_module_suite()

Modified: trunk/scipy/linalg/tests/test_decomp_cholesky.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp_cholesky.py	2011-01-31 21:52:21 UTC (rev 7115)
+++ trunk/scipy/linalg/tests/test_decomp_cholesky.py	2011-01-31 21:52:58 UTC (rev 7116)
@@ -4,8 +4,10 @@
 
 from numpy import array, transpose, dot, conjugate, zeros_like
 from numpy.random import rand
-from scipy.linalg import cholesky, cholesky_banded, cho_solve_banded
+from scipy.linalg import cholesky, cholesky_banded, cho_solve_banded, \
+     cho_factor, cho_solve
 
+from scipy.linalg._testutils import assert_no_overwrite
 
 def random(size):
     return rand(*size)
@@ -138,3 +140,20 @@
         b = array([0.0, 0.5j, 3.8j, 3.8])
         x = cho_solve_banded((c, True), b)
         assert_array_almost_equal(x, [0.0, 0.0, 1.0j, 1.0])
+
+class TestOverwrite(object):  # Cholesky routines must leave their (possibly wrapped) inputs unmodified
+    def test_cholesky(self):
+        assert_no_overwrite(cholesky, [(3,3)])
+    def test_cho_factor(self):
+        assert_no_overwrite(cho_factor, [(3,3)])
+    def test_cho_solve(self):
+        x = array([[2,-1,0], [-1,2,-1], [0,-1,2]])  # tridiagonal 2/-1: positive definite
+        xcho = cho_factor(x)  # factored once; only b goes through the overwrite check
+        assert_no_overwrite(lambda b: cho_solve(xcho, b), [(3,)])
+    def test_cholesky_banded(self):
+        assert_no_overwrite(cholesky_banded, [(2,3)])  # 2xN band input is generated positive definite
+    def test_cho_solve_banded(self):
+        x = array([[0, -1, -1], [2, 2, 2]])  # 2-row band form of the 2/-1 tridiagonal matrix above
+        xcho = cholesky_banded(x)
+        assert_no_overwrite(lambda b: cho_solve_banded((xcho, False), b),
+                            [(3,)])




More information about the Scipy-svn mailing list