[Python-checkins] r65032 - in python/trunk: Lib/test/test_threading.py Lib/threading.py Python/ceval.c

jesse.noller python-checkins at python.org
Wed Jul 16 22:03:47 CEST 2008


Author: jesse.noller
Date: Wed Jul 16 22:03:47 2008
New Revision: 65032

Log:
Apply patch for 874900: threading module can deadlock after fork


Modified:
   python/trunk/Lib/test/test_threading.py
   python/trunk/Lib/threading.py
   python/trunk/Python/ceval.c

Modified: python/trunk/Lib/test/test_threading.py
==============================================================================
--- python/trunk/Lib/test/test_threading.py	(original)
+++ python/trunk/Lib/test/test_threading.py	Wed Jul 16 22:03:47 2008
@@ -323,6 +323,82 @@
                                sys.getrefcount(weak_raising_cyclic_object())))
 
 
+class ThreadJoinOnShutdown(unittest.TestCase):
+
+    def _run_and_join(self, script):
+        script = """if 1:
+            import sys, os, time, threading
+
+            # a thread, which waits for the main program to terminate
+            def joiningfunc(mainthread):
+                mainthread.join()
+                print 'end of thread'
+        \n""" + script
+
+        import subprocess
+        p = subprocess.Popen([sys.executable, "-c", script], stdout=subprocess.PIPE)
+        rc = p.wait()
+        self.assertEqual(p.stdout.read(), "end of main\nend of thread\n")
+        self.failIf(rc == 2, "interpreter was blocked")
+        self.failUnless(rc == 0, "Unexpected error")
+
+    def test_1_join_on_shutdown(self):
+        # The usual case: on exit, wait for a non-daemon thread
+        script = """if 1:
+            import os
+            t = threading.Thread(target=joiningfunc,
+                                 args=(threading.current_thread(),))
+            t.start()
+            time.sleep(0.1)
+            print 'end of main'
+            """
+        self._run_and_join(script)
+
+
+    def test_2_join_in_forked_process(self):
+        # Like the test above, but from a forked interpreter
+        import os
+        if not hasattr(os, 'fork'):
+            return
+        script = """if 1:
+            childpid = os.fork()
+            if childpid != 0:
+                os.waitpid(childpid, 0)
+                sys.exit(0)
+
+            t = threading.Thread(target=joiningfunc,
+                                 args=(threading.current_thread(),))
+            t.start()
+            print 'end of main'
+            """
+        self._run_and_join(script)
+
+    def test_3_join_in_forked_from_thread(self):
+        # Like the test above, but fork() was called from a worker thread.
+        # In the forked process, the main Thread object must be marked as stopped.
+        import os
+        if not hasattr(os, 'fork'):
+            return
+        script = """if 1:
+            main_thread = threading.current_thread()
+            def worker():
+                childpid = os.fork()
+                if childpid != 0:
+                    os.waitpid(childpid, 0)
+                    sys.exit(0)
+
+                t = threading.Thread(target=joiningfunc,
+                                     args=(main_thread,))
+                print 'end of main'
+                t.start()
+                t.join() # Should not block: main_thread is already stopped
+
+            w = threading.Thread(target=worker)
+            w.start()
+            """
+        self._run_and_join(script)
+
+
 class ThreadingExceptionTests(unittest.TestCase):
     # A RuntimeError should be raised if Thread.start() is called
     # multiple times.
@@ -363,7 +439,9 @@
 
 def test_main():
     test.test_support.run_unittest(ThreadTests,
-                                   ThreadingExceptionTests)
+                                   ThreadJoinOnShutdown,
+                                   ThreadingExceptionTests,
+                                   )
 
 if __name__ == "__main__":
     test_main()

Modified: python/trunk/Lib/threading.py
==============================================================================
--- python/trunk/Lib/threading.py	(original)
+++ python/trunk/Lib/threading.py	Wed Jul 16 22:03:47 2008
@@ -825,6 +825,37 @@
     from _threading_local import local
 
 
+def _after_fork():
+    # This function is called by Python/ceval.c:PyEval_ReInitThreads which
+    # is called from PyOS_AfterFork.  Here we clean up threading module state
+    # that should not exist after a fork.
+
+    # Reset _active_limbo_lock, in case we forked while the lock was held
+    # by another (non-forked) thread.  http://bugs.python.org/issue874900
+    global _active_limbo_lock
+    _active_limbo_lock = _allocate_lock()
+
+    # fork() only copied the current thread; clear references to others.
+    new_active = {}
+    current = current_thread()
+    with _active_limbo_lock:
+        for ident, thread in _active.iteritems():
+            if thread is current:
+                # There is only one active thread.
+                new_active[ident] = thread
+            else:
+                # All the others are already stopped.
+                # We don't call _Thread__stop() because it tries to acquire
+                # thread._Thread__block which could also have been held while
+                # we forked.
+                thread._Thread__stopped = True
+
+        _limbo.clear()
+        _active.clear()
+        _active.update(new_active)
+        assert len(_active) == 1
+
+
 # Self-test code
 
 def _test():

Modified: python/trunk/Python/ceval.c
==============================================================================
--- python/trunk/Python/ceval.c	(original)
+++ python/trunk/Python/ceval.c	Wed Jul 16 22:03:47 2008
@@ -274,6 +274,9 @@
 void
 PyEval_ReInitThreads(void)
 {
+	PyObject *threading, *result;
+	PyThreadState *tstate;
+
 	if (!interpreter_lock)
 		return;
 	/*XXX Can't use PyThread_free_lock here because it does too
@@ -283,6 +286,23 @@
 	interpreter_lock = PyThread_allocate_lock();
 	PyThread_acquire_lock(interpreter_lock, 1);
 	main_thread = PyThread_get_thread_ident();
+
+	/* Update the threading module with the new state.
+	 */
+	tstate = PyThreadState_GET();
+	threading = PyMapping_GetItemString(tstate->interp->modules,
+					    "threading");
+	if (threading == NULL) {
+		/* threading not imported */
+		PyErr_Clear();
+		return;
+	}
+	result = PyObject_CallMethod(threading, "_after_fork", NULL);
+	if (result == NULL)
+		PyErr_WriteUnraisable(threading);
+	else
+		Py_DECREF(result);
+	Py_DECREF(threading);
 }
 #endif
 


More information about the Python-checkins mailing list