[Scipy-svn] r4256 - branches/refactor_fft/scipy/fftpack/src/mkl
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun May 11 05:59:40 EDT 2008
Author: cdavid
Date: 2008-05-11 04:59:35 -0500 (Sun, 11 May 2008)
New Revision: 4256
Modified:
branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx
Log:
MKL backend for zfftnd now uses the C++ cyclic cache.
Modified: branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx
===================================================================
--- branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx 2008-05-11 09:20:47 UTC (rev 4255)
+++ branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx 2008-05-11 09:59:35 UTC (rev 4256)
@@ -5,9 +5,146 @@
*
* Last Change: Sun May 11 06:00 PM 2008 J
*/
+#include <new>
-static long *convert_dims(int n, int *dims)
+#include "cycliccache.h"
+
+using namespace fft;
+
/*
 * Cache key for the n-dimensional MKL backend: identifies a transform by its
 * rank and the length of each dimension.  The id owns a malloc'ed copy of the
 * dimension array (released in the destructor), so every copy operation has
 * to deep-copy that array.
 */
class NDMKLCacheId {
    public:
        /* Copies the first `rank` entries of `dims` into an owned array. */
        NDMKLCacheId(int rank, int *dims);
        virtual ~NDMKLCacheId();

        NDMKLCacheId(const NDMKLCacheId &);

        /*
         * Deep-copy assignment.  Without it the compiler-generated operator=
         * would copy the m_dims pointer itself: the old array would leak and
         * the shared one would be freed twice by the two destructors.
         * Throws std::bad_alloc if the new array cannot be allocated; in that
         * case *this is left unchanged.
         */
        NDMKLCacheId & operator=(const NDMKLCacheId & other)
        {
            if (this != &other) {
                /* Allocate at least one element: malloc(0) may legitimately
                 * return NULL, which must not look like an allocation
                 * failure. */
                size_t count = other.m_rank > 0 ? (size_t) other.m_rank : 1;
                int *tmp = (int *) malloc(sizeof(int) * count);
                if (tmp == NULL) {
                    throw std::bad_alloc();
                }
                if (other.m_rank > 0) {
                    memcpy(tmp, other.m_dims, sizeof(int) * other.m_rank);
                }
                free(m_dims);
                m_dims = tmp;
                m_rank = other.m_rank;
            }
            return *this;
        }

        virtual bool operator==(const NDMKLCacheId & other) const {
            return is_equal(other);
        }

        /* True when both ids describe the same rank and dimension lengths. */
        virtual bool is_equal(const NDMKLCacheId & other) const;

    public:
        int m_rank;     /* number of dimensions */
        int *m_dims;    /* owned copy of the m_rank dimension lengths */

    private:
        /* Fills m_dims from dims; returns 0 on success, -1 on failure. */
        int init(int rank, int *dims);
};
+
+int NDMKLCacheId::init(int rank, int *dims)
{
+ m_dims = (int *) malloc(sizeof(int) * rank);
+ if (m_dims == NULL) {
+ return -1;
+ }
+ memcpy(m_dims, dims, rank * sizeof(*m_dims));
+
+ return 0;
+
+}
+
+NDMKLCacheId::NDMKLCacheId(int rank, int *dims) :
+ m_rank(rank)
+{
+ if (init(rank, dims)) {
+ goto fail;
+ }
+
+fail:
+ std::bad_alloc();
+}
+
+NDMKLCacheId::NDMKLCacheId(const NDMKLCacheId & copy) :
+ m_rank(copy.m_rank)
+{
+ if (init(copy.m_rank, copy.m_dims)) {
+ goto fail;
+ }
+
+fail:
+ std::bad_alloc();
+}
+
+NDMKLCacheId::~NDMKLCacheId()
+{
+ free(m_dims);
+}
+
+bool NDMKLCacheId::is_equal(const NDMKLCacheId & other) const
+{
+ bool res;
+
+ if (m_rank == other.m_rank) {
+ res = equal_dims(m_rank, m_dims, other.m_dims);
+ } else {
+ return false;
+ }
+
+ return res;
+}
+
+/*
+ * Cache class for nd-MKL
+ */
+class NDMKLCache:public Cache < NDMKLCacheId > {
+ public:
+ NDMKLCache(const NDMKLCacheId & id);
+ virtual ~ NDMKLCache();
+
+ int compute_forward(double * inout) const
+ {
+ DftiComputeForward(m_hdl, inout);
+ return 0;
+ };
+
+ int compute_backward(double * inout) const
+ {
+ DftiComputeBackward(m_hdl, inout);
+ return 0;
+ };
+
+ protected:
+ int m_rank;
+ int *m_dims;
+ long *m_ndims;
+ DFTI_DESCRIPTOR_HANDLE m_hdl;
+
+ private:
+ long *convert_dims(int n, int *dims) const;
+
+};
+
+NDMKLCache::NDMKLCache(const NDMKLCacheId & id)
+: Cache < NDMKLCacheId > (id)
+{
+ m_rank = id.m_rank;
+ m_ndims = convert_dims(id.m_rank, id.m_dims);
+ m_dims = (int *) malloc(sizeof(int) * m_rank);
+ if (m_dims == NULL) {
+ goto fail;
+ }
+
+ memcpy(m_dims, id.m_dims, sizeof(int) * m_rank);
+ DftiCreateDescriptor(&m_hdl, DFTI_DOUBLE, DFTI_COMPLEX, (long) m_rank,
+ m_ndims);
+ DftiCommitDescriptor(m_hdl);
+
+ return;
+
+fail:
+ throw std::bad_alloc();
+}
+
+NDMKLCache::~NDMKLCache()
+{
+ DftiFreeDescriptor(&m_hdl);
+ free(m_dims);
+ free(m_ndims);
+}
+
+long* NDMKLCache::convert_dims(int n, int *dims) const
+{
long *ndim;
int i;
@@ -18,44 +155,42 @@
return ndim;
}
-GEN_CACHE(zfftnd_mkl, (int n, int *dims)
- , DFTI_DESCRIPTOR_HANDLE desc_handle;
- int *dims;
- long *ndims;, ((caches_zfftnd_mkl[i].n == n) &&
- (equal_dims(n, caches_zfftnd_mkl[i].dims, dims)))
- , caches_zfftnd_mkl[id].ndims = convert_dims(n, dims);
- caches_zfftnd_mkl[id].n = n;
- caches_zfftnd_mkl[id].dims = (int *) malloc(sizeof(int) * n);
- memcpy(caches_zfftnd_mkl[id].dims, dims, sizeof(int) * n);
- DftiCreateDescriptor(&caches_zfftnd_mkl[id].desc_handle,
- DFTI_DOUBLE, DFTI_COMPLEX, (long) n,
- caches_zfftnd_mkl[id].ndims);
- DftiCommitDescriptor(caches_zfftnd_mkl[id].desc_handle);,
- DftiFreeDescriptor(&caches_zfftnd_mkl[id].desc_handle);
- free(caches_zfftnd_mkl[id].dims);
- free(caches_zfftnd_mkl[id].ndims);, 10)
+static CacheManager < NDMKLCacheId, NDMKLCache > mkl_cmgr(10);
+/* stub to make GEN_CACHE happy */
+static void destroy_zfftnd_mkl_caches()
+{
+}
+
extern void zfftnd_mkl(complex_double * inout, int rank,
int *dims, int direction, int howmany,
int normalize)
{
- int i, sz, id;
+ int i, sz;
complex_double *ptr = inout;
+ NDMKLCache *cache;
- DFTI_DESCRIPTOR_HANDLE desc_handle;
sz = 1;
for (i = 0; i < rank; ++i) {
sz *= dims[i];
}
- id = get_cache_id_zfftnd_mkl(rank, dims);
- desc_handle = caches_zfftnd_mkl[id].desc_handle;
- for (i = 0; i < howmany; ++i, ptr += sz) {
- if (direction == 1) {
- DftiComputeForward(desc_handle, (double *) ptr);
- } else if (direction == -1) {
- DftiComputeBackward(desc_handle, (double *) ptr);
- }
+ cache = mkl_cmgr.get_cache(NDMKLCacheId(rank, dims));
+ switch(direction) {
+ case 1:
+ for (i = 0; i < howmany; ++i, ptr += sz) {
+ cache->compute_forward((double*)ptr);
+ }
+ break;
+ case -1:
+ for (i = 0; i < howmany; ++i, ptr += sz) {
+ cache->compute_backward((double*)ptr);
+ }
+ break;
+ default:
+ fprintf(stderr,
+ "nd mkl:Wrong direction (this is a bug)\n");
+ return;
}
if (normalize) {
ptr = inout;
More information about the Scipy-svn
mailing list