[Scipy-svn] r5040 - trunk/scipy/lib/lapack/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Nov 10 09:54:57 EST 2008


Author: cdavid
Date: 2008-11-10 08:54:44 -0600 (Mon, 10 Nov 2008)
New Revision: 5040

Added:
   trunk/scipy/lib/lapack/tests/test_gesv.py
Log:
refactor gesv tests.

Added: trunk/scipy/lib/lapack/tests/test_gesv.py
===================================================================
--- trunk/scipy/lib/lapack/tests/test_gesv.py	2008-11-10 14:54:25 UTC (rev 5039)
+++ trunk/scipy/lib/lapack/tests/test_gesv.py	2008-11-10 14:54:44 UTC (rev 5040)
@@ -0,0 +1,119 @@
+import numpy as np
+from numpy.testing import TestCase, assert_array_almost_equal, dec, \
+                          assert_equal
+
+from scipy.lib.lapack import flapack, clapack
+
+A = np.array([[1,2,3],[2,2,3],[3,3,6]])
+B = np.array([[10,-1,1],[-1,8,-2],[1,-2,6]])
+
+FUNCS_TP = {'ssygv' : np.float32, 
+         'dsygv': np.float,
+         'ssygvd' : np.float32,
+         'dsygvd' : np.float}
+
+# Test FLAPACK if not empty
+if hasattr(flapack, 'empty_module'):
+    FLAPACK_IS_EMPTY = True
+else:
+    FLAPACK_IS_EMPTY = False
+
+# Test CLAPACK if not empty and not the same as clapack
+if hasattr(clapack, 'empty_module') or (clapack == flapack):
+    CLAPACK_IS_EMPTY = True
+else:
+    CLAPACK_IS_EMPTY = False
+
+if not FLAPACK_IS_EMPTY:
+    FUNCS_FLAPACK = {'ssygv' : flapack.ssygv,
+                     'dsygv': flapack.dsygv,
+                     'ssygvd' : flapack.ssygvd,
+                     'dsygvd' : flapack.dsygvd}
+
+if not CLAPACK_IS_EMPTY:
+    FUNCS_CLAPACK = {'ssygv' : clapack.ssygv,
+                     'dsygv': clapack.dsygv,
+                     'ssygvd' : clapack.ssygvd,
+                     'dsygvd' : clapack.dsygvd}
+
+
+PREC = {np.float32: 5, np.float: 12}
+
+class TestSygv(TestCase):
+    def _test_base(self, func, lang, itype):
+        tp = FUNCS_TP[func]
+        a = A.astype(tp)
+        b = B.astype(tp)
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
+
+        w, v, info = f(a, b, itype=itype)
+
+        assert not info, `info`
+        for i in range(3):
+            if itype == 1:
+                assert_array_almost_equal(np.dot(a,v[:,i]), w[i]*np.dot(b,v[:,i]),
+                                          decimal=PREC[tp])
+            elif itype == 2:
+                assert_array_almost_equal(np.dot(a,np.dot(b,v[:,i])), w[i]*v[:,i],
+                                          decimal=PREC[tp])
+            elif itype == 3:
+                assert_array_almost_equal(np.dot(b,np.dot(a,v[:,i])), 
+                                          w[i]*v[:,i], decimal=PREC[tp] - 1)
+            else:
+                raise ValueError, `itype`
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssygv_1(self):
+        self._test_base('ssygv', 'F', 1)
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssygv_2(self):
+        self._test_base('ssygv', 'F', 2)
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_ssygv_3(self):
+        self._test_base('ssygv', 'F', 3)
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsygv_1(self):
+        self._test_base('dsygv', 'F', 1)
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsygv_2(self):
+        self._test_base('dsygv', 'F', 2)
+
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dsygv_3(self):
+        self._test_base('dsygv', 'F', 3)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_ssygv_1(self):
+        self._test_base('ssygv', 'C', 1)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_ssygv_2(self):
+        self._test_base('ssygv', 'C', 2)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_ssygv_3(self):
+        self._test_base('ssygv', 'C', 3)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_dsygv_1(self):
+        self._test_base('dsygv', 'C', 1)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_dsygv_2(self):
+        self._test_base('dsygv', 'C', 2)
+
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip flapack test")
+    def test_clapack_dsygv_3(self):
+        self._test_base('dsygv', 'C', 3)
+
+if __name__=="__main__":
+    run_module_suite()




More information about the Scipy-svn mailing list