[Scipy-svn] r7113 - branches/0.9.x/scipy/fftpack/tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Jan 31 16:18:57 EST 2011


Author: ptvirtan
Date: 2011-01-31 15:18:57 -0600 (Mon, 31 Jan 2011)
New Revision: 7113

Modified:
   branches/0.9.x/scipy/fftpack/tests/test_basic.py
   branches/0.9.x/scipy/fftpack/tests/test_pseudo_diffs.py
   branches/0.9.x/scipy/fftpack/tests/test_real_transforms.py
Log:
TST: fftpack: add tests checking fft routine overwrite behavior

(backport of r7111)

Modified: branches/0.9.x/scipy/fftpack/tests/test_basic.py
===================================================================
--- branches/0.9.x/scipy/fftpack/tests/test_basic.py	2011-01-31 21:18:43 UTC (rev 7112)
+++ branches/0.9.x/scipy/fftpack/tests/test_basic.py	2011-01-31 21:18:57 UTC (rev 7113)
@@ -652,5 +652,117 @@
             except ValueError:
                 pass
 
+
+
+class TestOverwrite(object):
+    """
+    Check input overwrite behavior of the FFT functions.
+
+    Each check runs a routine on a copy of the input and then asserts
+    that the copy still equals the original, i.e. that the routine did
+    not modify its input in place.
+    """
+
+    real_dtypes = [np.float32, np.float64]
+    dtypes = real_dtypes + [np.complex64, np.complex128]
+
+    def _make_data(self, dtype, shape):
+        """Return reproducible random data of the given dtype and shape."""
+        np.random.seed(1234)
+        if np.issubdtype(dtype, np.complexfloating):
+            data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+        else:
+            data = np.random.randn(*shape)
+        return data.astype(dtype)
+
+    def _check(self, x, routine, fftsize, axis):
+        # order='A' preserves x's memory layout; a plain x.copy() is always
+        # C-contiguous, which would silently turn the Fortran-order checks
+        # in _check_nd_one back into C-order checks.
+        x2 = x.copy(order='A')
+        routine(x2, fftsize, axis)
+
+        sig = "%s(%s%r, %r, axis=%r)" % (routine.__name__, x.dtype, x.shape,
+                                         fftsize, axis)
+        assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+    def _check_1d(self, routine, dtype, shape, axis):
+        data = self._make_data(dtype, shape)
+        for fftsize in [8, 16, 32]:
+            self._check(data, routine, fftsize, axis)
+
+    def test_fft(self):
+        for dtype in self.dtypes:
+            self._check_1d(fft, dtype, (16,), -1)
+            self._check_1d(fft, dtype, (16, 2), 0)
+            self._check_1d(fft, dtype, (2, 16), 1)
+
+    def test_ifft(self):
+        for dtype in self.dtypes:
+            self._check_1d(ifft, dtype, (16,), -1)
+            self._check_1d(ifft, dtype, (16, 2), 0)
+            self._check_1d(ifft, dtype, (2, 16), 1)
+
+    def test_rfft(self):
+        for dtype in self.real_dtypes:
+            self._check_1d(rfft, dtype, (16,), -1)
+            self._check_1d(rfft, dtype, (16, 2), 0)
+            self._check_1d(rfft, dtype, (2, 16), 1)
+
+    def test_irfft(self):
+        for dtype in self.real_dtypes:
+            self._check_1d(irfft, dtype, (16,), -1)
+            self._check_1d(irfft, dtype, (16, 2), 0)
+            self._check_1d(irfft, dtype, (2, 16), 1)
+
+    def _check_nd_one(self, routine, dtype, shape, axes):
+        data = self._make_data(dtype, shape)
+
+        def fftshape_iter(shp):
+            # Yield every combination of half, equal and double the
+            # length of each entry of shp.
+            if len(shp) <= 0:
+                yield ()
+            else:
+                for j in (shp[0]//2, shp[0], shp[0]*2):
+                    for rest in fftshape_iter(shp[1:]):
+                        yield (j,) + rest
+
+        if axes is None:
+            part_shape = shape
+        else:
+            part_shape = tuple(np.take(shape, axes))
+
+        for fftshape in fftshape_iter(part_shape):
+            self._check(data, routine, fftshape, axes)
+            if data.ndim > 1:
+                # data.T is a Fortran-ordered view; the routines are
+                # expected never to overwrite Fortran-ordered input.
+                self._check(data.T, routine, fftshape, axes)
+
+    def _check_nd(self, routine, dtype):
+        self._check_nd_one(routine, dtype, (16,), None)
+        self._check_nd_one(routine, dtype, (16,), (0,))
+        self._check_nd_one(routine, dtype, (16, 2), (0,))
+        self._check_nd_one(routine, dtype, (2, 16), (1,))
+        self._check_nd_one(routine, dtype, (8, 16), None)
+        self._check_nd_one(routine, dtype, (8, 16), (0, 1))
+        self._check_nd_one(routine, dtype, (8, 16, 2), (0, 1))
+        self._check_nd_one(routine, dtype, (8, 16, 2), (1, 2))
+        self._check_nd_one(routine, dtype, (8, 16, 2), (0,))
+        self._check_nd_one(routine, dtype, (8, 16, 2), (1,))
+        self._check_nd_one(routine, dtype, (8, 16, 2), (2,))
+        self._check_nd_one(routine, dtype, (8, 16, 2), None)
+        self._check_nd_one(routine, dtype, (8, 16, 2), (0, 1, 2))
+
+    def test_fftn(self):
+        for dtype in self.dtypes:
+            self._check_nd(fftn, dtype)
+
+    def test_ifftn(self):
+        for dtype in self.dtypes:
+            self._check_nd(ifftn, dtype)
+
+
 if __name__ == "__main__":
     run_module_suite()

Modified: branches/0.9.x/scipy/fftpack/tests/test_pseudo_diffs.py
===================================================================
--- branches/0.9.x/scipy/fftpack/tests/test_pseudo_diffs.py	2011-01-31 21:18:43 UTC (rev 7112)
+++ branches/0.9.x/scipy/fftpack/tests/test_pseudo_diffs.py	2011-01-31 21:18:57 UTC (rev 7113)
@@ -13,8 +13,10 @@
 
 from numpy.testing import *
 from scipy.fftpack import diff, fft, ifft, tilbert, itilbert, hilbert, \
-                          ihilbert, shift, fftfreq
+                          ihilbert, shift, fftfreq, cs_diff, sc_diff, \
+                          ss_diff, cc_diff
 
+import numpy as np
 from numpy import arange, sin, cos, pi, exp, tanh, sum, sign
 
 def random(size):
@@ -312,5 +314,76 @@
             assert_array_almost_equal(shift(sin(x),pi/2),cos(x))
 
 
+class TestOverwrite(object):
+    """
+    Check input overwrite behavior of the pseudo-differential routines.
+
+    Each check runs a routine on a copy of the input and asserts that
+    the copy is unchanged afterwards.
+    """
+
+    real_dtypes = [np.float32, np.float64]
+    dtypes = real_dtypes + [np.complex64, np.complex128]
+
+    def _check(self, x, routine, *args, **kwargs):
+        x2 = x.copy()
+        routine(x2, *args, **kwargs)
+        sig = routine.__name__
+        if args:
+            sig += repr(args)
+        if kwargs:
+            sig += repr(kwargs)
+        assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+    def _check_1d(self, routine, dtype, shape, *args, **kwargs):
+        # Reproducible random input; complex dtypes get an imaginary part.
+        np.random.seed(1234)
+        if np.issubdtype(dtype, np.complexfloating):
+            data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+        else:
+            data = np.random.randn(*shape)
+        data = data.astype(dtype)
+        self._check(data, routine, *args, **kwargs)
+
+    def test_diff(self):
+        for dtype in self.dtypes:
+            self._check_1d(diff, dtype, (16,))
+
+    def test_tilbert(self):
+        for dtype in self.dtypes:
+            self._check_1d(tilbert, dtype, (16,), 1.6)
+
+    def test_itilbert(self):
+        for dtype in self.dtypes:
+            self._check_1d(itilbert, dtype, (16,), 1.6)
+
+    def test_hilbert(self):
+        for dtype in self.dtypes:
+            self._check_1d(hilbert, dtype, (16,))
+
+    def test_ihilbert(self):
+        for dtype in self.dtypes:
+            self._check_1d(ihilbert, dtype, (16,))
+
+    def test_cs_diff(self):
+        for dtype in self.dtypes:
+            self._check_1d(cs_diff, dtype, (16,), 1.0, 4.0)
+
+    def test_sc_diff(self):
+        for dtype in self.dtypes:
+            self._check_1d(sc_diff, dtype, (16,), 1.0, 4.0)
+
+    def test_ss_diff(self):
+        for dtype in self.dtypes:
+            self._check_1d(ss_diff, dtype, (16,), 1.0, 4.0)
+
+    def test_cc_diff(self):
+        for dtype in self.dtypes:
+            self._check_1d(cc_diff, dtype, (16,), 1.0, 4.0)
+
+    def test_shift(self):
+        for dtype in self.dtypes:
+            self._check_1d(shift, dtype, (16,), 1.0)
+
 if __name__ == "__main__":
     run_module_suite()

Modified: branches/0.9.x/scipy/fftpack/tests/test_real_transforms.py
===================================================================
--- branches/0.9.x/scipy/fftpack/tests/test_real_transforms.py	2011-01-31 21:18:43 UTC (rev 7112)
+++ branches/0.9.x/scipy/fftpack/tests/test_real_transforms.py	2011-01-31 21:18:57 UTC (rev 7113)
@@ -3,7 +3,7 @@
 
 import numpy as np
 from numpy.fft import fft as numfft
-from numpy.testing import assert_array_almost_equal, TestCase
+from numpy.testing import assert_array_almost_equal, assert_equal, TestCase
 
 from scipy.fftpack.realtransforms import dct, idct
 
@@ -47,8 +47,8 @@
             # XXX: we divide by np.max(y) because the tests fail otherwise. We
             # should really use something like assert_array_approx_equal. The
             # difference is due to fftw using a better algorithm w.r.t error
-            # propagation compared to the ones from fftpack. 
-            assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec, 
+            # propagation compared to the ones from fftpack.
+            assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
                     err_msg="Size %d failed" % i)
 
     def test_axis(self):
@@ -144,8 +144,8 @@
             # XXX: we divide by np.max(y) because the tests fail otherwise. We
             # should really use something like assert_array_approx_equal. The
             # difference is due to fftw using a better algorithm w.r.t error
-            # propagation compared to the ones from fftpack. 
-            assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec, 
+            # propagation compared to the ones from fftpack.
+            assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
                     err_msg="Size %d failed" % i)
 
 class TestIDCTIDouble(_TestIDCTBase):
@@ -184,5 +184,54 @@
         self.dec = 5
         self.type = 3
 
+class TestOverwrite(object):
+    """
+    Check input overwrite behavior of the DCT functions.
+
+    Each check runs a routine on a copy of the input and asserts that
+    the copy is unchanged afterwards.
+    """
+
+    real_dtypes = [np.float32, np.float64]
+
+    def _check(self, x, routine, dct_type, fftsize, axis, norm):
+        x2 = x.copy()
+        routine(x2, dct_type, fftsize, axis, norm)
+
+        # Include the DCT type and normalization so that a failure
+        # identifies which variant overwrote its input.
+        sig = "%s(%s%r, %r, axis=%r, type=%r, norm=%r)" % (
+            routine.__name__, x.dtype, x.shape, fftsize, axis, dct_type, norm)
+        assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
+
+    def _check_1d(self, routine, dtype, shape, axis):
+        # Reproducible random input; complex dtypes get an imaginary part.
+        np.random.seed(1234)
+        if np.issubdtype(dtype, np.complexfloating):
+            data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
+        else:
+            data = np.random.randn(*shape)
+        data = data.astype(dtype)
+
+        for dct_type in [1, 2, 3]:
+            for norm in [None, 'ortho']:
+                # The type 1 / 'ortho' combination is skipped (presumably
+                # unsupported by this dct/idct version -- TODO confirm).
+                if dct_type == 1 and norm == 'ortho':
+                    continue
+                self._check(data, routine, dct_type, None, axis, norm)
+
+    def test_dct(self):
+        for dtype in self.real_dtypes:
+            self._check_1d(dct, dtype, (16,), -1)
+            self._check_1d(dct, dtype, (16, 2), 0)
+            self._check_1d(dct, dtype, (2, 16), 1)
+
+    def test_idct(self):
+        for dtype in self.real_dtypes:
+            self._check_1d(idct, dtype, (16,), -1)
+            self._check_1d(idct, dtype, (16, 2), 0)
+            self._check_1d(idct, dtype, (2, 16), 1)
+
 if __name__ == "__main__":
     np.testing.run_module_suite()




More information about the Scipy-svn mailing list