[Scipy-svn] r5042 - trunk/scipy/lib/lapack/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Nov 10 09:55:51 EST 2008


Author: cdavid
Date: 2008-11-10 08:55:31 -0600 (Mon, 10 Nov 2008)
New Revision: 5042

Removed:
   trunk/scipy/lib/lapack/tests/esv_tests.py
   trunk/scipy/lib/lapack/tests/gesv_tests.py
Modified:
   trunk/scipy/lib/lapack/tests/common.py
   trunk/scipy/lib/lapack/tests/test_lapack.py
Log:
Finish refactoring of scipy.lib.lapack tests.

Modified: trunk/scipy/lib/lapack/tests/common.py
===================================================================
--- trunk/scipy/lib/lapack/tests/common.py	2008-11-10 14:55:08 UTC (rev 5041)
+++ trunk/scipy/lib/lapack/tests/common.py	2008-11-10 14:55:31 UTC (rev 5042)
@@ -9,7 +9,11 @@
         'ssyev' : np.float32, 
         'dsyev': np.float,
         'ssyevr' : np.float32,
-         'dsyevr' : np.float}
+         'dsyevr' : np.float,
+         'sgehrd' : np.float32,
+         'dgehrd' : np.float,
+         'sgebal' : np.float32, 
+         'dgebal': np.float}

 # Test FLAPACK if not empty
 if hasattr(flapack, 'empty_module'):
@@ -31,7 +35,11 @@
                     'ssyev' : flapack.ssyev,
                     'dsyev': flapack.dsyev,
                     'ssyevr' : flapack.ssyevr,
-                     'dsyevr' : flapack.dsyevr}
+                     'dsyevr' : flapack.dsyevr,
+                     'sgehrd' : flapack.sgehrd,
+                     'dgehrd' : flapack.dgehrd,
+                     'sgebal' : flapack.sgebal,
+                     'dgebal': flapack.dgebal}
 else:
     FUNCS_FLAPACK = None

@@ -43,7 +51,11 @@
                     'ssyev' : clapack.ssyev,
                     'dsyev': clapack.dsyev,
                     'ssyevr' : clapack.ssyevr,
-                     'dsyevr' : clapack.dsyevr}
+                     'dsyevr' : clapack.dsyevr,
+                     'sgehrd' : clapack.sgehrd,
+                     'dgehrd' : clapack.dgehrd,
+                     'sgebal' : clapack.sgebal,
+                     'dgebal': clapack.dgebal}
 else:
     FUNCS_CLAPACK = None
 

Deleted: trunk/scipy/lib/lapack/tests/esv_tests.py
===================================================================
--- trunk/scipy/lib/lapack/tests/esv_tests.py	2008-11-10 14:55:08 UTC (rev 5041)
+++ trunk/scipy/lib/lapack/tests/esv_tests.py	2008-11-10 14:55:31 UTC (rev 5042)
@@ -1,116 +0,0 @@
-import numpy as np
-from numpy.testing import *
-from numpy import dot
-
-class _test_ev(object):
-
-    def check_syev(self,sym='sy',suffix=''):
-        a = [[1,2,3],[2,2,3],[3,3,6]]
-        exact_w = [-0.6699243371851365,0.4876938861533345,9.182230451031804]
-        f = getattr(self.lapack,sym+'ev'+suffix)
-        w,v,info=f(a)
-        assert not info,`info`
-        assert_array_almost_equal(w,exact_w)
-        for i in range(3):
-            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i])
-
-    def check_syevd(self): self.check_syev(suffix='d')
-
-    #def check_heev(self): self.check_syev(sym='he')
-
-    #def check_heevd(self): self.check_syev(sym='he',suffix='d')
-
-##    def check_heev_complex(self,suffix=''):
-##        a= [[1,2-2j,3+7j],[2+2j,2,3],[3-7j,3,5]]
-##        exact_w=[-6.305141710654834,2.797880950890922,11.50726075976392]
-##        f = getattr(self.lapack,'heev'+suffix)
-##        w,v,info=f(a)
-##        assert not info,`info`
-##        assert_array_almost_equal(w,exact_w)
-##        for i in range(3):
-##            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i],self.decimal)
-
-    #def check_heevd_complex(self): self.check_heev_complex(suffix='d')
-
-    def check_syevr(self,sym='sy'):
-        a = [[1,2,3],[2,2,3],[3,3,6]]
-        if self.lapack.prefix == 's':
-            exact_dtype = np.float32
-        else:
-            exact_dtype = np.float
-        exact_w = np.array([-0.6699243371851365, 0.4876938861533345,
-                            9.182230451031804], exact_dtype)
-        f = getattr(self.lapack,sym+'evr')
-        w,v,info = f(a)
-        assert not info,`info`
-        assert_array_almost_equal(w,exact_w, decimal=self.decimal)
-        for i in range(3):
-            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i], decimal=self.decimal)
-
-##    def check_heevr_complex(self):
-##        a= [[1,2-2j,3+7j],[2+2j,2,3],[3-7j,3,5]]
-##        exact_w=[-6.305141710654834,2.797880950890922,11.50726075976392]
-##        f = self.lapack.heevr
-##        w,v,info = f(a)
-##        assert not info,`info`
-##        assert_array_almost_equal(w,exact_w)
-##        for i in range(3):
-##            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i])
-
-##    def check_heevr(self): self.check_syevr(sym='he')
-
-    def check_syevr_irange(self,sym='sy',irange=[0,2]):
-        a = [[1,2,3],[2,2,3],[3,3,6]]
-        if self.lapack.prefix == 's':
-            exact_dtype = np.float32
-        else:
-            exact_dtype = np.float
-        exact_w = np.array([-0.6699243371851365, 0.4876938861533345,
-                            9.182230451031804], exact_dtype)
-        f = getattr(self.lapack,sym+'evr')
-        w,v,info = f(a,irange=irange)
-        assert not info,`info`
-        rslice = slice(irange[0],irange[1]+1)
-        m = irange[1] - irange[0] + 1
-        assert_equal(len(w),m)
-        assert_array_almost_equal(w,exact_w[rslice], decimal=self.decimal)
-        for i in range(m):
-            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i], decimal=self.decimal)
-
-    def check_syevr_irange_low(self): self.check_syevr_irange(irange=[0,1])
-
-    def check_syevr_irange_mid(self): self.check_syevr_irange(irange=[1,1])
-
-    def check_syevr_irange_high(self): self.check_syevr_irange(irange=[1,2])
-
-##    def check_heevr_irange(self): self.check_syevr_irange(sym='he')
-
-##    def check_heevr_irange_low(self): self.check_syevr_irange(sym='he',irange=[0,1])
-
-##    def check_heevr_irange_high(self): self.check_syevr_irange(sym='he',irange=[1,2])
-
-    def check_syevr_vrange(self,sym='sy',vrange=None):
-        a = [[1,2,3],[2,2,3],[3,3,6]]
-        exact_w = [-0.6699243371851365,0.4876938861533345,9.182230451031804]
-        if vrange is None:
-            vrange = [-1,10]
-        ew = [value for value in exact_w if vrange[0]<value<=vrange[1]]
-        f = getattr(self.lapack,sym+'evr')
-        w,v,info = f(a,vrange=vrange)
-        assert not info,`info`
-        assert_array_almost_equal(w,ew)
-        m = len(w)
-        for i in range(m):
-            assert_array_almost_equal(dot(a,v[:,i]),w[i]*v[:,i])
-
-    def check_syevr_vrange_low(self): self.check_syevr_vrange(vrange=[-1,1])
-
-    def check_syevr_vrange_mid(self): self.check_syevr_vrange(vrange=[0,1])
-
-    def check_syevr_vrange_high(self): self.check_syevr_vrange(vrange=[1,10])
-
-##    def check_heevr_vrange(self): self.check_syevr_vrange(sym='he')
-
-##    def check_heevr_vrange_low(self): self.check_syevr_vrange(sym='he',vrange=[-1,1])
-
-##    def check_heevr_vrange_high(self): self.check_syevr_vrange(sym='he',vrange=[1,10])

Deleted: trunk/scipy/lib/lapack/tests/gesv_tests.py
===================================================================
--- trunk/scipy/lib/lapack/tests/gesv_tests.py	2008-11-10 14:55:08 UTC (rev 5041)
+++ trunk/scipy/lib/lapack/tests/gesv_tests.py	2008-11-10 14:55:31 UTC (rev 5042)
@@ -1,43 +0,0 @@
-
-from numpy.testing import *
-from numpy import dot
-
-class _test_gev(object):
-
-    def check_sygv(self,sym='sy',suffix='',itype=1):
-        a = [[1,2,3],[2,2,3],[3,3,6]]
-        b = [[10,-1,1],[-1,8,-2],[1,-2,6]]
-        f = getattr(self.lapack,sym+'gv'+suffix)
-        w,v,info=f(a,b,itype=itype)
-        assert not info,`info`
-        for i in range(3):
-            if itype==1:
-                assert_array_almost_equal(dot(a,v[:,i]),w[i]*dot(b,v[:,i]),self.decimal)
-            elif itype==2:
-                assert_array_almost_equal(dot(a,dot(b,v[:,i])),w[i]*v[:,i],self.decimal)
-            elif itype==3:
-                assert_array_almost_equal(dot(b,dot(a,v[:,i])),w[i]*v[:,i],self.decimal-1)
-            else:
-                raise ValueError,`itype`
-
-    def check_sygv_2(self): self.check_sygv(itype=2)
-
-    def check_sygv_3(self): self.check_sygv(itype=3)
-
-##    def check_hegv(self): self.check_sygv(sym='he')
-
-##    def check_hegv_2(self): self.check_sygv(sym='he',itype=2)
-
-##    def check_hegv_3(self): self.check_sygv(sym='he',itype=3)
-
-    def check_sygvd(self): self.check_sygv(suffix='d')
-
-    def check_sygvd_2(self): self.check_sygv(suffix='d',itype=2)
-
-    def check_sygvd_3(self): self.check_sygv(suffix='d',itype=3)
-
-##    def check_hegvd(self): self.check_sygv(sym='he',suffix='d')
-
-##    def check_hegvd_2(self): self.check_sygv(sym='he',suffix='d',itype=2)
-
-##    def check_hegvd_3(self): self.check_sygv(sym='he',suffix='d',itype=3)

Modified: trunk/scipy/lib/lapack/tests/test_lapack.py
===================================================================
--- trunk/scipy/lib/lapack/tests/test_lapack.py	2008-11-10 14:55:08 UTC (rev 5041)
+++ trunk/scipy/lib/lapack/tests/test_lapack.py	2008-11-10 14:55:31 UTC (rev 5042)
@@ -2,136 +2,83 @@
 #
 # Created by: Pearu Peterson, September 2002
 #
-'''
-This file adapted for nose tests 1/1/08
-
-Note that the conversion is not very complete.
-
-This and the included files deliberately use "check_" as the test
-method names.  There are no subclasses of TestCase.  Thus nose will
-pick up nothing but the final test_all_lapack generator function.
-This does the work of collecting the test methods and checking if they
-can be run (see the isrunnable method).
-'''
-
+import numpy as np
 from numpy.testing import *
-from numpy import ones
 
-from scipy.lib.lapack import flapack, clapack
+from common import FUNCS_TP, FUNCS_CLAPACK, FUNCS_FLAPACK, FLAPACK_IS_EMPTY, \
+                   CLAPACK_IS_EMPTY
 
-#sys.path.insert(0, os.path.split(__file__))
-from gesv_tests import _test_gev
-from esv_tests import _test_ev
-#del sys.path[0]
+class TestLapack(TestCase):
+    def _test_gebal_base(self, func, lang):
+        tp = FUNCS_TP[func]
 
-#class _test_ev: pass
+        a = np.array([[1,2,3],[4,5,6],[7,8,9]]).astype(tp)
+        a1 = np.array([[1,0,0,3e-4],
+                       [4,0,0,2e-3],
+                       [7,1,0,0],
+                       [0,1,0,0]]).astype(tp)
 
-class _TestLapack( _test_ev,
-                   _test_gev):
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
 
-    def check_gebal(self):
-        a = [[1,2,3],[4,5,6],[7,8,9]]
-        a1 = [[1,0,0,3e-4],
-              [4,0,0,2e-3],
-              [7,1,0,0],
-              [0,1,0,0]]
-        f = self.lapack.gebal
+        ba, lo, hi, pivscale, info = f(a)
+        assert not info, `info`
+        assert_array_almost_equal(ba, a)
+        assert_equal((lo,hi), (0, len(a[0])-1))
+        assert_array_almost_equal(pivscale, np.ones(len(a)))
 
-        ba,lo,hi,pivscale,info = f(a)
-        assert not info,`info`
-        assert_array_almost_equal(ba,a)
-        assert_equal((lo,hi),(0,len(a[0])-1))
-        assert_array_almost_equal(pivscale,ones(len(a)))
+        ba, lo, hi, pivscale, info = f(a1,permute=1,scale=1)
+        assert not info, `info`
 
-        ba,lo,hi,pivscale,info = f(a1,permute=1,scale=1)
-        assert not info,`info`
+    def _test_gehrd_base(self, func, lang):
+        tp = FUNCS_TP[func]
 
-    def check_gehrd(self):
-        a = [[-149, -50,-154],
+        a = np.array([[-149, -50,-154],
              [ 537, 180, 546],
-             [ -27,  -9, -25]]
-        f = self.lapack.gehrd
-        ht,tau,info = f(a)
+             [ -27,  -9, -25]]).astype(tp)
+        
+        if lang == 'C':
+            f = FUNCS_CLAPACK[func]
+        elif lang == 'F':
+            f = FUNCS_FLAPACK[func]
+        else:
+            raise ValueError("Lang %s ??" % lang)
+
+        ht, tau, info = f(a)
         assert not info,`info`
 
-    def isrunnable(self,mthname):
-        ''' Return True if required routines for check method present in module '''
-        l = mthname.split('_')
-        if len(l)>1 and l[0]=='check':
-            return hasattr(self.lapack,l[1])
-        return 2
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_sgebal(self):
+        self._test_gebal_base('sgebal', 'F')
+                
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dgebal(self):
+        self._test_gebal_base('dgebal', 'F')
+                
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_sgehrd(self):
+        self._test_gehrd_base('sgehrd', 'F')
+                
+    @dec.skipif(FLAPACK_IS_EMPTY, "Flapack empty, skip flapack test")
+    def test_dgehrd(self):
+        self._test_gehrd_base('dgehrd', 'F')
 
-class PrefixWrapper(object):
-    def __init__(self,module,prefix):
-        self.module = module
-        self.prefix = prefix
-        self.__doc__ = module.__doc__
-
-    def __getattr__(self, name):
-        class A: pass
-        a = getattr(self.module,self.prefix+name,getattr(self.module,name,A()))
-        if isinstance(a,A):
-            raise AttributeError,'%s has no attribute %r' % (self.module,name)
-        return a
-
-if hasattr(flapack,'empty_module'):
-    print """
-****************************************************************
-WARNING: flapack module is empty
------------
-See scipy/INSTALL.txt for troubleshooting.
-****************************************************************
-"""
-else:
-    class TestFlapackDouble(_TestLapack):
-        lapack = PrefixWrapper(flapack,'d')
-        decimal = 12
-    class TestFlapackFloat(_TestLapack):
-        lapack = PrefixWrapper(flapack,'s')
-        decimal = 5
-    class TestFlapackComplex(_TestLapack):
-        lapack = PrefixWrapper(flapack,'c')
-        decimal = 5
-    class TestFlapackDoubleComplex(_TestLapack):
-        lapack = PrefixWrapper(flapack,'z')
-        decimal = 12
-
-if hasattr(clapack,'empty_module') or clapack is flapack:
-    print """
-****************************************************************
-WARNING: clapack module is empty
------------
-See scipy/INSTALL.txt for troubleshooting.
-Notes:
-* If atlas library is not found by numpy/distutils/system_info.py,
-  then scipy uses flapack instead of clapack.
-****************************************************************
-"""
-else:
-    class TestClapackDouble(_TestLapack):
-        lapack = PrefixWrapper(clapack,'d')
-        decimal = 12
-    class TestClapackFloat(_TestLapack):
-        lapack = PrefixWrapper(clapack,'s')
-        decimal = 5
-    class TestClapackComplex(_TestLapack):
-        lapack = PrefixWrapper(clapack,'c')
-        decimal = 5
-    class TestClapackDoubleComplex(_TestLapack):
-        lapack = PrefixWrapper(clapack,'z')
-        decimal = 12
-
-# Collect test classes and methods with generator
-# This is a moderate hack replicating some obscure numpy testing
-# functionality for use with nose
-
-def test_all_lapack():
-    methods = []
-    for name, value in globals().items():
-        if not (name.startswith('Test')
-                and issubclass(value, _TestLapack)):
-            continue
-        o = value()
-        methods += [getattr(o, n) for n in dir(o) if o.isrunnable(n) is True]
-    for method in methods:
-        yield (method, )
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_sgebal(self):
+        self._test_gebal_base('sgebal', 'C')
+                
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_dgebal(self):
+        self._test_gebal_base('dgebal', 'C')
+                
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_sgehrd(self):
+        self._test_gehrd_base('sgehrd', 'C')
+                
+    @dec.skipif(CLAPACK_IS_EMPTY, "Clapack empty, skip clapack test")
+    def test_clapack_dgehrd(self):
+        self._test_gehrd_base('dgehrd', 'C')




More information about the Scipy-svn mailing list