[Scipy-svn] r6803 - in trunk/scipy: lib/blas/tests linalg/tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Sep 12 18:11:24 EDT 2010
Author: warren.weckesser
Date: 2010-09-12 17:11:24 -0500 (Sun, 12 Sep 2010)
New Revision: 6803
Modified:
trunk/scipy/lib/blas/tests/test_fblas.py
trunk/scipy/linalg/tests/test_fblas.py
Log:
TST: Remove plain asserts from two more files.
Modified: trunk/scipy/lib/blas/tests/test_fblas.py
===================================================================
--- trunk/scipy/lib/blas/tests/test_fblas.py 2010-09-12 21:46:08 UTC (rev 6802)
+++ trunk/scipy/lib/blas/tests/test_fblas.py 2010-09-12 22:11:24 UTC (rev 6803)
@@ -86,7 +86,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
@@ -96,7 +96,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
try:
class TestSaxpy(TestCase, BaseAxpy):
@@ -148,7 +148,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
try:
class TestSscal(TestCase, BaseScal):
@@ -214,7 +214,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
@@ -224,7 +224,8 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
#def test_y_bad_type(self):
## Hmmm. Should this work? What should be the output.
# x = arange(3.,dtype=self.dtype)
@@ -306,7 +307,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
@@ -316,7 +317,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
try:
class TestSswap(TestCase, BaseSwap):
@@ -401,12 +402,12 @@
alpha,beta,a,x,y = self.get_data(x_stride=2)
try:
y = self.blas_func(1,a,x,1,y,trans=0,incx=3)
- assert(0)
+ assert_(0)
except:
pass
try:
y = self.blas_func(1,a,x,1,y,trans=1,incx=3)
- assert(0)
+ assert_(0)
except:
pass
@@ -429,12 +430,12 @@
alpha,beta,a,x,y = self.get_data(y_stride=2)
try:
y = self.blas_func(1,a,x,1,y,trans=0,incy=3)
- assert(0)
+ assert_(0)
except:
pass
try:
y = self.blas_func(1,a,x,1,y,trans=1,incy=3)
- assert(0)
+ assert_(0)
except:
pass
Modified: trunk/scipy/linalg/tests/test_fblas.py
===================================================================
--- trunk/scipy/linalg/tests/test_fblas.py 2010-09-12 21:46:08 UTC (rev 6802)
+++ trunk/scipy/linalg/tests/test_fblas.py 2010-09-12 22:11:24 UTC (rev 6803)
@@ -42,18 +42,21 @@
class BaseAxpy(object):
''' Mixin class for axpy tests '''
+
def test_default_a(self):
x = arange(3.,dtype=self.dtype)
y = arange(3.,dtype=x.dtype)
real_y = x*1.+y
self.blas_func(x,y)
assert_array_equal(real_y,y)
+
def test_simple(self):
x = arange(3.,dtype=self.dtype)
y = arange(3.,dtype=x.dtype)
real_y = x*3.+y
self.blas_func(x,y,a=3.)
assert_array_equal(real_y,y)
+
def test_x_stride(self):
x = arange(6.,dtype=self.dtype)
y = zeros(3,x.dtype)
@@ -61,18 +64,21 @@
real_y = x[::2]*3.+y
self.blas_func(x,y,a=3.,n=3,incx=2)
assert_array_equal(real_y,y)
+
def test_y_stride(self):
x = arange(3.,dtype=self.dtype)
y = zeros(6,x.dtype)
real_y = x*3.+y[::2]
self.blas_func(x,y,a=3.,n=3,incy=2)
assert_array_equal(real_y,y[::2])
+
def test_x_and_y_stride(self):
x = arange(12.,dtype=self.dtype)
y = zeros(6,x.dtype)
real_y = x[::4]*3.+y[::2]
self.blas_func(x,y,a=3.,n=3,incx=4,incy=2)
assert_array_equal(real_y,y[::2])
+
def test_x_bad_size(self):
x = arange(12.,dtype=self.dtype)
y = zeros(6,x.dtype)
@@ -81,7 +87,8 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
y = zeros(6,x.dtype)
@@ -90,7 +97,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
try:
class TestSaxpy(TestCase, BaseAxpy):
@@ -98,15 +105,18 @@
dtype = float32
except AttributeError:
class TestSaxpy: pass
+
class TestDaxpy(TestCase, BaseAxpy):
blas_func = fblas.daxpy
dtype = float64
+
try:
class TestCaxpy(TestCase, BaseAxpy):
blas_func = fblas.caxpy
dtype = complex64
except AttributeError:
class TestCaxpy: pass
+
class TestZaxpy(TestCase, BaseAxpy):
blas_func = fblas.zaxpy
dtype = complex128
@@ -117,17 +127,20 @@
class BaseScal(object):
''' Mixin class for scal testing '''
+
def test_simple(self):
x = arange(3.,dtype=self.dtype)
real_x = x*3.
self.blas_func(3.,x)
assert_array_equal(real_x,x)
+
def test_x_stride(self):
x = arange(6.,dtype=self.dtype)
real_x = x.copy()
real_x[::2] = x[::2]*array(3.,self.dtype)
self.blas_func(3.,x,n=3,incx=2)
assert_array_equal(real_x,x)
+
def test_x_bad_size(self):
x = arange(12.,dtype=self.dtype)
try:
@@ -135,54 +148,61 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
try:
class TestSscal(TestCase, BaseScal):
blas_func = fblas.sscal
dtype = float32
except AttributeError:
class TestSscal: pass
+
class TestDscal(TestCase, BaseScal):
blas_func = fblas.dscal
dtype = float64
+
try:
class TestCscal(TestCase, BaseScal):
blas_func = fblas.cscal
dtype = complex64
except AttributeError:
class TestCscal: pass
+
class TestZscal(TestCase, BaseScal):
blas_func = fblas.zscal
dtype = complex128
-
-
##################################################
### Test blas ?copy
class BaseCopy(object):
''' Mixin class for copy testing '''
+
def test_simple(self):
x = arange(3.,dtype=self.dtype)
y = zeros(shape(x),x.dtype)
self.blas_func(x,y)
assert_array_equal(x,y)
+
def test_x_stride(self):
x = arange(6.,dtype=self.dtype)
y = zeros(3,x.dtype)
self.blas_func(x,y,n=3,incx=2)
assert_array_equal(x[::2],y)
+
def test_y_stride(self):
x = arange(3.,dtype=self.dtype)
y = zeros(6,x.dtype)
self.blas_func(x,y,n=3,incy=2)
assert_array_equal(x,y[::2])
+
def test_x_and_y_stride(self):
x = arange(12.,dtype=self.dtype)
y = zeros(6,x.dtype)
self.blas_func(x,y,n=3,incx=4,incy=2)
assert_array_equal(x[::4],y[::2])
+
def test_x_bad_size(self):
x = arange(12.,dtype=self.dtype)
y = zeros(6,x.dtype)
@@ -191,7 +211,8 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
y = zeros(6,x.dtype)
@@ -200,7 +221,8 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
#def test_y_bad_type(self):
## Hmmm. Should this work? What should be the output.
# x = arange(3.,dtype=self.dtype)
@@ -214,15 +236,18 @@
dtype = float32
except AttributeError:
class TestScopy: pass
+
class TestDcopy(TestCase, BaseCopy):
blas_func = fblas.dcopy
dtype = float64
+
try:
class TestCcopy(TestCase, BaseCopy):
blas_func = fblas.ccopy
dtype = complex64
except AttributeError:
class TestCcopy: pass
+
class TestZcopy(TestCase, BaseCopy):
blas_func = fblas.zcopy
dtype = complex128
@@ -233,6 +258,7 @@
class BaseSwap(object):
''' Mixin class for swap tests '''
+
def test_simple(self):
x = arange(3.,dtype=self.dtype)
y = zeros(shape(x),x.dtype)
@@ -241,6 +267,7 @@
self.blas_func(x,y)
assert_array_equal(desired_x,x)
assert_array_equal(desired_y,y)
+
def test_x_stride(self):
x = arange(6.,dtype=self.dtype)
y = zeros(3,x.dtype)
@@ -249,6 +276,7 @@
self.blas_func(x,y,n=3,incx=2)
assert_array_equal(desired_x,x[::2])
assert_array_equal(desired_y,y)
+
def test_y_stride(self):
x = arange(3.,dtype=self.dtype)
y = zeros(6,x.dtype)
@@ -266,6 +294,7 @@
self.blas_func(x,y,n=3,incx=4,incy=2)
assert_array_equal(desired_x,x[::4])
assert_array_equal(desired_y,y[::2])
+
def test_x_bad_size(self):
x = arange(12.,dtype=self.dtype)
y = zeros(6,x.dtype)
@@ -274,7 +303,8 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
+
def test_y_bad_size(self):
x = arange(12.,dtype=complex64)
y = zeros(6,x.dtype)
@@ -283,7 +313,7 @@
except: # what kind of error should be caught?
return
# should catch error and never get here
- assert(0)
+ assert_(0)
try:
class TestSswap(TestCase, BaseSwap):
@@ -291,15 +321,18 @@
dtype = float32
except AttributeError:
class TestSswap: pass
+
class TestDswap(TestCase, BaseSwap):
blas_func = fblas.dswap
dtype = float64
+
try:
class TestCswap(TestCase, BaseSwap):
blas_func = fblas.cswap
dtype = complex64
except AttributeError:
class TestCswap: pass
+
class TestZswap(TestCase, BaseSwap):
blas_func = fblas.zswap
dtype = complex128
@@ -310,6 +343,7 @@
class BaseGemv(object):
''' Mixin class for gemv tests '''
+
def get_data(self,x_stride=1,y_stride=1):
mult = array(1, dtype = self.dtype)
if self.dtype in [complex64, complex128]:
@@ -321,72 +355,82 @@
x = arange(shape(a)[0]*x_stride,dtype=self.dtype) * mult
y = arange(shape(a)[1]*y_stride,dtype=self.dtype) * mult
return alpha,beta,a,x,y
+
def test_simple(self):
alpha,beta,a,x,y = self.get_data()
desired_y = alpha*matrixmultiply(a,x)+beta*y
y = self.blas_func(alpha,a,x,beta,y)
assert_array_almost_equal(desired_y,y)
+
def test_default_beta_y(self):
alpha,beta,a,x,y = self.get_data()
desired_y = matrixmultiply(a,x)
y = self.blas_func(1,a,x)
assert_array_almost_equal(desired_y,y)
+
def test_simple_transpose(self):
alpha,beta,a,x,y = self.get_data()
desired_y = alpha*matrixmultiply(transpose(a),x)+beta*y
y = self.blas_func(alpha,a,x,beta,y,trans=1)
assert_array_almost_equal(desired_y,y)
+
def test_simple_transpose_conj(self):
alpha,beta,a,x,y = self.get_data()
desired_y = alpha*matrixmultiply(transpose(conjugate(a)),x)+beta*y
y = self.blas_func(alpha,a,x,beta,y,trans=2)
assert_array_almost_equal(desired_y,y)
+
def test_x_stride(self):
alpha,beta,a,x,y = self.get_data(x_stride=2)
desired_y = alpha*matrixmultiply(a,x[::2])+beta*y
y = self.blas_func(alpha,a,x,beta,y,incx=2)
assert_array_almost_equal(desired_y,y)
+
def test_x_stride_transpose(self):
alpha,beta,a,x,y = self.get_data(x_stride=2)
desired_y = alpha*matrixmultiply(transpose(a),x[::2])+beta*y
y = self.blas_func(alpha,a,x,beta,y,trans=1,incx=2)
assert_array_almost_equal(desired_y,y)
+
def test_x_stride_assert(self):
# What is the use of this test?
alpha,beta,a,x,y = self.get_data(x_stride=2)
try:
y = self.blas_func(1,a,x,1,y,trans=0,incx=3)
- assert(0)
+ assert_(0)
except:
pass
try:
y = self.blas_func(1,a,x,1,y,trans=1,incx=3)
- assert(0)
+ assert_(0)
except:
pass
+
def test_y_stride(self):
alpha,beta,a,x,y = self.get_data(y_stride=2)
desired_y = y.copy()
desired_y[::2] = alpha*matrixmultiply(a,x)+beta*y[::2]
y = self.blas_func(alpha,a,x,beta,y,incy=2)
assert_array_almost_equal(desired_y,y)
+
def test_y_stride_transpose(self):
alpha,beta,a,x,y = self.get_data(y_stride=2)
desired_y = y.copy()
desired_y[::2] = alpha*matrixmultiply(transpose(a),x)+beta*y[::2]
y = self.blas_func(alpha,a,x,beta,y,trans=1,incy=2)
assert_array_almost_equal(desired_y,y)
+
def test_y_stride_assert(self):
# What is the use of this test?
alpha,beta,a,x,y = self.get_data(y_stride=2)
try:
y = self.blas_func(1,a,x,1,y,trans=0,incy=3)
- assert(0)
+ assert_(0)
except:
pass
try:
y = self.blas_func(1,a,x,1,y,trans=1,incy=3)
- assert(0)
+ assert_(0)
except:
pass
@@ -396,15 +440,18 @@
dtype = float32
except AttributeError:
class TestSgemv: pass
+
class TestDgemv(TestCase, BaseGemv):
blas_func = fblas.dgemv
dtype = float64
+
try:
class TestCgemv(TestCase, BaseGemv):
blas_func = fblas.cgemv
dtype = complex64
except AttributeError:
class TestCgemv: pass
+
class TestZgemv(TestCase, BaseGemv):
blas_func = fblas.zgemv
dtype = complex128
More information about the Scipy-svn
mailing list