[pypy-commit] pypy win32-stdlib: rework zipfile and test_zipfile to respect open files

mattip noreply at buildbot.pypy.org
Thu Apr 12 16:35:50 CEST 2012


Author: Matti Picus <matti.picus at gmail.com>
Branch: win32-stdlib
Changeset: r54314:ba500a555707
Date: 2012-04-12 17:35 +0300
http://bitbucket.org/pypy/pypy/changeset/ba500a555707/

Log:	rework zipfile and test_zipfile to respect open files

diff --git a/lib-python/2.7/test/test_zipfile.py b/lib-python/modified-2.7/test/test_zipfile.py
copy from lib-python/2.7/test/test_zipfile.py
copy to lib-python/modified-2.7/test/test_zipfile.py
--- a/lib-python/2.7/test/test_zipfile.py
+++ b/lib-python/modified-2.7/test/test_zipfile.py
@@ -234,8 +234,9 @@
 
         # Read the ZIP archive
         with zipfile.ZipFile(f, "r") as zipfp:
-            for line, zipline in zip(self.line_gen, zipfp.open(TESTFN)):
-                self.assertEqual(zipline, line + '\n')
+            with zipfp.open(TESTFN) as f:
+                for line, zipline in zip(self.line_gen, f):
+                    self.assertEqual(zipline, line + '\n')
 
     def test_readline_read_stored(self):
         # Issue #7610: calls to readline() interleaved with calls to read().
@@ -340,7 +341,8 @@
         produces the expected result."""
         with zipfile.ZipFile(TESTFN2, "w") as zipfp:
             zipfp.write(TESTFN)
-            self.assertEqual(zipfp.read(TESTFN), open(TESTFN).read())
+            with open(TESTFN) as f:
+                self.assertEqual(zipfp.read(TESTFN), f.read())
 
     @skipUnless(zlib, "requires zlib")
     def test_per_file_compression(self):
@@ -382,7 +384,8 @@
                 self.assertEqual(writtenfile, correctfile)
 
                 # make sure correct data is in correct file
-                self.assertEqual(fdata, open(writtenfile, "rb").read())
+                with open(writtenfile, "rb") as fid:
+                    self.assertEqual(fdata, fid.read())
                 os.remove(writtenfile)
 
         # remove the test file subdirectories
@@ -401,24 +404,25 @@
                 else:
                     outfile = os.path.join(os.getcwd(), fpath)
 
-                self.assertEqual(fdata, open(outfile, "rb").read())
+                with  open(outfile, "rb") as fid:
+                    self.assertEqual(fdata, fid.read())
                 os.remove(outfile)
 
         # remove the test file subdirectories
         shutil.rmtree(os.path.join(os.getcwd(), 'ziptest2dir'))
 
     def test_writestr_compression(self):
-        zipfp = zipfile.ZipFile(TESTFN2, "w")
-        zipfp.writestr("a.txt", "hello world", compress_type=zipfile.ZIP_STORED)
-        if zlib:
-            zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_DEFLATED)
+        with zipfile.ZipFile(TESTFN2, "w") as zipfp:
+            zipfp.writestr("a.txt", "hello world", compress_type=zipfile.ZIP_STORED)
+            if zlib:
+                zipfp.writestr("b.txt", "hello world", compress_type=zipfile.ZIP_DEFLATED)
 
-        info = zipfp.getinfo('a.txt')
-        self.assertEqual(info.compress_type, zipfile.ZIP_STORED)
+            info = zipfp.getinfo('a.txt')
+            self.assertEqual(info.compress_type, zipfile.ZIP_STORED)
 
-        if zlib:
-            info = zipfp.getinfo('b.txt')
-            self.assertEqual(info.compress_type, zipfile.ZIP_DEFLATED)
+            if zlib:
+                info = zipfp.getinfo('b.txt')
+                self.assertEqual(info.compress_type, zipfile.ZIP_DEFLATED)
 
 
     def zip_test_writestr_permissions(self, f, compression):
@@ -646,7 +650,8 @@
 
     def test_write_non_pyfile(self):
         with zipfile.PyZipFile(TemporaryFile(), "w") as zipfp:
-            open(TESTFN, 'w').write('most definitely not a python file')
+            with open(TESTFN, 'w') as f:
+                f.write('most definitely not a python file')
             self.assertRaises(RuntimeError, zipfp.writepy, TESTFN)
             os.remove(TESTFN)
 
@@ -795,7 +800,8 @@
         self.assertRaises(RuntimeError, zipf.open, "foo.txt")
         self.assertRaises(RuntimeError, zipf.testzip)
         self.assertRaises(RuntimeError, zipf.writestr, "bogus.txt", "bogus")
-        open(TESTFN, 'w').write('zipfile test data')
+        with open(TESTFN, 'w') as fp:
+            fp.write('zipfile test data')
         self.assertRaises(RuntimeError, zipf.write, TESTFN)
 
     def test_bad_constructor_mode(self):
@@ -803,7 +809,6 @@
         self.assertRaises(RuntimeError, zipfile.ZipFile, TESTFN, "q")
 
     def test_bad_open_mode(self):
-        """Check that bad modes passed to ZipFile.open are caught."""
         with zipfile.ZipFile(TESTFN, mode="w") as zipf:
             zipf.writestr("foo.txt", "O, for a Muse of Fire!")
 
@@ -851,7 +856,6 @@
 
     def test_comments(self):
         """Check that comments on the archive are handled properly."""
-
         # check default comment is empty
         with zipfile.ZipFile(TESTFN, mode="w") as zipf:
             self.assertEqual(zipf.comment, '')
@@ -953,14 +957,16 @@
         with zipfile.ZipFile(TESTFN, mode="w") as zipf:
             pass
         try:
-            zipf = zipfile.ZipFile(TESTFN, mode="r")
+            with zipfile.ZipFile(TESTFN, mode="r") as zipf:
+                pass
         except zipfile.BadZipfile:
             self.fail("Unable to create empty ZIP file in 'w' mode")
 
         with zipfile.ZipFile(TESTFN, mode="a") as zipf:
             pass
         try:
-            zipf = zipfile.ZipFile(TESTFN, mode="r")
+            with zipfile.ZipFile(TESTFN, mode="r") as zipf:
+                pass
         except:
             self.fail("Unable to create empty ZIP file in 'a' mode")
 
@@ -1160,6 +1166,8 @@
             data1 += zopen1.read(500)
             data2 += zopen2.read(500)
             self.assertEqual(data1, data2)
+            zopen1.close()
+            zopen2.close()
 
     def test_different_file(self):
         # Verify that (when the ZipFile is in control of creating file objects)
@@ -1207,9 +1215,9 @@
 
     def test_store_dir(self):
         os.mkdir(os.path.join(TESTFN2, "x"))
-        zipf = zipfile.ZipFile(TESTFN, "w")
-        zipf.write(os.path.join(TESTFN2, "x"), "x")
-        self.assertTrue(zipf.filelist[0].filename.endswith("x/"))
+        with zipfile.ZipFile(TESTFN, "w") as zipf:
+            zipf.write(os.path.join(TESTFN2, "x"), "x")
+            self.assertTrue(zipf.filelist[0].filename.endswith("x/"))
 
     def tearDown(self):
         shutil.rmtree(TESTFN2)
@@ -1226,7 +1234,8 @@
         for n, s in enumerate(self.seps):
             self.arcdata[s] = s.join(self.line_gen) + s
             self.arcfiles[s] = '%s-%d' % (TESTFN, n)
-            open(self.arcfiles[s], "wb").write(self.arcdata[s])
+            with open(self.arcfiles[s], "wb") as f:
+                f.write(self.arcdata[s])
 
     def make_test_archive(self, f, compression):
         # Create the ZIP archive
@@ -1295,8 +1304,9 @@
         # Read the ZIP archive
         with zipfile.ZipFile(f, "r") as zipfp:
             for sep, fn in self.arcfiles.items():
-                for line, zipline in zip(self.line_gen, zipfp.open(fn, "rU")):
-                    self.assertEqual(zipline, line + '\n')
+                with zipfp.open(fn, "rU") as f:
+                    for line, zipline in zip(self.line_gen, f):
+                        self.assertEqual(zipline, line + '\n')
 
     def test_read_stored(self):
         for f in (TESTFN2, TemporaryFile(), StringIO()):
diff --git a/lib-python/2.7/zipfile.py b/lib-python/modified-2.7/zipfile.py
copy from lib-python/2.7/zipfile.py
copy to lib-python/modified-2.7/zipfile.py
--- a/lib-python/2.7/zipfile.py
+++ b/lib-python/modified-2.7/zipfile.py
@@ -648,6 +648,10 @@
         return data
 
 
+class ZipExtFileWithClose(ZipExtFile):
+    def close(self):
+        self._fileobj.close()
+
 
 class ZipFile:
     """ Class with methods to open, read, write, close, list zip files.
@@ -843,9 +847,9 @@
             try:
                 # Read by chunks, to avoid an OverflowError or a
                 # MemoryError with very large embedded files.
-                f = self.open(zinfo.filename, "r")
-                while f.read(chunk_size):     # Check CRC-32
-                    pass
+                with self.open(zinfo.filename, "r") as f:
+                    while f.read(chunk_size):     # Check CRC-32
+                        pass
             except BadZipfile:
                 return zinfo.filename
 
@@ -864,7 +868,9 @@
 
     def read(self, name, pwd=None):
         """Return file bytes (as a string) for name."""
-        return self.open(name, "r", pwd).read()
+        with self.open(name, "r", pwd) as f:
+            retval = f.read()
+            return retval
 
     def open(self, name, mode="r", pwd=None):
         """Return file-like object for 'name'."""
@@ -881,59 +887,66 @@
         else:
             zef_file = open(self.filename, 'rb')
 
-        # Make sure we have an info object
-        if isinstance(name, ZipInfo):
-            # 'name' is already an info object
-            zinfo = name
+        try:
+            # Make sure we have an info object
+            if isinstance(name, ZipInfo):
+                # 'name' is already an info object
+                zinfo = name
+            else:
+                # Get info object for name
+                zinfo = self.getinfo(name)
+
+            zef_file.seek(zinfo.header_offset, 0)
+
+            # Skip the file header:
+            fheader = zef_file.read(sizeFileHeader)
+            if fheader[0:4] != stringFileHeader:
+                raise BadZipfile, "Bad magic number for file header"
+
+            fheader = struct.unpack(structFileHeader, fheader)
+            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
+            if fheader[_FH_EXTRA_FIELD_LENGTH]:
+                zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+
+            if fname != zinfo.orig_filename:
+                raise BadZipfile, \
+                          'File name in directory "%s" and header "%s" differ.' % (
+                              zinfo.orig_filename, fname)
+
+            # check for encrypted flag & handle password
+            is_encrypted = zinfo.flag_bits & 0x1
+            zd = None
+            if is_encrypted:
+                if not pwd:
+                    pwd = self.pwd
+                if not pwd:
+                    raise RuntimeError, "File %s is encrypted, " \
+                          "password required for extraction" % name
+
+                zd = _ZipDecrypter(pwd)
+                # The first 12 bytes in the cypher stream is an encryption header
+                #  used to strengthen the algorithm. The first 11 bytes are
+                #  completely random, while the 12th contains the MSB of the CRC,
+                #  or the MSB of the file time depending on the header type
+                #  and is used to check the correctness of the password.
+                bytes = zef_file.read(12)
+                h = map(zd, bytes[0:12])
+                if zinfo.flag_bits & 0x8:
+                    # compare against the file time from extended local headers
+                    check_byte = (zinfo._raw_time >> 8) & 0xff
+                else:
+                    # compare against the CRC otherwise
+                    check_byte = (zinfo.CRC >> 24) & 0xff
+                if ord(h[11]) != check_byte:
+                    raise RuntimeError("Bad password for file", name)
+        except:
+            if not self._filePassed:
+                zef_file.close()
+            raise    
+        if self._filePassed:
+            return  ZipExtFile(zef_file, mode, zinfo, zd)
         else:
-            # Get info object for name
-            zinfo = self.getinfo(name)
-
-        zef_file.seek(zinfo.header_offset, 0)
-
-        # Skip the file header:
-        fheader = zef_file.read(sizeFileHeader)
-        if fheader[0:4] != stringFileHeader:
-            raise BadZipfile, "Bad magic number for file header"
-
-        fheader = struct.unpack(structFileHeader, fheader)
-        fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
-        if fheader[_FH_EXTRA_FIELD_LENGTH]:
-            zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
-
-        if fname != zinfo.orig_filename:
-            raise BadZipfile, \
-                      'File name in directory "%s" and header "%s" differ.' % (
-                          zinfo.orig_filename, fname)
-
-        # check for encrypted flag & handle password
-        is_encrypted = zinfo.flag_bits & 0x1
-        zd = None
-        if is_encrypted:
-            if not pwd:
-                pwd = self.pwd
-            if not pwd:
-                raise RuntimeError, "File %s is encrypted, " \
-                      "password required for extraction" % name
-
-            zd = _ZipDecrypter(pwd)
-            # The first 12 bytes in the cypher stream is an encryption header
-            #  used to strengthen the algorithm. The first 11 bytes are
-            #  completely random, while the 12th contains the MSB of the CRC,
-            #  or the MSB of the file time depending on the header type
-            #  and is used to check the correctness of the password.
-            bytes = zef_file.read(12)
-            h = map(zd, bytes[0:12])
-            if zinfo.flag_bits & 0x8:
-                # compare against the file type from extended local headers
-                check_byte = (zinfo._raw_time >> 8) & 0xff
-            else:
-                # compare against the CRC otherwise
-                check_byte = (zinfo.CRC >> 24) & 0xff
-            if ord(h[11]) != check_byte:
-                raise RuntimeError("Bad password for file", name)
-
-        return  ZipExtFile(zef_file, mode, zinfo, zd)
+            return  ZipExtFileWithClose(zef_file, mode, zinfo, zd)
 
     def extract(self, member, path=None, pwd=None):
         """Extract a member from the archive to the current working directory,
@@ -989,7 +1002,6 @@
             if not os.path.isdir(targetpath):
                 os.mkdir(targetpath)
             return targetpath
-
         source = self.open(member, pwd=pwd)
         target = file(targetpath, "wb")
         shutil.copyfileobj(source, target)


More information about the pypy-commit mailing list