Sudoku solver: reduction + brute force

Pavol.Severa at gmail.com Pavol.Severa at gmail.com
Sat Jan 14 09:30:51 EST 2006


ago wrote:
> Inspired by some recent readings on LinuxJournal and an ASPN recipe, I
> decided to revamp my old python hack... The new code is a combination
> of (2) reduction methods and brute force and it is quite faster than
> the
> ASPN program. If anyone is interested I attached the code in
> http://agolb.blogspot.com/2006/01/sudoku-solver-in-python.html

I suggest trying

input="""
0,0,0,0,9,6,8,0,0
0,0,1,0,0,0,0,7,0
0,2,0,0,0,0,0,0,3
0,3,0,0,0,8,0,0,6
0,0,4,0,2,0,3,0,0
6,0,0,5,0,0,0,8,0
9,0,0,0,0,0,0,5,0
0,7,0,0,0,0,1,0,0
0,0,5,9,4,0,0,0,0"""

your program seems to take too long to solve it.

I think the hard part is not to solve, but rather to create *difficult*
sudoku grids.
But to inflate my ego beyond the known universe, here is my solver
(that solves the above-mentioned grid reasonably fast). I suppose the
only difference is that it uses 3, rather than 2, rules to simplify
before starting a tree-like search.



#########
#if a copyright is needed:
#this is public domain, do with it whatever you want
#i.e. most probably nothing
#########

class DeadEnd(Exception):
    """Raised when a reduction leaves some cell with no candidate digit."""

class sudoku(object):
    """Candidate grid for a 9x9 sudoku puzzle.

    Every cell holds the list of digits still possible for it; a cell is
    decided once its list has exactly one element.  Cells are addressed as
    ``grid[x, y]`` with ``x`` and ``y`` in ``range(9)`` and are stored at
    index ``9*x + y`` of ``self.possible``.

    The code runs unchanged on Python 2 and Python 3.
    """

    def __init__(self, *args):
        """Build the grid from 81 numbers; 0 marks an undecided cell.

        Raises ValueError unless exactly 81 values are given.
        """
        # Progress flag: set by the reductions, consumed by clean_all().
        self.changed = True
        self.possible = []
        if len(args) != 81:
            # Call form of raise is valid on both Python 2 and 3.
            raise ValueError("need 81 numbers")
        for value in args:
            if value == 0:
                # A real list is required: candidates are removed in place.
                self.possible.append(list(range(1, 10)))
            else:
                self.possible.append([value])

    def __getitem__(self, pos):
        """Return the candidate list of cell ``pos == (x, y)``."""
        # Tuple parameters are not valid syntax on Python 3; unpack here.
        x, y = pos
        return self.possible[9 * x + y]

    def __setitem__(self, pos, what):
        """Replace the candidate list of cell ``pos == (x, y)``."""
        x, y = pos
        self.possible[9 * x + y] = what

    def copy(self):
        """Return a new grid whose candidate lists are independent copies.

        The copy starts with ``changed`` set to True, as any new grid does.
        """
        result = sudoku(*(81 * [1]))
        for i in range(9):
            for j in range(9):
                result[i, j] = list(self[i, j])
        return result

    def solved(self):
        """Return True when every cell has exactly one candidate left.

        Only the candidate counts are checked, not the sudoku rules.
        """
        for i in range(9):
            for j in range(9):
                if len(self[i, j]) != 1:
                    return False
        return True

    def trials(self):
        """Yield guesses: copies of the grid with one undecided cell fixed.

        Undecided cells are visited by increasing number of candidates
        (2 up to 9); for each such cell one copy per candidate is yielded.
        """
        for i, j in ((i, j) for ln in range(2, 10)
                     for i in range(9) for j in range(9)
                     if len(self[i, j]) == ln):
            for k in self[i, j]:
                new = self.copy()
                new[i, j] = [k]
                yield new

    def _discard(self, x, y, digits):
        """Drop *digits* from cell (x, y); raise DeadEnd if none remain.

        Sets ``self.changed`` only when at least one digit was removed.
        """
        cell = self[x, y]
        kept = [n for n in cell if n not in digits]
        if len(kept) == len(cell):
            return
        if not kept:
            raise DeadEnd
        self[x, y] = kept
        self.changed = True

    def clean1(self, x, y):
        """Reduce the candidates of cell (x, y) using its row, column and box.

        Two rules are applied for each region returned by regions():
        a digit already decided in another cell of the region is removed
        from this cell, and a digit that no other cell of the region can
        hold is forced into this cell.

        Sets ``self.changed`` when the cell changes and never clears it, so
        progress recorded for earlier cells in a sweep is not lost.
        Raises DeadEnd when the cell is left without candidates or when the
        forced digit contradicts the cell's candidates.
        """
        if len(self[x, y]) == 1:
            return
        remove = set()
        for places in self.regions(x, y):
            missing = set(range(1, 10))
            for xx, yy in places:
                if xx == x and yy == y:
                    continue
                if len(self[xx, yy]) == 1:
                    remove.add(self[xx, yy][0])
                missing -= set(self[xx, yy])
            if missing:
                # Each missing digit has to go into this one cell: more
                # than one, or one the cell cannot hold, is a contradiction.
                if len(missing) > 1 or not missing.issubset(self[x, y]):
                    raise DeadEnd
                self[x, y] = [missing.pop()]
                self.changed = True
        self._discard(x, y, remove)

    def clean3(self, out1, out2):
        """Cross-reduce two groups of cells, as produced by outs().

        A digit that no cell of one group can hold is removed from every
        cell of the other group; this is done in both directions.  Sets
        ``self.changed`` on progress and raises DeadEnd when a cell is left
        without candidates.
        """
        for (o1, o2) in ((out1, out2), (out2, out1)):
            remove = set(range(1, 10))
            for x, y in o1:
                remove -= set(self[x, y])
            if not remove:
                continue
            for x, y in o2:
                self._discard(x, y, remove)

    @staticmethod
    def regions(x, y):
        """Return three iterables of cells: row-wise, column-wise and the
        3x3 box, each containing (x, y) itself."""
        return (((xx, y) for xx in range(9)),
                ((x, yy) for yy in range(9)),
                ((xx, yy) for xx in range(3 * (x // 3), 3 * (x // 3) + 3)
                 for yy in range(3 * (y // 3), 3 * (y // 3) + 3)))

    @staticmethod
    def outs():
        """Yield the (out1, out2) cell-group pairs consumed by clean3().

        For every 3x3 box and every line (first index fixed, then second
        index fixed) crossing it: out1 is the box without that line and
        out2 is that line without the box.
        """
        for i in range(3):
            for j in range(3):
                for k in range(3):
                    # Compare ints by value, not identity.
                    out1 = [(a + 3 * i, b + 3 * j) for a in range(3)
                            if a != k for b in range(3)]
                    out2 = [(k + 3 * i, n) for n in range(9) if n // 3 != j]
                    yield out1, out2
                for k in range(3):
                    out1 = [(a + 3 * i, b + 3 * j) for a in range(3)
                            for b in range(3) if b != k]
                    out2 = [(n, k + 3 * j) for n in range(9) if n // 3 != i]
                    yield out1, out2

    def clean_all(self):
        """Apply clean1 and clean3 repeatedly until a pass changes nothing.

        DeadEnd raised by a reduction propagates to the caller.
        """
        while self.changed:
            self.changed = False
            for x in range(9):
                for y in range(9):
                    self.clean1(x, y)
            for out1, out2 in self.outs():
                self.clean3(out1, out2)

    def __repr__(self):
        """Return nine lines; a decided cell shows its digit, any other
        cell shows its candidate list.  Every cell is followed by a space."""
        lines = []
        for x in range(9):
            cells = []
            for y in range(9):
                cell = self[x, y]
                shown = cell[0] if len(cell) == 1 else cell
                cells.append(str(shown) + ' ')
            lines.append(''.join(cells) + '\n')
        return ''.join(lines)




from collections import deque

class liter(object):
    """Round-robin iterator over several iterables that can grow while
    it is being consumed.

    Iteration takes one item from each pending iterator in turn; an
    exhausted iterator is dropped, and iteration ends when none is left.
    """

    def __init__(self, *iterables):
        """Queue an iterator for each of the given iterables."""
        self.iters = deque(iter(x) for x in iterables)

    def __iter__(self):
        """Yield items in round-robin order until every iterator is spent."""
        while self.iters:
            it = self.iters.popleft()
            try:
                # Builtin next() works on Python 2.6+ and Python 3;
                # the .next() method exists only on Python 2.
                result = next(it)
            except StopIteration:
                # Exhausted: do not put it back in the queue.
                continue
            self.iters.append(it)
            yield result

    def append(self, what):
        """Add another iterable at the end of the rotation."""
        self.iters.append(iter(what))



def solve(me):
    """Search for a solved grid reachable from *me*.

    Each grid taken from the work queue is reduced with clean_all(); grids
    that hit a DeadEnd are dropped, a solved grid is returned at once, and
    any other grid contributes its trials() to the queue.  Returns None
    when the queue runs dry without a solved grid.
    """
    pending = liter([me])
    for candidate in pending:
        try:
            candidate.clean_all()
        except DeadEnd:
            # Contradiction reached: abandon this branch.
            continue
        if not candidate.solved():
            # Still undecided cells: queue the guesses for later turns.
            pending.append(candidate.trials())
            continue
        return candidate
    return None

######

# Example puzzle, row by row; 0 marks an empty cell.
# NOTE: the name shadows the builtin input(); kept as-is so that existing
# references to this module-level name keep working.
input=(
0,0,0,0,9,6,8,0,0,
0,0,1,0,0,0,0,7,0,
0,2,0,0,0,0,0,0,3,
0,3,0,0,0,8,0,0,6,
0,0,4,0,2,0,3,0,0,
6,0,0,5,0,0,0,8,0,
9,0,0,0,0,0,0,5,0,
0,7,0,0,0,0,1,0,0,
0,0,5,9,4,0,0,0,0)

result=solve(sudoku(*input))
# Parenthesised single argument prints the same on Python 2 and Python 3.
print(result)




More information about the Python-list mailing list