[Python-checkins] python/nondist/sandbox/sio sio.py,NONE,1.1 test_sio.py,NONE,1.1

gvanrossum@users.sourceforge.net gvanrossum@users.sourceforge.net
Wed, 09 Apr 2003 14:17:44 -0700


Update of /cvsroot/python/python/nondist/sandbox/sio
In directory sc8-pr-cvs1:/tmp/cvs-serv14889

Added Files:
	sio.py test_sio.py 
Log Message:
Experimental new standard I/O library.

--- NEW FILE: sio.py ---
"""New standard I/O library.

This code is still very young and experimental!

There are fairly complete unit tests in test_sio.py.

The design is simple:

- A raw stream supports read(n), write(s), seek(offset, whence=0) and
  tell().  This is generally unbuffered.  Raw streams may support
  Unicode.

- A basis stream provides the raw stream API and builds on a much more
  low-level API, e.g. the os, mmap or socket modules.

- A filtering stream is a raw stream built on top of another raw stream.
  There are filtering streams for universal newline translation and
  for unicode translation.

- A buffering stream supports the full classic Python I/O API:
  read(n=-1), readline(), readlines(sizehint=0), tell(), seek(offset,
  whence=0), write(s), writelines(lst), as well as __iter__() and
  next().  (There's also readall() but that's a synonym for read()
  without arguments.)  This is a superset of the raw stream API.  I
  haven't thought about fileno() and isatty() yet.  We really need
  only one buffering stream implementation, which is a filtering
  stream.

You typically take a basis stream, place zero or more filtering
streams on top of it, and then top it off with a buffering stream.

"""

import os
import mmap

class BufferingInputStream(object):

    """Standard buffering input stream.

    This is typically the top of the stack.  The base stream must provide
    read(n); tell() and seek() are taken from the base when present (and
    are None otherwise), so this object's tell()/seek() may fail when the
    base lacks the operation they need.
    """

    bigsize = 2**19 # Half a Meg: cap on readall()'s growing read size
    bufsize = 2**13 # 8 K

    def __init__(self, base, bufsize=None):
        self.do_read = getattr(base, "read", None)
                       # function to fill buffer some more
        self.do_tell = getattr(base, "tell", None)
                       # None, or return a byte offset
        self.do_seek = getattr(base, "seek", None)
                       # None, or seek to a byte offset
        if bufsize is None:     # Get default from the class
            bufsize = self.bufsize
        self.bufsize = bufsize  # buffer size (hint only)
        self.lines = []         # ready-made lines (sans "\n")
        self.buf = ""           # raw data (may contain "\n")
        # Invariant: readahead == "\n".join(self.lines + [self.buf])
        # self.lines contains no "\n"
        # self.buf may contain "\n"

    def tell(self):
        """Return the logical position: base position minus readahead."""
        pos = self.do_tell()  # This may fail
        offset = len(self.buf)
        for line in self.lines:
            offset += len(line) + 1  # every buffered line is followed by "\n"
        assert pos >= offset, (locals(), self.__dict__)
        return pos - offset

    def seek(self, offset, whence=0):
        """Seek like file.seek(offset, whence).

        This may fail on the do_seek() or do_tell() call.  But it won't
        call either on a relative forward seek.  Nor on a seek relative
        to the end when the base cannot seek: the data is read through
        instead, keeping only what is needed.

        Raises TypeError when an end-relative seek on an unseekable base
        would go back before the data still available, and ValueError
        for a bad whence.
        """
        if whence == 0 or (whence == 2 and self.do_seek is not None):
            self.do_seek(offset, whence)
            self.lines = []
            self.buf = ""
            return
        if whence == 2:
            # Skip relative to EOF by reading and saving only just as
            # much as needed
            assert self.do_seek is None
            data = "\n".join(self.lines + [self.buf])
            total = len(data)
            buffers = [data]
            self.lines = []
            self.buf = ""
            while 1:
                data = self.do_read(self.bufsize)
                if not data:
                    break
                buffers.append(data)
                total += len(data)
                # Drop leading buffers lying entirely before the target.
                while buffers and total >= len(buffers[0]) - offset:
                    total -= len(buffers[0])
                    del buffers[0]
            cutoff = total + offset
            if cutoff < 0:
                raise TypeError("cannot seek back")
            if buffers:
                buffers[0] = buffers[0][cutoff:]
            self.buf = "".join(buffers)
            self.lines = []
            return
        if whence == 1:
            if offset < 0:
                self.do_seek(self.tell() + offset, 0)
                self.lines = []
                self.buf = ""
                return
            while self.lines:
                line = self.lines[0]
                if offset <= len(line):
                    self.lines[0] = line[offset:]
                    return
                # Skip the whole line *plus* the "\n" that follows it.
                offset -= len(line) + 1
                del self.lines[0]
            assert not self.lines
            if offset <= len(self.buf):
                self.buf = self.buf[offset:]
                return
            offset -= len(self.buf)
            self.buf = ""
            if self.do_seek is None:
                self.read(offset)
            else:
                self.do_seek(offset, 1)
            return
        raise ValueError("whence should be 0, 1 or 2")

    def readall(self):
        """Return all remaining data (readahead first, then the base)."""
        self.lines.append(self.buf)
        more = ["\n".join(self.lines)]
        self.lines = []
        self.buf = ""
        bufsize = self.bufsize
        while 1:
            data = self.do_read(bufsize)
            if not data:
                break
            more.append(data)
            # Grow the request geometrically, but never beyond bigsize.
            bufsize = min(bufsize*2, self.bigsize)
        return "".join(more)

    def read(self, n=-1):
        """Read up to n bytes; n < 0 means read everything ("" at EOF)."""
        if n < 0:
            return self.readall()

        if self.lines:
            # See if this can be satisfied from self.lines[0]
            line = self.lines[0]
            if len(line) >= n:
                self.lines[0] = line[n:]
                return line[:n]

            # See if this can be satisfied *without exhausting* self.lines
            # (k counts bytes so far, including the "\n" after each line)
            k = 0
            i = 0
            for line in self.lines:
                k += len(line)
                if k >= n:
                    lines = self.lines[:i]
                    data = self.lines[i]
                    cutoff = len(data) - (k-n)
                    lines.append(data[:cutoff])
                    self.lines[:i+1] = [data[cutoff:]]
                    return "\n".join(lines)
                k += 1
                i += 1

            # See if this can be satisfied from self.lines plus self.buf
            if k + len(self.buf) >= n:
                lines = self.lines
                self.lines = []
                cutoff = n - k
                lines.append(self.buf[:cutoff])
                self.buf = self.buf[cutoff:]
                return "\n".join(lines)

        else:
            # See if this can be satisfied from self.buf
            data = self.buf
            k = len(data)
            if k >= n:
                cutoff = len(data) - (k-n)
                self.buf = data[cutoff:]
                return data[:cutoff]

        # Not enough readahead: take all of it and read more from the base.
        lines = self.lines
        self.lines = []
        lines.append(self.buf)
        self.buf = ""
        data = "\n".join(lines)
        more = [data]
        k = len(data)
        while k < n:
            data = self.do_read(max(self.bufsize, n-k))
            k += len(data)
            more.append(data)
            if not data:
                break
        # Keep whatever the last chunk had beyond n as readahead.
        cutoff = len(data) - (k-n)
        self.buf = data[cutoff:]
        more[-1] = data[:cutoff]
        return "".join(more)

    def __iter__(self):
        return self

    def next(self):
        """Return the next line (with its "\\n" if any); StopIteration at EOF."""
        if self.lines:
            return self.lines.pop(0) + "\n"

        # This block is needed because read() can leave self.buf
        # containing newlines
        self.lines = self.buf.split("\n")
        self.buf = self.lines.pop()
        if self.lines:
            return self.lines.pop(0) + "\n"

        buf = self.buf and [self.buf] or []
        while 1:
            self.buf = self.do_read(self.bufsize)
            self.lines = self.buf.split("\n")
            self.buf = self.lines.pop()
            if self.lines:
                buf.append(self.lines.pop(0))
                buf.append("\n")
                break
            if not self.buf:
                break
            buf.append(self.buf)

        line = "".join(buf)
        if not line:
            raise StopIteration
        return line

    # Python 3 spelling of the iterator protocol.
    __next__ = next

    def readline(self):
        """Return the next line, or "" at EOF."""
        try:
            return self.next()
        except StopIteration:
            return ""

    def readlines(self, sizehint=0):
        """Return all remaining lines as a list (sizehint is ignored)."""
        return list(self)

class CRLFFilter(object):

    """Filtering stream for universal newlines.

    Translates CR LF pairs and lone CRs to LF.  TextInputFilter is more
    general, but this is faster when you don't need tell/seek.
    """

    def __init__(self, base):
        self.do_read = base.read  # base must provide read(n)
        self.atcr = False         # True when the last byte read was a CR

    def read(self, n):
        """Read up to n bytes from the base and translate newlines.

        Returns "" only when the base stream itself returns "".
        """
        data = self.do_read(n)
        if self.atcr:
            if data.startswith("\n"):
                # Very rare case: in the middle of "\r\n"; the "\r" was
                # already returned as "\n", so drop this "\n".
                data = data[1:]
                if not data:
                    # Don't let the swallowed "\n" look like EOF.
                    data = self.do_read(n)
            self.atcr = False
        if "\r" in data:
            self.atcr = data.endswith("\r")     # Test this before removing \r
            data = data.replace("\r\n", "\n")   # Catch \r\n this first
            data = data.replace("\r", "\n")     # Remaining \r are standalone
        return data

class MMapFile(object):

    """Standard I/O basis stream using mmap.

    Modes: "r" maps the file read-only; "w" opens it read/write, creating
    it if needed; "a" opens an existing file read/write.  No mode
    truncates the file, and every mode starts at position 0.
    NOTE(review): mapping an empty (e.g. freshly created) file may fail in
    mmap.mmap() -- confirm before relying on "w" for new files.
    """

    def __init__(self, filename, mode="r"):
        # Set these first so close() (and therefore __del__) is safe even
        # when the mode is bad or opening/mapping the file fails below.
        self.fd = None
        self.mm = None
        self.filename = filename
        self.mode = mode
        if mode == "r":
            flag = os.O_RDONLY
            self.access = mmap.ACCESS_READ
        else:
            if mode == "w":
                flag = os.O_RDWR | os.O_CREAT
            elif mode == "a":
                flag = os.O_RDWR
            else:
                raise ValueError("mode should be 'r', 'w' or 'a'")
            self.access = mmap.ACCESS_WRITE
        if hasattr(os, "O_BINARY"):
            flag |= os.O_BINARY
        self.fd = os.open(filename, flag)
        self.mm = mmap.mmap(self.fd, 0, access=self.access)
        self.pos = 0

    def __del__(self):
        self.close()

    def close(self):
        """Unmap the file and close its descriptor; safe to call twice."""
        if self.mm is not None:
            self.mm.close()
            self.mm = None
        if self.fd is not None:
            os.close(self.fd)
            self.fd = None

    def _remap(self):
        # Map the whole file afresh so data added since the last mapping
        # becomes visible, then release the previous mapping explicitly.
        old = self.mm
        self.mm = mmap.mmap(self.fd, 0, access=self.access)
        if old is not None:
            old.close()

    def tell(self):
        return self.pos

    def seek(self, offset, whence=0):
        """Set the position; it is clamped at 0 but may go past the end."""
        if whence == 0:
            self.pos = max(0, offset)
        elif whence == 1:
            self.pos = max(0, self.pos + offset)
        elif whence == 2:
            self.pos = max(0, self.mm.size() + offset)
        else:
            raise ValueError("seek(): whence must be 0, 1 or 2")

    def readall(self):
        return self.read()

    def read(self, n=-1):
        """Read up to n bytes (to the end of the file when n < 0)."""
        if n >= 0:
            aim = self.pos + n
        else:
            aim = self.mm.size() # Actual file size, may be more than mapped
            n = aim - self.pos
        data = self.mm[self.pos:aim]
        if len(data) < n:
            # File grew since opened; remap to get the new data
            self._remap()
            data = self.mm[self.pos:aim]
        self.pos += len(data)
        return data

    def __iter__(self):
        return self

    def readline(self):
        """Return the next line including its "\\n"; "" at EOF."""
        hit = self.mm.find("\n", self.pos) + 1
        if not hit:
            # Remap the file just in case it grew
            self._remap()
            hit = self.mm.find("\n", self.pos) + 1
        if hit:
            data = self.mm[self.pos:hit]
            self.pos = hit
            return data
        # No newline: read whatever we've got -- may be empty
        data = self.mm[self.pos:self.mm.size()]
        self.pos += len(data)
        return data

    def next(self):
        """Like readline(), but raise StopIteration instead of returning ""."""
        line = self.readline()
        if not line:
            raise StopIteration
        return line

    # Python 3 spelling of the iterator protocol.
    __next__ = next

    def readlines(self, sizehint=0):
        """Return all remaining lines as a list (sizehint is ignored)."""
        return list(iter(self.readline, ""))

    def write(self, data):
        """Write data at the current position, growing the map if needed."""
        end = self.pos + len(data)
        try:
            self.mm[self.pos:end]  = data
        except IndexError:
            # Slice assignment past the end of the map fails; resize the
            # mapping (and the underlying file) and retry.
            self.mm.resize(end)
            self.mm[self.pos:end]  = data
        self.pos = end

    def writelines(self, lines):
        # Plain loop: filter() was only used for its side effect (and is
        # lazy on Python 3, where it would write nothing at all).
        for line in lines:
            self.write(line)

class DiskFile(object):

    """Standard I/O basis stream using os.open/close/read/write/lseek.

    Modes: "r" read-only; "w" read/write, creating the file if needed;
    "a" read/write on an existing file.  No O_TRUNC or O_APPEND is passed,
    so no mode truncates and "a" does not start at the end.
    """

    def __init__(self, filename, mode="r"):
        # Set before anything can fail so close()/__del__ stay safe.
        self.fd = None
        self.filename = filename
        self.mode = mode
        if mode == "r":
            flag = os.O_RDONLY
        elif mode == "w":
            flag = os.O_RDWR | os.O_CREAT
        elif mode == "a":
            flag = os.O_RDWR
        else:
            raise ValueError("mode should be 'r', 'w' or 'a'")
        if hasattr(os, "O_BINARY"):
            flag |= os.O_BINARY
        self.fd = os.open(filename, flag)

    def seek(self, offset, whence=0):
        os.lseek(self.fd, offset, whence)

    def tell(self):
        return os.lseek(self.fd, 0, 1)

    def read(self, n):
        """Return at most n bytes; "" at EOF."""
        return os.read(self.fd, n)

    def write(self, data):
        """Write all of data, looping over partial os.write() results."""
        while data:
            n = os.write(self.fd, data)
            data = data[n:]

    def close(self):
        """Close the descriptor; safe to call more than once."""
        fd = self.fd
        if fd is not None:
            self.fd = None
            os.close(fd)

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Best effort only: never let a failing close escape __del__.
            pass

class TextInputFilter(object):

    """Filtering input stream for universal newline translation.

    CR LF pairs and lone CRs are returned as LF.  tell() reports positions
    in the base stream and seek() takes such positions.
    """

    def __init__(self, base):
        self.base = base   # must implement read, may implement tell, seek
        self.atcr = False  # Set when last char read was \r
        self.buf = ""      # Optional one-character read-ahead buffer
                           # (raw, untranslated; only filled by tell())

    def read(self, n):
        """Read up to n bytes, translating newlines; "" means EOF."""
        if n <= 0:
            return ""
        if self.buf:
            assert not self.atcr
            data = self.buf
            self.buf = ""
            if data == "\r":
                # The read-ahead byte is raw: translate it, and remember
                # it so a following "\n" is swallowed as part of "\r\n".
                self.atcr = True
                return "\n"
            return data
        data = self.base.read(n)
        if self.atcr:
            if data.startswith("\n"):
                data = data[1:]
                if not data:
                    data = self.base.read(n)
            self.atcr = False
        if "\r" in data:
            self.atcr = data.endswith("\r")
            data = data.replace("\r\n", "\n").replace("\r", "\n")
        return data

    def seek(self, offset, whence=0):
        self.base.seek(offset, whence)
        self.atcr = False
        self.buf = ""

    def tell(self):
        """Return the base-stream position of the next unread byte."""
        pos = self.base.tell()
        if self.atcr:
            # Must read the next byte to see if it's \n,
            # because then we must report the next position.
            assert not self.buf
            self.buf = self.base.read(1)
            pos += len(self.buf)  # at EOF nothing was consumed
            self.atcr = False
            if self.buf == "\n":
                self.buf = ""
        return pos - len(self.buf)

class TextOutputFilter(object):

    """Filtering output stream for universal newline translation.

    Every LF written is replaced by linesep before reaching the base.
    """

    def __init__(self, base, linesep=os.linesep):
        # Validate with an exception, not assert (asserts vanish under -O).
        if linesep not in ("\n", "\r\n", "\r"):
            raise ValueError("linesep must be '\\n', '\\r\\n' or '\\r', not %r"
                             % (linesep,))
        self.base = base    # must implement write, may implement seek, tell
        self.linesep = linesep

    def write(self, data):
        # Compare by value: "is not" only worked when linesep happened to
        # be the very same "\n" object.
        if self.linesep != "\n" and "\n" in data:
            data = data.replace("\n", self.linesep)
        self.base.write(data)

    def seek(self, offset, whence=0):
        self.base.seek(offset, whence)

    def tell(self):
        return self.base.tell()

class DecodingInputFilter(object):

    """Input filter that reads encoded bytes and yields decoded text.

    tell() and seek() are the base stream's own methods, so positions are
    byte offsets in the encoded data.
    """

    def __init__(self, base, encoding="utf8", errors="strict"):
        self.base = base          # needs read, tell and seek
        self.encoding = encoding  # codec name handed to decode()
        self.errors = errors      # decode() error policy
        # Delegate positioning straight to the base stream.
        self.tell = base.tell
        self.seek = base.seek

    def read(self, n):
        """Read *approximately* n bytes from the base and decode them.

        The result is whatever decode() returns for the configured codec
        and may, in extreme cases, be longer than n.  Newlines are left
        alone; stack a TextInputFilter to translate them.
        """
        raw = self.base.read(n)
        try:
            return raw.decode(self.encoding, self.errors)
        except ValueError:
            # decode() can't cope with a multibyte sequence cut in half by
            # the read boundary.  Like codecs.StreamReader, pull in one
            # byte at a time (at most 9 extra) and try again.
            attempts = 0
            while attempts < 9:
                attempts += 1
                extra = self.base.read(1)
                if not extra:
                    raise
                raw += extra
                try:
                    return raw.decode(self.encoding, self.errors)
                except ValueError:
                    continue
            raise

class EncodingOutputFilter(object):

    """Filtering output stream that writes to an encoded file.

    write() takes unicode text.  Byte strings are accepted only when they
    are pure ASCII.
    """

    def __init__(self, base, encoding="utf8", errors="strict"):
        self.base = base          # must implement write, tell and seek
        self.encoding = encoding
        self.errors = errors
        self.tell = base.tell
        self.seek = base.seek

    def write(self, chars):
        """Encode chars with self.encoding and write the result to the base."""
        if not isinstance(chars, type(u"")):
            # Byte string: decode explicitly as ASCII (fail if it's not
            # ASCII) rather than with the interpreter's default encoding,
            # which is what unicode(chars) used.
            chars = chars.decode("ascii")
        self.base.write(chars.encode(self.encoding, self.errors))

--- NEW FILE: test_sio.py ---
"""Unit tests for sio (new standard I/O)."""

import time
import tempfile
import unittest
from test import test_support

import sio

class TestSource(object):

    """In-memory read source that hands out data packet by packet.

    read(n) never returns more than one packet (splitting it when longer
    than n), so tests control exactly where reads are cut.
    """

    def __init__(self, packets):
        for packet in packets:
            assert packet  # an empty packet would look like EOF
        self.orig_packets = list(packets)
        self.packets = list(packets)
        self.pos = 0

    def tell(self):
        return self.pos

    def seek(self, offset, whence=0):
        """Rewind to the start, then read forward to the target offset."""
        if whence == 2:
            offset += sum([len(packet) for packet in self.orig_packets])
        elif whence == 1:
            offset += self.pos
        else:
            assert whence == 0
        self.packets = list(self.orig_packets)
        self.pos = 0
        while self.pos < offset:
            if not self.read(offset - self.pos):
                break
        assert self.pos == offset

    def read(self, n):
        """Return up to n bytes of the next packet, or "" when exhausted."""
        if not self.packets:
            return ""
        chunk = self.packets.pop(0)
        if len(chunk) > n:
            self.packets.insert(0, chunk[n:])
            chunk = chunk[:n]
        self.pos += len(chunk)
        return chunk

class TestReader(TestSource):

    """Packet-based read source used by the filter tests.

    Its behavior is exactly TestSource's: read(n) hands out at most one
    packet, tell() reports the position and seek() rewinds and re-reads.
    Inheriting avoids keeping two identical copies of that code in sync.
    """

class TestWriter(object):

    """In-memory write target that behaves like a seekable file.

    Writing past the end pads the gap with NUL bytes; writing inside the
    existing data overwrites it in place.
    """

    def __init__(self):
        self.buf = ""
        self.pos = 0

    def write(self, data):
        gap = self.pos - len(self.buf)
        if gap >= 0:
            # At or beyond the end: pad with NULs, then append.
            self.buf = self.buf + "\0" * gap + data
            self.pos = len(self.buf)
            return
        end = self.pos + len(data)
        self.buf = self.buf[:self.pos] + data + self.buf[end:]
        self.pos = end

    def tell(self):
        return self.pos

    def seek(self, offset, whence=0):
        """Move the position; negative results are clamped to 0."""
        if whence == 1:
            offset = offset + self.pos
        elif whence == 2:
            offset = offset + len(self.buf)
        elif whence != 0:
            raise ValueError("whence should be 0, 1 or 2")
        self.pos = max(offset, 0)

class BufferingInputStreamTests(unittest.TestCase):

    """Tests for sio.BufferingInputStream on top of a TestSource.

    Subclasses may override makeStream() to run the same tests on
    another stream type.
    """

    packets = ["a", "b", "\n", "def", "\nxy\npq\nuv", "wx"]
    lines = ["ab\n", "def\n", "xy\n", "pq\n", "uvwx"]

    def setUp(self):
        pass

    def makeStream(self, tell=False, seek=False, bufsize=None):
        """Return a BufferingInputStream over a fresh TestSource.

        The source's tell/seek are replaced by None unless requested.
        """
        base = TestSource(self.packets)
        if not tell:
            base.tell = None
        if not seek:
            base.seek = None
        return sio.BufferingInputStream(base, bufsize)

    def collectBlocks(self, stream, size, checkZero=True):
        """Call stream.read(size) until it returns a false value.

        Returns the non-empty blocks.  With checkZero, also asserts that
        read(0) returns "" after every block.
        """
        blocks = []
        block = stream.read(size)
        while block:
            blocks.append(block)
            if checkZero:
                self.assertEqual(stream.read(0), "")
            block = stream.read(size)
        return blocks

    def checkTell(self, stream, reader, pos=0):
        """Assert stream.tell() tracks the data returned by reader().

        pos is the position expected before the first reader() call;
        reader() is called until it returns an empty result.
        """
        while True:
            self.assertEqual(stream.tell(), pos)
            got = len(reader())
            if not got:
                break
            pos += got

    def test_readline(self):
        stream = self.makeStream()
        self.assertEqual(list(iter(stream.readline, "")), self.lines)

    def test_readlines(self):
        # This also tests next() and __iter__()
        stream = self.makeStream()
        self.assertEqual(stream.readlines(), self.lines)

    def test_readlines_small_bufsize(self):
        stream = self.makeStream(bufsize=1)
        self.assertEqual(list(stream), self.lines)

    def test_readall(self):
        stream = self.makeStream()
        self.assertEqual(stream.readall(), "".join(self.lines))

    def test_readall_small_bufsize(self):
        stream = self.makeStream(bufsize=1)
        self.assertEqual(stream.readall(), "".join(self.lines))

    def test_readall_after_readline(self):
        stream = self.makeStream()
        for expected in self.lines[:2]:
            self.assertEqual(stream.readline(), expected)
        self.assertEqual(stream.readall(), "".join(self.lines[2:]))

    def test_read_1_after_readline(self):
        stream = self.makeStream()
        self.assertEqual(stream.readline(), "ab\n")
        self.assertEqual(stream.readline(), "def\n")
        self.assertEqual(self.collectBlocks(stream, 1),
                         list("".join(self.lines)[7:]))

    def test_read_1(self):
        stream = self.makeStream()
        self.assertEqual(self.collectBlocks(stream, 1),
                         list("".join(self.lines)))

    def test_read_2(self):
        stream = self.makeStream()
        self.assertEqual(self.collectBlocks(stream, 2),
                         ["ab", "\nd", "ef", "\nx", "y\n", "pq",
                          "\nu", "vw", "x"])

    def test_read_4(self):
        stream = self.makeStream()
        self.assertEqual(self.collectBlocks(stream, 4),
                         ["ab\nd", "ef\nx", "y\npq", "\nuvw", "x"])

    def test_read_4_after_readline(self):
        stream = self.makeStream()
        self.assertEqual(stream.readline(), "ab\n")
        self.assertEqual(stream.readline(), "def\n")
        first = stream.read(4)
        blocks = [first] + self.collectBlocks(stream, 4)
        self.assertEqual(blocks, ["xy\np", "q\nuv", "wx"])

    def test_read_4_small_bufsize(self):
        stream = self.makeStream(bufsize=1)
        self.assertEqual(self.collectBlocks(stream, 4, checkZero=False),
                         ["ab\nd", "ef\nx", "y\npq", "\nuvw", "x"])

    def test_tell_1(self):
        stream = self.makeStream(tell=True)
        self.checkTell(stream, lambda: stream.read(1))

    def test_tell_1_after_readline(self):
        stream = self.makeStream(tell=True)
        pos = 0
        for _ in range(2):
            pos += len(stream.readline())
            self.assertEqual(stream.tell(), pos)
        self.checkTell(stream, lambda: stream.read(1), pos)

    def test_tell_2(self):
        stream = self.makeStream(tell=True)
        self.checkTell(stream, lambda: stream.read(2))

    def test_tell_4(self):
        stream = self.makeStream(tell=True)
        self.checkTell(stream, lambda: stream.read(4))

    def test_tell_readline(self):
        stream = self.makeStream(tell=True)
        self.checkTell(stream, lambda: stream.readline())

    def test_seek(self):
        stream = self.makeStream(tell=True, seek=True)
        everything = stream.readall()
        end = len(everything)
        for readto in range(end + 1):
            for seekto in range(end + 1):
                for whence in (0, 1, 2):
                    stream.seek(0)
                    self.assertEqual(stream.tell(), 0)
                    head = stream.read(readto)
                    self.assertEqual(head, everything[:readto])
                    offset = {0: seekto,
                              1: seekto - readto,
                              2: seekto - end}[whence]
                    stream.seek(offset, whence)
                    self.assertEqual(stream.tell(), seekto)
                    self.assertEqual(stream.readall(), everything[seekto:])

    def test_seek_noseek(self):
        stream = self.makeStream()
        everything = stream.readall()
        end = len(everything)
        for readto in range(end + 1):
            for seekto in range(readto, end + 1):
                for whence in (1, 2):
                    stream = self.makeStream()
                    head = stream.read(readto)
                    self.assertEqual(head, everything[:readto])
                    if whence == 1:
                        offset = seekto - readto
                    else:
                        offset = seekto - end
                    stream.seek(offset, whence)
                    self.assertEqual(stream.readall(), everything[seekto:])

class CRLFFilterTests(unittest.TestCase):

    """Tests for sio.CRLFFilter."""

    def test_filter(self):
        packets = ["abc\ndef\rghi\r\nxyz\r", "123\r", "\n456"]
        expected = ["abc\ndef\nghi\nxyz\n", "123\n", "456"]
        crlf = sio.CRLFFilter(TestSource(packets))
        blocks = []
        block = crlf.read(100)
        while block:
            blocks.append(block)
            block = crlf.read(100)
        self.assertEqual(blocks, expected)

class MMapFileTests(BufferingInputStreamTests):

    """Run the BufferingInputStream tests against sio.MMapFile.

    Each makeStream() call writes self.packets to a fresh temporary file.
    tearDown() closes every stream that was opened and removes every
    temporary file that was created.
    """

    tfn = None  # name of the most recently created temporary file

    def setUp(self):
        self.tfns = []     # every temporary file created by this test
        self.streams = []  # every MMapFile opened by this test

    def tearDown(self):
        import os, sys
        for stream in self.streams:
            stream.close()
        self.streams = []
        for tfn in self.tfns:
            try:
                os.remove(tfn)
            except os.error:
                msg = sys.exc_info()[1]
                print("can't remove %s: %s" % (tfn, msg))
        self.tfns = []
        self.tfn = None

    def makeStream(self, tell=None, seek=None, bufsize=None, mode="r"):
        """Write self.packets to a new temp file and open it with MMapFile.

        tell, seek and bufsize are accepted for interface compatibility
        with the base class and ignored.
        """
        import os
        # mkstemp() creates the file securely (mktemp() only picks a name).
        fd, tfn = tempfile.mkstemp()
        self.tfns.append(tfn)
        self.tfn = tfn
        f = os.fdopen(fd, "wb")
        try:
            f.writelines(self.packets)
        finally:
            f.close()
        stream = sio.MMapFile(tfn, mode)
        self.streams.append(stream)
        return stream

    def test_write(self):
        file = self.makeStream(mode="w")
        file.write("BooHoo\n")
        file.write("Barf\n")
        file.writelines(["a\n", "b\n", "c\n"])
        self.assertEqual(file.tell(), len("BooHoo\nBarf\na\nb\nc\n"))
        file.seek(0)
        self.assertEqual(file.read(), "BooHoo\nBarf\na\nb\nc\n")
        file.seek(0)
        self.assertEqual(file.readlines(),
                         ["BooHoo\n", "Barf\n", "a\n", "b\n", "c\n"])
        self.assertEqual(file.tell(), len("BooHoo\nBarf\na\nb\nc\n"))

class TextInputFilterTests(unittest.TestCase):

    """Tests for sio.TextInputFilter reading, tell() and seek()."""

    packets = [
        "foo\r",
        "bar\r",
        "\nfoo\r\n",
        "abc\ndef\rghi\r\nxyz",
        "\nuvw\npqr\r",
        "\n",
        "abc\n",
        ]
    expected = [
        ("foo\n", 4),
        ("bar\n", 9),
        ("foo\n", 14),
        ("abc\ndef\nghi\nxyz", 30),
        ("\nuvw\npqr\n", 40),
        ("abc\n", 44),
        ("", 44),
        ("", 44),
        ]
    expected_with_tell = [
        ("foo\n", 4),
        ("b", 5),
        ("ar\n", 9),
        ("foo\n", 14),
        ("abc\ndef\nghi\nxyz", 30),
        ("\nuvw\npqr\n", 40),
        ("abc\n", 44),
        ("", 44),
        ("", 44),
        ]

    def test_read(self):
        text = sio.TextInputFilter(TestReader(self.packets))
        for data, _ in self.expected:
            self.assertEqual(text.read(100), data)

    def test_read_tell(self):
        text = sio.TextInputFilter(TestReader(self.packets))
        for data, pos in self.expected_with_tell:
            self.assertEqual(text.read(100), data)
            self.assertEqual(text.tell(), pos)
            self.assertEqual(text.tell(), pos) # Repeat the tell() !

    def test_seek(self):
        text = sio.TextInputFilter(TestReader(self.packets))
        # Record (text so far, tell()) before every single-byte read.
        prefix = ""
        pairs = []
        while True:
            pairs.append((prefix, text.tell()))
            c = text.read(1)
            if not c:
                break
            self.assertEqual(len(c), 1)
            prefix += c
        everything = prefix
        # Seeking back to each recorded position must resume exactly there.
        for prefix, pos in pairs:
            text.seek(pos)
            self.assertEqual(text.tell(), pos)
            self.assertEqual(text.tell(), pos)
            bufs = [prefix]
            data = text.read(100)
            while data:
                bufs.append(data)
                data = text.read(100)
            self.assertEqual(text.read(100), "")
            self.assertEqual("".join(bufs), everything)

class TextOutputFilterTests(unittest.TestCase):

    """Tests for sio.TextOutputFilter with each supported linesep."""

    def makeFilter(self, linesep):
        """Return (TestWriter base, TextOutputFilter over it)."""
        base = TestWriter()
        return base, sio.TextOutputFilter(base, linesep=linesep)

    def writeSample(self, linesep):
        """Write three sample chunks and return what reached the base."""
        base, out = self.makeFilter(linesep)
        for chunk in ("abc", "def\npqr\nuvw", "\n123\n"):
            out.write(chunk)
        return base.buf

    def checkTell(self, linesep, after):
        """Check tell() after writing "xxx" (3) and then "\\nabc\\n"."""
        base, out = self.makeFilter(linesep)
        out.write("xxx")
        self.assertEqual(out.tell(), 3)
        out.write("\nabc\n")
        self.assertEqual(out.tell(), after)

    def test_write_nl(self):
        self.assertEqual(self.writeSample("\n"), "abcdef\npqr\nuvw\n123\n")

    def test_write_cr(self):
        self.assertEqual(self.writeSample("\r"), "abcdef\rpqr\ruvw\r123\r")

    def test_write_crnl(self):
        self.assertEqual(self.writeSample("\r\n"),
                         "abcdef\r\npqr\r\nuvw\r\n123\r\n")

    def test_write_tell_nl(self):
        self.checkTell("\n", 8)

    def test_write_tell_cr(self):
        self.checkTell("\r", 8)

    def test_write_tell_crnl(self):
        self.checkTell("\r\n", 10)

    def test_write_seek(self):
        base, out = self.makeFilter("\n")
        out.write("x"*100)
        out.seek(50)
        out.write("y"*10)
        self.assertEqual(base.buf, "x"*50 + "y"*10 + "x"*40)

class DecodingInputFilterTests(unittest.TestCase):

    """Tests for sio.DecodingInputFilter."""

    def test_read(self):
        chars = u"abc\xff\u1234\u4321\x80xyz"
        data = chars.encode("utf8")
        for n in range(1, 11):
            # Fresh base, filter and buffer for every read size: sharing
            # one base meant it was exhausted after n == 1, so the larger
            # sizes were never really exercised.
            base = TestReader([data])
            decoder = sio.DecodingInputFilter(base)
            bufs = []
            while 1:
                c = decoder.read(n)
                self.assertEqual(type(c), type(u""))
                if not c:
                    break
                bufs.append(c)
            self.assertEqual(u"".join(bufs), chars)

class EncodingOutputFilterTests(unittest.TestCase):

    """Tests for sio.EncodingOutputFilter."""

    def test_write(self):
        chars = u"abc\xff\u1234\u4321\x80xyz"
        data = chars.encode("utf8")
        for n in range(1, 11):
            base = TestWriter()
            encoder = sio.EncodingOutputFilter(base)
            # Feed the text in n-character slices.
            for start in range(0, len(chars), n):
                encoder.write(chars[start:start + n])
            self.assertEqual(base.buf, data)

# Speed test

FN = "BIG"

def timeit(fn=FN):
    f = sio.MMapFile(fn, "r")
    lines = bytes = 0
    t0 = time.clock()
    for line in f:
        lines += 1
        bytes += len(line)
    t1 = time.clock()
    print "%d lines (%d bytes) in %.3f seconds" % (lines, bytes, t1-t0)

def timeold(fn=FN):
    f = open(fn, "rb")
    lines = bytes = 0
    t0 = time.clock()
    for line in f:
        lines += 1
        bytes += len(line)
    t1 = time.clock()
    print "%d lines (%d bytes) in %.3f seconds" % (lines, bytes, t1-t0)

# Functional test

def main():
    """Functional smoke test: print the first ten lines of sio.py.

    The file is read through the stack DiskFile -> DecodingInputFilter ->
    TextInputFilter -> BufferingInputStream.
    """
    stream = sio.BufferingInputStream(
        sio.TextInputFilter(
            sio.DecodingInputFilter(
                sio.DiskFile("sio.py"))))
    for _ in range(10):
        print(repr(stream.readline()))

# Unit test main program

def test_main():
    """Collect the enabled test cases into one suite and run it."""
    cases = [
        BufferingInputStreamTests,
        CRLFFilterTests,
        # MMapFileTests is deliberately left out of the suite.
        TextInputFilterTests,
        TextOutputFilterTests,
        DecodingInputFilterTests,
        EncodingOutputFilterTests,
        ]
    suite = unittest.TestSuite()
    for case in cases:
        suite.addTest(unittest.makeSuite(case))
    test_support.run_suite(suite)

if __name__ == "__main__":
    test_main()