[Scipy-svn] r7114 - in trunk/scipy/special: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Jan 31 16:51:56 EST 2011


Author: ptvirtan
Date: 2011-01-31 15:51:56 -0600 (Mon, 31 Jan 2011)
New Revision: 7114

Added:
   trunk/scipy/special/_testutils.py
Removed:
   trunk/scipy/special/tests/testutils.py
Modified:
   trunk/scipy/special/tests/test_basic.py
   trunk/scipy/special/tests/test_data.py
   trunk/scipy/special/tests/test_lambertw.py
   trunk/scipy/special/tests/test_mpmath.py
   trunk/scipy/special/tests/test_orthogonal_eval.py
Log:
TST: move local test helper in the modules -- nose doesn't like them in tests/

Apparently, nose puts all test/ directories in sys.path, and so one
cannot have more than one "testutils.py" in total.  Better to have them
as a _testutils module in the subpackage.

Copied: trunk/scipy/special/_testutils.py (from rev 7111, trunk/scipy/special/tests/testutils.py)
===================================================================
--- trunk/scipy/special/_testutils.py	                        (rev 0)
+++ trunk/scipy/special/_testutils.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -0,0 +1,236 @@
+import os
+import warnings
+
+import numpy as np
+from numpy.testing import assert_
+from numpy.testing.noseclasses import KnownFailureTest
+
+import scipy.special as sc
+
+__all__ = ['with_special_errors', 'assert_tol_equal', 'assert_func_equal',
+           'FuncData']
+
+#------------------------------------------------------------------------------
+# Enable convergence and loss of precision warnings -- turn off one by one
+#------------------------------------------------------------------------------
+
+def with_special_errors(func):
+    """
+    Enable special function errors (such as underflow, overflow,
+    loss of precision, etc.)
+    """
+    def wrapper(*a, **kw):
+        old_filters = list(getattr(warnings, 'filters', []))
+        old_errprint = sc.errprint(1)
+        warnings.filterwarnings("error", category=sc.SpecialFunctionWarning)
+        try:
+            return func(*a, **kw)
+        finally:
+            sc.errprint(old_errprint)
+            setattr(warnings, 'filters', old_filters)
+    wrapper.__name__ = func.__name__
+    wrapper.__doc__ = func.__doc__
+    return wrapper
+
+#------------------------------------------------------------------------------
+# Comparing function values at many data points at once, with helpful
+#------------------------------------------------------------------------------
+
+def assert_tol_equal(a, b, rtol=1e-7, atol=0, err_msg='', verbose=True):
+    """Assert that `a` and `b` are equal to tolerance ``atol + rtol*abs(b)``"""
+    def compare(x, y):
+        return np.allclose(x, y, rtol=rtol, atol=atol)
+    a, b = np.asanyarray(a), np.asanyarray(b)
+    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
+    np.testing.utils.assert_array_compare(compare, a, b, err_msg=str(err_msg),
+                                          verbose=verbose, header=header)
+
+#------------------------------------------------------------------------------
+# Comparing function values at many data points at once, with helpful
+# error reports
+#------------------------------------------------------------------------------
+
+def assert_func_equal(func, results, points, rtol=None, atol=None,
+                      param_filter=None, knownfailure=None,
+                      vectorized=True, dtype=None):
+    if hasattr(points, 'next'):
+        # it's a generator
+        points = list(points)
+
+    points = np.asarray(points)
+    if points.ndim == 1:
+        points = points[:,None]
+
+    if hasattr(results, '__name__'):
+        # function
+        if vectorized:
+            results = results(*tuple(points.T))
+        else:
+            results = np.array([results(*tuple(p)) for p in points])
+            if results.dtype == object:
+                try:
+                    results = results.astype(float)
+                except TypeError:
+                    results = results.astype(complex)
+    else:
+        results = np.asarray(results)
+
+    npoints = points.shape[1]
+
+    data = np.c_[points, results]
+    fdata = FuncData(func, data, range(npoints), range(npoints, data.shape[1]),
+                     rtol=rtol, atol=atol, param_filter=param_filter,
+                     knownfailure=knownfailure)
+    fdata.check()
+
+class FuncData(object):
+    """
+    Data set for checking a special function.
+
+    Parameters
+    ----------
+    func : function
+        Function to test
+    filename : str
+        Input file name
+    param_columns : int or tuple of ints
+        Columns indices in which the parameters to `func` lie.
+        Can be imaginary integers to indicate that the parameter
+        should be cast to complex.
+    result_columns : int or tuple of ints
+        Column indices for expected results from `func`.
+    rtol : float, optional
+        Required relative tolerance. Default is 5*eps.
+    atol : float, optional
+        Required absolute tolerance. Default is 5*tiny.
+    param_filter : function, or tuple of functions/Nones, optional
+        Filter functions to exclude some parameter ranges.
+        If omitted, no filtering is done.
+    knownfailure : str, optional
+        Known failure error message to raise when the test is run.
+        If omitted, no exception is raised.
+
+    """
+
+    def __init__(self, func, data, param_columns, result_columns,
+                 rtol=None, atol=None, param_filter=None, knownfailure=None,
+                 dataname=None):
+        self.func = func
+        self.data = data
+        self.dataname = dataname
+        if not hasattr(param_columns, '__len__'):
+            param_columns = (param_columns,)
+        if not hasattr(result_columns, '__len__'):
+            result_columns = (result_columns,)
+        self.param_columns = tuple(param_columns)
+        self.result_columns = tuple(result_columns)
+        self.rtol = rtol
+        self.atol = atol
+        if not hasattr(param_filter, '__len__'):
+            param_filter = (param_filter,)
+        self.param_filter = param_filter
+        self.knownfailure = knownfailure
+
+    def get_tolerances(self, dtype):
+        info = np.finfo(dtype)
+        rtol, atol = self.rtol, self.atol
+        if rtol is None:
+            rtol = 5*info.eps
+        if atol is None:
+            atol = 5*info.tiny
+        return rtol, atol
+
+    def check(self, data=None, dtype=None):
+        """Check the special function against the data."""
+
+        if self.knownfailure:
+            raise KnownFailureTest(self.knownfailure)
+
+        if data is None:
+            data = self.data
+
+        if dtype is None:
+            dtype = data.dtype
+        else:
+            data = data.astype(dtype)
+
+        rtol, atol = self.get_tolerances(dtype)
+
+        # Apply given filter functions
+        if self.param_filter:
+            param_mask = np.ones((data.shape[0],), np.bool_)
+            for j, filter in zip(self.param_columns, self.param_filter):
+                if filter:
+                    param_mask &= filter(data[:,j])
+            data = data[param_mask]
+
+        # Pick parameters and results from the correct columns
+        params = []
+        for j in self.param_columns:
+            if np.iscomplexobj(j):
+                j = int(j.imag)
+                params.append(data[:,j].astype(np.complex))
+            else:
+                params.append(data[:,j])
+        wanted = tuple([data[:,j] for j in self.result_columns])
+
+        # Evaluate
+        got = self.func(*params)
+        if not isinstance(got, tuple):
+            got = (got,)
+
+        # Check the validity of each output returned
+
+        assert_(len(got) == len(wanted))
+
+        for output_num, (x, y) in enumerate(zip(got, wanted)):
+            pinf_x = np.isinf(x) & (x > 0)
+            pinf_y = np.isinf(y) & (x > 0)
+            minf_x = np.isinf(x) & (x < 0)
+            minf_y = np.isinf(y) & (x < 0)
+            nan_x = np.isnan(x)
+            nan_y = np.isnan(y)
+
+            abs_y = np.absolute(y)
+            abs_y[~np.isfinite(abs_y)] = 0
+            diff = np.absolute(x - y)
+            diff[~np.isfinite(diff)] = 0
+
+            rdiff = diff / np.absolute(y)
+            rdiff[~np.isfinite(rdiff)] = 0
+
+            tol_mask = (diff < atol + rtol*abs_y)
+            pinf_mask = (pinf_x == pinf_y)
+            minf_mask = (minf_x == minf_y)
+            nan_mask = (nan_x == nan_y)
+
+            bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)
+
+            if np.any(bad_j):
+                # Some bad results: inform what, where, and how bad
+                msg = [""]
+                msg.append("Max |adiff|: %g" % diff.max())
+                msg.append("Max |rdiff|: %g" % rdiff.max())
+                msg.append("Bad results for the following points (in output %d):"
+                           % output_num)
+                for j in np.where(bad_j)[0]:
+                    j = int(j)
+                    fmt = lambda x: "%30s" % np.array2string(x[j], precision=18)
+                    a = "  ".join(map(fmt, params))
+                    b = "  ".join(map(fmt, got))
+                    c = "  ".join(map(fmt, wanted))
+                    d = fmt(rdiff)
+                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
+                assert_(False, "\n".join(msg))
+
+    def __repr__(self):
+        """Pretty-printing, esp. for Nose output"""
+        if np.any(map(np.iscomplexobj, self.param_columns)):
+            is_complex = " (complex)"
+        else:
+            is_complex = ""
+        if self.dataname:
+            return "<Data for %s%s: %s>" % (self.func.__name__, is_complex,
+                                            os.path.basename(self.dataname))
+        else:
+            return "<Data for %s%s>" % (self.func.__name__, is_complex)

Modified: trunk/scipy/special/tests/test_basic.py
===================================================================
--- trunk/scipy/special/tests/test_basic.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/test_basic.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -30,7 +30,7 @@
 from scipy import special
 import scipy.special._cephes as cephes
 
-from testutils import assert_tol_equal, with_special_errors
+from scipy.special._testutils import assert_tol_equal, with_special_errors
 
 class TestCephes(TestCase):
     def test_airy(self):

Modified: trunk/scipy/special/tests/test_data.py
===================================================================
--- trunk/scipy/special/tests/test_data.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/test_data.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -8,7 +8,7 @@
     zeta, gammaincinv, lpmv
 )
 
-from testutils import FuncData
+from scipy.special._testutils import FuncData
 
 DATASETS = np.load(os.path.join(os.path.dirname(__file__),
                                 "data", "boost.npz"))

Modified: trunk/scipy/special/tests/test_lambertw.py
===================================================================
--- trunk/scipy/special/tests/test_lambertw.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/test_lambertw.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -11,7 +11,7 @@
 from scipy.special import lambertw
 from numpy import nan, inf, pi, e, isnan, log, r_, array, complex_
 
-from testutils import FuncData
+from scipy.special._testutils import FuncData
 
 
 def test_values():

Modified: trunk/scipy/special/tests/test_mpmath.py
===================================================================
--- trunk/scipy/special/tests/test_mpmath.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/test_mpmath.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -7,7 +7,7 @@
 from numpy.testing import dec
 import scipy.special as sc
 
-from testutils import FuncData, assert_func_equal
+from scipy.special._testutils import FuncData, assert_func_equal
 
 try:
     import mpmath

Modified: trunk/scipy/special/tests/test_orthogonal_eval.py
===================================================================
--- trunk/scipy/special/tests/test_orthogonal_eval.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/test_orthogonal_eval.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -2,7 +2,7 @@
 from numpy.testing import assert_
 import scipy.special.orthogonal as orth
 
-from testutils import FuncData
+from scipy.special._testutils import FuncData
 
 
 def test_eval_chebyt():

Deleted: trunk/scipy/special/tests/testutils.py
===================================================================
--- trunk/scipy/special/tests/testutils.py	2011-01-31 21:18:57 UTC (rev 7113)
+++ trunk/scipy/special/tests/testutils.py	2011-01-31 21:51:56 UTC (rev 7114)
@@ -1,236 +0,0 @@
-import os
-import warnings
-
-import numpy as np
-from numpy.testing import assert_
-from numpy.testing.noseclasses import KnownFailureTest
-
-import scipy.special as sc
-
-__all__ = ['with_special_errors', 'assert_tol_equal', 'assert_func_equal',
-           'FuncData']
-
-#------------------------------------------------------------------------------
-# Enable convergence and loss of precision warnings -- turn off one by one
-#------------------------------------------------------------------------------
-
-def with_special_errors(func):
-    """
-    Enable special function errors (such as underflow, overflow,
-    loss of precision, etc.)
-    """
-    def wrapper(*a, **kw):
-        old_filters = list(getattr(warnings, 'filters', []))
-        old_errprint = sc.errprint(1)
-        warnings.filterwarnings("error", category=sc.SpecialFunctionWarning)
-        try:
-            return func(*a, **kw)
-        finally:
-            sc.errprint(old_errprint)
-            setattr(warnings, 'filters', old_filters)
-    wrapper.__name__ = func.__name__
-    wrapper.__doc__ = func.__doc__
-    return wrapper
-
-#------------------------------------------------------------------------------
-# Comparing function values at many data points at once, with helpful
-#------------------------------------------------------------------------------
-
-def assert_tol_equal(a, b, rtol=1e-7, atol=0, err_msg='', verbose=True):
-    """Assert that `a` and `b` are equal to tolerance ``atol + rtol*abs(b)``"""
-    def compare(x, y):
-        return np.allclose(x, y, rtol=rtol, atol=atol)
-    a, b = np.asanyarray(a), np.asanyarray(b)
-    header = 'Not equal to tolerance rtol=%g, atol=%g' % (rtol, atol)
-    np.testing.utils.assert_array_compare(compare, a, b, err_msg=str(err_msg),
-                                          verbose=verbose, header=header)
-
-#------------------------------------------------------------------------------
-# Comparing function values at many data points at once, with helpful
-# error reports
-#------------------------------------------------------------------------------
-
-def assert_func_equal(func, results, points, rtol=None, atol=None,
-                      param_filter=None, knownfailure=None,
-                      vectorized=True, dtype=None):
-    if hasattr(points, 'next'):
-        # it's a generator
-        points = list(points)
-
-    points = np.asarray(points)
-    if points.ndim == 1:
-        points = points[:,None]
-
-    if hasattr(results, '__name__'):
-        # function
-        if vectorized:
-            results = results(*tuple(points.T))
-        else:
-            results = np.array([results(*tuple(p)) for p in points])
-            if results.dtype == object:
-                try:
-                    results = results.astype(float)
-                except TypeError:
-                    results = results.astype(complex)
-    else:
-        results = np.asarray(results)
-
-    npoints = points.shape[1]
-
-    data = np.c_[points, results]
-    fdata = FuncData(func, data, range(npoints), range(npoints, data.shape[1]),
-                     rtol=rtol, atol=atol, param_filter=param_filter,
-                     knownfailure=knownfailure)
-    fdata.check()
-
-class FuncData(object):
-    """
-    Data set for checking a special function.
-
-    Parameters
-    ----------
-    func : function
-        Function to test
-    filename : str
-        Input file name
-    param_columns : int or tuple of ints
-        Columns indices in which the parameters to `func` lie.
-        Can be imaginary integers to indicate that the parameter
-        should be cast to complex.
-    result_columns : int or tuple of ints
-        Column indices for expected results from `func`.
-    rtol : float, optional
-        Required relative tolerance. Default is 5*eps.
-    atol : float, optional
-        Required absolute tolerance. Default is 5*tiny.
-    param_filter : function, or tuple of functions/Nones, optional
-        Filter functions to exclude some parameter ranges.
-        If omitted, no filtering is done.
-    knownfailure : str, optional
-        Known failure error message to raise when the test is run.
-        If omitted, no exception is raised.
-
-    """
-
-    def __init__(self, func, data, param_columns, result_columns,
-                 rtol=None, atol=None, param_filter=None, knownfailure=None,
-                 dataname=None):
-        self.func = func
-        self.data = data
-        self.dataname = dataname
-        if not hasattr(param_columns, '__len__'):
-            param_columns = (param_columns,)
-        if not hasattr(result_columns, '__len__'):
-            result_columns = (result_columns,)
-        self.param_columns = tuple(param_columns)
-        self.result_columns = tuple(result_columns)
-        self.rtol = rtol
-        self.atol = atol
-        if not hasattr(param_filter, '__len__'):
-            param_filter = (param_filter,)
-        self.param_filter = param_filter
-        self.knownfailure = knownfailure
-
-    def get_tolerances(self, dtype):
-        info = np.finfo(dtype)
-        rtol, atol = self.rtol, self.atol
-        if rtol is None:
-            rtol = 5*info.eps
-        if atol is None:
-            atol = 5*info.tiny
-        return rtol, atol
-
-    def check(self, data=None, dtype=None):
-        """Check the special function against the data."""
-
-        if self.knownfailure:
-            raise KnownFailureTest(self.knownfailure)
-
-        if data is None:
-            data = self.data
-
-        if dtype is None:
-            dtype = data.dtype
-        else:
-            data = data.astype(dtype)
-
-        rtol, atol = self.get_tolerances(dtype)
-
-        # Apply given filter functions
-        if self.param_filter:
-            param_mask = np.ones((data.shape[0],), np.bool_)
-            for j, filter in zip(self.param_columns, self.param_filter):
-                if filter:
-                    param_mask &= filter(data[:,j])
-            data = data[param_mask]
-
-        # Pick parameters and results from the correct columns
-        params = []
-        for j in self.param_columns:
-            if np.iscomplexobj(j):
-                j = int(j.imag)
-                params.append(data[:,j].astype(np.complex))
-            else:
-                params.append(data[:,j])
-        wanted = tuple([data[:,j] for j in self.result_columns])
-
-        # Evaluate
-        got = self.func(*params)
-        if not isinstance(got, tuple):
-            got = (got,)
-
-        # Check the validity of each output returned
-
-        assert_(len(got) == len(wanted))
-
-        for output_num, (x, y) in enumerate(zip(got, wanted)):
-            pinf_x = np.isinf(x) & (x > 0)
-            pinf_y = np.isinf(y) & (x > 0)
-            minf_x = np.isinf(x) & (x < 0)
-            minf_y = np.isinf(y) & (x < 0)
-            nan_x = np.isnan(x)
-            nan_y = np.isnan(y)
-
-            abs_y = np.absolute(y)
-            abs_y[~np.isfinite(abs_y)] = 0
-            diff = np.absolute(x - y)
-            diff[~np.isfinite(diff)] = 0
-
-            rdiff = diff / np.absolute(y)
-            rdiff[~np.isfinite(rdiff)] = 0
-
-            tol_mask = (diff < atol + rtol*abs_y)
-            pinf_mask = (pinf_x == pinf_y)
-            minf_mask = (minf_x == minf_y)
-            nan_mask = (nan_x == nan_y)
-
-            bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)
-
-            if np.any(bad_j):
-                # Some bad results: inform what, where, and how bad
-                msg = [""]
-                msg.append("Max |adiff|: %g" % diff.max())
-                msg.append("Max |rdiff|: %g" % rdiff.max())
-                msg.append("Bad results for the following points (in output %d):"
-                           % output_num)
-                for j in np.where(bad_j)[0]:
-                    j = int(j)
-                    fmt = lambda x: "%30s" % np.array2string(x[j], precision=18)
-                    a = "  ".join(map(fmt, params))
-                    b = "  ".join(map(fmt, got))
-                    c = "  ".join(map(fmt, wanted))
-                    d = fmt(rdiff)
-                    msg.append("%s => %s != %s  (rdiff %s)" % (a, b, c, d))
-                assert_(False, "\n".join(msg))
-
-    def __repr__(self):
-        """Pretty-printing, esp. for Nose output"""
-        if np.any(map(np.iscomplexobj, self.param_columns)):
-            is_complex = " (complex)"
-        else:
-            is_complex = ""
-        if self.dataname:
-            return "<Data for %s%s: %s>" % (self.func.__name__, is_complex,
-                                            os.path.basename(self.dataname))
-        else:
-            return "<Data for %s%s>" % (self.func.__name__, is_complex)




More information about the Scipy-svn mailing list