[Scipy-svn] r5046 - trunk/scipy/linalg/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Nov 10 10:49:11 EST 2008


Author: tzito
Date: 2008-11-10 09:48:31 -0600 (Mon, 10 Nov 2008)
New Revision: 5046

Modified:
   trunk/scipy/linalg/tests/test_decomp.py
Log:
Added tests for linalg.eigh.
Added helper functions to linalg/tests/test_decomp.py that would be
better placed elsewhere:
assert_dtype_equal -> numpy.testing
hermitian, symrand, random_rot -> scipy.linalg


Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py	2008-11-10 15:46:35 UTC (rev 5045)
+++ trunk/scipy/linalg/tests/test_decomp.py	2008-11-10 15:48:31 UTC (rev 5046)
@@ -18,17 +18,97 @@
 
 from scipy.linalg import eig,eigvals,lu,svd,svdvals,cholesky,qr, \
      schur,rsf2csf, lu_solve,lu_factor,solve,diagsvd,hessenberg,rq, \
-     eig_banded, eigvals_banded
+     eig_banded, eigvals_banded, eigh
 from scipy.linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs, \
      dsbev, dsbevd, dsbevx, zhbevd, zhbevx
 
 from numpy import array, transpose, sometrue, diag, ones, linalg, \
      argsort, zeros, arange, float32, complex64, dot, conj, identity, \
      ravel, sqrt, iscomplex, shape, sort, conjugate, bmat, sign, \
-     asarray, matrix, isfinite, all
+     asarray, matrix, isfinite, all, ndarray, outer, eye, dtype
 
-from numpy.random import rand
+from numpy.random import rand, normal
 
+# decimal digits checked by assert_array_almost_equal, keyed by dtype code
+DIGITS = dict(d=12, D=12, f=6, F=6)
+
+# size of the square test matrices
+DIM = 5
+
+# XXX: This function should be available through numpy.testing
+def assert_dtype_equal(act, des):
+    """Assert that 'act' and 'des' describe the same dtype.
+
+    Arrays are compared via their .dtype attribute; any other argument
+    (type code string, scalar type, dtype object) goes through dtype().
+    """
+    def _to_dtype(obj):
+        if isinstance(obj, ndarray):
+            return obj.dtype
+        return dtype(obj)
+    act, des = _to_dtype(act), _to_dtype(des)
+    assert act == des, 'dtype mismatch: "%s" (should be "%s") '%(act, des)
+
+# XXX: This function should not be defined here, but somewhere in
+#      scipy.linalg namespace
+def hermitian(x):
+    """Return the conjugate transpose (Hermitian adjoint) of x."""
+    return x.conj().T
+
+# XXX: This function should not be defined here, but somewhere in
+#      scipy.linalg namespace
+def symrand(dim_or_eigv, dtype="d"):
+    """Return a random symmetric (Hermitian) matrix of type 'dtype'.
+
+    If 'dim_or_eigv' is an integer N, return an NxN matrix whose
+        eigenvalues are drawn uniformly from [0.1, 1).
+    If 'dim_or_eigv' is a 1-D real array 'a', return a matrix whose
+        eigenvalues are sort(a).
+    Any other input raises TypeError.
+    """
+    if isinstance(dim_or_eigv, int):
+        dim = dim_or_eigv
+        d = (rand(dim)*0.9)+0.1
+    elif (isinstance(dim_or_eigv, ndarray) and
+          len(dim_or_eigv.shape) == 1):
+        dim = dim_or_eigv.shape[0]
+        d = dim_or_eigv
+    else:
+        raise TypeError("input type not supported.")
+    # v is orthogonal, so h = v^H diag(d) v is similar to diag(d)
+    v = random_rot(dim, dtype=dtype)
+    h = dot(dot(hermitian(v), diag(d)), v)
+    # to avoid roundoff errors, symmetrize the matrix (again)
+    return (0.5*(hermitian(h)+h)).astype(dtype)
+
+# XXX: This function should not be defined here, but somewhere in
+#      scipy.linalg namespace
+def random_rot(dim, dtype='d'):
+    """Return a random dim x dim rotation matrix of type 'dtype', drawn
+    from the Haar distribution (the only uniform distribution on SO(n)).
+    The algorithm is described in the paper
+    Stewart, G.W., 'The efficient generation of random orthogonal
+    matrices with an application to condition estimators', SIAM Journal
+    on Numerical Analysis, 17(3), pp. 403-409, 1980.
+    For more information see
+    http://en.wikipedia.org/wiki/Orthogonal_matrix#Randomization"""
+    H = eye(dim, dtype=dtype)
+    D = ones((dim, ), dtype=dtype)
+    for n in range(1, dim):
+        x = normal(size=(dim-n+1, )).astype(dtype)
+        D[n-1] = sign(x[0])
+        x[0] -= D[n-1]*sqrt((x*x).sum())
+        # Householder reflection (determinant -1) on the trailing block
+        Hx = eye(dim-n+1, dtype=dtype) - 2.*outer(x, x)/(x*x).sum()
+        mat = eye(dim, dtype=dtype)
+        mat[n-1:,n-1:] = Hx
+        H = dot(H, mat)
+    # dim-1 reflections give det(H) = (-1)**(dim-1); choose the last sign
+    # so that det(diag(D) H) = +1 for every dim, not just even dim.
+    D[-1] = (-1)**(dim-1)*D[:-1].prod()
+    H = (D*H.T).T
+    return H
+
 def random(size):
     return rand(*size)
 
@@ -413,9 +493,37 @@
         y_lin = linalg.solve(self.comp_mat, self.bc)
         assert_array_almost_equal(y, y_lin)
 
+class TestEigH(TestCase):
+    def eigenproblem_standard(self, dim, dtype, overwrite, lower):
+        """Solve a standard eigenproblem; check z^H a z == diag(w)."""
+        a = symrand(dim, dtype)
+        # eigh may overwrite 'a' when overwrite_a is set; keep a pristine copy
+        a_c = a.copy() if overwrite else a
+        w, z = eigh(a, lower=lower, overwrite_a = overwrite)
+        assert_dtype_equal(z.dtype, dtype)
+        w = w.astype(dtype)
+        # Compare the whole projected matrix, not only its diagonal, so that
+        # non-orthogonal or mixed eigenvectors are caught as well.
+        proj = dot(hermitian(z), dot(a_c, z))
+        assert_array_almost_equal(proj, diag(w), DIGITS[dtype])
 
+    def test_eigh_real_standard(self):
+        self.eigenproblem_standard(DIM, 'd', False, False)
+        self.eigenproblem_standard(DIM, 'd', False, True)
+        self.eigenproblem_standard(DIM, 'd', True, True)
+        self.eigenproblem_standard(DIM, 'f', False, False)
+        self.eigenproblem_standard(DIM, 'f', False, True)
+        self.eigenproblem_standard(DIM, 'f', True, True)
+    
+    def test_eigh_complex_standard(self):
+        self.eigenproblem_standard(DIM, 'D', False, False)
+        self.eigenproblem_standard(DIM, 'D', False, True)
+        self.eigenproblem_standard(DIM, 'D', True, True)
+        self.eigenproblem_standard(DIM, 'F', False, False)
+        self.eigenproblem_standard(DIM, 'F', False, True)
+        self.eigenproblem_standard(DIM, 'F', True, True)
 
-
+    
 class TestLU(TestCase):
 
     def __init__(self, *args, **kw):




More information about the Scipy-svn mailing list