[Scipy-svn] r2238 - in trunk/Lib/stats: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Sat Sep 30 19:39:37 EDT 2006


Author: stefan
Date: 2006-09-30 18:39:22 -0500 (Sat, 30 Sep 2006)
New Revision: 2238

Modified:
   trunk/Lib/stats/stats.py
   trunk/Lib/stats/tests/test_stats.py
Log:
For stats.threshold: update doc, add test, optimise.


Modified: trunk/Lib/stats/stats.py
===================================================================
--- trunk/Lib/stats/stats.py	2006-09-29 16:47:46 UTC (rev 2237)
+++ trunk/Lib/stats/stats.py	2006-09-30 23:39:22 UTC (rev 2238)
@@ -1203,19 +1203,24 @@
 #####################################
 
 def threshold(a, threshmin=None, threshmax=None, newval=0):
-    """
-Like numpy.clip() except that values <threshmid or >threshmax are replaced
-by newval instead of by threshmin/threshmax (respectively).
+    """Clip array to a given value.
+    
+Similar to numpy.clip(), except that values less than threshmin or
+greater than threshmax are replaced by newval, instead of by
+threshmin and threshmax respectively.
 
-Returns: a, with values <threshmin or >threshmax replaced with newval
+Returns: a, with values less than threshmin or greater than threshmax
+         replaced with newval
+
 """
-    a = asarray(a)
+    a = asarray(a).copy()
     mask = zeros(a.shape, dtype=bool)
-    if threshmin != None:
-        mask |= (a < threshmin)
-    if threshmax != None:
+    if threshmin is not None:
+        mask = (a < threshmin)
+    if threshmax is not None:
         mask |= (a > threshmax)
-    return np.where(mask, newval, a)
+    a[mask] = newval
+    return a
 
 
 def trimboth(a, proportiontocut):

Modified: trunk/Lib/stats/tests/test_stats.py
===================================================================
--- trunk/Lib/stats/tests/test_stats.py	2006-09-29 16:47:46 UTC (rev 2237)
+++ trunk/Lib/stats/tests/test_stats.py	2006-09-30 23:39:22 UTC (rev 2238)
@@ -759,5 +759,16 @@
         y = scipy.stats.kurtosis(self.testcase,0,0)
         assert_approx_equal(y,1.64)
 
+class test_threshold(ScipyTestCase):
+    def check_basic(self):
+        a = [-1,2,3,4,5,-1,-2]
+        assert_array_equal(stats.threshold(a),a)
+        assert_array_equal(stats.threshold(a,3,None,0),
+                           [0,0,3,4,5,0,0])
+        assert_array_equal(stats.threshold(a,None,3,0),
+                           [-1,2,3,0,0,-1,-2])
+        assert_array_equal(stats.threshold(a,2,4,0),
+                           [0,2,3,4,0,0,0])
+        
 if __name__ == "__main__":
     ScipyTest().run()




More information about the Scipy-svn mailing list