[Scipy-svn] r6791 - trunk/scipy/lib/blas/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Sep 12 17:19:54 EDT 2010


Author: warren.weckesser
Date: 2010-09-12 16:19:54 -0500 (Sun, 12 Sep 2010)
New Revision: 6791

Modified:
   trunk/scipy/lib/blas/tests/test_blas.py
   trunk/scipy/lib/blas/tests/test_fblas.py
Log:
TST: lib: Don't use 'import *'. Don't use plain 'assert'.

Modified: trunk/scipy/lib/blas/tests/test_blas.py
===================================================================
--- trunk/scipy/lib/blas/tests/test_blas.py	2010-09-12 21:18:42 UTC (rev 6790)
+++ trunk/scipy/lib/blas/tests/test_blas.py	2010-09-12 21:19:54 UTC (rev 6791)
@@ -13,7 +13,8 @@
 import math
 
 from numpy import array
-from numpy.testing import *
+from numpy.testing import assert_equal, assert_almost_equal, \
+        assert_array_almost_equal, TestCase, run_module_suite
 from scipy.lib.blas import fblas
 from scipy.lib.blas import cblas
 from scipy.lib.blas import get_blas_funcs

Modified: trunk/scipy/lib/blas/tests/test_fblas.py
===================================================================
--- trunk/scipy/lib/blas/tests/test_fblas.py	2010-09-12 21:18:42 UTC (rev 6790)
+++ trunk/scipy/lib/blas/tests/test_fblas.py	2010-09-12 21:19:54 UTC (rev 6791)
@@ -8,7 +8,8 @@
 
 from numpy import zeros, transpose, newaxis, shape, float32, float64, \
                   complex64, complex128, arange, array, common_type, conjugate
-from numpy.testing import *
+from numpy.testing import assert_equal, assert_array_almost_equal, \
+        run_module_suite, TestCase
 from scipy.lib.blas import fblas
 
 #decimal accuracy to require between Python and LAPACK/BLAS calculations
@@ -22,7 +23,7 @@
         b = b[:,newaxis]
     else:
         b_is_vector = False
-    assert a.shape[1] == b.shape[0]
+    assert_equal(a.shape[1], b.shape[0])
     c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
     for i in xrange(a.shape[0]):
         for j in xrange(b.shape[1]):
@@ -38,19 +39,23 @@
 ### Test blas ?axpy
 
 class BaseAxpy(object):
+
     # Mixin class to test dtypes
+
     def test_default_a(self):
         x = arange(3.,dtype=self.dtype)
         y = arange(3.,dtype=x.dtype)
         real_y = x*1.+y
         self.blas_func(x,y)
         assert_array_almost_equal(real_y,y)
+
     def test_simple(self):
         x = arange(3.,dtype=self.dtype)
         y = arange(3.,dtype=x.dtype)
         real_y = x*3.+y
         self.blas_func(x,y,a=3.)
         assert_array_almost_equal(real_y,y)
+
     def test_x_stride(self):
         x = arange(6.,dtype=self.dtype)
         y = zeros(3,x.dtype)
@@ -58,18 +63,21 @@
         real_y = x[::2]*3.+y
         self.blas_func(x,y,a=3.,n=3,incx=2)
         assert_array_almost_equal(real_y,y)
+
     def test_y_stride(self):
         x = arange(3.,dtype=self.dtype)
         y = zeros(6,x.dtype)
         real_y = x*3.+y[::2]
         self.blas_func(x,y,a=3.,n=3,incy=2)
         assert_array_almost_equal(real_y,y[::2])
+
     def test_x_and_y_stride(self):
         x = arange(12.,dtype=self.dtype)
         y = zeros(6,x.dtype)
         real_y = x[::4]*3.+y[::2]
         self.blas_func(x,y,a=3.,n=3,incx=4,incy=2)
         assert_array_almost_equal(real_y,y[::2])
+
     def test_x_bad_size(self):
         x = arange(12.,dtype=self.dtype)
         y = zeros(6,x.dtype)
@@ -79,6 +87,7 @@
             return
         # should catch error and never get here
         assert(0)
+
     def test_y_bad_size(self):
         x = arange(12.,dtype=complex64)
         y = zeros(6,x.dtype)
@@ -95,15 +104,18 @@
         dtype = float32
 except AttributeError:
     class TestSaxpy: pass
+
 class TestDaxpy(TestCase, BaseAxpy):
     blas_func = fblas.daxpy
     dtype = float64
+
 try:
     class TestCaxpy(TestCase, BaseAxpy):
         blas_func = fblas.caxpy
         dtype = complex64
 except AttributeError:
     class TestCaxpy: pass
+
 class TestZaxpy(TestCase, BaseAxpy):
     blas_func = fblas.zaxpy
     dtype = complex128
@@ -113,18 +125,22 @@
 ### Test blas ?scal
 
 class BaseScal(object):
+
     # Mixin class for testing particular dtypes
+
     def test_simple(self):
         x = arange(3.,dtype=self.dtype)
         real_x = x*3.
         self.blas_func(3.,x)
         assert_array_almost_equal(real_x,x)
+
     def test_x_stride(self):
         x = arange(6.,dtype=self.dtype)
         real_x = x.copy()
         real_x[::2] = x[::2]*array(3.,self.dtype)
         self.blas_func(3.,x,n=3,incx=2)
         assert_array_almost_equal(real_x,x)
+
     def test_x_bad_size(self):
         x = arange(12.,dtype=self.dtype)
         try:
@@ -133,21 +149,25 @@
             return
         # should catch error and never get here
         assert(0)
+
 try:
     class TestSscal(TestCase, BaseScal):
         blas_func = fblas.sscal
         dtype = float32
 except AttributeError:
     class TestSscal: pass
+
 class TestDscal(TestCase, BaseScal):
     blas_func = fblas.dscal
     dtype = float64
+
 try:
     class TestCscal(TestCase, BaseScal):
         blas_func = fblas.cscal
         dtype = complex64
 except AttributeError:
     class TestCscal: pass
+
 class TestZscal(TestCase, BaseScal):
     blas_func = fblas.zscal
     dtype = complex128
@@ -159,27 +179,33 @@
 ### Test blas ?copy
 
 class BaseCopy(object):
+
     # Mixin class for testing dtypes
+
     def test_simple(self):
         x = arange(3.,dtype=self.dtype)
         y = zeros(shape(x),x.dtype)
         self.blas_func(x,y)
         assert_array_almost_equal(x,y)
+
     def test_x_stride(self):
         x = arange(6.,dtype=self.dtype)
         y = zeros(3,x.dtype)
         self.blas_func(x,y,n=3,incx=2)
         assert_array_almost_equal(x[::2],y)
+
     def test_y_stride(self):
         x = arange(3.,dtype=self.dtype)
         y = zeros(6,x.dtype)
         self.blas_func(x,y,n=3,incy=2)
         assert_array_almost_equal(x,y[::2])
+
     def test_x_and_y_stride(self):
         x = arange(12.,dtype=self.dtype)
         y = zeros(6,x.dtype)
         self.blas_func(x,y,n=3,incx=4,incy=2)
         assert_array_almost_equal(x[::4],y[::2])
+
     def test_x_bad_size(self):
         x = arange(12.,dtype=self.dtype)
         y = zeros(6,x.dtype)
@@ -189,6 +215,7 @@
             return
         # should catch error and never get here
         assert(0)
+
     def test_y_bad_size(self):
         x = arange(12.,dtype=complex64)
         y = zeros(6,x.dtype)
@@ -211,15 +238,18 @@
         dtype = float32
 except AttributeError:
     class TestScopy: pass
+
 class TestDcopy(TestCase, BaseCopy):
     blas_func = fblas.dcopy
     dtype = float64
+
 try:
     class TestCcopy(TestCase, BaseCopy):
         blas_func = fblas.ccopy
         dtype = complex64
 except AttributeError:
     class TestCcopy: pass
+
 class TestZcopy(TestCase, BaseCopy):
     blas_func = fblas.zcopy
     dtype = complex128
@@ -229,7 +259,9 @@
 ### Test blas ?swap
 
 class BaseSwap(object):
+
     # Mixin class to implement test objects
+
     def test_simple(self):
         x = arange(3.,dtype=self.dtype)
         y = zeros(shape(x),x.dtype)
@@ -238,6 +270,7 @@
         self.blas_func(x,y)
         assert_array_almost_equal(desired_x,x)
         assert_array_almost_equal(desired_y,y)
+
     def test_x_stride(self):
         x = arange(6.,dtype=self.dtype)
         y = zeros(3,x.dtype)
@@ -246,6 +279,7 @@
         self.blas_func(x,y,n=3,incx=2)
         assert_array_almost_equal(desired_x,x[::2])
         assert_array_almost_equal(desired_y,y)
+
     def test_y_stride(self):
         x = arange(3.,dtype=self.dtype)
         y = zeros(6,x.dtype)
@@ -263,6 +297,7 @@
         self.blas_func(x,y,n=3,incx=4,incy=2)
         assert_array_almost_equal(desired_x,x[::4])
         assert_array_almost_equal(desired_y,y[::2])
+
     def test_x_bad_size(self):
         x = arange(12.,dtype=self.dtype)
         y = zeros(6,x.dtype)
@@ -272,6 +307,7 @@
             return
         # should catch error and never get here
         assert(0)
+
     def test_y_bad_size(self):
         x = arange(12.,dtype=complex64)
         y = zeros(6,x.dtype)
@@ -288,15 +324,18 @@
         dtype = float32
 except AttributeError:
     class TestSswap: pass
+
 class TestDswap(TestCase, BaseSwap):
     blas_func = fblas.dswap
     dtype = float64
+
 try:
     class TestCswap(TestCase, BaseSwap):
         blas_func = fblas.cswap
         dtype = complex64
 except AttributeError:
     class TestCswap: pass
+
 class TestZswap(TestCase, BaseSwap):
     blas_func = fblas.zswap
     dtype = complex128
@@ -306,7 +345,9 @@
 ### This will be a mess to test all cases.
 
 class BaseGemv(object):
+
     # Mixin class to test dtypes
+
     def get_data(self,x_stride=1,y_stride=1):
         mult = array(1, dtype = self.dtype)
         if self.dtype in [complex64, complex128]:
@@ -318,36 +359,43 @@
         x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
         y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
         return alpha,beta,a,x,y
+
     def test_simple(self):
         alpha,beta,a,x,y = self.get_data()
         desired_y = alpha*matrixmultiply(a,x)+beta*y
         y = self.blas_func(alpha,a,x,beta,y)
         assert_array_almost_equal(desired_y,y)
+
     def test_default_beta_y(self):
         alpha,beta,a,x,y = self.get_data()
         desired_y = matrixmultiply(a,x)
         y = self.blas_func(1,a,x)
         assert_array_almost_equal(desired_y,y)
+
     def test_simple_transpose(self):
         alpha,beta,a,x,y = self.get_data()
         desired_y = alpha*matrixmultiply(transpose(a),x)+beta*y
         y = self.blas_func(alpha,a,x,beta,y,trans=1)
         assert_array_almost_equal(desired_y,y)
+
     def test_simple_transpose_conj(self):
         alpha,beta,a,x,y = self.get_data()
         desired_y = alpha*matrixmultiply(transpose(conjugate(a)),x)+beta*y
         y = self.blas_func(alpha,a,x,beta,y,trans=2)
         assert_array_almost_equal(desired_y,y)
+
     def test_x_stride(self):
         alpha,beta,a,x,y = self.get_data(x_stride=2)
         desired_y = alpha*matrixmultiply(a,x[::2])+beta*y
         y = self.blas_func(alpha,a,x,beta,y,incx=2)
         assert_array_almost_equal(desired_y,y)
+
     def test_x_stride_transpose(self):
         alpha,beta,a,x,y = self.get_data(x_stride=2)
         desired_y = alpha*matrixmultiply(transpose(a),x[::2])+beta*y
         y = self.blas_func(alpha,a,x,beta,y,trans=1,incx=2)
         assert_array_almost_equal(desired_y,y)
+
     def test_x_stride_assert(self):
         # What is the use of this test?
         alpha,beta,a,x,y = self.get_data(x_stride=2)
@@ -361,18 +409,21 @@
             assert(0)
         except:
             pass
+
     def test_y_stride(self):
         alpha,beta,a,x,y = self.get_data(y_stride=2)
         desired_y = y.copy()
         desired_y[::2] = alpha*matrixmultiply(a,x)+beta*y[::2]
         y = self.blas_func(alpha,a,x,beta,y,incy=2)
         assert_array_almost_equal(desired_y,y)
+
     def test_y_stride_transpose(self):
         alpha,beta,a,x,y = self.get_data(y_stride=2)
         desired_y = y.copy()
         desired_y[::2] = alpha*matrixmultiply(transpose(a),x)+beta*y[::2]
         y = self.blas_func(alpha,a,x,beta,y,trans=1,incy=2)
         assert_array_almost_equal(desired_y,y)
+
     def test_y_stride_assert(self):
         # What is the use of this test?
         alpha,beta,a,x,y = self.get_data(y_stride=2)
@@ -393,15 +444,18 @@
         dtype = float32
 except AttributeError:
     class TestSgemv: pass
+
 class TestDgemv(TestCase, BaseGemv):
     blas_func = fblas.dgemv
     dtype = float64
+
 try:
     class TestCgemv(TestCase, BaseGemv):
         blas_func = fblas.cgemv
         dtype = complex64
 except AttributeError:
     class TestCgemv: pass
+
 class TestZgemv(TestCase, BaseGemv):
     blas_func = fblas.zgemv
     dtype = complex128
@@ -412,6 +466,7 @@
 ### This will be a mess to test all cases.
 
 class BaseGer(TestCase):
+
     def get_data(self,x_stride=1,y_stride=1):
         from numpy.random import normal
         alpha = array(1., dtype = self.dtype)
@@ -419,17 +474,20 @@
         x = arange(shape(a)[0]*x_stride,dtype=self.dtype)
         y = arange(shape(a)[1]*y_stride,dtype=self.dtype)
         return alpha,a,x,y
+
     def test_simple(self):
         alpha,a,x,y = self.get_data()
         # tranpose takes care of Fortran vs. C(and Python) memory layout
         desired_a = alpha*transpose(x[:,newaxis]*y) + a
         self.blas_func(x,y,a)
         assert_array_almost_equal(desired_a,a)
+
     def test_x_stride(self):
         alpha,a,x,y = self.get_data(x_stride=2)
         desired_a = alpha*transpose(x[::2,newaxis]*y) + a
         self.blas_func(x,y,a,incx=2)
         assert_array_almost_equal(desired_a,a)
+
     def test_x_stride_assert(self):
         alpha,a,x,y = self.get_data(x_stride=2)
         try:
@@ -437,6 +495,7 @@
             assert(0)
         except:
             pass
+
     def test_y_stride(self):
         alpha,a,x,y = self.get_data(y_stride=2)
         desired_a = alpha*transpose(x[:,newaxis]*y[::2]) + a
@@ -454,6 +513,7 @@
 class TestSger(BaseGer):
     blas_func = fblas.sger
     dtype = float32
+
 class TestDger(BaseGer):
     blas_func = fblas.dger
     dtype = float64
@@ -464,6 +524,7 @@
 
 """
 class BaseGerComplex(BaseGer):
+
     def get_data(self,x_stride=1,y_stride=1):
         from numpy.random import normal
         alpha = array(1+1j, dtype = self.dtype)
@@ -474,6 +535,7 @@
         y = normal(0.,1.,shape(a)[1]*y_stride).astype(self.dtype)
         y = y + y * array(1j, dtype = self.dtype)
         return alpha,a,x,y
+
     def test_simple(self):
         alpha,a,x,y = self.get_data()
         # tranpose takes care of Fortran vs. C(and Python) memory layout
@@ -496,25 +558,34 @@
     #    assert_array_almost_equal(desired_a,a)
 
 class TestCgeru(BaseGerComplex):
+
     blas_func = fblas.cgeru
     dtype = complex64
+
     def transform(self,x):
         return x
+
 class TestZgeru(BaseGerComplex):
+
     blas_func = fblas.zgeru
     dtype = complex128
+
     def transform(self,x):
         return x
 
 class TestCgerc(BaseGerComplex):
+
     blas_func = fblas.cgerc
     dtype = complex64
+
     def transform(self,x):
         return conjugate(x)
 
 class TestZgerc(BaseGerComplex):
+
     blas_func = fblas.zgerc
     dtype = complex128
+
     def transform(self,x):
         return conjugate(x)
 """




More information about the Scipy-svn mailing list