[Scipy-svn] r5039 - trunk/scipy/lib/lapack/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Nov 10 09:54:33 EST 2008


Author: cdavid
Date: 2008-11-10 08:54:25 -0600 (Mon, 10 Nov 2008)
New Revision: 5039

Added:
   trunk/scipy/lib/lapack/tests/test_esv.py
Log:
Refactor syev* tests in scipy.lib.lapack.

Added: trunk/scipy/lib/lapack/tests/test_esv.py
===================================================================
--- trunk/scipy/lib/lapack/tests/test_esv.py	2008-11-10 13:20:36 UTC (rev 5038)
+++ trunk/scipy/lib/lapack/tests/test_esv.py	2008-11-10 14:54:25 UTC (rev 5039)
@@ -0,0 +1,162 @@
+import numpy as np
+from numpy.testing import TestCase, assert_array_almost_equal, dec, \
+                          assert_equal
+
+from scipy.lib.lapack import flapack, clapack
+
+SYEV_ARG = np.array([[1,2,3],[2,2,3],[3,3,6]])
+SYEV_REF = np.array([-0.6699243371851365, 0.4876938861533345,
+                     9.182230451031804])
+
+FUNCS_TP = {'ssyev' : np.float32, 
+         'dsyev': np.float,
+         'ssyevr' : np.float32,
+         'dsyevr' : np.float}
+
+# Test FLAPACK if not empty
+if hasattr(flapack, 'empty_module'):
+    FLAPACK_IS_EMPTY = True
+else:
+    FLAPACK_IS_EMPTY = False
+
+# Test CLAPACK if not empty and not the same as clapack
+if hasattr(clapack, 'empty_module') or (clapack == flapack):
+    CLAPACK_IS_EMPTY = True
+else:
+    CLAPACK_IS_EMPTY = False
+
+if not FLAPACK_IS_EMPTY:
+    FUNCS_FLAPACK = {'ssyev' : flapack.ssyev,
+                     'dsyev': flapack.dsyev,
+                     'ssyevr' : flapack.ssyevr,
+                     'dsyevr' : flapack.dsyevr}
+
+if not CLAPACK_IS_EMPTY:
+    FUNCS_CLAPACK = {'ssyev' : clapack.ssyev,
+                     'dsyev': clapack.dsyev,
+                     'ssyevr' : clapack.ssyevr,
+                     'dsyevr' : clapack.dsyevr}
+
+PREC = {np.float32: 5, np.float: 12}
+
+class TestEsv(TestCase):
+    def _test_base(self, func, lang):
+        tp = FUNCS_TP[func]
+        a = SYEV_ARG.astype(tp)
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
+
+        w, v, info = f(a)
+
+        assert not info, `info`
+        assert_array_almost_equal(w, SYEV_REF, decimal=PREC[tp])
+        for i in range(3):
+            assert_array_almost_equal(np.dot(a,v[:,i]), w[i]*v[:,i], 
+                                      decimal=PREC[tp])
+
+    def _test_base_irange(self, func, irange, lang):
+        tp = FUNCS_TP[func]
+        a = SYEV_ARG.astype(tp)
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
+
+        w, v, info = f(a, irange=irange)
+        rslice = slice(irange[0], irange[1]+1)
+        m = irange[1] - irange[0] + 1
+        assert not info, `info`
+
+        assert_equal(len(w),m)
+        assert_array_almost_equal(w, SYEV_REF[rslice], decimal=PREC[tp])
+
+        for i in range(m):
+            assert_array_almost_equal(np.dot(a,v[:,i]), w[i]*v[:,i], 
+                                      decimal=PREC[tp])
+
+    def _test_base_vrange(self, func, vrange, lang):
+        tp = FUNCS_TP[func]
+        a = SYEV_ARG.astype(tp)
+        ew = [value for value in SYEV_REF if vrange[0] < value <= vrange[1]]
+
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
+
+        w, v, info = f(a, vrange=vrange)
+        assert not info, `info`
+
+        assert_array_almost_equal(w, ew, decimal=PREC[tp])
+
+        for i in range(len(w)):
+            assert_array_almost_equal(np.dot(a,v[:,i]), w[i]*v[:,i], 
+                                      decimal=PREC[tp])
+
+    def _test_syevr_ranges(self, func, lang):
+        for irange in ([0, 2], [0, 1], [1, 1], [1, 2]):
+            self._test_base_irange(func, irange, lang)
+
+        for vrange in ([-1, 10], [-1, 1], [0, 1], [1, 10]):
+            self._test_base_vrange(func, vrange, lang)
+
+    # Flapack tests
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssyev(self):
+        self._test_base('ssyev', 'F')
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsyev(self):
+        self._test_base('dsyev', 'F')
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssyevr(self):
+        self._test_base('ssyevr', 'F')
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsyevr(self):
+        self._test_base('dsyevr', 'F')
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssyevr_ranges(self):
+        self._test_syevr_ranges('ssyevr', 'F')
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsyevr_ranges(self):
+        self._test_syevr_ranges('dsyevr', 'F')
+
+    # Clapack tests
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_ssyev(self):
+        self._test_base('ssyev', 'C')
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_dsyev(self):
+        self._test_base('dsyev', 'C')
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_ssyevr(self):
+        self._test_base('ssyevr', 'C')
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_dsyevr(self):
+        self._test_base('dsyevr', 'C')
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_ssyevr_ranges(self):
+        self._test_syevr_ranges('ssyevr', 'C')
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_dsyevr_ranges(self):
+        self._test_syevr_ranges('dsyevr', 'C')
+
+if __name__=="__main__":
+    run_module_suite()




More information about the Scipy-svn mailing list