[Scipy-svn] r3094 - in trunk/Lib/sandbox/pyem: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Jun 11 05:19:13 EDT 2007


Author: cdavid
Date: 2007-06-11 04:18:57 -0500 (Mon, 11 Jun 2007)
New Revision: 3094

Added:
   trunk/Lib/sandbox/pyem/tests/diag_1d_3k.mat
   trunk/Lib/sandbox/pyem/tests/diag_1d_4k.mat
   trunk/Lib/sandbox/pyem/tests/diag_2d_3k.mat
   trunk/Lib/sandbox/pyem/tests/full_2d_3k.mat
   trunk/Lib/sandbox/pyem/tests/generate_tests_data.py
Modified:
   trunk/Lib/sandbox/pyem/gauss_mix.py
   trunk/Lib/sandbox/pyem/gmm_em.py
   trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
Log:
Add basic tests for EM in 1d and 2d, for both full and diag modes

Modified: trunk/Lib/sandbox/pyem/gauss_mix.py
===================================================================
--- trunk/Lib/sandbox/pyem/gauss_mix.py	2007-06-11 09:18:25 UTC (rev 3093)
+++ trunk/Lib/sandbox/pyem/gauss_mix.py	2007-06-11 09:18:57 UTC (rev 3094)
@@ -1,5 +1,5 @@
 # /usr/bin/python
-# Last Change: Mon Jun 11 03:00 PM 2007 J
+# Last Change: Mon Jun 11 06:00 PM 2007 J
 
 """Module implementing GM, a class which represents Gaussian mixtures.
 
@@ -132,6 +132,7 @@
         :SeeAlso:
             If you know already the parameters when creating the model, you can
             simply use the method class GM.fromvalues."""
+        #XXX: when fromvalues is called, parameters are checked twice...
         k, d, mode  = check_gmm_param(weights, mu, sigma)
         if not k == self.k:
             raise GmParamError("Number of given components is %d, expected %d" 
@@ -664,14 +665,14 @@
     """
         
     # Check that w is valid
-    if N.fabs(N.sum(w, 0)  - 1) > misc._MAX_DBL_DEV:
+    if not len(w.shape) == 1:
+        raise GmParamError('weight should be a rank 1 array')
+
+    if N.fabs(N.sum(w)  - 1) > misc._MAX_DBL_DEV:
         raise GmParamError('weight does not sum to 1')
     
-    if not len(w.shape) == 1:
-        raise GmParamError('weight is not a vector')
-
     # Check that mean and va have the same number of components
-    K           = len(w)
+    K = len(w)
 
     if N.ndim(mu) < 2:
         msg = "mu should be a K,d matrix, and a row vector if only 1 comp"

Modified: trunk/Lib/sandbox/pyem/gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/gmm_em.py	2007-06-11 09:18:25 UTC (rev 3093)
+++ trunk/Lib/sandbox/pyem/gmm_em.py	2007-06-11 09:18:57 UTC (rev 3094)
@@ -1,5 +1,5 @@
 # /usr/bin/python
-# Last Change: Mon Jun 11 01:00 PM 2007 J
+# Last Change: Mon Jun 11 04:00 PM 2007 J
 
 """Module implementing GMM, a class to estimate Gaussian mixture models using
 EM, and EM, a class which use GMM instances to estimate models parameters using

Added: trunk/Lib/sandbox/pyem/tests/diag_1d_3k.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/sandbox/pyem/tests/diag_1d_3k.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/sandbox/pyem/tests/diag_1d_4k.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/sandbox/pyem/tests/diag_1d_4k.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/sandbox/pyem/tests/diag_2d_3k.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/sandbox/pyem/tests/diag_2d_3k.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/sandbox/pyem/tests/full_2d_3k.mat
===================================================================
(Binary files differ)


Property changes on: trunk/Lib/sandbox/pyem/tests/full_2d_3k.mat
___________________________________________________________________
Name: svn:mime-type
   + application/octet-stream

Added: trunk/Lib/sandbox/pyem/tests/generate_tests_data.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/generate_tests_data.py	2007-06-11 09:18:25 UTC (rev 3093)
+++ trunk/Lib/sandbox/pyem/tests/generate_tests_data.py	2007-06-11 09:18:57 UTC (rev 3094)
@@ -0,0 +1,103 @@
+#! /usr/bin/env python
+# Last Change: Mon Jun 11 05:00 PM 2007 J
+
+# This script generates some random data used for testing EM implementations.
+import copy
+import numpy as N
+from numpy.testing import set_package_path, restore_path
+from scipy.io import savemat, loadmat
+
+set_package_path()
+import pyem
+restore_path()
+
+from pyem import GM, GMM, EM
+
+def generate_dataset(d, k, mode, nframes):
+    """Generate a dataset useful for EM and GMM testing.
+
+    d, k, mode : dimension, number of components and covariance mode of the
+        randomly generated true model.
+    nframes : number of samples drawn from the true model.
+
+    returns:
+        data : ndarray
+            data from the true model.
+        tgm : GM
+            the true model (randomly generated)
+        gm0 : GM
+            the initial model (random initialization from data)
+        gm : GM
+            the trained model
+    """
+    # Generate a model
+    w, mu, va = GM.gen_param(d, k, mode, spread = 2.0)
+    tgm = GM.fromvalues(w, mu, va)
+
+    # Generate data from the model
+    data = tgm.sample(nframes)
+
+    # Run EM on the model, running the initialization separately so that the
+    # initial model can be returned (and saved) as well.
+    gmm = GMM(GM(d, k, mode), 'test')
+    gmm.init_random(data)
+    # Deep copies: a shallow copy would share the parameter arrays, so gm0
+    # could be silently modified if training updates them in place.
+    gm0 = copy.deepcopy(gmm.gm)
+
+    gmm = GMM(copy.deepcopy(gmm.gm), 'test')
+    EM().train(data, gmm)
+
+    return data, tgm, gm0, gmm.gm
+
+def save_dataset(filename, data, tgm, gm0, gm):
+    """Save data together with the true (t*), initial (*0) and trained model
+    parameters into the MATLAB file filename."""
+    dic = dict(tw = tgm.w, tmu = tgm.mu, tva = tgm.va,
+               w0 = gm0.w, mu0 = gm0.mu, va0 = gm0.va,
+               w = gm.w, mu = gm.mu, va = gm.va, data = data)
+    savemat(filename, dic)
+
+def doall(d, k, mode, nframes = 500):
+    """Generate one random dataset, save it to disk and plot it.
+
+    The results are saved to '<mode>_<d>d_<k>k.mat' (current directory).
+    The top subplot shows the initial model, the bottom one the trained
+    model; both are overlaid with the true model and the sampled data.
+
+    d, k : int
+        dimension of the data and number of mixture components.
+    mode : str
+        covariance mode, e.g. 'diag' or 'full'.
+    nframes : int
+        number of samples drawn from the true model (500 by default)."""
+    import pylab as P
+
+    data, tgm, gm0, gm = generate_dataset(d, k, mode, nframes)
+    save_dataset('%s_%dd_%dk.mat' % (mode, d, k), data, tgm, gm0, gm)
+
+    for row, model in enumerate((gm0, gm)):
+        P.subplot(2, 1, row + 1)
+        if d == 1:
+            model.plot1d()
+            tgm.plot1d(gpdf = True)
+            P.hist(data[:, 0], 20, normed = 1, fill = False)
+        else:
+            model.plot()
+            for line in tgm.plot():
+                line.set_color('g')
+            P.plot(data[:, 0], data[:, 1], '.')
+
+    P.show()
+
+if __name__ == '__main__':
+    # Datasets to generate, as (dimension, number of components, mode).
+    # They produce full_2d_3k.mat, diag_2d_3k.mat and diag_1d_4k.mat, which
+    # are the files loaded by test_gmm_em.py.
+    configs = [
+        (2, 3, 'full'),
+        (2, 3, 'diag'),
+        (1, 4, 'diag'),
+    ]
+
+    for d, k, mode in configs:
+        # Reseed before each dataset so that each file only depends on its
+        # own parameters, not on which datasets were generated before it.
+        # Changing the seed changes the generated data, so the .mat files
+        # (and hence the reference values used by the tests) must then be
+        # regenerated.
+        N.random.seed(0)
+        doall(d, k, mode)

Modified: trunk/Lib/sandbox/pyem/tests/test_gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_gmm_em.py	2007-06-11 09:18:25 UTC (rev 3093)
+++ trunk/Lib/sandbox/pyem/tests/test_gmm_em.py	2007-06-11 09:18:57 UTC (rev 3094)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-# Last Change: Sat Jun 09 03:00 PM 2007 J
+# Last Change: Mon Jun 11 06:00 PM 2007 J
 
 # For now, just test that all mode/dim execute correctly
 
@@ -12,6 +12,14 @@
 from pyem import GMM, GM, EM
 restore_path()
 
+def load_dataset(filename):
+    """Load a reference dataset written by generate_tests_data.py.
+
+    A relative filename which does not exist in the current directory is
+    looked up in the directory of this module instead, so that the tests
+    do not depend on the directory they are run from.
+
+    The weight vectors (w0, w, tw) are squeezed, since loadmat is called
+    with squeeze_me = False and the GM parameter checks require rank 1
+    weights."""
+    import os
+    from scipy.io import loadmat
+    if not (os.path.isabs(filename) or os.path.exists(filename)):
+        here = os.path.dirname(os.path.abspath(__file__))
+        filename = os.path.join(here, filename)
+    dic = loadmat(filename, squeeze_me = False)
+    for key in ('w0', 'w', 'tw'):
+        dic[key] = dic[key].squeeze()
+    return dic
+
 class EmTest(NumpyTestCase):
     def _create_model_and_run_em(self, d, k, mode, nframes):
         #+++++++++++++++++++++++++++++++++++++++++++++++++
@@ -32,61 +40,127 @@
         em  = EM()
         lk  = em.train(data, gmm)
 
-class test_full(EmTest):
-    def check_1d(self, level = 1):
-        d       = 1
-        k       = 2
-        mode    = 'full'
-        nframes = int(1e2)
+#class test_full_run(EmTest):
+#    """This class only tests whether the algorithms runs. Do not check the
+#    results."""
+#    def check_1d(self, level = 1):
+#        d       = 1
+#        k       = 2
+#        mode    = 'full'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
+#
+#    def check_2d(self, level = 1):
+#        d       = 2
+#        k       = 2
+#        mode    = 'full'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
+#
+#    def check_5d(self, level = 1):
+#        d       = 5
+#        k       = 3
+#        mode    = 'full'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
+#
+#class test_diag_run(EmTest):
+#    """This class only tests whether the algorithms runs. Do not check the
+#    results."""
+#    def check_1d(self, level = 1):
+#        d       = 1
+#        k       = 2
+#        mode    = 'diag'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
+#
+#    def check_2d(self, level = 1):
+#        d       = 2
+#        k       = 2
+#        mode    = 'diag'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
+#
+#    def check_5d(self, level = 1):
+#        d       = 5
+#        k       = 3
+#        mode    = 'diag'
+#        nframes = int(1e2)
+#
+#        #seed(1)
+#        self._create_model_and_run_em(d, k, mode, nframes)
 
-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+class test_datasets(EmTest):
+    """This class tests whether the EM algorithm works using pre-computed
+    datasets.
+
+    Each check starts EM from the initial model (w0, mu0, va0) stored in a
+    .mat file and compares the trained parameters with the reference ones
+    (w, mu, va) stored in the same file. The comparison is done up to 6
+    decimals rather than bit for bit, since floating point results of EM may
+    differ slightly between platforms and BLAS implementations."""
+    def check_1d_full(self, level = 1):
+        d = 1
+        k = 4
+        mode = 'full'
+        # The data are exactly the same as in diagonal mode: just check that
+        # asking for full mode works even in 1d, even if it is kind of stupid
+        # to do so.
+        # NOTE(review): mode is never passed to GM.fromvalues, which derives
+        # the mode from the parameters themselves; in 1d the diag and full
+        # shapes presumably coincide, so confirm this exercises full mode.
+        dic = load_dataset('diag_1d_4k.mat')

-    def check_2d(self, level = 1):
-        d       = 2
-        k       = 2
-        mode    = 'full'
-        nframes = int(1e2)
+        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
+        gmm = GMM(gm, 'test')
+        EM().train(dic['data'], gmm)

-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+        assert_array_almost_equal(gmm.gm.w, dic['w'])
+        assert_array_almost_equal(gmm.gm.mu, dic['mu'])
+        assert_array_almost_equal(gmm.gm.va, dic['va'])

-    def check_5d(self, level = 1):
-        d       = 5
-        k       = 3
-        mode    = 'full'
-        nframes = int(1e2)
+    def check_1d_diag(self, level = 1):
+        d = 1
+        k = 4
+        mode = 'diag'
+        dic = load_dataset('diag_1d_4k.mat')

-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
+        gmm = GMM(gm, 'test')
+        EM().train(dic['data'], gmm)

-class test_diag(EmTest):
-    def check_1d(self, level = 1):
-        d       = 1
-        k       = 2
-        mode    = 'diag'
-        nframes = int(1e2)
+        assert_array_almost_equal(gmm.gm.w, dic['w'])
+        assert_array_almost_equal(gmm.gm.mu, dic['mu'])
+        assert_array_almost_equal(gmm.gm.va, dic['va'])

-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+    def check_2d_full(self, level = 1):
+        d = 2
+        k = 3
+        mode = 'full'
+        dic = load_dataset('full_2d_3k.mat')

-    def check_2d(self, level = 1):
-        d       = 2
-        k       = 2
-        mode    = 'diag'
-        nframes = int(1e2)
+        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
+        gmm = GMM(gm, 'test')
+        EM().train(dic['data'], gmm)

-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+        assert_array_almost_equal(gmm.gm.w, dic['w'])
+        assert_array_almost_equal(gmm.gm.mu, dic['mu'])
+        assert_array_almost_equal(gmm.gm.va, dic['va'])

-    def check_5d(self, level = 1):
-        d       = 5
-        k       = 3
-        mode    = 'diag'
-        nframes = int(1e2)
+    def check_2d_diag(self, level = 1):
+        d = 2
+        k = 3
+        mode = 'diag'
+        dic = load_dataset('diag_2d_3k.mat')

-        #seed(1)
-        self._create_model_and_run_em(d, k, mode, nframes)
+        gm = GM.fromvalues(dic['w0'], dic['mu0'], dic['va0'])
+        gmm = GMM(gm, 'test')
+        EM().train(dic['data'], gmm)

+        assert_array_almost_equal(gmm.gm.w, dic['w'])
+        assert_array_almost_equal(gmm.gm.mu, dic['mu'])
+        assert_array_almost_equal(gmm.gm.va, dic['va'])
+
 if __name__ == "__main__":
     NumpyTest().run()




More information about the Scipy-svn mailing list