[Scipy-svn] r3126 - trunk/Lib/sandbox/pyem
scipy-svn at scipy.org
scipy-svn at scipy.org
Sun Jul 1 05:32:06 EDT 2007
Author: cdavid
Date: 2007-07-01 04:32:00 -0500 (Sun, 01 Jul 2007)
New Revision: 3126
Modified:
trunk/Lib/sandbox/pyem/TODO
trunk/Lib/sandbox/pyem/gmm_em.py
trunk/Lib/sandbox/pyem/misc.py
Log:
Add (crude) regularized EM
Modified: trunk/Lib/sandbox/pyem/TODO
===================================================================
--- trunk/Lib/sandbox/pyem/TODO 2007-07-01 09:30:56 UTC (rev 3125)
+++ trunk/Lib/sandbox/pyem/TODO 2007-07-01 09:32:00 UTC (rev 3126)
@@ -1,9 +1,8 @@
-# Last Change: Fri Jun 22 05:00 PM 2007 J
+# Last Change: Sun Jul 01 06:00 PM 2007 J
Things which must be implemented for a 1.0 version (in importante order)
- A classifier
- handle rank 1 for 1d data
- - basic regularization
- demo for pdf estimation, discriminant analysis and clustering
- scaling of data: maybe something to handle scaling internally ?
Modified: trunk/Lib/sandbox/pyem/gmm_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/gmm_em.py 2007-07-01 09:30:56 UTC (rev 3125)
+++ trunk/Lib/sandbox/pyem/gmm_em.py 2007-07-01 09:32:00 UTC (rev 3126)
@@ -1,5 +1,5 @@
# /usr/bin/python
-# Last Change: Fri Jun 22 05:00 PM 2007 J
+# Last Change: Sun Jul 01 05:00 PM 2007 J
"""Module implementing GMM, a class to estimate Gaussian mixture models using
EM, and EM, a class which use GMM instances to estimate models parameters using
@@ -23,6 +23,9 @@
#from misc import _DEF_ALPHA, _MIN_DBL_DELTA, _MIN_INV_COND
+_PRIOR_COUNT = 0.05
+_COV_PRIOR = 0.1
+
# Error classes
class GmmError(Exception):
"""Base class for exceptions in this module."""
@@ -209,6 +212,7 @@
mu[c, :] = x / ngamma[c]
va[c, :] = xx / ngamma[c] - mu[c, :] ** 2
+
w = invn * ngamma
return w, mu, va
@@ -361,6 +365,11 @@
# Initialize the data (may do nothing depending on the model)
model.init(data)
+ # Actual training
+ like = self._train_simple_em(data, model, maxiter, thresh)
+ return like
+
+ def _train_simple_em(self, data, model, maxiter, thresh):
# Likelihood is kept
like = N.zeros(maxiter)
@@ -376,8 +385,45 @@
if has_em_converged(like[i], like[i-1], thresh):
return like[0:i]
- return like
-
+class RegularizedEM:
+ # TODO: separate regularizer from EM class ?
+ def __init__(self, pcnt = _PRIOR_COUNT, pval = _COV_PRIOR):
+ """Create a regularized EM object.
+
+ Covariance matrices are regularized after initialization and after each parameter update.
+
+ :Parameters:
+ pcnt : float
+ proportion of soft counts to be counted as prior counts (e.g. if
+ you have 1000 samples and the prior count is 0.1, then the
+ prior would "weigh" as much as 100 samples).
+ pval : float
+ value of the prior: the prior covariance is pval times the identity matrix.
+ """
+ self.pcnt = pcnt
+ self.pval = pval
+
+ def train(self, data, model, maxiter = 20, thresh = 1e-5):
+ model.init(data)
+ regularize_full(model.gm.va, self.pcnt, self.pval * N.eye(model.gm.d))
+ # Likelihood is kept
+ like = N.empty(maxiter, N.float)
+
+ # Em computation, with computation of the likelihood
+ g, tgd = model.compute_log_responsabilities(data)
+ g = N.exp(g)
+ like[0] = N.sum(densities.logsumexp(tgd), axis = 0)
+ model.update_em(data, g)
+ regularize_full(model.gm.va, self.pcnt, self.pval * N.eye(model.gm.d))
+ for i in range(1, maxiter):
+ g, tgd = model.compute_log_responsabilities(data)
+ g = N.exp(g)
+ like[i] = N.sum(densities.logsumexp(tgd), axis = 0)
+ model.update_em(data, g)
+ regularize_full(model.gm.va, self.pcnt, self.pval * N.eye(model.gm.d))
+ if has_em_converged(like[i], like[i-1], thresh):
+ return like[0:i]
+
# Misc functions
def bic(lk, deg, n):
""" Expects lk to be log likelihood """
@@ -394,6 +440,16 @@
else:
return False
+def regularize_full(va, np, prior):
+ """np * n is the number of prior counts (np is a proportion, and n is the
+ number of points)."""
+ d = va.shape[1]
+ k = va.shape[0] / d
+
+ for i in range(k):
+ va[i*d:i*d+d,:] *= 1. / (1 + np)
+ va[i*d:i*d+d,:] += np / (1. + np) * prior
+
if __name__ == "__main__":
pass
## # #++++++++++++++++++
Modified: trunk/Lib/sandbox/pyem/misc.py
===================================================================
--- trunk/Lib/sandbox/pyem/misc.py 2007-07-01 09:30:56 UTC (rev 3125)
+++ trunk/Lib/sandbox/pyem/misc.py 2007-07-01 09:32:00 UTC (rev 3126)
@@ -1,4 +1,4 @@
-# Last Change: Sat Jun 09 08:00 PM 2007 J
+# Last Change: Thu Jun 28 06:00 PM 2007 J
#========================================================
# Constants used throughout the module (def args, etc...)
@@ -7,6 +7,7 @@
DEF_VIS_DIM = (0, 1)
DEF_ELL_NP = 100
DEF_LEVEL = 0.39
+
#=====================================================================
# "magic number", that is number used to control regularization and co
# Change them at your risk !
@@ -16,13 +17,13 @@
# I should actually use a number of decimals)
_MAX_DBL_DEV = 1e-10
-# max conditional number allowed
-_MAX_COND = 1e8
-_MIN_INV_COND = 1/_MAX_COND
-
-# Default alpha for regularization
-_DEF_ALPHA = 1e-1
-
-# Default min delta for regularization
-_MIN_DBL_DELTA = 1e-5
-
+## # max conditional number allowed
+## _MAX_COND = 1e8
+## _MIN_INV_COND = 1/_MAX_COND
+##
+## # Default alpha for regularization
+## _DEF_ALPHA = 1e-1
+##
+## # Default min delta for regularization
+## _MIN_DBL_DELTA = 1e-5
+##
More information about the Scipy-svn
mailing list