[Scipy-svn] r7145 - in trunk/scipy/misc: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Feb 16 18:43:55 EST 2011
Author: warren.weckesser
Date: 2011-02-16 17:43:54 -0600 (Wed, 16 Feb 2011)
New Revision: 7145
Added:
trunk/scipy/misc/tests/test_common.py
Modified:
trunk/scipy/misc/common.py
Log:
ENH: misc: In pade(), use 'linalg.solve' instead of multiplying by the inverse (ticket #1387)
Modified: trunk/scipy/misc/common.py
===================================================================
--- trunk/scipy/misc/common.py 2011-02-15 19:22:32 UTC (rev 7144)
+++ trunk/scipy/misc/common.py 2011-02-16 23:43:54 UTC (rev 7145)
@@ -323,19 +323,19 @@
from scipy import linalg
an = asarray(an)
N = len(an) - 1
- n = N-m
- if (n < 0):
+ n = N - m
+ if n < 0:
raise ValueError("Order of q <m> must be smaller than len(an)-1.")
- Akj = eye(N+1,n+1)
- Bkj = zeros((N+1,m),'d')
- for row in range(1,m+1):
+ Akj = eye(N+1, n+1)
+ Bkj = zeros((N+1, m), 'd')
+ for row in range(1, m+1):
Bkj[row,:row] = -(an[:row])[::-1]
- for row in range(m+1,N+1):
+ for row in range(m+1, N+1):
Bkj[row,:] = -(an[row-m:row])[::-1]
- C = hstack((Akj,Bkj))
- pq = dot(linalg.inv(C),an)
+ C = hstack((Akj, Bkj))
+ pq = linalg.solve(C, an)
p = pq[:n+1]
- q = r_[1.0,pq[n+1:]]
+ q = r_[1.0, pq[n+1:]]
return poly1d(p[::-1]), poly1d(q[::-1])
def lena():
Added: trunk/scipy/misc/tests/test_common.py
===================================================================
--- trunk/scipy/misc/tests/test_common.py (rev 0)
+++ trunk/scipy/misc/tests/test_common.py 2011-02-16 23:43:54 UTC (rev 7145)
@@ -0,0 +1,32 @@
+
+from numpy.testing import assert_array_equal, assert_array_almost_equal
+
+from scipy.misc import pade
+
+
+def test_pade_trivial():
+ nump, denomp = pade([1.0], 0)
+ assert_array_equal(nump.c, [1.0])
+ assert_array_equal(denomp.c, [1.0])
+
+def test_pade_4term_exp():
+ # First four Taylor coefficients of exp(x).
+ # Unlike poly1d, the first array element is the zero-order term.
+ an = [1.0, 1.0, 0.5, 1.0/6]
+
+ nump, denomp = pade(an, 0)
+ assert_array_almost_equal(nump.c, [1.0/6, 0.5, 1.0, 1.0])
+ assert_array_almost_equal(denomp.c, [1.0])
+
+ nump, denomp = pade(an, 1)
+ assert_array_almost_equal(nump.c, [1.0/6, 2.0/3, 1.0])
+ assert_array_almost_equal(denomp.c, [-1.0/3, 1.0])
+
+ nump, denomp = pade(an, 2)
+ assert_array_almost_equal(nump.c, [1.0/3, 1.0])
+ assert_array_almost_equal(denomp.c, [1.0/6, -2.0/3, 1.0])
+
+ nump, denomp = pade(an, 3)
+ assert_array_almost_equal(nump.c, [1.0])
+ assert_array_almost_equal(denomp.c, [-1.0/6, 0.5, -1.0, 1.0])
+
More information about the Scipy-svn mailing list