[Scipy-svn] r2397 - in trunk/Lib/sandbox/cdavid: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Dec 12 03:22:03 EST 2006


Author: cdavid
Date: 2006-12-12 02:21:54 -0600 (Tue, 12 Dec 2006)
New Revision: 2397

Added:
   trunk/Lib/sandbox/cdavid/segmentaxis.py
   trunk/Lib/sandbox/cdavid/tests/test_segmentaxis.py
Modified:
   trunk/Lib/sandbox/cdavid/Changelog
   trunk/Lib/sandbox/cdavid/TODO
   trunk/Lib/sandbox/cdavid/__init__.py
   trunk/Lib/sandbox/cdavid/autocorr.py
   trunk/Lib/sandbox/cdavid/tests/test_autocorr.py
Log:
Add FFT-based autocorr, plus A. M. Archibald's segment_axis code (equivalent to Matlab's buffer)

Modified: trunk/Lib/sandbox/cdavid/Changelog
===================================================================
--- trunk/Lib/sandbox/cdavid/Changelog	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/Changelog	2006-12-12 08:21:54 UTC (rev 2397)
@@ -1,5 +1,14 @@
-pyem (0.1) Tue, 28 Nov 2006 16:56:35 +0900
+cdavid (0.2) Tue, 12 Dec 2006 17:14:18 +0900
 
+	* second release
+	* add FFT-based autocorrelation (should check the definition for
+	complex arrays, e.g. whether to use the conjugate or not)
+	* add code segment_axis from A.M Archibald
+
+-- David Cournapeau <david at ar.media.kyoto-u.ac.jp> 
+
+cdavid (0.1) Tue, 28 Nov 2006 16:56:35 +0900
+
 	* first release
 
 -- David Cournapeau <david at ar.media.kyoto-u.ac.jp> 

Modified: trunk/Lib/sandbox/cdavid/TODO
===================================================================
--- trunk/Lib/sandbox/cdavid/TODO	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/TODO	2006-12-12 08:21:54 UTC (rev 2397)
@@ -1,6 +1,9 @@
-# Last Change: Tue Nov 28 05:00 PM 2006 J
-
+# Last Change: Tue Dec 12 05:00 PM 2006 J
+Various things to do before moving the code out of the sandbox:
     - there is no doc.
     - the handling of non contiguous arrays is not really 
     elegant, and the code is difficult to maintain
     - rank > 2: must code in C ? (yuk)
+    - for correlation: there is no reason to support only autocorrelation. Also,
+it makes no sense to offer different functions for different implementations:
+the implementation should be an argument (fft vs no fft). The current API is poor.

Modified: trunk/Lib/sandbox/cdavid/__init__.py
===================================================================
--- trunk/Lib/sandbox/cdavid/__init__.py	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/__init__.py	2006-12-12 08:21:54 UTC (rev 2397)
@@ -1,8 +1,9 @@
-# Last Change: Tue Nov 28 04:00 PM 2006 J
+# Last Change: Tue Dec 12 05:00 PM 2006 J
 from info import __doc__
 
 from lpc import lpc2 as lpc
-from autocorr import autocorr_oneside_nofft as autocorr
+from autocorr import autocorr_oneside_nofft, autocorr_fft
+from segmentaxis import segment_axis
 
 from numpy.testing import NumpyTest
 test = NumpyTest().test

Modified: trunk/Lib/sandbox/cdavid/autocorr.py
===================================================================
--- trunk/Lib/sandbox/cdavid/autocorr.py	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/autocorr.py	2006-12-12 08:21:54 UTC (rev 2397)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-# Last Change: Tue Nov 28 03:00 PM 2006 J
+# Last Change: Tue Dec 12 05:00 PM 2006 J
 
 # TODO: - proper test
 # TODO: - proper profiling
@@ -279,6 +279,40 @@
 
     return res
 
+def nextpow2(n):
+    """Return the smallest integer p >= 0 such that 2 ** p >= n."""
+    p = 0
+    while 2 ** p < n:   # exact integer test; float log2 misrounds large n
+        p += 1
+    return p
+
+def autocorr_fft(signal, axis = -1):
+    """Return the full autocorrelation along axis, computed with the FFT.
+
+    The result has 2 * n - 1 samples along axis (n = signal.shape[axis]),
+    lags -(n - 1) .. n - 1 with zero lag at index n - 1, and signal's dtype.
+    Any rank is accepted; a 1-d signal always uses its only axis."""
+    if N.ndim(signal) == 0:
+        return signal
+    signal  = N.asarray(signal)
+    if signal.ndim == 1:
+        axis = -1
+    n       = signal.shape[axis]
+    lag     = n - 1
+    # pad to >= 2n - 1 points so the circular correlation equals the linear one
+    nfft    = int(2 ** nextpow2(2 * n - 1))
+    a       = fft(signal, n = nfft, axis = axis)
+    au      = ifft(a * N.conj(a), n = nfft, axis = axis)
+    # negative lags wrap to the end of au; slice from nfft - lag, because
+    # au[-lag:] would return the whole array when lag == 0 (n == 1)
+    neg, pos = [slice(None)] * au.ndim, [slice(None)] * au.ndim
+    neg[axis], pos[axis] = slice(nfft - lag, None), slice(0, lag + 1)
+    res     = N.concatenate((au[tuple(neg)], au[tuple(pos)]), axis = axis)
+    if not N.iscomplexobj(signal):
+        # real input: drop the round-off imaginary part, round before int casts
+        res = N.around(res.real) if signal.dtype.kind in 'iu' else res.real
+    return res.astype(signal.dtype)
+        
 def bench():
     size    = 256
     nframes = 4000

Added: trunk/Lib/sandbox/cdavid/segmentaxis.py
===================================================================
--- trunk/Lib/sandbox/cdavid/segmentaxis.py	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/segmentaxis.py	2006-12-12 08:21:54 UTC (rev 2397)
@@ -0,0 +1,93 @@
+import numpy as N
+import unittest
+from numpy.testing import NumpyTestCase, assert_array_almost_equal,             assert_almost_equal, assert_equal
+import warnings
+
+def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
+    """Chop an array along an axis into (possibly overlapping) frames.
+
+    example:
+    >>> segment_axis(N.arange(10), 4, 2)
+    array([[0, 1, 2, 3],
+           [2, 3, 4, 5],
+           [4, 5, 6, 7],
+           [6, 7, 8, 9]])
+
+    arguments:
+    a        The array to segment
+    length   The length of each frame (> 0)
+    overlap  Elements shared by consecutive frames (0 <= overlap < length)
+    axis     The axis to operate on (may be negative); None flattens a first
+    end      What to do when the array is not evenly divisible into frames:
+             'cut' discards the extra values, 'pad' pads with endvalue,
+             'wrap' copies values (cyclically) from the start of the array
+    endvalue The value to use for end='pad'
+
+    The segmented axis is replaced by (nframes, length). The result is a
+    strided view (overlapping frames share memory); a is only copied when
+    needed ('pad', 'wrap', non-contiguous ravel, failed view). Raises ValueError."""
+
+    if axis is None:
+        a = N.ravel(a) # may copy
+        axis = 0
+    elif axis < 0:
+        axis += a.ndim  # normalize: a.shape[axis+1:] below breaks for axis<0
+    l = a.shape[axis]
+
+    if overlap>=length:
+        raise ValueError("frames cannot overlap by more than 100%")
+    if overlap<0 or length<=0:
+        raise ValueError("overlap must be nonnegative and length must be positive")
+    step = length-overlap  # offset between the starts of consecutive frames
+
+    if l<length or (l-length)%step:
+        # uneven: round l up/down to the nearest evenly splittable length
+        if l>length:
+            roundup = length + (1+(l-length)//step)*step
+            rounddown = length + ((l-length)//step)*step
+        else:
+            roundup = length
+            rounddown = 0
+        assert rounddown<=l<roundup
+        assert roundup==rounddown+step or (roundup==length and rounddown==0)
+        a = a.swapaxes(-1,axis)
+
+        if end=='cut':
+            a = a[...,:rounddown]
+        elif end in ['pad','wrap']: # copying will be necessary
+            s = list(a.shape)
+            s[-1]=roundup
+            b = N.empty(s,dtype=a.dtype)
+            b[...,:l] = a
+            if end=='pad':
+                b[...,l:] = endvalue
+            elif l==0:
+                raise ValueError("cannot wrap an empty array")
+            else:
+                # indices modulo l: roundup-l can exceed l when l<length
+                b[...,l:] = a.take(N.arange(l,roundup) % l, axis=-1)
+            a = b
+        else:
+            raise ValueError("end must be 'cut', 'pad' or 'wrap'")
+        a = a.swapaxes(-1,axis)
+
+    l = a.shape[axis]
+    if l==0:
+        raise ValueError("Not enough data points to segment array in 'cut' mode; try 'pad' or 'wrap'")
+    assert l>=length
+    assert (l-length)%step == 0
+    n = 1+(l-length)//step
+    newshape = a.shape[:axis]+(n,length)+a.shape[axis+1:]
+    def frames(arr):
+        # read the strides from arr itself: a fresh copy has different ones
+        s = arr.strides[axis]
+        st = arr.strides[:axis]+(step*s,s)+arr.strides[axis+1:]
+        return N.ndarray.__new__(N.ndarray,strides=st,shape=newshape,buffer=arr,dtype=arr.dtype)
+    try:
+        return frames(a)
+    except TypeError:
+        warnings.warn("Problem with ndarray creation forces copy.")
+        return frames(a.copy())
+        
+
+

Modified: trunk/Lib/sandbox/cdavid/tests/test_autocorr.py
===================================================================
--- trunk/Lib/sandbox/cdavid/tests/test_autocorr.py	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/tests/test_autocorr.py	2006-12-12 08:21:54 UTC (rev 2397)
@@ -1,9 +1,10 @@
 #! /usr/bin/env python
-# Last Change: Tue Nov 28 05:00 PM 2006 J
+# Last Change: Tue Dec 12 05:00 PM 2006 J
 
 from numpy.testing import *
 from numpy.random import randn, seed
-from numpy import correlate, array, concatenate, require
+from numpy import correlate, array, concatenate, require, corrcoef
+from numpy.fft import fft, ifft
 
 from numpy.ctypeslib import ndpointer, load_library
 from ctypes import c_uint
@@ -11,6 +12,7 @@
 set_package_path()
 from cdavid.autocorr import _raw_autocorr_1d, _raw_autocorr_1d_noncontiguous
 from cdavid.autocorr import autocorr_oneside_nofft as autocorr
+from cdavid.autocorr import autocorr_fft 
 from cdavid.autocorr import _autocorr_oneside_nofft_py as autocorr_py
 restore_path()
 
@@ -289,6 +291,54 @@
         yr      = autocorr_py(xt, lag, axis = axis)
         assert_array_equal(yt, yr)
 
+class test_autocorr_fft(NumpyTestCase):
+    """Tests for autocorr_fft: real input is compared with correlate; for
+    complex input only the zero lag (the signal energy) is checked."""
+    n   = 5
+    d   = 3
+    def check_r1r(self):
+        """real case, rank 1"""
+        a   = randn(self.n)
+
+        aref    = correlate(a, a, mode = 'full')
+        atest   = autocorr_fft(a)
+        assert_array_almost_equal(atest, aref, decimal = md)
+        assert atest.dtype == a.dtype
+
+    def check_r1c(self):
+        """complex case, rank 1"""
+        a   = randn(self.n) + 1.0j * randn(self.n)
+
+        atest   = autocorr_fft(a)
+        # zero lag = energy; ndarray methods avoid relying on a `numpy` name
+        aref    = (a * a.conj()).sum()
+        assert_array_almost_equal(atest[self.n - 1], aref, decimal = md)
+        assert atest.dtype == a.dtype
+
+    def check_r2c(self):
+        """complex case, rank 2: zero lag of each column is its energy"""
+        a       = randn(self.n, self.d) + 1.0j * randn(self.n, self.d)
+        energy  = (a * a.conj()).sum(axis = 0)
+
+        atest   = autocorr_fft(a, axis = 0)
+        assert_array_almost_equal(atest[self.n - 1], energy, decimal = md)
+        assert atest.dtype == a.dtype
+        atest   = autocorr_fft(a.T, axis = 1)
+        assert_array_almost_equal(atest[:, self.n - 1], energy, decimal = md)
+
+    def check_r2r(self):
+        """real case, rank 2, along axis 0 then axis 1"""
+        a       = randn(self.n, self.d)
+        c       = [correlate(a[:, i], a[:, i], mode = 'full') for i in range(self.d)]
+        atest   = autocorr_fft(a, axis = 0)
+        assert_array_almost_equal(atest, array(c).T, decimal = md)
+
+        a       = randn(self.n, self.d)
+        c       = [correlate(a[i], a[i], mode = 'full') for i in range(self.n)]
+        atest   = autocorr_fft(a, axis = 1)
+        assert_array_almost_equal(atest, array(c), decimal = md)
+        assert atest.dtype == a.dtype
+
 if __name__ == "__main__":
     ScipyTest().run()
 

Added: trunk/Lib/sandbox/cdavid/tests/test_segmentaxis.py
===================================================================
--- trunk/Lib/sandbox/cdavid/tests/test_segmentaxis.py	2006-12-12 07:24:01 UTC (rev 2396)
+++ trunk/Lib/sandbox/cdavid/tests/test_segmentaxis.py	2006-12-12 08:21:54 UTC (rev 2397)
@@ -0,0 +1,64 @@
+#! /usr/bin/env python
+# Last Change: Fri Nov 24 04:00 PM 2006 J
+
+from numpy.testing import *
+
+import numpy as N
+
+set_package_path()
+from segmentaxis import segment_axis
+restore_path()
+
+# #Optional:
+# set_local_path()
+# # import modules that are located in the same directory as this file.
+# restore_path()
+
+class test_segment(NumpyTestCase):
+    """Tests for segment_axis."""
+    def check_simple(self):
+        x = N.arange(6)
+        assert_equal(segment_axis(x, length=3, overlap=0),
+                     N.array([[0, 1, 2], [3, 4, 5]]))
+        x = N.arange(7)
+        assert_equal(segment_axis(x, length=3, overlap=1),
+                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, 6]]))
+        # overlap = length - 1: each frame advances by a single sample
+        expected = N.array([[0, 1, 2], [1, 2, 3], [2, 3, 4],
+                            [3, 4, 5], [4, 5, 6]])
+        assert_equal(segment_axis(x, length=3, overlap=2), expected)
+
+    def check_error_checking(self):
+        # (length, overlap) pairs rejected for a 7-sample array: negative
+        # overlap, empty frames, 100% overlap, frame longer than the data
+        x = N.arange(7)
+        for length, overlap in [(3, -1), (0, 0), (3, 3), (8, 3)]:
+            self.assertRaises(ValueError, segment_axis, x,
+                              length=length, overlap=overlap)
+
+    def check_ending(self):
+        # 6 samples in frames of 3 overlapping by 1: the last frame is short
+        x = N.arange(6)
+        assert_equal(segment_axis(x, length=3, overlap=1, end='cut'),
+                     N.array([[0, 1, 2], [2, 3, 4]]))
+        assert_equal(segment_axis(x, length=3, overlap=1, end='wrap'),
+                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, 0]]))
+        assert_equal(segment_axis(x, length=3, overlap=1, end='pad', endvalue=-17),
+                     N.array([[0, 1, 2], [2, 3, 4], [4, 5, -17]]))
+
+    def check_multidimensional(self):
+        """The segmented axis is replaced by an (nframes, length) pair."""
+        ones = N.ones((2, 3, 4, 5, 6))
+        # even split along axis 3, for a contiguous and a swapped-axes input
+        assert_equal(segment_axis(ones, axis=3, length=3, overlap=1).shape,
+                     (2, 3, 4, 2, 3, 6))
+        swapped = N.ones((2, 5, 4, 3, 6)).swapaxes(1, 3)
+        assert_equal(segment_axis(swapped, axis=3, length=3, overlap=1).shape,
+                     (2, 3, 4, 2, 3, 6))
+        # uneven split along axis 2 (4 samples): 'cut' keeps 1 frame, others 2
+        for end, nframes in [('cut', 1), ('wrap', 2), ('pad', 2)]:
+            shape = segment_axis(ones, axis=2, length=3, overlap=1, end=end).shape
+            assert_equal(shape, (2, 3, nframes, 3, 5, 6))
+
+if __name__ == '__main__':
+    NumpyTest().run()   # only when this file is executed as a script




More information about the Scipy-svn mailing list