[Scipy-svn] r4919 - in trunk/scipy/stats: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sun Nov 2 17:37:41 EST 2008


Author: stefan
Date: 2008-11-02 16:37:22 -0600 (Sun, 02 Nov 2008)
New Revision: 4919

Modified:
   trunk/scipy/stats/distributions.py
   trunk/scipy/stats/tests/test_distributions.py
Log:
Start refactoring rv_continuous and rv_discrete around a shared rv_generic
base class.  Make the rvs() output of distributions more like NumPy's (i.e.
return a scalar when a single value is requested).  Closes #539.

Modified: trunk/scipy/stats/distributions.py
===================================================================
--- trunk/scipy/stats/distributions.py	2008-11-02 11:32:47 UTC (rev 4918)
+++ trunk/scipy/stats/distributions.py	2008-11-02 22:37:22 UTC (rev 4919)
@@ -240,7 +240,81 @@
         newargs[k] = extract(cond,newarr*expand_arr)
     return newargs
 
-class rv_continuous(object):
+class rv_generic(object):
+    """Class which encapsulates common functionality between rv_discrete
+    and rv_continuous.
+
+    """
+    def _fix_loc_scale(self, args, loc, scale=1):
+        N = len(args)
+        if N > self.numargs:
+            if N == self.numargs + 1 and loc is None:
+                # loc is given without keyword
+                loc = args[-1]
+            if N == self.numargs + 2 and scale is None:
+                # loc and scale given without keyword
+                loc, scale = args[-2:]
+            args = args[:self.numargs]
+        if scale is None:
+            scale = 1.0
+        if loc is None:
+            loc = 0.0
+        return args, loc, scale
+
+    def _fix_loc(self, args, loc):
+        args, loc, scale = self._fix_loc_scale(args, loc)
+        return args, loc
+
+    # These are actually called, and should not be overwritten if you
+    # want to keep error checking.
+    def rvs(self,*args,**kwds):
+        """Random variates of given type.
+
+        *args
+        =====
+        The shape parameter(s) for the distribution (see docstring of the
+           instance object for more information)
+
+        **kwds
+        ======
+        size  - number of random variates (default=1)
+        loc   - location parameter (default=0)
+        scale - scale parameter (default=1)
+        """
+        kwd_names = ['loc', 'scale', 'size', 'discrete']
+        loc, scale, size, discrete = map(kwds.get, kwd_names,
+                                         [None]*len(kwd_names))
+
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
+        cond = logical_and(self._argcheck(*args),(scale >= 0))
+        if not all(cond):
+            raise ValueError, "Domain error in arguments."
+
+        # self._size is total size of all output values
+        self._size = product(size, axis=0)
+        if self._size > 1:
+            size = numpy.array(size, ndmin=1)
+
+        if scale == 0:
+            return loc*ones(size, 'd')
+
+        vals = self._rvs(*args)
+        if self._size is not None:
+            vals = reshape(vals, size)
+
+        vals = vals * scale + loc
+
+        # Cast to int if discrete
+        if discrete:
+            if numpy.isscalar(vals):
+                vals = int(vals)
+            else:
+                vals = vals.astype(int)
+
+        return vals
+
+
+class rv_continuous(rv_generic):
     """A Generic continuous random variable.
 
     Continuous random variables are defined from a standard form chosen
@@ -284,7 +358,12 @@
         - frozen RV object with the same methods but holding the
             given shape, location, and scale fixed
     """
-    def __init__(self, momtype=1, a=None, b=None, xa=-10.0, xb=10.0, xtol=1e-14, badvalue=None, name=None, longname=None, shapes=None, extradoc=None):
+    def __init__(self, momtype=1, a=None, b=None, xa=-10.0, xb=10.0,
+                 xtol=1e-14, badvalue=None, name=None, longname=None,
+                 shapes=None, extradoc=None):
+
+        rv_generic.__init__(self)
+
         if badvalue is None:
             badvalue = nan
         self.badvalue = badvalue
@@ -400,59 +479,6 @@
     def _munp(self,n,*args):
         return self.generic_moment(n,*args)
 
-    def __fix_loc_scale(self, args, loc, scale):
-        N = len(args)
-        if N > self.numargs:
-            if N == self.numargs + 1 and loc is None:
-                # loc is given without keyword
-                loc = args[-1]
-            if N == self.numargs + 2 and scale is None:
-                # loc and scale given without keyword
-                loc, scale = args[-2:]
-            args = args[:self.numargs]
-        if scale is None:
-            scale = 1.0
-        if loc is None:
-            loc = 0.0
-        return args, loc, scale
-
-    # These are actually called, but should not
-    #  be overwritten if you want to keep
-    #  the error checking.
-    def rvs(self,*args,**kwds):
-        """Random variates of given type.
-
-        *args
-        =====
-        The shape parameter(s) for the distribution (see docstring of the
-           instance object for more information)
-
-        **kwds
-        ======
-        size  - number of random variates (default=1)
-        loc   - location parameter (default=0)
-        scale - scale parameter (default=1)
-        """
-        loc,scale,size=map(kwds.get,['loc','scale','size'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
-        cond = logical_and(self._argcheck(*args),(scale >= 0))
-        if not all(cond):
-            raise ValueError, "Domain error in arguments."
-
-        if size is None:
-            size = 1
-        else:
-            self._size = product(size,axis=0)
-        if numpy.isscalar(size):
-            self._size = size
-            size = (size,)
-
-        vals = reshape(self._rvs(*args),size)
-        if scale == 0:
-            return loc*ones(size,'d')
-        else:
-            return vals * scale + loc
-
     def pdf(self,x,*args,**kwds):
         """Probability density function at x of the given RV.
 
@@ -467,7 +493,7 @@
         scale - scale parameter (default=1)
         """
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         x,loc,scale = map(arr,(x,loc,scale))
         args = tuple(map(arr,args))
         x = arr((x-loc)*1.0/scale)
@@ -497,7 +523,7 @@
         scale - scale parameter (default=1)
         """
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         x,loc,scale = map(arr,(x,loc,scale))
         args = tuple(map(arr,args))
         x = (x-loc)*1.0/scale
@@ -528,7 +554,7 @@
         scale - scale parameter (default=1)
         """
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         x,loc,scale = map(arr,(x,loc,scale))
         args = tuple(map(arr,args))
         x = (x-loc)*1.0/scale
@@ -559,7 +585,7 @@
         scale - scale parameter (default=1)
         """
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         q,loc,scale = map(arr,(q,loc,scale))
         args = tuple(map(arr,args))
         cond0 = self._argcheck(*args) & (scale > 0) & (loc==loc)
@@ -590,7 +616,7 @@
         scale - scale parameter (default=1)
         """
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         q,loc,scale = map(arr,(q,loc,scale))
         args = tuple(map(arr,args))
         cond0 = self._argcheck(*args) & (scale > 0) & (loc==loc)
@@ -800,7 +826,7 @@
 
     def entropy(self, *args, **kwds):
         loc,scale=map(kwds.get,['loc','scale'])
-        args, loc, scale = self.__fix_loc_scale(args, loc, scale)
+        args, loc, scale = self._fix_loc_scale(args, loc, scale)
         args = map(arr,args)
         cond0 = self._argcheck(*args) & (scale > 0) & (loc==loc)
         output = zeros(shape(cond0),'d')
@@ -3301,7 +3327,7 @@
 # Must over-ride one of _pmf or _cdf or pass in
 #  x_k, p(x_k) lists in initialization
 
-class rv_discrete:
+class rv_discrete(rv_generic):
     """A generic discrete random variable.
 
     Discrete random variables are defined from a standard form.
@@ -3350,6 +3376,9 @@
     def __init__(self, a=0, b=inf, name=None, badvalue=None,
                  moment_tol=1e-8,values=None,inc=1,longname=None,
                  shapes=None, extradoc=None):
+
+        rv_generic.__init__(self)
+
         if badvalue is None:
             badvalue = nan
         self.badvalue = badvalue
@@ -3424,16 +3453,6 @@
     def _rvs(self, *args):
         return self._ppf(mtrand.random_sample(self._size),*args)
 
-    def __fix_loc(self, args, loc):
-        N = len(args)
-        if N > self.numargs:
-            if N == self.numargs + 1 and loc is None:  # loc is given without keyword
-                loc = args[-1]
-            args = args[:self.numargs]
-        if loc is None:
-            loc = 0
-        return args, loc
-
     def _nonzero(self, k, *args):
         return floor(k)==k
 
@@ -3470,28 +3489,10 @@
         return self.generic_moment(n)
 
 
-    def rvs(self, *args, **kwds):
-        loc,size=map(kwds.get,['loc','size'])
-        args, loc = self.__fix_loc(args, loc)
-        cond = self._argcheck(*args)
-        if not all(cond):
-            raise ValueError, "Domain error in arguments."
+    def rvs(self, *args, **kwargs):
+        kwargs['discrete'] = True
+        return rv_generic.rvs(self, *args, **kwargs)
 
-        if size is None:
-            size = 1
-        else:
-            self._size = product(size,axis=0)
-        if numpy.isscalar(size):
-            self._size = size
-            size = (size,)
-
-        vals = reshape(self._rvs(*args),size)
-        if self.return_integers:
-            vals = arr(vals)
-            if vals.dtype.char not in numpy.typecodes['AllInteger']:
-                vals = vals.astype(int)
-        return vals + loc
-
     def pmf(self, k,*args, **kwds):
         """Probability mass function at k of the given RV.
 
@@ -3505,7 +3506,7 @@
         loc   - location parameter (default=0)
         """
         loc = kwds.get('loc')
-        args, loc  = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         k,loc = map(arr,(k,loc))
         args = tuple(map(arr,args))
         k = arr((k-loc))
@@ -3533,7 +3534,7 @@
         loc   - location parameter (default=0)
         """
         loc = kwds.get('loc')
-        args, loc = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         k,loc = map(arr,(k,loc))
         args = tuple(map(arr,args))
         k = arr((k-loc))
@@ -3563,7 +3564,7 @@
         loc   - location parameter (default=0)
         """
         loc= kwds.get('loc')
-        args, loc = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         k,loc = map(arr,(k,loc))
         args = tuple(map(arr,args))
         k = arr(k-loc)
@@ -3593,7 +3594,7 @@
         loc   - location parameter (default=0)
         """
         loc = kwds.get('loc')
-        args, loc = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         q,loc  = map(arr,(q,loc))
         args = tuple(map(arr,args))
         cond0 = self._argcheck(*args) & (loc == loc)
@@ -3624,7 +3625,7 @@
         """
 
         loc = kwds.get('loc')
-        args, loc = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         q,loc  = map(arr,(q,loc))
         args = tuple(map(arr,args))
         cond0 = self._argcheck(*args) & (loc == loc)
@@ -3795,7 +3796,7 @@
 
     def entropy(self, *args, **kwds):
         loc= kwds.get('loc')
-        args, loc = self.__fix_loc(args, loc)
+        args, loc = self._fix_loc(args, loc)
         loc = arr(loc)
         args = map(arr,args)
         cond0 = self._argcheck(*args) & (loc==loc)
@@ -4180,15 +4181,13 @@
         g1 = 0.0
         g2 = -6.0/5.0*(d*d+1.0)/(d-1.0)*(d+1.0)
         return mu, var, g1, g2
-    def rvs(self, min, max=None, size=None):
+    def _rvs(self, min, max=None):
         """An array of *size* random integers >= min and < max.
 
         If max is None, then range is >=0  and < min
         """
+        return mtrand.randint(min, max, self._size)
 
-        # Return an array scalar if needed.
-        return arr(mtrand.randint(min, max, size))[()]
-
     def _entropy(self, min, max):
         return log(max-min)
 randint = randint_gen(name='randint',longname='A discrete uniform '\

Modified: trunk/scipy/stats/tests/test_distributions.py
===================================================================
--- trunk/scipy/stats/tests/test_distributions.py	2008-11-02 11:32:47 UTC (rev 4918)
+++ trunk/scipy/stats/tests/test_distributions.py	2008-11-02 22:37:22 UTC (rev 4919)
@@ -83,6 +83,7 @@
         val = stats.randint.rvs(15,46)
         assert((val >= 15) & (val < 46))
         assert isinstance(val, numpy.ScalarType),`type(val)`
+        val = stats.randint(15,46).rvs(3)
         assert(val.dtype.char in typecodes['AllInteger'])
 
     def test_pdf(self):
@@ -105,6 +106,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.binom.rvs(10, 0.75)
+        assert(isinstance(val, int))
+        val = stats.binom(10, 0.75).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -116,6 +119,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.bernoulli.rvs(0.75)
+        assert(isinstance(val, int))
+        val = stats.bernoulli(0.75).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -126,6 +131,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.nbinom.rvs(10, 0.75)
+        assert(isinstance(val, int))
+        val = stats.nbinom(10, 0.75).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -136,6 +143,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.geom.rvs(0.75)
+        assert(isinstance(val, int))
+        val = stats.geom(0.75).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -159,6 +168,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.hypergeom.rvs(20, 3, 10)
+        assert(isinstance(val, int))
+        val = stats.hypergeom(20, 3, 10).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -169,6 +180,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.logser.rvs(0.75)
+        assert(isinstance(val, int))
+        val = stats.logser(0.75).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -179,6 +192,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.poisson.rvs(0.5)
+        assert(isinstance(val, int))
+        val = stats.poisson(0.5).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -189,6 +204,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.zipf.rvs(1.5)
+        assert(isinstance(val, int))
+        val = stats.zipf(1.5).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -198,6 +215,8 @@
         assert(numpy.shape(vals) == (2, 50))
         assert(vals.dtype.char in typecodes['AllInteger'])
         val = stats.dlaplace.rvs(1.5)
+        assert(isinstance(val, int))
+        val = stats.dlaplace(1.5).rvs(3)
         assert(isinstance(val, numpy.ndarray))
         assert(val.dtype.char in typecodes['AllInteger'])
 
@@ -208,10 +227,14 @@
         samples = 1000
         r = stats.rv_discrete(name='sample',values=(states,probability))
         x = r.rvs(size=samples)
+        assert(isinstance(x, numpy.ndarray))
 
         for s,p in zip(states,probability):
             assert abs(sum(x == s)/float(samples) - p) < 0.05
 
+        x = r.rvs()
+        assert(isinstance(x, int))
+
 class TestExpon(TestCase):
     def test_zero(self):
         assert_equal(stats.expon.pdf(0),1)




More information about the Scipy-svn mailing list