[SciPy-Dev] Slow moment function in scipy.stats
Ralf Gommers
ralf.gommers at gmail.com
Tue Feb 24 02:17:48 EST 2015
Hi Stefan,
On Thu, Feb 19, 2015 at 12:55 PM, stefan <stefan.peterson at rubico.com> wrote:
> Hello list,
>
> First time poster here. Anyway, some time ago I noticed that the scipy
> skewness function was a major bottleneck in an algorithm of mine. Back
> then, I typed up my own replacement and thought no more about it. Today,
> for some unknown reason, I decided to dig a little deeper in this and found
> the major culprit to be the way moments are computed, specifically the use
> of np.power.
>
np.power is indeed slow; for explanations, see:
http://stackoverflow.com/questions/25254541/why-is-numpy-power-60x-slower-than-in-lining
http://stackoverflow.com/questions/26770996/why-is-numpy-power-slower-for-integer-exponents
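To see the effect locally, here is a minimal sketch that computes the same cube with `np.power(x, 3)` and with plain multiplication, checks that the two agree, and prints both timings. It assumes only NumPy and the standard-library `timeit` module. The array size is an arbitrary choice, and on a current NumPy release the gap may differ from the numbers quoted in this thread.

```python
import timeit

import numpy as np

# Assumed test data: one million standard-normal float64 values.
x = np.random.default_rng(0).standard_normal(1_000_000)

# The same cube computed two ways.
cube_power = np.power(x, 3)
cube_mult = x * x * x

# Both agree up to floating-point rounding (a few ulps per element).
assert np.allclose(cube_power, cube_mult)

# Absolute timings depend on the NumPy version, the platform's libm and the
# hardware, so only the ratio between the two is worth comparing.
t_power = timeit.timeit(lambda: np.power(x, 3), number=20)
t_mult = timeit.timeit(lambda: x * x * x, number=20)
print(f"np.power: {t_power:.3f} s   x*x*x: {t_mult:.3f} s")
```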
>
> Testing an alternative approach (missing some safety checks, obviously) in
> IPython:
>
> In [1]: import numpy as np
>
> In [2]: import scipy.stats as spstat
>
> In [3]: def moment(x, mom=1, axis=0):
> ...:     if mom == 1:
> ...:         return np.float64(0.0)
> ...:     else:
> ...:         x_zero_mean = x - np.expand_dims(np.mean(x, axis), axis)
> ...:         x_zero_mean_2 = x_zero_mean**2
> ...:         s = x_zero_mean_2.copy()
> ...:         for k in range(1, mom // 2):
> ...:             s *= x_zero_mean_2
> ...:         if mom % 2:
> ...:             s *= x_zero_mean
> ...:         return np.mean(s, axis)
> ...:
>
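The loop above squares the deviations once and then multiplies by that square `mom // 2 - 1` more times. A possible generalization, sketched below under the assumption that `mom` is a positive integer, swaps in exponentiation by squaring so that an order-k moment needs only O(log k) array multiplications, and uses `keepdims=True` instead of `np.expand_dims`. The names `int_power` and `fast_moment` are hypothetical, not SciPy API, and unlike `scipy.stats.moment` this sketch does no input validation or NaN handling.

```python
import numpy as np


def int_power(a, n):
    """Raise array `a` to a non-negative integer power `n` by repeated
    squaring, using only elementwise multiplication (hypothetical helper)."""
    result = np.ones_like(a)
    base = a
    while n > 0:
        if n & 1:
            # Current lowest bit of n is set: fold this power of `a` in.
            result = result * base
        n >>= 1
        if n:
            # Square only while higher bits remain.
            base = base * base
    return result


def fast_moment(x, mom=1, axis=0):
    """Central moment of order `mom` along `axis` (sketch, hypothetical name).

    Assumes `mom` is a positive integer.
    """
    x = np.asarray(x, dtype=float)
    if mom == 1:
        # The first central moment is zero by definition.
        return np.zeros_like(x.mean(axis=axis))
    # keepdims=True lets the mean broadcast against x directly.
    dev = x - x.mean(axis=axis, keepdims=True)
    return int_power(dev, mom).mean(axis=axis)
```

For example, `fast_moment(a, 4, axis=1)` returns one fourth central moment per row of `a`.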
> In [4]: a = np.random.randn(25, 100000)
>
> In [5]: for ax in range(2):
> ...:     for k in range(1, 8):
> ...:         %timeit spstat.moment(a, k, ax)
> ...:
> 10000 loops, best of 3: 41.9 µs per loop
> 10 loops, best of 3: 44 ms per loop
> 1 loops, best of 3: 233 ms per loop
> 1 loops, best of 3: 229 ms per loop
> 1 loops, best of 3: 232 ms per loop
> 1 loops, best of 3: 232 ms per loop
> 1 loops, best of 3: 236 ms per loop
> 100000 loops, best of 3: 4.14 µs per loop
> 10 loops, best of 3: 43.3 ms per loop
> 1 loops, best of 3: 227 ms per loop
> 1 loops, best of 3: 225 ms per loop
> 1 loops, best of 3: 227 ms per loop
> 1 loops, best of 3: 232 ms per loop
> 1 loops, best of 3: 232 ms per loop
>
> In [6]: for ax in range(2):
> ...:     for k in range(1, 8):
> ...:         %timeit moment(a, k, ax)
> ...:
> 1000000 loops, best of 3: 458 ns per loop
> 10 loops, best of 3: 21.7 ms per loop
> 10 loops, best of 3: 26 ms per loop
> 10 loops, best of 3: 25.9 ms per loop
> 10 loops, best of 3: 30.4 ms per loop
> 10 loops, best of 3: 30.3 ms per loop
> 10 loops, best of 3: 33.1 ms per loop
> 1000000 loops, best of 3: 463 ns per loop
> 10 loops, best of 3: 21.2 ms per loop
> 10 loops, best of 3: 25.5 ms per loop
> 10 loops, best of 3: 25.5 ms per loop
> 10 loops, best of 3: 30.1 ms per loop
> 10 loops, best of 3: 30.1 ms per loop
> 10 loops, best of 3: 33.2 ms per loop
>
> In [7]: for ax in range(2):
> ...:     for k in range(1, 8):
> ...:         print(np.sum((spstat.moment(a, k, ax) - moment(a, k, ax))**2))
> ...:
> 0.0
> 0.0
> 6.87146461986e-28
> 1.49841810127e-26
> 6.84527222501e-26
> 1.26529038747e-24
> 1.35165136907e-23
> 0.0
> 0.0
> 1.34532977463e-33
> 3.94430452611e-30
> 1.5467173476e-31
> 1.95637504495e-28
> 2.48413854543e-29
>
> So there are some rounding errors, but they're hardly alarming. Is there
> another reason not to do it this way?
>
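The summed squared differences in `In [7]` grow with `k` partly because the moments themselves grow (for standard-normal data the sixth central moment is about 15), so the raw numbers are hard to judge. One scale-aware check, sketched below with arbitrary test data, divides the difference between the `np.power` result and a plain repeated-multiplication result (not the exact multiplication sequence used in `moment` above) by the mean magnitude of the summands, `mean(|x - mean|**k)`. That scale stays away from zero even for odd moments that are close to zero.

```python
import numpy as np

# Arbitrary test data, smaller than in the thread to keep the check quick.
rng = np.random.default_rng(0)
a = rng.standard_normal((25, 10_000))
axis = 1
dev = a - a.mean(axis=axis, keepdims=True)

worst = {}
for k in range(2, 8):
    via_power = np.mean(np.power(dev, k), axis=axis)
    # Same moment from k - 1 successive elementwise multiplications.
    prod = dev.copy()
    for _ in range(k - 1):
        prod *= dev
    via_mult = np.mean(prod, axis=axis)
    # Normalize by the mean magnitude of the summands for each row.
    scale = np.mean(np.abs(dev) ** k, axis=axis)
    worst[k] = np.max(np.abs(via_power - via_mult) / scale)
    print(k, worst[k])
```

If the printed values stay within a small multiple of machine epsilon (about 2.2e-16), the two approaches differ only by ordinary rounding.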
I'd say that replacing one call to np.power with 6 lines of code to get a
~2x speedup for the second moment and a ~7-9x speedup for orders 3 to 7 is
a good tradeoff. Pull requests are welcome :)
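Since the original bottleneck was skewness, here is a minimal sketch of the biased sample skewness m3 / m2**1.5 (the estimator `scipy.stats.skew` computes with its default `bias=True`) built only from multiplications. `fast_skew` is a hypothetical name, and the sketch does not handle zero-variance input, which SciPy treats specially.

```python
import numpy as np


def fast_skew(x, axis=0):
    """Biased sample skewness m3 / m2**1.5 along `axis` (sketch).

    Central moments are formed with elementwise multiplication only,
    avoiding np.power on the full array.
    """
    x = np.asarray(x, dtype=float)
    dev = x - x.mean(axis=axis, keepdims=True)
    dev2 = dev * dev
    m2 = dev2.mean(axis=axis)
    m3 = (dev2 * dev).mean(axis=axis)
    # m2 * sqrt(m2) equals m2**1.5; it is evaluated once per slice,
    # so its cost is negligible next to the full-array work above.
    return m3 / (m2 * np.sqrt(m2))
```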
Cheers,
Ralf