[Scipy-svn] r4256 - branches/refactor_fft/scipy/fftpack/src/mkl

scipy-svn at scipy.org scipy-svn at scipy.org
Sun May 11 05:59:40 EDT 2008


Author: cdavid
Date: 2008-05-11 04:59:35 -0500 (Sun, 11 May 2008)
New Revision: 4256

Modified:
   branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx
Log:
mkl backend for zfftnd now uses the C++ cyclic cache.

Modified: branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx
===================================================================
--- branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx	2008-05-11 09:20:47 UTC (rev 4255)
+++ branches/refactor_fft/scipy/fftpack/src/mkl/zfftnd.cxx	2008-05-11 09:59:35 UTC (rev 4256)
@@ -5,9 +5,146 @@
  *
  * Last Change: Sun May 11 06:00 PM 2008 J
  */
+#include <new>
 
-static long *convert_dims(int n, int *dims)
+#include "cycliccache.h"
+
+using namespace fft;
+
+/*
+ * Cache key for the n-dimensional MKL transform: identifies a transform by
+ * its rank and its per-axis sizes. The dims array is owned (malloc'ed) by
+ * each instance, so copies must be deep.
+ */
+class NDMKLCacheId {
+        public:
+                NDMKLCacheId(int rank, int *dims);
+                virtual ~NDMKLCacheId();
+
+                NDMKLCacheId(const NDMKLCacheId &);
+
+                /*
+                 * Deep-copying assignment. The compiler-generated version
+                 * would copy the m_dims pointer, and both objects would then
+                 * free() the same buffer in their destructors.
+                 * Throws std::bad_alloc if the new dims buffer cannot be
+                 * allocated; *this is left unchanged in that case.
+                 */
+                NDMKLCacheId & operator=(const NDMKLCacheId & other)
+                {
+                        int *dims;
+
+                        if (this == &other) {
+                                return *this;
+                        }
+
+                        dims = (int *) malloc(sizeof(int) * other.m_rank);
+                        if (dims == NULL) {
+                                throw std::bad_alloc();
+                        }
+                        memcpy(dims, other.m_dims, sizeof(int) * other.m_rank);
+
+                        free(m_dims);
+                        m_dims = dims;
+                        m_rank = other.m_rank;
+                        return *this;
+                };
+
+                virtual bool operator==(const NDMKLCacheId & other) const {
+                        return is_equal(other);
+                };
+
+                /* True when both ids have the same rank and the same dims. */
+                virtual bool is_equal(const NDMKLCacheId & other) const;
+
+        public:
+                int m_rank;
+                int *m_dims;
+
+        private:
+                /* Allocates m_dims and copies rank entries of dims into it;
+                 * returns 0 on success, -1 on allocation failure. */
+                int init(int rank, int *dims);
+};
+
+int NDMKLCacheId::init(int rank, int *dims)
 {
+	/* Private copy of the caller's dims; released in the destructor. */
+	const size_t nbytes = sizeof(int) * rank;
+
+	m_dims = (int *) malloc(nbytes);
+	if (!m_dims) {
+		return -1;
+	}
+	memcpy(m_dims, dims, nbytes);
+	return 0;
+}
+
+NDMKLCacheId::NDMKLCacheId(int rank, int *dims) :
+        m_rank(rank)
+{
+        /*
+         * init() returns non-zero when the private copy of dims cannot be
+         * allocated. The exception must actually be thrown: a bare
+         * "std::bad_alloc();" only builds and discards a temporary, which
+         * would leave m_dims NULL and be dereferenced later by is_equal().
+         */
+        if (init(rank, dims)) {
+                throw std::bad_alloc();
+        }
+}
+
+NDMKLCacheId::NDMKLCacheId(const NDMKLCacheId & copy) :
+        m_rank(copy.m_rank)
+{
+	/*
+	 * Deep copy: every id owns its own dims buffer. Throw on allocation
+	 * failure instead of constructing (and discarding) a temporary
+	 * std::bad_alloc, which silently left m_dims NULL.
+	 */
+	if (init(copy.m_rank, copy.m_dims)) {
+		throw std::bad_alloc();
+	}
+}
+
+NDMKLCacheId::~NDMKLCacheId()
+{
+	/* Release the dims buffer obtained with malloc() in init(). */
+	free(this->m_dims);
+}
+
+bool NDMKLCacheId::is_equal(const NDMKLCacheId & other) const
+{
+	/* Different ranks can never match; only then compare axis sizes. */
+	if (m_rank != other.m_rank) {
+		return false;
+	}
+
+	return equal_dims(m_rank, m_dims, other.m_dims);
+}
+
+/*
+ * Cache entry for the n-dimensional complex MKL transform: holds a
+ * committed DFTI descriptor for one (rank, dims) shape together with the
+ * dimension arrays it was built from.
+ */
+class NDMKLCache:public Cache < NDMKLCacheId > {
+        public:
+                NDMKLCache(const NDMKLCacheId & id);
+                virtual ~ NDMKLCache();
+
+                /*
+                 * In-place forward transform of one array (callers pass
+                 * complex_double data cast to double *).
+                 * Returns 0 on success, -1 if MKL reports an error; the
+                 * status used to be discarded and 0 returned unconditionally.
+                 */
+                int compute_forward(double * inout) const
+                {
+                        long status = DftiComputeForward(m_hdl, inout);
+                        return (status == DFTI_NO_ERROR) ? 0 : -1;
+                };
+
+                /* In-place backward transform; same contract as above. */
+                int compute_backward(double * inout) const
+                {
+                        long status = DftiComputeBackward(m_hdl, inout);
+                        return (status == DFTI_NO_ERROR) ? 0 : -1;
+                };
+
+        protected:
+                int m_rank;
+                int *m_dims;            /* malloc'ed copy of the id's dims */
+                long *m_ndims;          /* dims as long, as passed to MKL */
+                DFTI_DESCRIPTOR_HANDLE m_hdl;
+
+        private:
+                long *convert_dims(int n, int *dims) const;
+
+};
+
+NDMKLCache::NDMKLCache(const NDMKLCacheId & id)
+:  Cache < NDMKLCacheId > (id)
+{
+        long status;
+
+        /* Start from a known state so the failure path can free safely. */
+        m_rank = id.m_rank;
+        m_dims = NULL;
+        m_hdl = NULL;
+
+        m_ndims = convert_dims(id.m_rank, id.m_dims);
+        if (m_ndims == NULL) {
+                goto fail;
+        }
+
+        m_dims = (int *) malloc(sizeof(int) * m_rank);
+        if (m_dims == NULL) {
+                goto fail;
+        }
+	memcpy(m_dims, id.m_dims, sizeof(int) * m_rank);
+
+        /*
+         * Check MKL statuses: caching an uncommitted or invalid descriptor
+         * would make every later transform of this shape fail.
+         */
+        status = DftiCreateDescriptor(&m_hdl, DFTI_DOUBLE, DFTI_COMPLEX,
+                                      (long) m_rank, m_ndims);
+        if (status != DFTI_NO_ERROR) {
+                m_hdl = NULL;
+                goto fail;
+        }
+        status = DftiCommitDescriptor(m_hdl);
+        if (status != DFTI_NO_ERROR) {
+                DftiFreeDescriptor(&m_hdl);
+                goto fail;
+        }
+
+        return;
+
+fail:
+        /*
+         * The destructor does not run when a constructor throws, so release
+         * whatever was acquired above (free(NULL) is a no-op). MKL failures
+         * are reported as std::bad_alloc as well, the only exception type
+         * this file already depends on.
+         */
+        free(m_dims);
+        free(m_ndims);
+        throw std::bad_alloc();
+}
+
+NDMKLCache::~NDMKLCache()
+{
+        /* Drop the MKL descriptor, then the two dimension buffers. */
+        DftiFreeDescriptor(&m_hdl);
+        free(m_ndims);
+        free(m_dims);
+}
+
+long* NDMKLCache::convert_dims(int n, int *dims) const
+{
         long *ndim;
         int i;
 
@@ -18,44 +155,42 @@
         return ndim;
 }
 
-GEN_CACHE(zfftnd_mkl, (int n, int *dims)
-	  , DFTI_DESCRIPTOR_HANDLE desc_handle;
-	  int *dims;
-	  long *ndims;, ((caches_zfftnd_mkl[i].n == n) &&
-			 (equal_dims(n, caches_zfftnd_mkl[i].dims, dims)))
-	  , caches_zfftnd_mkl[id].ndims = convert_dims(n, dims);
-	  caches_zfftnd_mkl[id].n = n;
-	  caches_zfftnd_mkl[id].dims = (int *) malloc(sizeof(int) * n);
-	  memcpy(caches_zfftnd_mkl[id].dims, dims, sizeof(int) * n);
-	  DftiCreateDescriptor(&caches_zfftnd_mkl[id].desc_handle,
-			       DFTI_DOUBLE, DFTI_COMPLEX, (long) n,
-			       caches_zfftnd_mkl[id].ndims);
-	  DftiCommitDescriptor(caches_zfftnd_mkl[id].desc_handle);,
-	  DftiFreeDescriptor(&caches_zfftnd_mkl[id].desc_handle);
-	  free(caches_zfftnd_mkl[id].dims);
-	  free(caches_zfftnd_mkl[id].ndims);, 10)
+/* File-scope manager for NDMKLCache entries keyed by NDMKLCacheId; the
+ * argument 10 is presumably the cache capacity (matches the old GEN_CACHE
+ * size) -- confirm in cycliccache.h. */
+static CacheManager<NDMKLCacheId, NDMKLCache> mkl_cmgr(10);
 
+/*
+ * Intentionally empty: the descriptors now live in mkl_cmgr rather than in
+ * a GEN_CACHE-generated array. Kept only because GEN_CACHE-related code
+ * presumably still expects this function to exist.
+ */
+static void destroy_zfftnd_mkl_caches()
+{
+        /* nothing to release here */
+}
+
 extern void zfftnd_mkl(complex_double * inout, int rank,
 		       int *dims, int direction, int howmany,
 		       int normalize)
 {
-        int i, sz, id;
+        int i, sz;
         complex_double *ptr = inout;
+	NDMKLCache *cache;
 
-        DFTI_DESCRIPTOR_HANDLE desc_handle;
         sz = 1;
         for (i = 0; i < rank; ++i) {
                 sz *= dims[i];
         }
 
-        id = get_cache_id_zfftnd_mkl(rank, dims);
-        desc_handle = caches_zfftnd_mkl[id].desc_handle;
-        for (i = 0; i < howmany; ++i, ptr += sz) {
-                if (direction == 1) {
-                        DftiComputeForward(desc_handle, (double *) ptr);
-                } else if (direction == -1) {
-                        DftiComputeBackward(desc_handle, (double *) ptr);
-                }
+        cache = mkl_cmgr.get_cache(NDMKLCacheId(rank, dims));
+        switch(direction) {
+                case 1:
+                        for (i = 0; i < howmany; ++i, ptr += sz) {
+                                cache->compute_forward((double*)ptr);
+                        }
+                        break;
+                case -1:
+                        for (i = 0; i < howmany; ++i, ptr += sz) {
+                                cache->compute_backward((double*)ptr);
+                        }
+                        break;
+                default:
+                        fprintf(stderr,
+                                "nd mkl:Wrong direction (this is a bug)\n");
+                        return;
         }
         if (normalize) {
                 ptr = inout;




More information about the Scipy-svn mailing list