[Python-checkins] [2.7] bpo-31234: Join threads explicitly in tests (#7406)

Victor Stinner webhook-mailer at python.org
Mon Jun 4 17:53:55 EDT 2018

commit: 146351860a34b3cde387930a360e57391e7b99f3
branch: 2.7
author: Victor Stinner <vstinner at redhat.com>
committer: GitHub <noreply at github.com>
date: 2018-06-04T23:53:52+02:00

[2.7] bpo-31234: Join threads explicitly in tests (#7406)

* Add support.wait_threads_exit(): a context manager that, on exit,
  loops until the number of threads drops back to its original count.
* Add some missing thread.join()
* test_asyncore.test_quick_connect(): explicitly call t.join(), because
  a cleanup function registered with addCleanup() only runs after the
  test method returns, whereas the method's @test_support.reap_threads
  decorator checks for leftover threads as soon as the method returns
* test_hashlib: replace threading.Event with thread.join()
* test_thread:

  * Use wait_threads_exit() context manager
  * Replace test_support with support
  * test_forkinthread(): check the child process exit status in the
    main thread, so that a non-zero status fails the test instead of
    raising an assertion error inside the worker thread.

M Lib/test/support/__init__.py
M Lib/test/test_asyncore.py
M Lib/test/test_hashlib.py
M Lib/test/test_httpservers.py
M Lib/test/test_smtplib.py
M Lib/test/test_thread.py

diff --git a/Lib/test/support/__init__.py b/Lib/test/support/__init__.py
index 3e44c5a35d92..47af6b39465f 100644
--- a/Lib/test/support/__init__.py
+++ b/Lib/test/support/__init__.py
@@ -1722,6 +1722,43 @@ def decorator(*args):
     return decorator
+@contextlib.contextmanager
+def wait_threads_exit(timeout=60.0):
+    """
+    bpo-31234: Context manager to wait until all threads created in the with
+    statement exit.
+    Use thread.count() to check if threads exited. Indirectly, wait until
+    threads exit the internal t_bootstrap() C function of the thread module.
+    threading_setup() and threading_cleanup() are designed to emit a warning
+    if a test leaves running threads in the background. This context manager
+    is designed to cleanup threads started by the thread.start_new_thread()
+    which doesn't allow to wait for thread exit, whereas thread.Thread has a
+    join() method.
+    """
+    old_count = thread._count()
+    try:
+        yield
+    finally:
+        start_time = time.time()
+        deadline = start_time + timeout
+        while True:
+            count = thread._count()
+            if count <= old_count:
+                break
+            if time.time() > deadline:
+                dt = time.time() - start_time
+                msg = ("wait_threads() failed to cleanup %s "
+                       "threads after %.1f seconds "
+                       "(count: %s, old count: %s)"
+                       % (count - old_count, dt, count, old_count))
+                raise AssertionError(msg)
+            time.sleep(0.010)
+            gc_collect()
 def reap_children():
     """Use this function at the end of test_main() whenever sub-processes
     are started.  This will help ensure that no extra children (zombies)
diff --git a/Lib/test/test_asyncore.py b/Lib/test/test_asyncore.py
index 693d67cd8a3d..4b347a3a6dd6 100644
--- a/Lib/test/test_asyncore.py
+++ b/Lib/test/test_asyncore.py
@@ -727,19 +727,20 @@ def test_quick_connect(self):
         server = TCPServer()
         t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=500))
-        self.addCleanup(t.join)
-        for x in xrange(20):
-            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            s.settimeout(.2)
-            s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
-                         struct.pack('ii', 1, 0))
-            try:
-                s.connect(server.address)
-            except socket.error:
-                pass
-            finally:
-                s.close()
+        try:
+            for x in xrange(20):
+                s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+                s.settimeout(.2)
+                s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER,
+                             struct.pack('ii', 1, 0))
+                try:
+                    s.connect(server.address)
+                except socket.error:
+                    pass
+                finally:
+                    s.close()
+        finally:
+            t.join()
 class TestAPI_UseSelect(BaseTestAPI):
diff --git a/Lib/test/test_hashlib.py b/Lib/test/test_hashlib.py
index 471ebb4dd17b..b8d6388feaf9 100644
--- a/Lib/test/test_hashlib.py
+++ b/Lib/test/test_hashlib.py
@@ -371,25 +371,25 @@ def test_threaded_hashing(self):
         data = smallest_data*200000
         expected_hash = hashlib.sha1(data*num_threads).hexdigest()
-        def hash_in_chunks(chunk_size, event):
+        def hash_in_chunks(chunk_size):
             index = 0
             while index < len(data):
                 index += chunk_size
-            event.set()
-        events = []
+        threads = []
         for threadnum in xrange(num_threads):
             chunk_size = len(data) // (10**threadnum)
             assert chunk_size > 0
             assert chunk_size % len(smallest_data) == 0
-            event = threading.Event()
-            events.append(event)
-            threading.Thread(target=hash_in_chunks,
-                             args=(chunk_size, event)).start()
-        for event in events:
-            event.wait()
+            thread = threading.Thread(target=hash_in_chunks,
+                                      args=(chunk_size,))
+            threads.append(thread)
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
         self.assertEqual(expected_hash, hasher.hexdigest())
diff --git a/Lib/test/test_httpservers.py b/Lib/test/test_httpservers.py
index 11f0d5d61439..93807c1959bb 100644
--- a/Lib/test/test_httpservers.py
+++ b/Lib/test/test_httpservers.py
@@ -66,6 +66,7 @@ def run(self):
     def stop(self):
+        self.join()
 class BaseTestCase(unittest.TestCase):
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
index 1bb669018807..703b631c175b 100644
--- a/Lib/test/test_smtplib.py
+++ b/Lib/test/test_smtplib.py
@@ -306,12 +306,14 @@ def setUp(self):
         self.port = test_support.bind_port(self.sock)
         servargs = (self.evt, self.respdata, self.sock)
-        threading.Thread(target=server, args=servargs).start()
+        self.thread = threading.Thread(target=server, args=servargs)
+        self.thread.start()
     def tearDown(self):
+        self.thread.join()
         sys.stdout = self.old_stdout
     def testLineTooLong(self):
diff --git a/Lib/test/test_thread.py b/Lib/test/test_thread.py
index c8caa5ddff03..93690a60b2ff 100644
--- a/Lib/test/test_thread.py
+++ b/Lib/test/test_thread.py
@@ -1,8 +1,8 @@
 import os
 import unittest
 import random
-from test import test_support
-thread = test_support.import_module('thread')
+from test import support
+thread = support.import_module('thread')
 import time
 import sys
 import weakref
@@ -17,7 +17,7 @@
 def verbose_print(arg):
     """Helper function for printing out debugging output."""
-    if test_support.verbose:
+    if support.verbose:
         with _print_mutex:
             print arg
@@ -34,8 +34,8 @@ def setUp(self):
         self.running = 0
         self.next_ident = 0
-        key = test_support.threading_setup()
-        self.addCleanup(test_support.threading_cleanup, *key)
+        key = support.threading_setup()
+        self.addCleanup(support.threading_cleanup, *key)
 class ThreadRunningTests(BasicThreadTest):
@@ -60,12 +60,13 @@ def task(self, ident):
     def test_starting_threads(self):
-        # Basic test for thread creation.
-        for i in range(NUMTASKS):
-            self.newtask()
-        verbose_print("waiting for tasks to complete...")
-        self.done_mutex.acquire()
-        verbose_print("all tasks done")
+        with support.wait_threads_exit():
+            # Basic test for thread creation.
+            for i in range(NUMTASKS):
+                self.newtask()
+            verbose_print("waiting for tasks to complete...")
+            self.done_mutex.acquire()
+            verbose_print("all tasks done")
     def test_stack_size(self):
         # Various stack size tests.
@@ -95,12 +96,13 @@ def test_nt_and_posix_stack_size(self):
             verbose_print("trying stack_size = (%d)" % tss)
             self.next_ident = 0
             self.created = 0
-            for i in range(NUMTASKS):
-                self.newtask()
+            with support.wait_threads_exit():
+                for i in range(NUMTASKS):
+                    self.newtask()
-            verbose_print("waiting for all tasks to complete")
-            self.done_mutex.acquire()
-            verbose_print("all tasks done")
+                verbose_print("waiting for all tasks to complete")
+                self.done_mutex.acquire()
+                verbose_print("all tasks done")
@@ -110,25 +112,28 @@ def test__count(self):
         mut = thread.allocate_lock()
         started = []
         def task():
-        thread.start_new_thread(task, ())
-        while not started:
-            time.sleep(0.01)
-        self.assertEqual(thread._count(), orig + 1)
-        # Allow the task to finish.
-        mut.release()
-        # The only reliable way to be sure that the thread ended from the
-        # interpreter's point of view is to wait for the function object to be
-        # destroyed.
-        done = []
-        wr = weakref.ref(task, lambda _: done.append(None))
-        del task
-        while not done:
-            time.sleep(0.01)
-        self.assertEqual(thread._count(), orig)
+        with support.wait_threads_exit():
+            thread.start_new_thread(task, ())
+            while not started:
+                time.sleep(0.01)
+            self.assertEqual(thread._count(), orig + 1)
+            # Allow the task to finish.
+            mut.release()
+            # The only reliable way to be sure that the thread ended from the
+            # interpreter's point of view is to wait for the function object to be
+            # destroyed.
+            done = []
+            wr = weakref.ref(task, lambda _: done.append(None))
+            del task
+            while not done:
+                time.sleep(0.01)
+            self.assertEqual(thread._count(), orig)
     def test_save_exception_state_on_error(self):
         # See issue #14474
@@ -143,14 +148,13 @@ def mywrite(self, *args):
             real_write(self, *args)
         c = thread._count()
         started = thread.allocate_lock()
-        with test_support.captured_output("stderr") as stderr:
+        with support.captured_output("stderr") as stderr:
             real_write = stderr.write
             stderr.write = mywrite
-            thread.start_new_thread(task, ())
-            started.acquire()
-            while thread._count() > c:
-                time.sleep(0.01)
+            with support.wait_threads_exit():
+                thread.start_new_thread(task, ())
+                started.acquire()
         self.assertIn("Traceback", stderr.getvalue())
@@ -182,13 +186,14 @@ def enter(self):
 class BarrierTest(BasicThreadTest):
     def test_barrier(self):
-        self.bar = Barrier(NUMTASKS)
-        self.running = NUMTASKS
-        for i in range(NUMTASKS):
-            thread.start_new_thread(self.task2, (i,))
-        verbose_print("waiting for tasks to end")
-        self.done_mutex.acquire()
-        verbose_print("tasks done")
+        with support.wait_threads_exit():
+            self.bar = Barrier(NUMTASKS)
+            self.running = NUMTASKS
+            for i in range(NUMTASKS):
+                thread.start_new_thread(self.task2, (i,))
+            verbose_print("waiting for tasks to end")
+            self.done_mutex.acquire()
+            verbose_print("tasks done")
     def task2(self, ident):
         for i in range(NUMTRIPS):
@@ -226,8 +231,9 @@ def setUp(self):
                      "This test is only appropriate for POSIX-like systems.")
-    @test_support.reap_threads
+    @support.reap_threads
     def test_forkinthread(self):
+        non_local = {'status': None}
         def thread1():
                 pid = os.fork() # fork in a thread
@@ -246,11 +252,13 @@ def thread1():
             else: # parent
                 pid, status = os.waitpid(pid, 0)
-                self.assertEqual(status, 0)
+                non_local['status'] = status
-        thread.start_new_thread(thread1, ())
-        self.assertEqual(os.read(self.read_fd, 2), "OK",
-                         "Unable to fork() in thread")
+        with support.wait_threads_exit():
+            thread.start_new_thread(thread1, ())
+            self.assertEqual(os.read(self.read_fd, 2), "OK",
+                             "Unable to fork() in thread")
+        self.assertEqual(non_local['status'], 0)
     def tearDown(self):
@@ -265,7 +273,7 @@ def tearDown(self):
 def test_main():
-    test_support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
+    support.run_unittest(ThreadRunningTests, BarrierTest, LockTests,
 if __name__ == "__main__":

More information about the Python-checkins mailing list