[Scipy-svn] r4269 - branches/refactor_fft/scipy/fftpack/src/fftpack
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun May 11 08:49:13 EDT 2008
Author: cdavid
Date: 2008-05-11 07:49:09 -0500 (Sun, 11 May 2008)
New Revision: 4269
Modified:
branches/refactor_fft/scipy/fftpack/src/fftpack/zfftnd.cxx
Log:
fftpack backend for zfftnd now uses c++ cyclic cache.
Modified: branches/refactor_fft/scipy/fftpack/src/fftpack/zfftnd.cxx
===================================================================
--- branches/refactor_fft/scipy/fftpack/src/fftpack/zfftnd.cxx 2008-05-11 12:07:10 UTC (rev 4268)
+++ branches/refactor_fft/scipy/fftpack/src/fftpack/zfftnd.cxx 2008-05-11 12:49:09 UTC (rev 4269)
@@ -3,9 +3,146 @@
*
* Original code by Pearu Peaterson
*
- * Last Change: Wed Aug 08 02:00 PM 2007 J
+ * Last Change: Sun May 11 09:00 PM 2008 J
*/
+#include <new>
+#include "cycliccache.h"
+
+extern "C" {  // zfft is declared with C linkage; its definition is not part of this diff
+extern void zfft(complex_double * inout,  // NOTE(review): presumably transforms `howmany` consecutive rows of length n in place -- confirm against zfft's definition
+ int n, int direction, int howmany, int normalize);
+};
+
+static int next_comb(int *ia, int *da, int m);  // forward declaration; the definition appears later in this file
+static void flatten(complex_double * dest, complex_double * src,  // forward declaration; the definition appears later in this file
+ int rank, int strides_axis, int dims_axis, int unflat,  // unflat selects the copy direction (0: strided src -> packed dest, nonzero: the reverse)
+ int *tmp);  // tmp is integer scratch; flatten reads tmp + rank, tmp + 2 * rank and tmp + 3 * rank
+
+
+using namespace fft;
+
+class NDFFTPackCacheId : public CacheId {  // cache key for the N-d transform: a size n plus the array rank
+ public:
+ NDFFTPackCacheId(int n, int rank) :
+ CacheId(n),  // n is forwarded to the CacheId base; CacheId itself is not declared in this file (presumably cycliccache.h)
+ m_rank(rank)
+ {
+ };
+
+ virtual bool operator==(const NDFFTPackCacheId &other) const  // equality delegates to is_equal()
+ {  // NOTE(review): the parameter is the derived type, so this may not override a comparison declared on CacheId -- confirm against cycliccache.h
+ return is_equal(other);
+ };
+
+ virtual bool is_equal(const NDFFTPackCacheId &other) const;  // defined out of line; compares both m_n and m_rank
+
+ public:
+ int m_rank;  // rank given to the constructor; zfftnd_fftpack passes the number of entries in dims
+};
+
+bool NDFFTPackCacheId::is_equal(const NDFFTPackCacheId & other) const  // two ids match only when size and rank both match
+{
+ return m_n == other.m_n && m_rank == other.m_rank;  // m_n is not declared in this class; presumably inherited from CacheId
+}
+
+class NDFFTPackCache: public Cache<NDFFTPackCacheId> {  // owns the scratch buffers used by the N-d transform for one (n, rank) key
+ public:
+ NDFFTPackCache(const NDFFTPackCacheId& id);  // allocates both buffers; throws std::bad_alloc if either malloc fails
+ virtual ~NDFFTPackCache();  // frees both buffers
+
+ int compute(complex_double * inout, int sz, int* dims,  // transforms `howmany` arrays of sz elements each, in place; always returns 0
+ int direction, int normalize, int howmany) const;
+
+ protected:
+ complex_double* m_wsave;  // complex scratch of 2 * n elements, malloc'ed in the constructor
+ int* m_iptr;  // integer scratch of 4 * rank ints, malloc'ed in the constructor
+  // NOTE(review): raw owning pointers with no copy constructor/assignment declared here; a copy would free both buffers twice unless Cache<> forbids copying -- confirm
+ private:
+ int prepare(int *dims) const;  // fills m_iptr[0 .. rank) with per-axis element strides derived from dims
+};
+
+NDFFTPackCache::NDFFTPackCache(const NDFFTPackCacheId& id)  // allocates the complex and integer scratch buffers for this id
+: Cache<NDFFTPackCacheId>(id)
+{
+ int n = id.m_n;
+ int rank = id.m_rank;
+
+ m_wsave = (complex_double *)malloc(sizeof(*m_wsave) * (2 * n));  // room for 2 * n complex values
+ if (m_wsave == NULL) {
+ goto fail;  // nothing allocated yet, so there is nothing to release
+ }
+
+ m_iptr = (int*)malloc(4 * rank * sizeof(*m_iptr));  // four consecutive int regions of `rank` entries each
+ if (m_iptr == NULL) {
+ goto clean_wsave;  // release the first buffer before reporting the failure
+ }
+
+ return;  // both allocations succeeded
+
+clean_wsave:
+ free(m_wsave);
+fail:
+ throw std::bad_alloc();  // allocation failure is reported as an exception; no buffer stays allocated
+}
+
+NDFFTPackCache::~NDFFTPackCache()  // releases the two malloc'ed scratch buffers
+{
+ free(m_iptr);
+ free(m_wsave);
+}
+
+int NDFFTPackCache::compute(complex_double *inout, int sz, int *dims,  // in-place transform of `howmany` arrays of sz elements, one axis at a time
+ int direction, int normalize,
+ int howmany) const  // const, yet it writes through m_wsave and m_iptr (only the pointers themselves are const here)
+{
+ int rank = m_id.m_rank;
+ int i, axis, k, j;
+ complex_double *tmp = m_wsave;  // complex scratch that receives one gathered axis layout at a time
+ complex_double *ptr = inout;  // start of the array currently being processed
+
+ zfft(inout, dims[rank - 1], direction, howmany * sz / dims[rank - 1],  // last axis first: howmany * sz / dims[rank - 1] rows of length dims[rank - 1], directly on inout
+ normalize);  // NOTE(review): no guard here for rank == 0 or a zero-length last axis (out-of-range index / division by zero)
+ prepare(dims);  // fill m_iptr[0 .. rank) with the per-axis element strides
+
+ for (i = 0; i < howmany; ++i, ptr += sz) {  // one pass per array; consecutive arrays are sz elements apart
+ for (axis = 0; axis < rank - 1; ++axis) {  // every axis except the last, which was handled above
+ for (k = j = 0; k < rank; ++k) {  // collect the other axes' strides and (extent - 1) values
+ if (k != axis) {
+ *(m_iptr + rank + j) = m_iptr[k];  // second region of m_iptr: stride of each axis other than `axis`
+ *(m_iptr + 2 * rank + j++) = dims[k] - 1;  // third region: largest index of each axis other than `axis`
+ }
+ }
+ flatten(tmp, ptr, rank, m_iptr[axis], dims[axis], 0, m_iptr);  // unflat = 0: copy strided elements of `axis` into consecutive slots of tmp
+ zfft(tmp, dims[axis], direction, sz / dims[axis], normalize);  // sz / dims[axis] rows of length dims[axis]
+ flatten(ptr, tmp, rank, m_iptr[axis], dims[axis], 1, m_iptr);  // unflat = 1: copy the results back to their strided positions
+ }
+ }
+ return 0;  // no failure path: the result is always 0
+}
+
+int NDFFTPackCache::prepare(int *dims) const  // computes per-axis element strides into m_iptr[0 .. rank)
+{
+ int rank = m_id.m_rank;
+ int i;
+
+ m_iptr[rank - 1] = 1;  // the last axis has stride 1
+ for (i = 2; i <= rank; ++i) {  // walk from the second-to-last axis back to axis 0
+ m_iptr[rank - i] = m_iptr[rank - i + 1] * dims[rank - i + 1];  // stride = next axis' stride times next axis' extent
+ }
+
+ return 0;  // always 0
+}
+
+static CacheManager<NDFFTPackCacheId, NDFFTPackCache> ndfftpack_cmgr(10);  // file-local manager queried by zfftnd_fftpack; NOTE(review): 10 is presumably the cache capacity -- confirm in cycliccache.h
+
+#if 0  // this stub is compiled out
+/* stub to make PUBLIC_GEN_API happy */
+static void destroy_zfftnd_fftpack_caches()  // empty body: performs no cleanup
+{
+}
+#endif
+
GEN_CACHE(zfftnd_fftpack, (int n, int rank)
, complex_double * ptr; int *iptr; int rank;
, ((caches_zfftnd_fftpack[i].n == n)
@@ -24,14 +161,14 @@
/*inline : disabled because MSVC6.0 fails to compile it. */
int next_comb(int *ia, int *da, int m)  // advances the multi-index ia[0 .. m] by one step, with da[] as inclusive per-position maxima; returns 0 once exhausted
{
- while (m >= 0 && ia[m] == da[m]) {
- ia[m--] = 0;
- }
- if (m < 0) {
- return 0;
- }
- ia[m]++;
- return 1;
+ while (m >= 0 && ia[m] == da[m]) {  // positions already at their maximum, scanning from position m downwards
+ ia[m--] = 0;  // reset this position and carry into the one before it
+ }
+ if (m < 0) {
+ return 0;  // every position was at its maximum: no further index (ia is now all zeros)
+ }
+ ia[m]++;  // bump the first position that still had room
+ return 1;  // ia now holds the next index
}
static
@@ -39,82 +176,55 @@
int rank, int strides_axis, int dims_axis, int unflat,
int *tmp)
{
- int *new_strides = tmp + rank;
- int *new_dims = tmp + 2 * rank;
- int *ia = tmp + 3 * rank;
- int rm1 = rank - 1, rm2 = rank - 2;
- int i, j, k;
- for (i = 0; i < rm2; ++i)
- ia[i] = 0;
- ia[rm2] = -1;
- j = 0;
- if (unflat) {
- while (next_comb(ia, new_dims, rm2)) {
- k = 0;
- for (i = 0; i < rm1; ++i) {
- k += ia[i] * new_strides[i];
- }
- for (i = 0; i < dims_axis; ++i) {
- *(dest + k + i * strides_axis) = *(src + j++);
- }
+ int *new_strides = tmp + rank;
+ int *new_dims = tmp + 2 * rank;
+ int *ia = tmp + 3 * rank;
+ int rm1 = rank - 1, rm2 = rank - 2;
+ int i, j, k;
+
+ for (i = 0; i < rm2; ++i) {
+ ia[i] = 0;
}
- } else {
- while (next_comb(ia, new_dims, rm2)) {
- k = 0;
- for (i = 0; i < rm1; ++i) {
- k += ia[i] * new_strides[i];
- }
- for (i = 0; i < dims_axis; ++i) {
- *(dest + j++) = *(src + k + i * strides_axis);
- }
+
+ ia[rm2] = -1;
+ j = 0;
+ if (unflat) {
+ while (next_comb(ia, new_dims, rm2)) {
+ k = 0;
+ for (i = 0; i < rm1; ++i) {
+ k += ia[i] * new_strides[i];
+ }
+ for (i = 0; i < dims_axis; ++i) {
+ *(dest + k + i * strides_axis) = *(src + j++);
+ }
+ }
+ } else {
+ while (next_comb(ia, new_dims, rm2)) {
+ k = 0;
+ for (i = 0; i < rm1; ++i) {
+ k += ia[i] * new_strides[i];
+ }
+ for (i = 0; i < dims_axis; ++i) {
+ *(dest + j++) = *(src + k + i * strides_axis);
+ }
+ }
}
- }
}
-extern "C" {
-extern void zfft(complex_double * inout,
- int n, int direction, int howmany, int normalize);
-};
-
extern void zfftnd_fftpack(complex_double * inout, int rank,  // entry point: transforms `howmany` arrays of shape dims[0 .. rank) in place via a cached NDFFTPackCache
int *dims, int direction, int howmany,
int normalize)
{
- int i, sz;
- complex_double *ptr = inout;
- int axis;
- complex_double *tmp;
- int *itmp;
- int k, j;
+ int i, sz;
+ complex_double *ptr = inout;
+ NDFFTPackCache* cache;  // obtained from ndfftpack_cmgr below; not freed in this function
- sz = 1;
- for (i = 0; i < rank; ++i) {
- sz *= dims[i];
- }
- zfft(ptr, dims[rank - 1], direction, howmany * sz / dims[rank - 1],
- normalize);
+ sz = 1;
+ for (i = 0; i < rank; ++i) {  // sz = product of all extents = number of elements in one array
+ sz *= dims[i];
+ }
- i = get_cache_id_zfftnd_fftpack(sz, rank);
- tmp = caches_zfftnd_fftpack[i].ptr;
- itmp = caches_zfftnd_fftpack[i].iptr;
+ cache = ndfftpack_cmgr.get_cache(NDFFTPackCacheId(sz, rank));  // the cache is keyed on (total element count, rank)
+ cache->compute(ptr, sz, dims, direction, normalize, howmany);  // the int result is ignored (compute always returns 0)
- itmp[rank - 1] = 1;
- for (i = 2; i <= rank; ++i) {
- itmp[rank - i] = itmp[rank - i + 1] * dims[rank - i + 1];
- }
-
- for (i = 0; i < howmany; ++i, ptr += sz) {
- for (axis = 0; axis < rank - 1; ++axis) {
- for (k = j = 0; k < rank; ++k) {
- if (k != axis) {
- *(itmp + rank + j) = itmp[k];
- *(itmp + 2 * rank + j++) = dims[k] - 1;
- }
- }
- flatten(tmp, ptr, rank, itmp[axis], dims[axis], 0, itmp);
- zfft(tmp, dims[axis], direction, sz / dims[axis], normalize);
- flatten(ptr, tmp, rank, itmp[axis], dims[axis], 1, itmp);
- }
- }
-
}
More information about the Scipy-svn
mailing list