[Scipy-svn] r5093 - in trunk/scipy/interpolate: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Thu Nov 13 16:21:05 EST 2008
Author: ptvirtan
Date: 2008-11-13 15:20:50 -0600 (Thu, 13 Nov 2008)
New Revision: 5093
Modified:
trunk/scipy/interpolate/fitpack2.py
trunk/scipy/interpolate/tests/test_fitpack.py
Log:
interpolate: don't change the __class__ of a live UnivariateSpline object if the object is an instance of a user-defined subclass. (Fixes #660)
Modified: trunk/scipy/interpolate/fitpack2.py
===================================================================
--- trunk/scipy/interpolate/fitpack2.py 2008-11-13 21:03:45 UTC (rev 5092)
+++ trunk/scipy/interpolate/fitpack2.py 2008-11-13 21:20:50 UTC (rev 5093)
@@ -103,19 +103,28 @@
pass
elif ier==-1:
# the spline returned is an interpolating spline
- self.__class__ = InterpolatedUnivariateSpline
+ self._set_class(InterpolatedUnivariateSpline)
elif ier==-2:
# the spline returned is the weighted least-squares
# polynomial of degree k. In this extreme case fp gives
# the upper bound fp0 for the smoothing factor s.
- self.__class__ = LSQUnivariateSpline
+ self._set_class(LSQUnivariateSpline)
else:
# error
if ier==1:
- self.__class__ = LSQUnivariateSpline
+ self._set_class(LSQUnivariateSpline)
message = _curfit_messages.get(ier,'ier=%s' % (ier))
warnings.warn(message)
+ def _set_class(self, cls):
+ self._spline_class = cls
+ if self.__class__ in (UnivariateSpline, InterpolatedUnivariateSpline,
+ LSQUnivariateSpline):
+ self.__class__ = cls
+ else:
+ # It's an unknown subclass -- don't change class. cf. #660
+ pass
+
def _reset_nest(self, data, nest=None):
n = data[10]
if nest is None:
Modified: trunk/scipy/interpolate/tests/test_fitpack.py
===================================================================
--- trunk/scipy/interpolate/tests/test_fitpack.py 2008-11-13 21:03:45 UTC (rev 5092)
+++ trunk/scipy/interpolate/tests/test_fitpack.py 2008-11-13 21:20:50 UTC (rev 5093)
@@ -36,6 +36,15 @@
assert_almost_equal(lut.get_residual(),0.0)
assert_array_almost_equal(lut([1,1.5,2]),[0,1,2])
+ def test_subclassing(self):
+
+ class ZeroSpline(UnivariateSpline):
+ def __call__(self, x):
+ return 0*array(x)
+
+ sp = ZeroSpline([1,2,3,4,5], [3,2,3,2,3], k=2)
+ assert_array_equal(sp([1.5, 2.5]), [0., 0.])
+
class TestLSQBivariateSpline(TestCase):
def test_linear_constant(self):
x = [1,1,1,2,2,2,3,3,3]
More information about the Scipy-svn mailing list