[Numpy-discussion] Vectorizing code, for loops, and all that
Tim Hochberg
tim.hochberg at ieee.org
Mon Oct 2 19:49:52 EDT 2006
Travis Oliphant wrote:
> Albert Strasheim wrote:
>
>
>> In [571]: x1 = N.random.randn(2000,39)
>>
>> In [572]: y1 = N.random.randn(64,39)
>>
>> In [574]: %timeit z1=x1[...,N.newaxis,...]-y1
>> 10 loops, best of 3: 703 ms per loop
>>
>> In [575]: z1.shape
>> Out[575]: (2000, 64, 39)
>>
>> As far as I can figure, this operation is doing 2000*64*39 subtractions.
>> Doing this straight up yields the following:
>>
>> In [576]: x2 = N.random.randn(2000,64,39)
>>
>> In [577]: y2 = N.random.randn(2000,64,39)
>>
>> In [578]: %timeit z2 = x2-y2
>> 10 loops, best of 3: 108 ms per loop
>>
>> Does anybody have any ideas on why this is so much faster? Hopefully I
>> didn't mess up somewhere...
>>
>>
>>
>
> I suspect I know why, although the difference seems rather large. There
> is code optimization that is being taken advantage of in the second
> case. If you have contiguous arrays (no broadcasting needed), then 1
> C-loop is used for the subtraction (your second case).
>
> In the first case you are using broadcasting to generate the larger
> array. This requires more complicated looping constructs under the
> covers which causes your overhead. Basically, you will have 64*39 1-d
> loops of 2000 elements each in the first example with a bit of
> calculation over-head to reset the pointers before each loop.
>
> In the ufunc code, compare the ONE_UFUNCLOOP case with the
> NOBUFFER_UFUNCLOOP case. If you want to be sure what is running
> un-comment the fprintf statements so you can tell.
>
> I'm surprised the overhead of adjusting pointers is so high, but then
> again you are probably getting a lot of cache misses in the first case
> so there is more to it than that, the loops may run more slowly too.
>
>
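Travis's description of the broadcast case can be modeled in plain Python. The sketch below only illustrates the loop structure he describes (one 1-D loop of 2000 elements for each of the 64*39 cluster/feature pairs); it is not NumPy's actual ufunc code, and the helper name broadcast_subtract_loops is hypothetical. It also shows that each inner loop walks a column of x1, which is strided rather than contiguous in memory.

```python
import numpy as np

def broadcast_subtract_loops(x, y):
    """Hypothetical model (not NumPy's C implementation) of x[:, np.newaxis, :] - y:
    one 1-D loop over the first axis for every (j, k) pair."""
    n, d = x.shape
    m = y.shape[0]
    out = np.empty((n, m, d))
    nloops = 0
    for j in range(m):
        for k in range(d):
            # Inner 1-D loop of length n; x[:, k] is a strided column of a C-ordered x.
            out[:, j, k] = x[:, k] - y[j, k]
            nloops += 1
    return out, nloops

x1 = np.random.randn(2000, 39)
y1 = np.random.randn(64, 39)
z_loops, nloops = broadcast_subtract_loops(x1, y1)
z_bcast = x1[:, np.newaxis, :] - y1

print(nloops)                              # 2496, i.e. 64 * 39 inner loops
print(np.array_equal(z_loops, z_bcast))    # True: same elementwise subtractions
# A column steps through memory 39 * 8 = 312 bytes at a time:
print(x1[:, 0].flags['C_CONTIGUOUS'], x1[:, 0].strides)  # False (312,)
```

In the contiguous x2 - y2 case, by contrast, a single loop can run over all 2000*64*39 elements in memory order.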
I suspect that Travis is basically right about why your simple
subtraction runs much faster than your test case.
However, that doesn't mean you can't do better than at present. When
dealing with large, multidimensional arrays, my experience has been that
vectorizing away all of the for loops is frequently counterproductive. I
chalk this up to two factors. First, you tend to end up generating large
temporary arrays, and these in turn lead to cache misses. Second, you
lose flexibility in how you perform the calculation, which in turn
limits other possible optimizations.
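To put rough numbers on the temporary-array argument, here is the arithmetic for float64 data with the sizes used in this thread (2000 points, 64 clusters, 39 features). The per-iteration figures assume the loop structure of the looping versions in the code that follows: one data point per iteration, or one cluster/feature pair per step.

```python
import numpy as np

n, m, d = 2000, 64, 39                      # data points, clusters, features
itemsize = np.dtype(np.float64).itemsize    # 8 bytes per element

# Fully broadcast version: an (n, m, d) temporary for the difference
# (and, when squared, another temporary of the same size).
full_temp = n * m * d * itemsize
# Looping over data points: each iteration builds an (m, d) temporary.
per_point_temp = m * d * itemsize
# Looping over clusters and features: each step touches one length-n row.
per_row_temp = n * itemsize

print(full_temp)       # 39936000 bytes, roughly 40 MB
print(per_point_temp)  # 19968 bytes, roughly 20 kB
print(per_row_temp)    # 16000 bytes
```

The broadcast temporary is more than three orders of magnitude larger than either per-iteration temporary, so it is very unlikely to stay cache-resident, while the small ones plausibly can.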
I just spent a while playing with this and, assuming I've correctly
translated your original intent, I've come up with two alternative
looping versions (kmean2 and kmean3) that run 2 and 3 times faster,
respectively. I've a feeling that kmean3, the fastest one, still has a
little more room to be sped up, but I'm out of time now. Code is below.
-tim
import numpy as N
from timeit import Timer

data = N.random.randn(2000, 39)

def kmean0(data):
    # Fully vectorized: builds a (2000, 64, 39) temporary in one shot.
    nclusters = 64
    code = data[0:nclusters, :]
    return N.sum((data[:, N.newaxis, :] - code)**2, 2).argmin(axis=1)

def kmean1(data):
    # Same computation as kmean0, one step per line.
    nclusters = 64
    code = data[0:nclusters, :]
    z = data[:, N.newaxis, :]
    z = z - code
    z = z**2
    z = N.sum(z, 2)
    return z.argmin(axis=1)

def kmean2(data):
    # Loop over data points; each iteration builds only a (64, 39) temporary.
    nclusters = 64
    code = data[0:nclusters, :]
    data = data[:, N.newaxis]
    allz = N.zeros(len(data), dtype=int)
    for i, x in enumerate(data):
        z = x - code
        z **= 2
        allz[i] = z.sum(-1).argmin(0)
    return allz

def kmean3(data):
    # Loop over clusters and features; each step works on one contiguous
    # row of length len(data) and accumulates squared differences in place.
    nclusters = 64
    code = data[0:nclusters]
    totals = N.zeros([nclusters, len(data)], float)
    transdata = data.transpose().copy()   # (39, 2000) with contiguous rows
    for cluster, tot in zip(code, totals):
        for di, ci in zip(transdata, cluster):
            delta = di - ci
            delta **= 2
            tot += delta                  # tot is a view into totals
    return totals.argmin(axis=0)

if __name__ == '__main__':
    assert N.all(kmean0(data) == kmean1(data))
    assert N.all(kmean0(data) == kmean2(data))
    assert N.all(kmean0(data) == kmean3(data))
    for name in ('kmean0', 'kmean1', 'kmean2', 'kmean3'):
        print(name, Timer(name + '(data)', globals=globals()).timeit(3))
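One further direction, which is a different technique and not one of the versions above, avoids the 3-D temporary altogether by expanding the squared distance as ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2. Since ||x||^2 is the same for every cluster, it cannot change the argmin and can be dropped, leaving a single (2000, 64) matrix product. This is a sketch only: the function name kmean_dot is hypothetical, it has not been timed against kmean3 here, and its speed will depend on the BLAS NumPy is linked against.

```python
import numpy as np

def kmean_dot(data, nclusters=64):
    """Hypothetical nearest-centroid assignment using the dot-product expansion."""
    code = data[:nclusters]
    # ||c||^2 - 2 x.c for every (point, cluster) pair: a (len(data), nclusters) array.
    # ||x||^2 is constant across clusters for a given point, so it is omitted.
    dist = (code**2).sum(axis=1) - 2.0 * data.dot(code.T)
    return dist.argmin(axis=1)

data = np.random.randn(2000, 39)
reference = ((data[:, np.newaxis, :] - data[:64])**2).sum(axis=2).argmin(axis=1)
print(np.array_equal(kmean_dot(data), reference))
```

The trade-off is numerical: the expansion subtracts nearly equal quantities when points lie far from the origin relative to their separation, so near-ties can resolve differently from the direct computation.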