[Scipy-svn] r5500 - trunk/scipy/fftpack/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Jan 19 03:33:26 EST 2009


Author: cdavid
Date: 2009-01-19 02:33:22 -0600 (Mon, 19 Jan 2009)
New Revision: 5500

Modified:
   trunk/scipy/fftpack/tests/test_real_transforms.py
Log:
Adapt DCT tests to the new API (dct/idct with a type argument, replacing dct1/dct2/dct3), and add IDCT tests.

Modified: trunk/scipy/fftpack/tests/test_real_transforms.py
===================================================================
--- trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-19 08:33:05 UTC (rev 5499)
+++ trunk/scipy/fftpack/tests/test_real_transforms.py	2009-01-19 08:33:22 UTC (rev 5500)
@@ -5,7 +5,7 @@
 from numpy.fft import fft as numfft
 from numpy.testing import assert_array_almost_equal, TestCase
 
-from scipy.fftpack.realtransforms import dct1, dct2, dct3
+from scipy.fftpack.realtransforms import dct, idct
 
 # Matlab reference data
 MDATA = np.load(join(dirname(__file__), 'test.npz'))
@@ -21,8 +21,6 @@
 FFTWDATA_SINGLE = np.load(join(dirname(__file__), 'fftw_single_ref.npz'))
 FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
 
-TYPE2DCT = {1: dct1, 2: dct2, 3: dct3}
-
 def fftw_ref(type, size, dt):
     x = np.linspace(0, size-1, size).astype(dt)
     if dt == np.double:
@@ -39,12 +37,11 @@
         self.rdt = None
         self.dec = 14
         self.type = None
-        self.func = None
 
     def test_definition(self):
         for i in FFTWDATA_SIZES:
             x, yr = fftw_ref(self.type, i, self.rdt)
-            y = self.func(x)
+            y = dct(x, type=self.type)
             self.failUnless(y.dtype == self.rdt,
                     "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
             # XXX: we divide by np.max(y) because the tests fail otherwise. We
@@ -58,14 +55,16 @@
         nt = 2
         for i in [7, 8, 9, 16, 32, 64]:
             x = np.random.randn(nt, i)
-            y = self.func(x)
+            y = dct(x, type=self.type)
             for j in range(nt):
-                assert_array_almost_equal(y[j], self.func(x[j]), decimal=self.dec)
+                assert_array_almost_equal(y[j], dct(x[j], type=self.type),
+                        decimal=self.dec)
 
             x = x.T
-            y = self.func(x, axis=0)
+            y = dct(x, axis=0, type=self.type)
             for j in range(nt):
-                assert_array_almost_equal(y[:,j], self.func(x[:,j]), decimal=self.dec)
+                assert_array_almost_equal(y[:,j], dct(x[:,j], type=self.type),
+                        decimal=self.dec)
 
 class _TestDCTIIBase(_TestDCTBase):
     def test_definition_matlab(self):
@@ -73,7 +72,7 @@
         for i in range(len(X)):
             x = np.array(X[i], dtype=self.rdt)
             yr = Y[i]
-            y = dct2(x, norm="ortho")
+            y = dct(x, norm="ortho", type=2)
             self.failUnless(y.dtype == self.rdt,
                     "Output dtype is %s, expected %s" % (y.dtype, self.rdt))
             assert_array_almost_equal(y, yr, decimal=self.dec)
@@ -83,8 +82,8 @@
         """Test orthornomal mode."""
         for i in range(len(X)):
             x = np.array(X[i], dtype=self.rdt)
-            y = dct2(x, norm='ortho')
-            xi = dct3(y, norm="ortho")
+            y = dct(x, norm='ortho', type=2)
+            xi = dct(y, norm="ortho", type=3)
             self.failUnless(xi.dtype == self.rdt,
                     "Output dtype is %s, expected %s" % (xi.dtype, self.rdt))
             assert_array_almost_equal(xi, x, decimal=self.dec)
@@ -94,42 +93,96 @@
         self.rdt = np.double
         self.dec = 10
         self.type = 1
-        self.func = TYPE2DCT[self.type]
 
 class TestDCTIFloat(_TestDCTBase):
     def setUp(self):
         self.rdt = np.float32
         self.dec = 5
         self.type = 1
-        self.func = TYPE2DCT[self.type]
 
 class TestDCTIIDouble(_TestDCTIIBase):
     def setUp(self):
         self.rdt = np.double
         self.dec = 10
         self.type = 2
-        self.func = TYPE2DCT[self.type]
 
 class TestDCTIIFloat(_TestDCTIIBase):
     def setUp(self):
         self.rdt = np.float32
         self.dec = 5
         self.type = 2
-        self.func = TYPE2DCT[self.type]
 
 class TestDCTIIIDouble(_TestDCTIIIBase):
     def setUp(self):
         self.rdt = np.double
         self.dec = 14
         self.type = 3
-        self.func = TYPE2DCT[self.type]
 
 class TestDCTIIIFloat(_TestDCTIIIBase):
     def setUp(self):
         self.rdt = np.float32
         self.dec = 5
         self.type = 3
-        self.func = TYPE2DCT[self.type]
 
+class _TestIDCTBase(TestCase):
+    def setUp(self):
+        self.rdt = None
+        self.dec = 14
+        self.type = None
+
+    def test_definition(self):
+        for i in FFTWDATA_SIZES:
+            xr, yr = fftw_ref(self.type, i, self.rdt)
+            y = dct(xr, type=self.type)
+            x = idct(yr, type=self.type)
+            if self.type == 1:
+                x /= 2 * (i-1)
+            else:
+                x /= 2 * i
+            self.failUnless(x.dtype == self.rdt,
+                    "Output dtype is %s, expected %s" % (x.dtype, self.rdt))
+            # XXX: we divide by np.max(y) because the tests fail otherwise. We
+            # should really use something like assert_array_approx_equal. The
+            # difference is due to fftw using a better algorithm w.r.t error
+            # propagation compared to the ones from fftpack. 
+            assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec, 
+                    err_msg="Size %d failed" % i)
+
+class TestIDCTIDouble(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 10
+        self.type = 1
+
+class TestIDCTIFloat(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.float32
+        self.dec = 4
+        self.type = 1
+
+class TestIDCTIIDouble(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 10
+        self.type = 2
+
+class TestIDCTIIFloat(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.float32
+        self.dec = 5
+        self.type = 2
+
+class TestIDCTIIIDouble(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.double
+        self.dec = 14
+        self.type = 3
+
+class TestIDCTIIIFloat(_TestIDCTBase):
+    def setUp(self):
+        self.rdt = np.float32
+        self.dec = 5
+        self.type = 3
+
 if __name__ == "__main__":
     np.testing.run_module_suite()




More information about the Scipy-svn mailing list