[Numpy-discussion] Permutation in Numpy

David M. Cooke cookedm at physics.mcmaster.ca
Wed Jul 28 16:06:03 EDT 2004


On Sun, Jul 25, 2004 at 07:24:49AM -0400, Hee-Seng Kye wrote:
> #perm.py
> def perm(k):
>     # Compute the list of all permutations of k
>     if len(k) <= 1:
>         return [k]
>     r = []
>     for i in range(len(k)):
>         s =  k[:i] + k[i+1:]
>         p = perm(s)
>         for x in p:
>             r.append(k[i:i+1] + x)
>     return r
> 
> Does anyone know if there is a built-in function in Numpy (or Numarray) 
> that does the above task faster (computes the list of all permutations 
> of a list, k)?  Or is there a way to make the above function run faster 
> using Numpy?
> 
> I'm asking because I need to create a very large list which contains 
> all permutations of range(12), in which case there would be 12! 
> permutations.  I created a file test.py:

Do you really need a *list* of all those permutations? Think about it:
12! is about 479 million, which is probably about as many permutations
as your machine has bytes of RAM. Each permutation is going to be a list
taking 20 bytes of overhead plus 4 bytes per entry, so 68 bytes per
permutation. You would need about 32 GB of RAM to store them all.
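
The arithmetic behind that estimate can be checked in a few lines. The
per-list cost (20 bytes of overhead plus 4 bytes per entry) is the figure
assumed above for a 32-bit build; actual sizes depend on the Python
version and platform, and the variable names are only illustrative.

```python
import math

n = 12
count = math.factorial(n)        # number of permutations of range(12)

# Assumed per-list cost, taken from the estimate above (32-bit build):
# 20 bytes of overhead plus 4 bytes per entry.
bytes_per_perm = 20 + 4 * n
total_bytes = count * bytes_per_perm

print(count)                     # 479001600
print(bytes_per_perm)            # 68
print(total_bytes / 1e9)         # roughly 32.6 (GB)
```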

You probably want to just be able to access them in order, so a
generator is a better bet. That way, you're only storing the current
permutation instead of all of them. Something like

def perm(k):
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

Then:

for p in perm(range(12)):
    print(p)

(I'm using tuples instead of lists, as that gives better performance
here.)

For n = 9, your code takes 9.4 s on my machine. The above takes 3 s, and
its running time should grow in proportion to the number of permutations
(n=12 should take 3 s * 10*11*12 = 1.1 h). Your original code won't scale
that way, as more and more time will be taken up reallocating the list
of permutations.
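
To try this on your own machine, one option is to time the generator for
a small n and extrapolate by the ratio of permutation counts. This is
only a rough sketch: the helper name time_perm is made up for
illustration, the timings will differ from the numbers above, and the
extrapolation ignores the fact that each tuple gets a little longer as n
grows.

```python
import time

def perm(k):
    # Generator version: yields one permutation at a time as a tuple.
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

def time_perm(n):
    # Consume the generator without storing anything.
    # Returns (number of permutations, elapsed seconds).
    start = time.time()
    count = 0
    for p in perm(range(n)):
        count += 1
    return count, time.time() - start

count, seconds = time_perm(7)

# 12!/7! = 8*9*10*11*12, so n=12 has this many times more permutations.
factor = 8 * 9 * 10 * 11 * 12
print(count, seconds, seconds * factor)
```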

We can get fancier and unroll it a bit more:

def perm(k):
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    elif lk == 2:
        yield k
        yield (k[1], k[0])
    elif lk == 3:
        k0, k1, k2 = k
        yield k
        yield (k0, k2, k1)
        yield (k1, k0, k2)
        yield (k1, k2, k0)
        yield (k2, k0, k1)
        yield (k2, k1, k0)
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm(s):
                yield t + x

This takes 1.3 s for n = 9 on my machine.
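
A quick way to convince yourself that the unrolled version is a drop-in
replacement is to compare its output with the simple generator for small
n. The names perm_simple and perm_unrolled below are just labels for the
two versions above. The last check uses itertools.permutations, which is
only available in Python 2.6 and later, as an independent reference.

```python
import itertools

def perm_simple(k):
    # Plain recursive generator.
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm_simple(s):
                yield t + x

def perm_unrolled(k):
    # Same recursion, with the cases of length 2 and 3 written out by hand.
    k = tuple(k)
    lk = len(k)
    if lk <= 1:
        yield k
    elif lk == 2:
        yield k
        yield (k[1], k[0])
    elif lk == 3:
        k0, k1, k2 = k
        yield k
        yield (k0, k2, k1)
        yield (k1, k0, k2)
        yield (k1, k2, k0)
        yield (k2, k0, k1)
        yield (k2, k1, k0)
    else:
        for i in range(lk):
            s = k[:i] + k[i+1:]
            t = (k[i],)
            for x in perm_unrolled(s):
                yield t + x

# Both generators should yield the same tuples in the same order.
for n in range(7):
    simple = list(perm_simple(range(n)))
    unrolled = list(perm_unrolled(range(n)))
    assert unrolled == simple
    assert unrolled == list(itertools.permutations(range(n)))
```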

Hope this helps.

-- 
|>|\/|<
/--------------------------------------------------------------------------\
|David M. Cooke                      http://arbutus.physics.mcmaster.ca/dmc/
|cookedm at physics.mcmaster.ca



