[Python-checkins] cpython (2.7): #7944: close files explicitly in test_tarfile (backport d560eece0857).

ezio.melotti python-checkins at python.org
Wed Jan 13 15:21:27 EST 2016


https://hg.python.org/cpython/rev/e3763d98c46e
changeset:   99883:e3763d98c46e
branch:      2.7
user:        Ezio Melotti <ezio.melotti at gmail.com>
date:        Wed Jan 13 22:21:21 2016 +0200
summary:
  #7944: close files explicitly in test_tarfile (backport d560eece0857).

files:
  Lib/test/test_tarfile.py |  633 +++++++++++++++-----------
  1 files changed, 360 insertions(+), 273 deletions(-)


diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -60,9 +60,10 @@
         self.tar.extract("ustar/regtype", TEMPDIR)
         tarinfo = self.tar.getmember("ustar/regtype")
         fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") as fobj1:
+            lines1 = fobj1.readlines()
         fobj2 = self.tar.extractfile(tarinfo)
 
-        lines1 = fobj1.readlines()
         lines2 = fobj2.readlines()
         self.assertTrue(lines1 == lines2,
                 "fileobj.readlines() failed")
@@ -75,18 +76,17 @@
     def test_fileobj_iter(self):
         self.tar.extract("ustar/regtype", TEMPDIR)
         tarinfo = self.tar.getmember("ustar/regtype")
-        fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "rU") as fobj1:
+            lines1 = fobj1.readlines()
         fobj2 = self.tar.extractfile(tarinfo)
-        lines1 = fobj1.readlines()
         lines2 = [line for line in fobj2]
         self.assertTrue(lines1 == lines2,
                      "fileobj.__iter__() failed")
 
     def test_fileobj_seek(self):
         self.tar.extract("ustar/regtype", TEMPDIR)
-        fobj = open(os.path.join(TEMPDIR, "ustar/regtype"), "rb")
-        data = fobj.read()
-        fobj.close()
+        with open(os.path.join(TEMPDIR, "ustar/regtype"), "rb") as fobj:
+            data = fobj.read()
 
         tarinfo = self.tar.getmember("ustar/regtype")
         fobj = self.tar.extractfile(tarinfo)
@@ -238,19 +238,24 @@
         # This test checks if tarfile.open() is able to open an empty tar
         # archive successfully. Note that an empty tar archive is not the
         # same as an empty file!
-        tarfile.open(tmpname, self.mode.replace("r", "w")).close()
+        with tarfile.open(tmpname, self.mode.replace("r", "w")):
+            pass
         try:
             tar = tarfile.open(tmpname, self.mode)
             tar.getnames()
         except tarfile.ReadError:
             self.fail("tarfile.open() failed on empty archive")
-        self.assertListEqual(tar.getmembers(), [])
+        else:
+            self.assertListEqual(tar.getmembers(), [])
+        finally:
+            tar.close()
 
     def test_null_tarfile(self):
         # Test for issue6123: Allow opening empty archives.
         # This test guarantees that tarfile.open() does not treat an empty
         # file as an empty tar archive.
-        open(tmpname, "wb").close()
+        with open(tmpname, "wb"):
+            pass
         self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, self.mode)
         self.assertRaises(tarfile.ReadError, tarfile.open, tmpname)
 
@@ -274,15 +279,16 @@
         for char in ('\0', 'a'):
             # Test if EOFHeaderError ('\0') and InvalidHeaderError ('a')
             # are ignored correctly.
-            fobj = _open(tmpname, "wb")
-            fobj.write(char * 1024)
-            fobj.write(tarfile.TarInfo("foo").tobuf())
-            fobj.close()
+            with _open(tmpname, "wb") as fobj:
+                fobj.write(char * 1024)
+                fobj.write(tarfile.TarInfo("foo").tobuf())
 
             tar = tarfile.open(tmpname, mode="r", ignore_zeros=True)
-            self.assertListEqual(tar.getnames(), ["foo"],
+            try:
+                self.assertListEqual(tar.getnames(), ["foo"],
                     "ignore_zeros=True should have skipped the %r-blocks" % char)
-            tar.close()
+            finally:
+                tar.close()
 
     def test_premature_end_of_archive(self):
         for size in (512, 600, 1024, 1200):
@@ -313,19 +319,21 @@
     taropen = tarfile.TarFile.taropen
 
     def test_no_name_argument(self):
-        fobj = open(self.tarname, "rb")
-        tar = tarfile.open(fileobj=fobj, mode=self.mode)
-        self.assertEqual(tar.name, os.path.abspath(fobj.name))
+        with open(self.tarname, "rb") as fobj:
+            tar = tarfile.open(fileobj=fobj, mode=self.mode)
+            self.assertEqual(tar.name, os.path.abspath(fobj.name))
 
     def test_no_name_attribute(self):
-        data = open(self.tarname, "rb").read()
+        with open(self.tarname, "rb") as fobj:
+            data = fobj.read()
         fobj = StringIO.StringIO(data)
         self.assertRaises(AttributeError, getattr, fobj, "name")
         tar = tarfile.open(fileobj=fobj, mode=self.mode)
         self.assertEqual(tar.name, None)
 
     def test_empty_name_attribute(self):
-        data = open(self.tarname, "rb").read()
+        with open(self.tarname, "rb") as fobj:
+            data = fobj.read()
         fobj = StringIO.StringIO(data)
         fobj.name = ""
         tar = tarfile.open(fileobj=fobj, mode=self.mode)
@@ -346,12 +354,14 @@
         # Skip the first member and store values from the second member
         # of the testtar.
         tar = tarfile.open(self.tarname, mode=self.mode)
-        tar.next()
-        t = tar.next()
-        name = t.name
-        offset = t.offset
-        data = tar.extractfile(t).read()
-        tar.close()
+        try:
+            tar.next()
+            t = tar.next()
+            name = t.name
+            offset = t.offset
+            data = tar.extractfile(t).read()
+        finally:
+            tar.close()
 
         # Open the testtar and seek to the offset of the second member.
         if self.mode.endswith(":gz"):
@@ -361,26 +371,30 @@
         else:
             _open = open
         fobj = _open(self.tarname, "rb")
-        fobj.seek(offset)
+        try:
+            fobj.seek(offset)
 
-        # Test if the tarfile starts with the second member.
-        tar = tar.open(self.tarname, mode="r:", fileobj=fobj)
-        t = tar.next()
-        self.assertEqual(t.name, name)
-        # Read to the end of fileobj and test if seeking back to the
-        # beginning works.
-        tar.getmembers()
-        self.assertEqual(tar.extractfile(t).read(), data,
-                "seek back did not work")
-        tar.close()
+            # Test if the tarfile starts with the second member.
+            tar = tar.open(self.tarname, mode="r:", fileobj=fobj)
+            t = tar.next()
+            self.assertEqual(t.name, name)
+            # Read to the end of fileobj and test if seeking back to the
+            # beginning works.
+            tar.getmembers()
+            self.assertEqual(tar.extractfile(t).read(), data,
+                    "seek back did not work")
+            tar.close()
+        finally:
+            fobj.close()
 
     def test_fail_comp(self):
         # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
         if self.mode == "r:":
             self.skipTest('needs a gz or bz2 mode')
         self.assertRaises(tarfile.ReadError, tarfile.open, tarname, self.mode)
-        fobj = open(tarname, "rb")
-        self.assertRaises(tarfile.ReadError, tarfile.open, fileobj=fobj, mode=self.mode)
+        with open(tarname, "rb") as fobj:
+            self.assertRaises(tarfile.ReadError, tarfile.open,
+                              fileobj=fobj, mode=self.mode)
 
     def test_v7_dirtype(self):
         # Test old style dirtype member (bug #1336623):
@@ -434,22 +448,25 @@
         # Test if extractall() correctly restores directory permissions
         # and times (see issue1735).
         tar = tarfile.open(tarname, encoding="iso8859-1")
-        directories = [t for t in tar if t.isdir()]
-        tar.extractall(TEMPDIR, directories)
-        for tarinfo in directories:
-            path = os.path.join(TEMPDIR, tarinfo.name)
-            if sys.platform != "win32":
-                # Win32 has no support for fine grained permissions.
-                self.assertEqual(tarinfo.mode & 0777, os.stat(path).st_mode & 0777)
-            self.assertEqual(tarinfo.mtime, os.path.getmtime(path))
-        tar.close()
+        try:
+            directories = [t for t in tar if t.isdir()]
+            tar.extractall(TEMPDIR, directories)
+            for tarinfo in directories:
+                path = os.path.join(TEMPDIR, tarinfo.name)
+                if sys.platform != "win32":
+                    # Win32 has no support for fine grained permissions.
+                    self.assertEqual(tarinfo.mode & 0777, os.stat(path).st_mode & 0777)
+                self.assertEqual(tarinfo.mtime, os.path.getmtime(path))
+        finally:
+            tar.close()
 
     def test_init_close_fobj(self):
         # Issue #7341: Close the internal file object in the TarFile
         # constructor in case of an error. For the test we rely on
         # the fact that opening an empty file raises a ReadError.
         empty = os.path.join(TEMPDIR, "empty")
-        open(empty, "wb").write("")
+        with open(empty, "wb") as fobj:
+            fobj.write("")
 
         try:
             tar = object.__new__(tarfile.TarFile)
@@ -460,7 +477,7 @@
             else:
                 self.fail("ReadError not raised")
         finally:
-            os.remove(empty)
+            support.unlink(empty)
 
     def test_parallel_iteration(self):
         # Issue #16601: Restarting iteration over tarfile continued
@@ -489,42 +506,47 @@
 
     def test_compare_members(self):
         tar1 = tarfile.open(tarname, encoding="iso8859-1")
-        tar2 = self.tar
+        try:
+            tar2 = self.tar
 
-        while True:
-            t1 = tar1.next()
-            t2 = tar2.next()
-            if t1 is None:
-                break
-            self.assertTrue(t2 is not None, "stream.next() failed.")
+            while True:
+                t1 = tar1.next()
+                t2 = tar2.next()
+                if t1 is None:
+                    break
+                self.assertTrue(t2 is not None, "stream.next() failed.")
 
-            if t2.islnk() or t2.issym():
-                self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
-                continue
+                if t2.islnk() or t2.issym():
+                    self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
+                    continue
 
-            v1 = tar1.extractfile(t1)
-            v2 = tar2.extractfile(t2)
-            if v1 is None:
-                continue
-            self.assertTrue(v2 is not None, "stream.extractfile() failed")
-            self.assertTrue(v1.read() == v2.read(), "stream extraction failed")
-
-        tar1.close()
+                v1 = tar1.extractfile(t1)
+                v2 = tar2.extractfile(t2)
+                if v1 is None:
+                    continue
+                self.assertTrue(v2 is not None, "stream.extractfile() failed")
+                self.assertTrue(v1.read() == v2.read(), "stream extraction failed")
+        finally:
+            tar1.close()
 
 
 class DetectReadTest(unittest.TestCase):
 
     def _testfunc_file(self, name, mode):
         try:
-            tarfile.open(name, mode)
+            tar = tarfile.open(name, mode)
         except tarfile.ReadError:
             self.fail()
+        else:
+            tar.close()
 
     def _testfunc_fileobj(self, name, mode):
         try:
-            tarfile.open(name, mode, fileobj=open(name, "rb"))
+            tar = tarfile.open(name, mode, fileobj=open(name, "rb"))
         except tarfile.ReadError:
             self.fail()
+        else:
+            tar.close()
 
     def _test_modes(self, testfunc):
         testfunc(tarname, "r")
@@ -715,33 +737,39 @@
 
     def test_pax_global_headers(self):
         tar = tarfile.open(tarname, encoding="iso8859-1")
+        try:
 
-        tarinfo = tar.getmember("pax/regtype1")
-        self.assertEqual(tarinfo.uname, "foo")
-        self.assertEqual(tarinfo.gname, "bar")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+            tarinfo = tar.getmember("pax/regtype1")
+            self.assertEqual(tarinfo.uname, "foo")
+            self.assertEqual(tarinfo.gname, "bar")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
 
-        tarinfo = tar.getmember("pax/regtype2")
-        self.assertEqual(tarinfo.uname, "")
-        self.assertEqual(tarinfo.gname, "bar")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+            tarinfo = tar.getmember("pax/regtype2")
+            self.assertEqual(tarinfo.uname, "")
+            self.assertEqual(tarinfo.gname, "bar")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
 
-        tarinfo = tar.getmember("pax/regtype3")
-        self.assertEqual(tarinfo.uname, "tarfile")
-        self.assertEqual(tarinfo.gname, "tarfile")
-        self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+            tarinfo = tar.getmember("pax/regtype3")
+            self.assertEqual(tarinfo.uname, "tarfile")
+            self.assertEqual(tarinfo.gname, "tarfile")
+            self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), u"\xc4\xd6\xdc\xe4\xf6\xfc\xdf")
+        finally:
+            tar.close()
 
     def test_pax_number_fields(self):
         # All following number fields are read from the pax header.
         tar = tarfile.open(tarname, encoding="iso8859-1")
-        tarinfo = tar.getmember("pax/regtype4")
-        self.assertEqual(tarinfo.size, 7011)
-        self.assertEqual(tarinfo.uid, 123)
-        self.assertEqual(tarinfo.gid, 123)
-        self.assertEqual(tarinfo.mtime, 1041808783.0)
-        self.assertEqual(type(tarinfo.mtime), float)
-        self.assertEqual(float(tarinfo.pax_headers["atime"]), 1041808783.0)
-        self.assertEqual(float(tarinfo.pax_headers["ctime"]), 1041808783.0)
+        try:
+            tarinfo = tar.getmember("pax/regtype4")
+            self.assertEqual(tarinfo.size, 7011)
+            self.assertEqual(tarinfo.uid, 123)
+            self.assertEqual(tarinfo.gid, 123)
+            self.assertEqual(tarinfo.mtime, 1041808783.0)
+            self.assertEqual(type(tarinfo.mtime), float)
+            self.assertEqual(float(tarinfo.pax_headers["atime"]), 1041808783.0)
+            self.assertEqual(float(tarinfo.pax_headers["ctime"]), 1041808783.0)
+        finally:
+            tar.close()
 
 
 class WriteTestBase(unittest.TestCase):
@@ -773,52 +801,60 @@
         # a trailing '\0'.
         name = "0123456789" * 10
         tar = tarfile.open(tmpname, self.mode)
-        t = tarfile.TarInfo(name)
-        tar.addfile(t)
-        tar.close()
+        try:
+            t = tarfile.TarInfo(name)
+            tar.addfile(t)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        self.assertTrue(tar.getnames()[0] == name,
-                "failed to store 100 char filename")
-        tar.close()
+        try:
+            self.assertTrue(tar.getnames()[0] == name,
+                    "failed to store 100 char filename")
+        finally:
+            tar.close()
 
     def test_tar_size(self):
         # Test for bug #1013882.
         tar = tarfile.open(tmpname, self.mode)
-        path = os.path.join(TEMPDIR, "file")
-        fobj = open(path, "wb")
-        fobj.write("aaa")
-        fobj.close()
-        tar.add(path)
-        tar.close()
+        try:
+            path = os.path.join(TEMPDIR, "file")
+            with open(path, "wb") as fobj:
+                fobj.write("aaa")
+            tar.add(path)
+        finally:
+            tar.close()
         self.assertTrue(os.path.getsize(tmpname) > 0,
                 "tarfile is empty")
 
     # The test_*_size tests test for bug #1167128.
     def test_file_size(self):
         tar = tarfile.open(tmpname, self.mode)
+        try:
 
-        path = os.path.join(TEMPDIR, "file")
-        fobj = open(path, "wb")
-        fobj.close()
-        tarinfo = tar.gettarinfo(path)
-        self.assertEqual(tarinfo.size, 0)
+            path = os.path.join(TEMPDIR, "file")
+            with open(path, "wb"):
+                pass
+            tarinfo = tar.gettarinfo(path)
+            self.assertEqual(tarinfo.size, 0)
 
-        fobj = open(path, "wb")
-        fobj.write("aaa")
-        fobj.close()
-        tarinfo = tar.gettarinfo(path)
-        self.assertEqual(tarinfo.size, 3)
-
-        tar.close()
+            with open(path, "wb") as fobj:
+                fobj.write("aaa")
+            tarinfo = tar.gettarinfo(path)
+            self.assertEqual(tarinfo.size, 3)
+        finally:
+            tar.close()
 
     def test_directory_size(self):
         path = os.path.join(TEMPDIR, "directory")
         os.mkdir(path)
         try:
             tar = tarfile.open(tmpname, self.mode)
-            tarinfo = tar.gettarinfo(path)
-            self.assertEqual(tarinfo.size, 0)
+            try:
+                tarinfo = tar.gettarinfo(path)
+                self.assertEqual(tarinfo.size, 0)
+            finally:
+                tar.close()
         finally:
             os.rmdir(path)
 
@@ -826,16 +862,18 @@
         if hasattr(os, "link"):
             link = os.path.join(TEMPDIR, "link")
             target = os.path.join(TEMPDIR, "link_target")
-            fobj = open(target, "wb")
-            fobj.write("aaa")
-            fobj.close()
+            with open(target, "wb") as fobj:
+                fobj.write("aaa")
             os.link(target, link)
             try:
                 tar = tarfile.open(tmpname, self.mode)
-                # Record the link target in the inodes list.
-                tar.gettarinfo(target)
-                tarinfo = tar.gettarinfo(link)
-                self.assertEqual(tarinfo.size, 0)
+                try:
+                    # Record the link target in the inodes list.
+                    tar.gettarinfo(target)
+                    tarinfo = tar.gettarinfo(link)
+                    self.assertEqual(tarinfo.size, 0)
+                finally:
+                    tar.close()
             finally:
                 os.remove(target)
                 os.remove(link)
@@ -846,26 +884,30 @@
             os.symlink("link_target", path)
             try:
                 tar = tarfile.open(tmpname, self.mode)
-                tarinfo = tar.gettarinfo(path)
-                self.assertEqual(tarinfo.size, 0)
+                try:
+                    tarinfo = tar.gettarinfo(path)
+                    self.assertEqual(tarinfo.size, 0)
+                finally:
+                    tar.close()
             finally:
                 os.remove(path)
 
     def test_add_self(self):
         # Test for #1257255.
         dstname = os.path.abspath(tmpname)
+        tar = tarfile.open(tmpname, self.mode)
+        try:
+            self.assertTrue(tar.name == dstname, "archive name must be absolute")
+            tar.add(dstname)
+            self.assertTrue(tar.getnames() == [], "added the archive to itself")
 
-        tar = tarfile.open(tmpname, self.mode)
-        self.assertTrue(tar.name == dstname, "archive name must be absolute")
-
-        tar.add(dstname)
-        self.assertTrue(tar.getnames() == [], "added the archive to itself")
-
-        cwd = os.getcwd()
-        os.chdir(TEMPDIR)
-        tar.add(dstname)
-        os.chdir(cwd)
-        self.assertTrue(tar.getnames() == [], "added the archive to itself")
+            cwd = os.getcwd()
+            os.chdir(TEMPDIR)
+            tar.add(dstname)
+            os.chdir(cwd)
+            self.assertTrue(tar.getnames() == [], "added the archive to itself")
+        finally:
+            tar.close()
 
     def test_exclude(self):
         tempdir = os.path.join(TEMPDIR, "exclude")
@@ -878,14 +920,19 @@
             exclude = os.path.isfile
 
             tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
-            with test_support.check_warnings(("use the filter argument",
-                                              DeprecationWarning)):
-                tar.add(tempdir, arcname="empty_dir", exclude=exclude)
-            tar.close()
+            try:
+                with test_support.check_warnings(("use the filter argument",
+                                                DeprecationWarning)):
+                    tar.add(tempdir, arcname="empty_dir", exclude=exclude)
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            self.assertEqual(len(tar.getmembers()), 1)
-            self.assertEqual(tar.getnames()[0], "empty_dir")
+            try:
+                self.assertEqual(len(tar.getmembers()), 1)
+                self.assertEqual(tar.getnames()[0], "empty_dir")
+            finally:
+                tar.close()
         finally:
             shutil.rmtree(tempdir)
 
@@ -905,15 +952,19 @@
                 return tarinfo
 
             tar = tarfile.open(tmpname, self.mode, encoding="iso8859-1")
-            tar.add(tempdir, arcname="empty_dir", filter=filter)
-            tar.close()
+            try:
+                tar.add(tempdir, arcname="empty_dir", filter=filter)
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            for tarinfo in tar:
-                self.assertEqual(tarinfo.uid, 123)
-                self.assertEqual(tarinfo.uname, "foo")
-            self.assertEqual(len(tar.getmembers()), 3)
-            tar.close()
+            try:
+                for tarinfo in tar:
+                    self.assertEqual(tarinfo.uid, 123)
+                    self.assertEqual(tarinfo.uname, "foo")
+                self.assertEqual(len(tar.getmembers()), 3)
+            finally:
+                tar.close()
         finally:
             shutil.rmtree(tempdir)
 
@@ -931,12 +982,16 @@
             os.mkdir(foo)
 
         tar = tarfile.open(tmpname, self.mode)
-        tar.add(foo, arcname=path)
-        tar.close()
+        try:
+            tar.add(foo, arcname=path)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, "r")
-        t = tar.next()
-        tar.close()
+        try:
+            t = tar.next()
+        finally:
+            tar.close()
 
         if not dir:
             os.remove(foo)
@@ -972,16 +1027,18 @@
     def test_cwd(self):
         # Test adding the current working directory.
         with support.change_cwd(TEMPDIR):
-            open("foo", "w").close()
-
             tar = tarfile.open(tmpname, self.mode)
-            tar.add(".")
-            tar.close()
+            try:
+                tar.add(".")
+            finally:
+                tar.close()
 
             tar = tarfile.open(tmpname, "r")
-            for t in tar:
-                self.assertTrue(t.name == "." or t.name.startswith("./"))
-            tar.close()
+            try:
+                for t in tar:
+                    self.assertTrue(t.name == "." or t.name.startswith("./"))
+            finally:
+                tar.close()
 
     @unittest.skipUnless(hasattr(os, 'symlink'), "needs os.symlink")
     def test_extractall_symlinks(self):
@@ -1098,19 +1155,18 @@
         tar.close()
 
         if self.mode.endswith("gz"):
-            fobj = gzip.GzipFile(tmpname)
-            data = fobj.read()
-            fobj.close()
+            with gzip.GzipFile(tmpname) as fobj:
+                data = fobj.read()
         elif self.mode.endswith("bz2"):
             dec = bz2.BZ2Decompressor()
-            data = open(tmpname, "rb").read()
+            with open(tmpname, "rb") as fobj:
+                data = fobj.read()
             data = dec.decompress(data)
             self.assertTrue(len(dec.unused_data) == 0,
                     "found trailing data")
         else:
-            fobj = open(tmpname, "rb")
-            data = fobj.read()
-            fobj.close()
+            with open(tmpname, "rb") as fobj:
+                data = fobj.read()
 
         self.assertTrue(data.count("\0") == tarfile.RECORDSIZE,
                          "incorrect zero padding")
@@ -1171,23 +1227,27 @@
             tarinfo.type = tarfile.LNKTYPE
 
         tar = tarfile.open(tmpname, "w")
-        tar.format = tarfile.GNU_FORMAT
-        tar.addfile(tarinfo)
+        try:
+            tar.format = tarfile.GNU_FORMAT
+            tar.addfile(tarinfo)
 
-        v1 = self._calc_size(name, link)
-        v2 = tar.offset
-        self.assertTrue(v1 == v2, "GNU longname/longlink creation failed")
-
-        tar.close()
+            v1 = self._calc_size(name, link)
+            v2 = tar.offset
+            self.assertTrue(v1 == v2, "GNU longname/longlink creation failed")
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        member = tar.next()
-        self.assertIsNotNone(member,
-                "unable to read longname member")
-        self.assertEqual(tarinfo.name, member.name,
-                "unable to read longname member")
-        self.assertEqual(tarinfo.linkname, member.linkname,
-                "unable to read longname member")
+        try:
+            member = tar.next()
+            self.assertIsNotNone(member,
+                    "unable to read longname member")
+            self.assertEqual(tarinfo.name, member.name,
+                    "unable to read longname member")
+            self.assertEqual(tarinfo.linkname, member.linkname,
+                    "unable to read longname member")
+        finally:
+            tar.close()
 
     def test_longname_1023(self):
         self._test(("longnam/" * 127) + "longnam")
@@ -1227,9 +1287,8 @@
         self.foo = os.path.join(TEMPDIR, "foo")
         self.bar = os.path.join(TEMPDIR, "bar")
 
-        fobj = open(self.foo, "wb")
-        fobj.write("foo")
-        fobj.close()
+        with open(self.foo, "wb") as fobj:
+            fobj.write("foo")
 
         os.link(self.foo, self.bar)
 
@@ -1238,8 +1297,8 @@
 
     def tearDown(self):
         self.tar.close()
-        os.remove(self.foo)
-        os.remove(self.bar)
+        support.unlink(self.foo)
+        support.unlink(self.bar)
 
     def test_add_twice(self):
         # The same name will be added as a REGTYPE every
@@ -1270,16 +1329,21 @@
             tarinfo.type = tarfile.LNKTYPE
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
-        tar.addfile(tarinfo)
-        tar.close()
+        try:
+            tar.addfile(tarinfo)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname)
-        if link:
-            l = tar.getmembers()[0].linkname
-            self.assertTrue(link == l, "PAX longlink creation failed")
-        else:
-            n = tar.getmembers()[0].name
-            self.assertTrue(name == n, "PAX longname creation failed")
+        try:
+            if link:
+                l = tar.getmembers()[0].linkname
+                self.assertTrue(link == l, "PAX longlink creation failed")
+            else:
+                n = tar.getmembers()[0].name
+                self.assertTrue(name == n, "PAX longname creation failed")
+        finally:
+            tar.close()
 
     def test_pax_global_header(self):
         pax_headers = {
@@ -1291,23 +1355,28 @@
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT,
                 pax_headers=pax_headers)
-        tar.addfile(tarfile.TarInfo("test"))
-        tar.close()
+        try:
+            tar.addfile(tarfile.TarInfo("test"))
+        finally:
+            tar.close()
 
         # Test if the global header was written correctly.
         tar = tarfile.open(tmpname, encoding="iso8859-1")
-        self.assertEqual(tar.pax_headers, pax_headers)
-        self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
+        try:
+            self.assertEqual(tar.pax_headers, pax_headers)
+            self.assertEqual(tar.getmembers()[0].pax_headers, pax_headers)
 
-        # Test if all the fields are unicode.
-        for key, val in tar.pax_headers.iteritems():
-            self.assertTrue(type(key) is unicode)
-            self.assertTrue(type(val) is unicode)
-            if key in tarfile.PAX_NUMBER_FIELDS:
-                try:
-                    tarfile.PAX_NUMBER_FIELDS[key](val)
-                except (TypeError, ValueError):
-                    self.fail("unable to convert pax header field")
+            # Test if all the fields are unicode.
+            for key, val in tar.pax_headers.iteritems():
+                self.assertTrue(type(key) is unicode)
+                self.assertTrue(type(val) is unicode)
+                if key in tarfile.PAX_NUMBER_FIELDS:
+                    try:
+                        tarfile.PAX_NUMBER_FIELDS[key](val)
+                    except (TypeError, ValueError):
+                        self.fail("unable to convert pax header field")
+        finally:
+            tar.close()
 
     def test_pax_extended_header(self):
         # The fields from the pax header have priority over the
@@ -1315,18 +1384,23 @@
         pax_headers = {u"path": u"foo", u"uid": u"123"}
 
         tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, encoding="iso8859-1")
-        t = tarfile.TarInfo()
-        t.name = u"\xe4\xf6\xfc"     # non-ASCII
-        t.uid = 8**8        # too large
-        t.pax_headers = pax_headers
-        tar.addfile(t)
-        tar.close()
+        try:
+            t = tarfile.TarInfo()
+            t.name = u"\xe4\xf6\xfc"     # non-ASCII
+            t.uid = 8**8        # too large
+            t.pax_headers = pax_headers
+            tar.addfile(t)
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, encoding="iso8859-1")
-        t = tar.getmembers()[0]
-        self.assertEqual(t.pax_headers, pax_headers)
-        self.assertEqual(t.name, "foo")
-        self.assertEqual(t.uid, 123)
+        try:
+            t = tar.getmembers()[0]
+            self.assertEqual(t.pax_headers, pax_headers)
+            self.assertEqual(t.name, "foo")
+            self.assertEqual(t.uid, 123)
+        finally:
+            tar.close()
 
 
 class UstarUnicodeTest(unittest.TestCase):
@@ -1345,40 +1419,49 @@
 
     def _test_unicode_filename(self, encoding):
         tar = tarfile.open(tmpname, "w", format=self.format, encoding=encoding, errors="strict")
-        name = u"\xe4\xf6\xfc"
-        tar.addfile(tarfile.TarInfo(name))
-        tar.close()
+        try:
+            name = u"\xe4\xf6\xfc"
+            tar.addfile(tarfile.TarInfo(name))
+        finally:
+            tar.close()
 
         tar = tarfile.open(tmpname, encoding=encoding)
-        self.assertTrue(type(tar.getnames()[0]) is not unicode)
-        self.assertEqual(tar.getmembers()[0].name, name.encode(encoding))
-        tar.close()
+        try:
+            self.assertTrue(type(tar.getnames()[0]) is not unicode)
+            self.assertEqual(tar.getmembers()[0].name, name.encode(encoding))
+        finally:
+            tar.close()
 
     def test_unicode_filename_error(self):
         tar = tarfile.open(tmpname, "w", format=self.format, encoding="ascii", errors="strict")
-        tarinfo = tarfile.TarInfo()
+        try:
+            tarinfo = tarfile.TarInfo()
 
-        tarinfo.name = "\xe4\xf6\xfc"
-        if self.format == tarfile.PAX_FORMAT:
+            tarinfo.name = "\xe4\xf6\xfc"
+            if self.format == tarfile.PAX_FORMAT:
+                self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+            else:
+                tar.addfile(tarinfo)
+
+            tarinfo.name = u"\xe4\xf6\xfc"
             self.assertRaises(UnicodeError, tar.addfile, tarinfo)
-        else:
-            tar.addfile(tarinfo)
 
-        tarinfo.name = u"\xe4\xf6\xfc"
-        self.assertRaises(UnicodeError, tar.addfile, tarinfo)
-
-        tarinfo.name = "foo"
-        tarinfo.uname = u"\xe4\xf6\xfc"
-        self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+            tarinfo.name = "foo"
+            tarinfo.uname = u"\xe4\xf6\xfc"
+            self.assertRaises(UnicodeError, tar.addfile, tarinfo)
+        finally:
+            tar.close()
 
     def test_unicode_argument(self):
         tar = tarfile.open(tarname, "r", encoding="iso8859-1", errors="strict")
-        for t in tar:
-            self.assertTrue(type(t.name) is str)
-            self.assertTrue(type(t.linkname) is str)
-            self.assertTrue(type(t.uname) is str)
-            self.assertTrue(type(t.gname) is str)
-        tar.close()
+        try:
+            for t in tar:
+                self.assertTrue(type(t.name) is str)
+                self.assertTrue(type(t.linkname) is str)
+                self.assertTrue(type(t.uname) is str)
+                self.assertTrue(type(t.gname) is str)
+        finally:
+            tar.close()
 
     def test_uname_unicode(self):
         for name in (u"\xe4\xf6\xfc", "\xe4\xf6\xfc"):
@@ -1388,8 +1471,10 @@
 
             fobj = StringIO.StringIO()
             tar = tarfile.open("foo.tar", mode="w", fileobj=fobj, format=self.format, encoding="iso8859-1")
-            tar.addfile(t)
-            tar.close()
+            try:
+                tar.addfile(t)
+            finally:
+                tar.close()
             fobj.seek(0)
 
             tar = tarfile.open("foo.tar", fileobj=fobj, encoding="iso8859-1")
@@ -1447,22 +1532,20 @@
             os.remove(self.tarname)
 
     def _add_testfile(self, fileobj=None):
-        tar = tarfile.open(self.tarname, "a", fileobj=fileobj)
-        tar.addfile(tarfile.TarInfo("bar"))
-        tar.close()
+        with tarfile.open(self.tarname, "a", fileobj=fileobj) as tar:
+            tar.addfile(tarfile.TarInfo("bar"))
 
     def _create_testtar(self, mode="w:"):
-        src = tarfile.open(tarname, encoding="iso8859-1")
-        t = src.getmember("ustar/regtype")
-        t.name = "foo"
-        f = src.extractfile(t)
-        tar = tarfile.open(self.tarname, mode)
-        tar.addfile(t, f)
-        tar.close()
+        with tarfile.open(tarname, encoding="iso8859-1") as src:
+            t = src.getmember("ustar/regtype")
+            t.name = "foo"
+            f = src.extractfile(t)
+            with tarfile.open(self.tarname, mode) as tar:
+                tar.addfile(t, f)
 
     def _test(self, names=["bar"], fileobj=None):
-        tar = tarfile.open(self.tarname, fileobj=fileobj)
-        self.assertEqual(tar.getnames(), names)
+        with tarfile.open(self.tarname, fileobj=fileobj) as tar:
+            self.assertEqual(tar.getnames(), names)
 
     def test_non_existing(self):
         self._add_testfile()
@@ -1481,7 +1564,8 @@
 
     def test_fileobj(self):
         self._create_testtar()
-        data = open(self.tarname).read()
+        with open(self.tarname) as fobj:
+            data = fobj.read()
         fobj = StringIO.StringIO(data)
         self._add_testfile(fobj)
         fobj.seek(0)
@@ -1505,7 +1589,8 @@
     # Append mode is supposed to fail if the tarfile to append to
     # does not end with a zero block.
     def _test_error(self, data):
-        open(self.tarname, "wb").write(data)
+        with open(self.tarname, "wb") as fobj:
+            fobj.write(data)
         self.assertRaises(tarfile.ReadError, self._add_testfile)
 
     def test_null(self):
@@ -1641,15 +1726,14 @@
     def test_fileobj(self):
         # Test that __exit__() did not close the external file
         # object.
-        fobj = open(tmpname, "wb")
-        try:
-            with tarfile.open(fileobj=fobj, mode="w") as tar:
-                raise Exception
-        except:
-            pass
-        self.assertFalse(fobj.closed, "external file object was closed")
-        self.assertTrue(tar.closed, "context manager failed")
-        fobj.close()
+        with open(tmpname, "wb") as fobj:
+            try:
+                with tarfile.open(fileobj=fobj, mode="w") as tar:
+                    raise Exception
+            except:
+                pass
+            self.assertFalse(fobj.closed, "external file object was closed")
+            self.assertTrue(tar.closed, "context manager failed")
 
 
 class LinkEmulationTest(ReadTest):
@@ -1737,6 +1821,7 @@
 
 
 def test_main():
+    support.unlink(TEMPDIR)
     os.makedirs(TEMPDIR)
 
     tests = [
@@ -1766,15 +1851,14 @@
     else:
         tests.append(LinkEmulationTest)
 
-    fobj = open(tarname, "rb")
-    data = fobj.read()
-    fobj.close()
+    with open(tarname, "rb") as fobj:
+        data = fobj.read()
 
     if gzip:
         # Create testtar.tar.gz and add gzip-specific tests.
-        tar = gzip.open(gzipname, "wb")
-        tar.write(data)
-        tar.close()
+        support.unlink(gzipname)
+        with gzip.open(gzipname, "wb") as tar:
+            tar.write(data)
 
         tests += [
             GzipMiscReadTest,
@@ -1787,9 +1871,12 @@
 
     if bz2:
         # Create testtar.tar.bz2 and add bz2-specific tests.
+        support.unlink(bz2name)
         tar = bz2.BZ2File(bz2name, "wb")
-        tar.write(data)
-        tar.close()
+        try:
+            tar.write(data)
+        finally:
+            tar.close()
 
         tests += [
             Bz2MiscReadTest,

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list