[Numpy-discussion] Permutation in Numpy
David M. Cooke
cookedm at physics.mcmaster.ca
Wed Jul 28 16:06:03 EDT 2004
On Sun, Jul 25, 2004 at 07:24:49AM -0400, Hee-Seng Kye wrote:
> #perm.py
> def perm(k):
>     # Compute the list of all permutations of k
>     if len(k) <= 1:
>         return [k]
>     r = []
>     for i in range(len(k)):
>         s = k[:i] + k[i+1:]
>         p = perm(s)
>         for x in p:
>             r.append(k[i:i+1] + x)
>     return r
>
> Does anyone know if there is a built-in function in Numpy (or Numarray)
> that does the above task faster (computes the list of all permutations
> of a list, k)? Or is there a way to make the above function run faster
> using Numpy?
>
> I'm asking because I need to create a very large list which contains
> all permutations of range(12), in which case there would be 12!
> permutations. I created a file test.py:
Do you really need a *list* of all those permutations? Think about it:
12! is about 0.5 billion, which is about as much RAM as your machine
has. Each permutation is going to be a list taking 20 bytes of overhead
plus 4 bytes per entry, so 20 + 4*12 = 68 bytes per permutation. You'd
need about 32 GB of RAM to store them all.
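The estimate above can be checked with a few lines of Python. This is only a sketch: the 20-byte overhead and 4 bytes per entry are the figures quoted above, not measured values, and the variable names are made up for illustration.

```python
from math import factorial

# Figures taken from the estimate above; real sizes depend on the
# Python build, so treat them as assumptions.
LIST_OVERHEAD_BYTES = 20
BYTES_PER_ENTRY = 4
n = 12

n_perms = factorial(n)  # 12! = 479001600
bytes_per_perm = LIST_OVERHEAD_BYTES + BYTES_PER_ENTRY * n  # 68
total_bytes = n_perms * bytes_per_perm

print("%d permutations, %d bytes each, %.1f GB in total"
      % (n_perms, bytes_per_perm, total_bytes / 1e9))
```

This counts only the inner lists; the outer list holding them and the integer objects themselves would add to the total.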
You probably want to just be able to access them in order, so a
generator is a better bet. That way, you're only storing the current
permutation instead of all of them. Something like

def perm(k):
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

Then:

for p in perm(range(12)):
    print(p)

(I'm using tuples instead of lists as that gives better performance
here.)
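To see that the generator really does produce permutations one at a time, here is a small sketch that pulls only the first three out of the 12! available. The generator is repeated so the snippet runs on its own; `islice` comes from the standard `itertools` module, and the name `first_three` is just for illustration.

```python
from itertools import islice

def perm(k):
    # Same recursive generator as above, repeated so this runs on its own.
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

# Take only the first three permutations; the remaining ones are never built.
first_three = list(islice(perm(range(12)), 3))
for p in first_three:
    print(p)
```

This returns immediately, because nothing past the third permutation is ever computed.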
For n = 9, your code takes 9.4 s on my machine. The above takes 3 s, and
its running time will scale with the number of permutations (n = 12
should take 3 s * 10*11*12 = 1.1 h). Your original code won't scale that
way, as more and more time will be taken up reallocating the list of
permutations.
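The extrapolation above is just the ratio 12!/9! applied to the n = 9 timing. A rough sketch of the arithmetic follows, assuming the time per permutation stays constant (in practice the tuples get longer, so this is likely optimistic); the variable names are illustrative.

```python
from math import factorial

t9_seconds = 3.0  # timing for n = 9 quoted above

# Going from n = 9 to n = 12 multiplies the number of permutations by 12!/9!.
growth = factorial(12) // factorial(9)  # 10 * 11 * 12 = 1320
t12_seconds = t9_seconds * growth

print("estimated %.0f s, or %.1f h" % (t12_seconds, t12_seconds / 3600.0))
```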
We can get fancier and unroll it a bit more:

def perm(k):
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    elif lk == 2:
        yield k
        yield (k[1], k[0])
    elif lk == 3:
        k0, k1, k2 = k
        yield k
        yield (k0, k2, k1)
        yield (k1, k0, k2)
        yield (k1, k2, k0)
        yield (k2, k0, k1)
        yield (k2, k1, k0)
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

This takes 1.3 s for n = 9 on my machine.
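When unrolling by hand like this, it is easy to drop or duplicate a case, so a quick sanity check is worthwhile. The sketch below compares the unrolled generator with `itertools.permutations`, which was added to the standard library in Python 2.6 (after this message was written) and yields tuples in the same order. The helper name `matches_itertools` is made up for this example.

```python
from itertools import permutations

def perm(k):
    # Unrolled generator from above, repeated so this runs on its own.
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    elif lk == 2:
        yield k
        yield (k[1], k[0])
    elif lk == 3:
        k0, k1, k2 = k
        yield k
        yield (k0, k2, k1)
        yield (k1, k0, k2)
        yield (k1, k2, k0)
        yield (k2, k0, k1)
        yield (k2, k1, k0)
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

def matches_itertools(n):
    # True if perm() yields the same tuples, in the same order, as itertools.
    return list(perm(range(n))) == list(permutations(range(n)))

# Small sizes only; this exercises every branch, including the recursion.
for n in range(7):
    print(n, matches_itertools(n))
```

On a Python that has it, `itertools.permutations` is also a reasonable answer to the original question about a built-in, since it is lazy in the same way as the generators here.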
Hope this helps.
--
|>|\/|<
/--------------------------------------------------------------------------\
|David M. Cooke http://arbutus.physics.mcmaster.ca/dmc/
|cookedm at physics.mcmaster.ca