getting a submatrix of all true

Scott David Daniels Scott.Daniels at Acm.Org
Sun Jul 6 04:19:06 EDT 2003


In article <3f07bd72$1 at nntp0.pdx.net>,
	scott.daniels at acm.org (Scott David Daniels) writes:
> Following Bengt and anton at vredegoor.doge.nl (Anton Vredegoor)'s leads,
> the following code can be fast (at times).  It is quite sensitive to the
> probability of non-zeroness (.01 works well, the .05 is nowhere near so
> nice).
> 
> I get 2.25 min to do 25 x 25 at .05
>       2.75 min to do 30 x 30 at .05
> It gets slow _very_ fast, but gets good numbers if the probability is low
>        .25 min to do 45 x 45 at .01
>        1.5 min to do 50 x 50 at .01

OK, a bit faster: (bitcount faster, allow choose to sometimes quit early).
I get 0.25 min to do 25 x 25 at .05
       5.5 min to do 30 x 30 at .05	(?! maybe above was wrong)
and:
        1 sec to do 45 x 45 at .01
        9 sec to do 50 x 50 at .01

-Scott David Daniels
 Scott.Daniels at Acm.Org

###################################

# myopt.py

__version__ = '0.4'

# Python 2.2 lacks the enumerate builtin; supply a minimal substitute there.
# The name is probed directly (not inside an assert) so that the fallback is
# still installed when running under `python -O`, which strips asserts.
try:
    enumerate
except NameError:
    def enumerate(iterable):
        '''Return a list of (index, item) pairs for the items of iterable.'''
        lst = list(iterable)
        return zip(range(len(lst)), lst)


# Number of on bits in each byte value 0..255.  Built with the recurrence
# bits(n) == bits(n >> 1) + (n & 1), so no throwaway bit counter is needed.
bytebits = [0] * 256
for _byte in range(1, 256):
    bytebits[_byte] = bytebits[_byte >> 1] + (_byte & 1)
del _byte


def bitcount(row):
    '''Return the number of on bits in the non-negative integer row.

    Counts a byte at a time using the bytebits lookup table, so the loop
    runs once per byte of row rather than once per bit.
    '''
    assert row >= 0
    result = bytebits[row & 255]
    while row >= 256:
        row >>= 8
        result += bytebits[row & 255]
    return result


def rowencode(vector):
    '''Pack a sequence of values into an integer bit mask.

    Bit i of the result is set exactly when vector[i] is true, so
    [1, 0, 0, 1] encodes as 0b1001 == 9.  Any length is supported; the
    result simply grows into a long integer.
    '''
    result = 0
    bit = 1
    for element in vector:
        if element:
            result |= bit
        # Doubling promotes to a long integer on every Python version
        # (left shifts of plain ints could drop high bits before 2.4), and
        # avoids the Python-2-only `1L` literal.
        bit *= 2
    return result


class Answer(object):
    '''One candidate solution: kept row indices plus a kept-column bit mask.

    The score (number of kept columns times number of kept rows) is
    computed on first request and cached.  Two class-level counters record
    how many scores were requested and how many actually had to be computed.
    '''
    __slots__ = ('rows', 'colmask', '_score')
    totalrequests = 0   # score() calls made on any instance
    totalcalcs = 0      # score() calls that had to compute the value

    def __init__(self, colmask, rows):
        '''Store the bit mask of kept columns and the list of kept rows.'''
        self.rows = rows
        self.colmask = colmask
        self._score = None

    def score(self):
        '''Return bitcount(colmask) * len(rows), computing it at most once.'''
        cls = self.__class__
        cls.totalrequests += 1
        cached = self._score
        if cached is None:
            cls.totalcalcs += 1
            cached = self._score = bitcount(self.colmask) * len(self.rows)
        return cached

    def __repr__(self):
        fields = (type(self).__name__, len(self.rows), self.rows,
                  self.colmask, self._score)
        return '%s(%d:%s, %x):%s' % fields


# Count of choose() calls; main() resets it before solving and prints it.
totalcalls = 0

def choose(rows, keepcols, N=0, keeprows=None, best=None):
    '''Choose rows and columns to keep.  Return an Answer for the choice.

    Recursively decides, for rows N, N+1, ..., whether each row is kept.
      rows     -- list of row bit masks (bit i set = non-zero cell in
                  column i), as built by rowencode in getanswer
      keepcols -- bit mask of the columns still kept on this path
      N        -- index of the first undecided row
      keeprows -- indices of rows kept so far, in ascending order.  It is
                  appended to in place and handed on as-is to the
                  "drop this row" recursive call.
      best     -- an Answer computed earlier on another branch (or None),
                  offered as a shortcut (see the review note below)
    Also increments the module-level totalcalls counter.
    '''
    global totalcalls
    totalcalls += 1
    if keeprows is None:
        keeprows = []
    try:
        # A row with no non-zero cell in any kept column costs nothing:
        # keep such rows until the first conflicting one.
        while 0 == rows[N] & keepcols:
            keeprows.append(N)
            N += 1
    except IndexError:
        # Ran past the last row: every row has been decided.
        return Answer(keepcols, keeprows)

    # a difference: some kept columns in this row are non-0
    # must drop either those columns or this row

    # Calculate result if we keep this row (drop problem columns)
    newmask = keepcols & ~rows[N]
    # NOTE(review): when an earlier answer already has exactly this column
    # mask, it is reused instead of searching this subtree.  Its rows come
    # from another branch.  This assumes the subtree (whose masks can only be
    # subsets of newmask) holds nothing better.  Not proven here; confirm.
    if best and newmask == best.colmask:
        withrow = best # Already have a calculated with this mask.  use it.
    else:
        # The mask argument equals newmask.  keeprows + [N] is a fresh
        # list, so this branch never mutates keeprows.
        withrow = choose(rows, keepcols & ~rows[N], N+1, keeprows + [N], best)

    # Calculate result if we drop this row.  Deeper calls get best, or
    # failing that the answer just found, as their shortcut candidate.
    skiprow = choose(rows, keepcols, N+1, keeprows, best or withrow)

    # Choose the better answer (keep the row if everything is equal).
    # Equal column masks pick withrow without computing either score.
    if (withrow.colmask == skiprow.colmask
     or withrow.score() >= skiprow.score()):
        return withrow
    else:
        return skiprow


# Example matrix from the original thread: 6 rows by 5 columns.  A non-zero
# cell may not appear in the kept submatrix.
X = [
    [1, 0, 0, 0, 1],
    [0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0],
]


def printrep(row, symbols, mask=0):
    '''Return a string with one symbol for each element of row.

    symbols holds four characters, picked by (column kept?, element non-zero?):
    symbols[0] dropped/zero, symbols[1] dropped/non-zero,
    symbols[2] kept/zero,    symbols[3] kept/non-zero.
    Column i counts as kept when bit i of mask is set; mask must not have
    bits beyond the row's length.
    '''
    assert mask >= 0
    chars = []
    width = 0
    for element in row:
        kept = (mask >> width) & 1
        chars.append(symbols[kept * 2 + (element != 0)])
        width += 1
    assert mask >> width == 0  # mask doesn't extend beyond data.
    return ''.join(chars)


def printanswer(data, rows, keepcols):
    '''Print the represented row'''
    toohuge = len(data)
    rowqueue = rows + [toohuge]
    rowqueue.reverse()
    nextrow = rowqueue.pop()
    for rownumber, row in enumerate(data):
        if nextrow > rownumber:
            # This row was zapped
            print '#', printrep(row, '01OI', keepcols)
        else:
            assert rownumber == nextrow # This row was kept
            nextrow = rowqueue.pop()
            print '_', printrep(row, '01~@', keepcols)
    assert nextrow == toohuge and not rowqueue


def getanswer(data):
    '''Calculate the best cut for a matrix and return it as an Answer.

    data is a list of rows.  Non-zero entries are cells that may not appear
    in the kept submatrix.  Rows may differ in length; a shorter row has no
    non-zero cells in the missing trailing columns.  An empty matrix yields
    an Answer with no columns and no rows.
    '''
    # Prepending 0 keeps max() from failing on an empty matrix (max's
    # default= keyword is not available on the old Pythons this file targets).
    columns = max([0] + [len(row) for row in data])
    rowdata = [rowencode(row) for row in data]
    # 2 ** columns becomes a long integer as needed on every Python version.
    return choose(rowdata, 2 ** columns - 1)


def main(data=X):
    global totalcalls

    totalcalls = 0
    answer = getanswer(data)
    print 'Requested: %s, Calculated: %s,' % (
          Answer.totalrequests, Answer.totalcalcs),

    print 'answer: %r,' % answer,
    print 'Score: %s' % answer.score()
    print 'Total choose calls required: %s' % totalcalls

    printanswer(data, answer.rows, answer.colmask)



def rangen(rows, columns, probability=0.05):
    '''Return a rows x columns table (list of lists) of 0/1 ints.

    Each cell is independently 1 with the given probability.  Cells are
    drawn from the random module's shared generator in row-major order.
    '''
    import random
    draw = random.random
    return [[int(draw() < probability) for _ in range(columns)]
            for _ in range(rows)]


if __name__ == '__main__':
    import sys
    # Sanity checks on tiny matrices whose best scores are known.
    assert getanswer([[0]]).score() == 1
    assert getanswer([[0,1], [1,0]]).score() == 1
    assert getanswer([[0,1,0], [1,0,0]]).score() == 2

    # Command line: ROWS [COLUMNS] [PROBABILITY].  With no arguments the
    # built-in example X is solved; otherwise a random matrix is generated.
    # Bad arguments exit with a usage message; asserts are not used for
    # validation because python -O strips them.
    usage = 'usage: %s ROWS [COLUMNS] [PROBABILITY]' % sys.argv[0]
    args = sys.argv[1:]
    if not args:
        main(X)
    else:
        probability = .2
        # A final argument containing '.' is read as the probability.
        if '.' in args[-1]:
            if len(args) < 2:
                sys.exit(usage)
            probability = args.pop()
        if len(args) > 2:
            sys.exit(usage)
        try:
            probability = float(probability)
            rows = int(args[0])
            if len(args) == 1:
                cols = rows
            else:
                cols = int(args[1])
        except ValueError:
            sys.exit(usage)
        main(rangen(rows, cols, probability))






More information about the Python-list mailing list