[Scipy-svn] r5046 - trunk/scipy/linalg/tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Nov 10 10:49:11 EST 2008
Author: tzito
Date: 2008-11-10 09:48:31 -0600 (Mon, 10 Nov 2008)
New Revision: 5046
Modified:
trunk/scipy/linalg/tests/test_decomp.py
Log:
Added tests for linalg.eigh.
Added helper functions to linalg/tests/test_decomp.py that would be
better placed elsewhere:
assert_dtype_equal -> numpy.testing
hermitian, symrand, random_rot -> scipy.linalg
Modified: trunk/scipy/linalg/tests/test_decomp.py
===================================================================
--- trunk/scipy/linalg/tests/test_decomp.py 2008-11-10 15:46:35 UTC (rev 5045)
+++ trunk/scipy/linalg/tests/test_decomp.py 2008-11-10 15:48:31 UTC (rev 5046)
@@ -18,17 +18,97 @@
from scipy.linalg import eig,eigvals,lu,svd,svdvals,cholesky,qr, \
schur,rsf2csf, lu_solve,lu_factor,solve,diagsvd,hessenberg,rq, \
- eig_banded, eigvals_banded
+ eig_banded, eigvals_banded, eigh
from scipy.linalg.flapack import dgbtrf, dgbtrs, zgbtrf, zgbtrs, \
dsbev, dsbevd, dsbevx, zhbevd, zhbevx
from numpy import array, transpose, sometrue, diag, ones, linalg, \
argsort, zeros, arange, float32, complex64, dot, conj, identity, \
ravel, sqrt, iscomplex, shape, sort, conjugate, bmat, sign, \
- asarray, matrix, isfinite, all
+ asarray, matrix, isfinite, all, ndarray, outer, eye, dtype
-from numpy.random import rand
+from numpy.random import rand, normal
# Number of decimal places that assert_array_almost_equal must match,
# keyed by dtype character: double and complex double get 12,
# single and complex single get 6.
DIGITS = dict(d=12, D=12, f=6, F=6)

# Number of rows (and columns) of the square matrices used in the tests.
DIM = 5
+
# XXX: This function should be available through numpy.testing
def assert_dtype_equal(act, des):
    """Check that two dtypes are identical.

    Each argument may be an ndarray (its ``.dtype`` is compared) or anything
    the ``dtype`` constructor accepts (a dtype, a type code such as 'd', a
    scalar type such as float32, ...).  Raises AssertionError naming both
    dtypes when they differ.
    """
    def _to_dtype(obj):
        # Arrays carry their own dtype; everything else goes through dtype().
        return obj.dtype if isinstance(obj, ndarray) else dtype(obj)

    actual = _to_dtype(act)
    desired = _to_dtype(des)
    assert actual == desired, \
        'dtype mismatch: "%s" (should be "%s") ' % (actual, desired)
+
# XXX: This function should not be defined here, but somewhere in
# scipy.linalg namespace
def hermitian(x):
    """Return the Hermitian adjoint (conjugate transpose) of x.

    For a real-valued array this is simply the transpose.
    """
    conjugated = x.conj()
    return conjugated.T
+
# XXX: This function should not be defined here, but somewhere in
# scipy.linalg namespace
def symrand(dim_or_eigv, dtype="d"):
    """Return a random symmetric (Hermitian) matrix.

    If 'dim_or_eigv' is an integer N (a Python int or any NumPy integer
    scalar), return an NxN matrix whose eigenvalues are drawn uniformly
    from [0.1, 1).

    Otherwise 'dim_or_eigv' is converted with ``asarray``.  If the result is
    a 1-D array 'a', return a matrix whose eigenvalues are sort(a).

    The matrix is built as V^H diag(d) V with V = random_rot(dim, dtype),
    symmetrized to remove round-off asymmetry and cast to 'dtype'.  Because
    random_rot draws real numbers, a complex 'dtype' yields a real
    symmetric matrix stored with zero imaginary part.

    Raises TypeError if 'dim_or_eigv' is neither an integer nor 1-D.
    """
    import numbers

    if isinstance(dim_or_eigv, numbers.Integral):
        dim = int(dim_or_eigv)
        # rand() samples [0, 1), so the eigenvalues land in [0.1, 1).
        d = (rand(dim)*0.9)+0.1
    else:
        d = asarray(dim_or_eigv)
        if d.ndim != 1:
            raise TypeError("input type not supported.")
        dim = d.shape[0]

    v = random_rot(dim, dtype=dtype)
    h = dot(dot(hermitian(v), diag(d)), v)
    # to avoid roundoff errors, symmetrize the matrix (again)
    return (0.5*(hermitian(h)+h)).astype(dtype)
+
# XXX: This function should not be defined here, but somewhere in
# scipy.linalg namespace
def random_rot(dim, dtype='d'):
    """Return a random rotation matrix, drawn from the Haar distribution
    (the only uniform distribution on SO(n)).

    The algorithm is described in the paper
    Stewart, G.W., 'The efficient generation of random orthogonal
    matrices with an application to condition estimators', SIAM Journal
    on Numerical Analysis, 17(3), pp. 403-409, 1980.
    For more information see
    http://en.wikipedia.org/wiki/Orthogonal_matrix#Randomization

    Parameters
    ----------
    dim : int
        Number of rows/columns of the returned square matrix.
    dtype : dtype specifier, optional
        Dtype of the result (default 'd').  The random draws are real, so
        for complex dtypes the result is a real orthogonal matrix stored
        with zero imaginary part.

    Returns
    -------
    H : ndarray, shape (dim, dim)
        Orthogonal matrix with determinant +1.
    """
    H = eye(dim, dtype=dtype)
    if dim == 0:
        # Nothing to randomize; also avoids indexing D[-1] on an empty array.
        return H
    D = ones((dim, ), dtype=dtype)
    for n in range(1, dim):
        x = normal(size=(dim-n+1, )).astype(dtype)
        # sign(0) is 0, which would zero out a whole row of the result.
        D[n-1] = sign(x[0]) if x[0] != 0 else 1
        # Adding (rather than subtracting) sign(x0)*|x| avoids cancellation
        # and a 0/0 below when x is (nearly) parallel to the first axis.
        x[0] += D[n-1]*sqrt((x*x).sum())
        # Householder transformation
        Hx = eye(dim-n+1, dtype=dtype) - 2.*outer(x, x)/(x*x).sum()
        mat = eye(dim, dtype=dtype)
        mat[n-1:, n-1:] = Hx
        H = dot(H, mat)
    # Each of the dim-1 Householder reflections has determinant -1, so H now
    # has determinant (-1)**(dim-1).  Choose the last sign so that diag(D)
    # has that same determinant; the product then has determinant +1.
    D[-1] = (-1)**(dim-1) * D[:-1].prod()
    # Scale row i of H by D[i], i.e. compute dot(diag(D), H).
    H = (D*H.T).T
    return H
+
def random(size):
    """Return an array of shape `size` with samples uniform on [0, 1).

    `size` is an iterable of dimensions that is unpacked into
    numpy.random.rand.
    """
    shape = tuple(size)
    return rand(*shape)
@@ -413,9 +493,37 @@
y_lin = linalg.solve(self.comp_mat, self.bc)
assert_array_almost_equal(y, y_lin)
class TestEigH(TestCase):
    """Tests for scipy.linalg.eigh on random Hermitian matrices."""

    # Every (overwrite_a, lower) combination exercised for each dtype.
    _OPTIONS = ((False, False), (False, True), (True, False), (True, True))

    def eigenproblem_standard(self, dim, dtype, overwrite, lower):
        """Solve a standard eigenvalue problem.

        Builds a random dim x dim Hermitian matrix of the given dtype with
        symrand, calls eigh on it and checks that
          * the eigenvector matrix z has the requested dtype, and
          * the diagonal of z^H a z matches the returned eigenvalues w to
            DIGITS[dtype] decimal places.
        """
        a = symrand(dim, dtype)
        if overwrite:
            # eigh may clobber 'a' when overwrite_a is set; keep a pristine
            # copy for the check below.
            a_c = a.copy()
        else:
            a_c = a
        w, z = eigh(a, lower=lower, overwrite_a=overwrite)
        assert_dtype_equal(z.dtype, dtype)
        w = w.astype(dtype)
        diag_ = diag(dot(hermitian(z), dot(a_c, z))).real
        assert_array_almost_equal(diag_, w, DIGITS[dtype])

    def _check_all_options(self, dtypes):
        """Run eigenproblem_standard for each dtype and every option pair."""
        for dtype in dtypes:
            for overwrite, lower in self._OPTIONS:
                self.eigenproblem_standard(DIM, dtype, overwrite, lower)

    def test_eigh_real_standard(self):
        self._check_all_options('df')

    def test_eigh_complex_standard(self):
        self._check_all_options('DF')
-
+
class TestLU(TestCase):
def __init__(self, *args, **kw):
More information about the Scipy-svn
mailing list