[Scipy-svn] r3159 - in trunk/Lib/sandbox/pyem: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Thu Jul 12 03:24:11 EDT 2007


Author: cdavid
Date: 2007-07-12 02:23:55 -0500 (Thu, 12 Jul 2007)
New Revision: 3159

Modified:
   trunk/Lib/sandbox/pyem/densities.py
   trunk/Lib/sandbox/pyem/tests/test_densities.py
Log:
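Significantly improve speed of gauss_den: vectorize the non-log diagonal-covariance case (replace the per-dimension Python loop with a single N.dot) and add basic speed tests.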

Modified: trunk/Lib/sandbox/pyem/densities.py
===================================================================
--- trunk/Lib/sandbox/pyem/densities.py	2007-07-11 05:14:36 UTC (rev 3158)
+++ trunk/Lib/sandbox/pyem/densities.py	2007-07-12 07:23:55 UTC (rev 3159)
@@ -1,7 +1,7 @@
 #! /usr/bin/python
 #
 # Copyrighted David Cournapeau
-# Last Change: Mon Jul 02 06:00 PM 2007 J
+# Last Change: Thu Jul 12 04:00 PM 2007 J
 """This module implements various basic functions related to multivariate
 gaussian, such as pdf estimation, confidence interval/ellipsoids, etc..."""
 
@@ -115,7 +115,8 @@
     d       = mu.size
     inva    = 1/va
     fac     = (2*N.pi) ** (-d/2.0) * N.sqrt(inva)
-    y       = ((x-mu) ** 2) * -0.5 * inva
+    inva    *= -0.5
+    y       = ((x-mu) ** 2) * inva
     if not log:
         y   = fac * N.exp(y)
     else:
@@ -123,12 +124,6 @@
 
     return y
     
-#from ctypes import cdll, c_uint, c_int, c_double, POINTER
-#_gden   = cdll.LoadLibrary('src/libgden.so')
-#_gden.gden_diag.restype     = c_int
-#_gden.gden_diag.argtypes    = [POINTER(c_double), c_uint, c_uint,
-#        POINTER(c_double), POINTER(c_double), POINTER(c_double)]
-
 def _diag_gauss_den(x, mu, va, log):
     """ This function is the actual implementation
     of gaussian pdf in scalar case. It assumes all args
@@ -139,15 +134,14 @@
     d   = mu.size
     #n   = x.shape[0]
     if not log:
-        inva = 1/va[0, 0]
-        fac = (2*N.pi) ** (-d/2.0) * N.sqrt(inva)
-        y =  (x[:, 0] - mu[0, 0]) ** 2 * inva * -0.5
-        for i in range(1, d):
-            inva = 1/va[0, i]
-            fac *= N.sqrt(inva)
-            y += (x[:, i] - mu[0, i]) ** 2 * inva * -0.5
-        y = fac * N.exp(y)
+        inva = 1/va[0]
+        fac = (2*N.pi) ** (-d/2.0) * N.prod(N.sqrt(inva))
+        inva *= -0.5
+        x = x - mu
+        x **= 2
+        y = fac * N.exp(N.dot(x, inva))
     else:
+        # XXX optimize log case as non log case above
         y = _scalar_gauss_den(x[:, 0], mu[0, 0], va[0, 0], log)
         for i in range(1, d):
             y +=  _scalar_gauss_den(x[:, i], mu[0, i], va[0, i], log)

Modified: trunk/Lib/sandbox/pyem/tests/test_densities.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_densities.py	2007-07-11 05:14:36 UTC (rev 3158)
+++ trunk/Lib/sandbox/pyem/tests/test_densities.py	2007-07-12 07:23:55 UTC (rev 3159)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-# Last Change: Tue Jun 12 08:00 PM 2007 J
+# Last Change: Thu Jul 12 04:00 PM 2007 J
 
 # TODO:
 #   - having "fake tests" to check that all mode (scalar, diag and full) are
@@ -66,6 +66,9 @@
             0.00378789836599, 0.00015915297541, 
             0.00000253261067, 0.00000001526368])
 
+#=====================
+# Basic accuracy tests
+#=====================
 class test_py_implementation(TestDensities):
     def _test(self, level, decimal = DEF_DEC):
         Y   = gauss_den(self.X, self.mu, self.va)
@@ -99,6 +102,37 @@
         self._generate_test_data_1d()
         self._test_log(level)
 
+#=====================
+# Basic speed tests
+#=====================
+class test_speed(NumpyTestCase):
+    n = 1e5
+    niter = 10
+    cpud = 3.2e9
+    def _prepare(self, n, d, mode):
+        cls = self.__class__
+        x = 0.1 * N.random.randn(n, d)
+        mu = 0.1 * N.random.randn(d)
+        if mode == 'diag':
+            va = 0.1 * N.random.randn(d) ** 2
+        elif mode == 'full':
+            a = N.random.randn(d, d)
+            va = 0.1 * N.dot(a.T, a)
+        st = self.measure("gauss_den(x, mu, va)", cls.niter)
+        return st / cls.niter #* cls.cpud / n / d
+
+    def _bench(self, n, d, mode):
+        st = self._prepare(n, d, mode)
+        print "%d dimension, %d samples, %s mode: %8.2f " % (d, n, mode, st)
+
+    def test1(self, level = 5):
+        cls = self.__class__
+        for i in [1, 5, 10, 30]:
+            self._bench(cls.n, i, 'diag')
+
+#================
+# Logsumexp tests
+#================
 class test_py_logsumexp(TestDensities):
     """Class to compare logsumexp vs naive implementation."""
     def test_underlow(self):
@@ -110,17 +144,20 @@
             try:
                 a = N.array([[-1000]])
                 self.naive_logsumexp(a)
-                raise AssertionError("expected to catch underflow, we should not be here")
+                raise AssertionError("expected to catch underflow, we should"\
+                                     "not be here")
             except FloatingPointError, e:
                 print "Catching underflow, as expected"
             assert pyem.densities.logsumexp(a) == -1000.
             try:
                 a = N.array([[-1000, -1000, -1000]])
                 self.naive_logsumexp(a)
-                raise AssertionError("expected to catch underflow, we should not be here")
+                raise AssertionError("expected to catch underflow, we should"\
+                                     "not be here")
             except FloatingPointError, e:
                 print "Catching underflow, as expected"
-            assert_array_almost_equal(pyem.densities.logsumexp(a), -998.90138771)
+            assert_array_almost_equal(pyem.densities.logsumexp(a), 
+                                      -998.90138771)
         finally:
             N.seterr(under=errst['under'])
 
@@ -154,13 +191,16 @@
         a2 = self.naive_logsumexp(y)
         assert_array_almost_equal(a1, a2, DEF_DEC)
 
+#=======================
+# Test C implementation
+#=======================
 class test_c_implementation(TestDensities):
     def _test(self, level, decimal = DEF_DEC):
         try:
             from pyem._c_densities import gauss_den as c_gauss_den
             Y   = c_gauss_den(self.X, self.mu, self.va)
             assert_array_almost_equal(Y, self.Yt, decimal)
-        except ImportError, inst:
+        except Exception, inst:
             print "Error while importing C implementation, not tested"
             print " -> (Import error was %s)" % inst 
 




More information about the Scipy-svn mailing list