[Python-checkins] bpo-32604: Clean up test.support.interpreters. (gh-20926)

Eric Snow webhook-mailer at python.org
Tue Jun 16 20:24:48 EDT 2020


https://github.com/python/cpython/commit/818f5b597ae93411cc44e404544247d436026a00
commit: 818f5b597ae93411cc44e404544247d436026a00
branch: master
author: Eric Snow <ericsnowcurrently at gmail.com>
committer: GitHub <noreply at github.com>
date: 2020-06-16T18:24:40-06:00
summary:

bpo-32604: Clean up test.support.interpreters. (gh-20926)

Some minor adjustments were needed, and a few tests were missing.

https://bugs.python.org/issue32604

files:
M Lib/test/support/interpreters.py
M Lib/test/test__xxsubinterpreters.py
M Lib/test/test_interpreters.py

diff --git a/Lib/test/support/interpreters.py b/Lib/test/support/interpreters.py
index 09508e1bbeca0..2935708f9df1a 100644
--- a/Lib/test/support/interpreters.py
+++ b/Lib/test/support/interpreters.py
@@ -1,5 +1,6 @@
 """Subinterpreters High Level Module."""
 
+import time
 import _xxsubinterpreters as _interpreters
 
 # aliases:
@@ -19,47 +20,51 @@
 
 
 def create(*, isolated=True):
-    """
-    Initialize a new (idle) Python interpreter.
-    """
+    """Return a new (idle) Python interpreter."""
     id = _interpreters.create(isolated=isolated)
     return Interpreter(id, isolated=isolated)
 
 
 def list_all():
-    """
-    Get all existing interpreters.
-    """
-    return [Interpreter(id) for id in
-            _interpreters.list_all()]
+    """Return all existing interpreters."""
+    return [Interpreter(id) for id in _interpreters.list_all()]
 
 
 def get_current():
-    """
-    Get the currently running interpreter.
-    """
+    """Return the currently running interpreter."""
     id = _interpreters.get_current()
     return Interpreter(id)
 
 
 def get_main():
-    """
-    Get the main interpreter.
-    """
+    """Return the main interpreter."""
     id = _interpreters.get_main()
     return Interpreter(id)
 
 
 class Interpreter:
-    """
-    The Interpreter object represents
-    a single interpreter.
-    """
+    """A single Python interpreter."""
 
     def __init__(self, id, *, isolated=None):
+        if not isinstance(id, (int, _interpreters.InterpreterID)):
+            raise TypeError(f'id must be an int, got {id!r}')
         self._id = id
         self._isolated = isolated
 
+    def __repr__(self):
+        data = dict(id=int(self._id), isolated=self._isolated)
+        kwargs = (f'{k}={v!r}' for k, v in data.items())
+        return f'{type(self).__name__}({", ".join(kwargs)})'
+
+    def __hash__(self):
+        return hash(self._id)
+
+    def __eq__(self, other):
+        if not isinstance(other, Interpreter):
+            return NotImplemented
+        else:
+            return other._id == self._id
+
     @property
     def id(self):
         return self._id
@@ -67,84 +72,98 @@ def id(self):
     @property
     def isolated(self):
         if self._isolated is None:
+            # XXX The low-level function has not been added yet.
+            # See bpo-....
             self._isolated = _interpreters.is_isolated(self._id)
         return self._isolated
 
     def is_running(self):
-        """
-        Return whether or not the identified
-        interpreter is running.
-        """
+        """Return whether or not the identified interpreter is running."""
         return _interpreters.is_running(self._id)
 
     def close(self):
-        """
-        Finalize and destroy the interpreter.
+        """Finalize and destroy the interpreter.
 
-        Attempting to destroy the current
-        interpreter results in a RuntimeError.
+        Attempting to destroy the current interpreter results
+        in a RuntimeError.
         """
         return _interpreters.destroy(self._id)
 
     def run(self, src_str, /, *, channels=None):
-        """
-        Run the given source code in the interpreter.
+        """Run the given source code in the interpreter.
+
         This blocks the current Python thread until done.
         """
-        _interpreters.run_string(self._id, src_str)
+        _interpreters.run_string(self._id, src_str, channels)
 
 
 def create_channel():
-    """
-    Create a new channel for passing data between
-    interpreters.
-    """
+    """Return (recv, send) for a new cross-interpreter channel.
 
+    The channel may be used to pass data safely between interpreters.
+    """
     cid = _interpreters.channel_create()
-    return (RecvChannel(cid), SendChannel(cid))
+    recv, send = RecvChannel(cid), SendChannel(cid)
+    return recv, send
 
 
 def list_all_channels():
-    """
-    Get all open channels.
-    """
+    """Return a list of (recv, send) for all open channels."""
     return [(RecvChannel(cid), SendChannel(cid))
             for cid in _interpreters.channel_list_all()]
 
 
+class _ChannelEnd:
+    """The base class for RecvChannel and SendChannel."""
+
+    def __init__(self, id):
+        if not isinstance(id, (int, _interpreters.ChannelID)):
+            raise TypeError(f'id must be an int, got {id!r}')
+        self._id = id
+
+    def __repr__(self):
+        return f'{type(self).__name__}(id={int(self._id)})'
+
+    def __hash__(self):
+        return hash(self._id)
+
+    def __eq__(self, other):
+        if isinstance(self, RecvChannel):
+            if not isinstance(other, RecvChannel):
+                return NotImplemented
+        elif not isinstance(other, SendChannel):
+            return NotImplemented
+        return other._id == self._id
+
+    @property
+    def id(self):
+        return self._id
+
+
 _NOT_SET = object()
 
 
-class RecvChannel:
-    """
-    The RecvChannel object represents
-    a receiving channel.
-    """
+class RecvChannel(_ChannelEnd):
+    """The receiving end of a cross-interpreter channel."""
 
-    def __init__(self, id):
-        self._id = id
+    def recv(self, *, _sentinel=object(), _delay=10 / 1000):  # 10 milliseconds
+        """Return the next object from the channel.
 
-    def recv(self, *, _delay=10 / 1000):  # 10 milliseconds
-        """
-        Get the next object from the channel,
-        and wait if none have been sent.
-        Associate the interpreter with the channel.
+        This blocks until an object has been sent, if none have been
+        sent already.
         """
-        import time
-        sentinel = object()
-        obj = _interpreters.channel_recv(self._id, sentinel)
-        while obj is sentinel:
+        obj = _interpreters.channel_recv(self._id, _sentinel)
+        while obj is _sentinel:
             time.sleep(_delay)
-            obj = _interpreters.channel_recv(self._id, sentinel)
+            obj = _interpreters.channel_recv(self._id, _sentinel)
         return obj
 
     def recv_nowait(self, default=_NOT_SET):
-        """
-        Like recv(), but return the default
-        instead of waiting.
+        """Return the next object from the channel.
 
-        This function is blocked by a missing low-level
-        implementation of channel_recv_wait().
+        If none have been sent then return the default if one
+        is provided or fail with ChannelEmptyError.  Otherwise this
+        is the same as recv().
         """
         if default is _NOT_SET:
             return _interpreters.channel_recv(self._id)
@@ -152,32 +171,27 @@ def recv_nowait(self, default=_NOT_SET):
             return _interpreters.channel_recv(self._id, default)
 
 
-class SendChannel:
-    """
-    The SendChannel object represents
-    a sending channel.
-    """
-
-    def __init__(self, id):
-        self._id = id
+class SendChannel(_ChannelEnd):
+    """The sending end of a cross-interpreter channel."""
 
     def send(self, obj):
+        """Send the object (i.e. its data) to the channel's receiving end.
+
+        This blocks until the object is received.
         """
-        Send the object (i.e. its data) to the receiving
-        end of the channel and wait. Associate the interpreter
-        with the channel.
-        """
-        import time
         _interpreters.channel_send(self._id, obj)
+        # XXX We are missing a low-level channel_send_wait().
+        # See bpo-32604 and gh-19829.
+        # Until that shows up we fake it:
         time.sleep(2)
 
     def send_nowait(self, obj):
-        """
-        Like send(), but return False if not received.
+        """Send the object to the channel's receiving end.
 
-        This function is blocked by a missing low-level
-        implementation of channel_send_wait().
+        If the object is immediately received then return True
+        (else False).  Otherwise this is the same as send().
         """
-
-        _interpreters.channel_send(self._id, obj)
-        return False
+        # XXX Note that at the moment channel_send() only ever returns
+        # None.  This should be fixed when channel_send_wait() is added.
+        # See bpo-32604 and gh-19829.
+        return _interpreters.channel_send(self._id, obj)
diff --git a/Lib/test/test__xxsubinterpreters.py b/Lib/test/test__xxsubinterpreters.py
index 7aec021fb19a5..550a847616cdc 100644
--- a/Lib/test/test__xxsubinterpreters.py
+++ b/Lib/test/test__xxsubinterpreters.py
@@ -759,21 +759,9 @@ def test_still_running(self):
 
 class RunStringTests(TestBase):
 
-    SCRIPT = dedent("""
-        with open('{}', 'w') as out:
-            out.write('{}')
-        """)
-    FILENAME = 'spam'
-
     def setUp(self):
         super().setUp()
         self.id = interpreters.create()
-        self._fs = None
-
-    def tearDown(self):
-        if self._fs is not None:
-            self._fs.close()
-        super().tearDown()
 
     def test_success(self):
         script, file = _captured_script('print("it worked!", end="")')
diff --git a/Lib/test/test_interpreters.py b/Lib/test/test_interpreters.py
index 3451a4c8759d8..58258bb66af8a 100644
--- a/Lib/test/test_interpreters.py
+++ b/Lib/test/test_interpreters.py
@@ -31,10 +31,10 @@ def clean_up_interpreters():
             pass  # already destroyed
 
 
-def _run_output(interp, request, shared=None):
+def _run_output(interp, request, channels=None):
     script, rpipe = _captured_script(request)
     with rpipe:
-        interp.run(script)
+        interp.run(script, channels=channels)
         return rpipe.read()
 
 
@@ -68,25 +68,22 @@ class CreateTests(TestBase):
 
     def test_in_main(self):
         interp = interpreters.create()
-        lst = interpreters.list_all()
-        self.assertEqual(interp.id, lst[1].id)
+        self.assertIsInstance(interp, interpreters.Interpreter)
+        self.assertIn(interp, interpreters.list_all())
 
     def test_in_thread(self):
         lock = threading.Lock()
-        id = None
-        interp = interpreters.create()
-        lst = interpreters.list_all()
+        interp = None
         def f():
-            nonlocal id
-            id = interp.id
+            nonlocal interp
+            interp = interpreters.create()
             lock.acquire()
             lock.release()
-
         t = threading.Thread(target=f)
         with lock:
             t.start()
         t.join()
-        self.assertEqual(interp.id, lst[1].id)
+        self.assertIn(interp, interpreters.list_all())
 
     def test_in_subinterpreter(self):
         main, = interpreters.list_all()
@@ -94,11 +91,10 @@ def test_in_subinterpreter(self):
         out = _run_output(interp, dedent("""
             from test.support import interpreters
             interp = interpreters.create()
-            print(interp)
+            print(interp.id)
             """))
-        interp2 = out.strip()
-
-        self.assertEqual(len(set(interpreters.list_all())), len({main, interp, interp2}))
+        interp2 = interpreters.Interpreter(int(out))
+        self.assertEqual(interpreters.list_all(), [main, interp, interp2])
 
     def test_after_destroy_all(self):
         before = set(interpreters.list_all())
@@ -112,7 +108,7 @@ def test_after_destroy_all(self):
             interp.close()
         # Finally, create another.
         interp = interpreters.create()
-        self.assertEqual(len(set(interpreters.list_all())), len(before | {interp}))
+        self.assertEqual(set(interpreters.list_all()), before | {interp})
 
     def test_after_destroy_some(self):
         before = set(interpreters.list_all())
@@ -125,15 +121,15 @@ def test_after_destroy_some(self):
         interp2.close()
         # Finally, create another.
         interp = interpreters.create()
-        self.assertEqual(len(set(interpreters.list_all())), len(before | {interp3, interp}))
+        self.assertEqual(set(interpreters.list_all()), before | {interp3, interp})
 
 
 class GetCurrentTests(TestBase):
 
     def test_main(self):
-        main_interp_id = _interpreters.get_main()
-        cur_interp_id =  interpreters.get_current().id
-        self.assertEqual(cur_interp_id, main_interp_id)
+        main = interpreters.get_main()
+        current = interpreters.get_current()
+        self.assertEqual(current, main)
 
     def test_subinterpreter(self):
         main = _interpreters.get_main()
@@ -141,10 +137,10 @@ def test_subinterpreter(self):
         out = _run_output(interp, dedent("""
             from test.support import interpreters
             cur = interpreters.get_current()
-            print(cur)
+            print(cur.id)
             """))
-        cur = out.strip()
-        self.assertNotEqual(cur, main)
+        current = interpreters.Interpreter(int(out))
+        self.assertNotEqual(current, main)
 
 
 class ListAllTests(TestBase):
@@ -177,26 +173,75 @@ def test_after_destroying(self):
         self.assertEqual(ids, [main.id, second.id])
 
 
-class TestInterpreterId(TestBase):
+class TestInterpreterAttrs(TestBase):
 
-    def test_in_main(self):
-        main = interpreters.get_current()
-        self.assertEqual(0, main.id)
+    def test_id_type(self):
+        main = interpreters.get_main()
+        current = interpreters.get_current()
+        interp = interpreters.create()
+        self.assertIsInstance(main.id, _interpreters.InterpreterID)
+        self.assertIsInstance(current.id, _interpreters.InterpreterID)
+        self.assertIsInstance(interp.id, _interpreters.InterpreterID)
 
-    def test_with_custom_num(self):
+    def test_main_id(self):
+        main = interpreters.get_main()
+        self.assertEqual(main.id, 0)
+
+    def test_custom_id(self):
         interp = interpreters.Interpreter(1)
-        self.assertEqual(1, interp.id)
+        self.assertEqual(interp.id, 1)
+
+        with self.assertRaises(TypeError):
+            interpreters.Interpreter('1')
 
-    def test_for_readonly_property(self):
+    def test_id_readonly(self):
         interp = interpreters.Interpreter(1)
         with self.assertRaises(AttributeError):
             interp.id = 2
 
+    @unittest.skip('not ready yet (see bpo-32604)')
+    def test_main_isolated(self):
+        main = interpreters.get_main()
+        self.assertFalse(main.isolated)
+
+    @unittest.skip('not ready yet (see bpo-32604)')
+    def test_subinterpreter_isolated_default(self):
+        interp = interpreters.create()
+        self.assertFalse(interp.isolated)
+
+    def test_subinterpreter_isolated_explicit(self):
+        interp1 = interpreters.create(isolated=True)
+        interp2 = interpreters.create(isolated=False)
+        self.assertTrue(interp1.isolated)
+        self.assertFalse(interp2.isolated)
+
+    @unittest.skip('not ready yet (see bpo-32604)')
+    def test_custom_isolated_default(self):
+        interp = interpreters.Interpreter(1)
+        self.assertFalse(interp.isolated)
+
+    def test_custom_isolated_explicit(self):
+        interp1 = interpreters.Interpreter(1, isolated=True)
+        interp2 = interpreters.Interpreter(1, isolated=False)
+        self.assertTrue(interp1.isolated)
+        self.assertFalse(interp2.isolated)
+
+    def test_isolated_readonly(self):
+        interp = interpreters.Interpreter(1)
+        with self.assertRaises(AttributeError):
+            interp.isolated = True
+
+    def test_equality(self):
+        interp1 = interpreters.create()
+        interp2 = interpreters.create()
+        self.assertEqual(interp1, interp1)
+        self.assertNotEqual(interp1, interp2)
+
 
 class TestInterpreterIsRunning(TestBase):
 
     def test_main(self):
-        main = interpreters.get_current()
+        main = interpreters.get_main()
         self.assertTrue(main.is_running())
 
     def test_subinterpreter(self):
@@ -224,16 +269,29 @@ def test_already_destroyed(self):
         with self.assertRaises(RuntimeError):
             interp.is_running()
 
+    def test_does_not_exist(self):
+        interp = interpreters.Interpreter(1_000_000)
+        with self.assertRaises(RuntimeError):
+            interp.is_running()
+
+    def test_bad_id(self):
+        interp = interpreters.Interpreter(-1)
+        with self.assertRaises(ValueError):
+            interp.is_running()
 
-class TestInterpreterDestroy(TestBase):
+
+class TestInterpreterClose(TestBase):
 
     def test_basic(self):
+        main = interpreters.get_main()
         interp1 = interpreters.create()
         interp2 = interpreters.create()
         interp3 = interpreters.create()
-        self.assertEqual(4, len(interpreters.list_all()))
+        self.assertEqual(set(interpreters.list_all()),
+                         {main, interp1, interp2, interp3})
         interp2.close()
-        self.assertEqual(3, len(interpreters.list_all()))
+        self.assertEqual(set(interpreters.list_all()),
+                         {main, interp1, interp3})
 
     def test_all(self):
         before = set(interpreters.list_all())
@@ -241,10 +299,10 @@ def test_all(self):
         for _ in range(3):
             interp = interpreters.create()
             interps.add(interp)
-        self.assertEqual(len(set(interpreters.list_all())), len(before | interps))
+        self.assertEqual(set(interpreters.list_all()), before | interps)
         for interp in interps:
             interp.close()
-        self.assertEqual(len(set(interpreters.list_all())), len(before))
+        self.assertEqual(set(interpreters.list_all()), before)
 
     def test_main(self):
         main, = interpreters.list_all()
@@ -265,32 +323,44 @@ def test_already_destroyed(self):
         with self.assertRaises(RuntimeError):
             interp.close()
 
+    def test_does_not_exist(self):
+        interp = interpreters.Interpreter(1_000_000)
+        with self.assertRaises(RuntimeError):
+            interp.close()
+
+    def test_bad_id(self):
+        interp = interpreters.Interpreter(-1)
+        with self.assertRaises(ValueError):
+            interp.close()
+
     def test_from_current(self):
         main, = interpreters.list_all()
         interp = interpreters.create()
-        script = dedent(f"""
+        out = _run_output(interp, dedent(f"""
             from test.support import interpreters
+            interp = interpreters.Interpreter({int(interp.id)})
             try:
-                main = interpreters.get_current()
-                main.close()
+                interp.close()
             except RuntimeError:
-                pass
-            """)
-
-        interp.run(script)
-        self.assertEqual(len(set(interpreters.list_all())), len({main, interp}))
+                print('failed')
+            """))
+        self.assertEqual(out.strip(), 'failed')
+        self.assertEqual(set(interpreters.list_all()), {main, interp})
 
     def test_from_sibling(self):
         main, = interpreters.list_all()
         interp1 = interpreters.create()
-        script = dedent(f"""
+        interp2 = interpreters.create()
+        self.assertEqual(set(interpreters.list_all()),
+                         {main, interp1, interp2})
+        interp1.run(dedent(f"""
             from test.support import interpreters
-            interp2 = interpreters.create()
+            interp2 = interpreters.Interpreter(int({interp2.id}))
             interp2.close()
-            """)
-        interp1.run(script)
-
-        self.assertEqual(len(set(interpreters.list_all())), len({main, interp1}))
+            interp3 = interpreters.create()
+            interp3.close()
+            """))
+        self.assertEqual(set(interpreters.list_all()), {main, interp1})
 
     def test_from_other_thread(self):
         interp = interpreters.create()
@@ -312,41 +382,21 @@ def test_still_running(self):
 
 class TestInterpreterRun(TestBase):
 
-    SCRIPT = dedent("""
-        with open('{}', 'w') as out:
-            out.write('{}')
-        """)
-    FILENAME = 'spam'
-
-    def setUp(self):
-        super().setUp()
-        self.interp = interpreters.create()
-        self._fs = None
-
-    def tearDown(self):
-        if self._fs is not None:
-            self._fs.close()
-        super().tearDown()
-
-    @property
-    def fs(self):
-        if self._fs is None:
-            self._fs = FSFixture(self)
-        return self._fs
-
     def test_success(self):
+        interp = interpreters.create()
         script, file = _captured_script('print("it worked!", end="")')
         with file:
-            self.interp.run(script)
+            interp.run(script)
             out = file.read()
 
         self.assertEqual(out, 'it worked!')
 
     def test_in_thread(self):
+        interp = interpreters.create()
         script, file = _captured_script('print("it worked!", end="")')
         with file:
             def f():
-                self.interp.run(script)
+                interp.run(script)
 
             t = threading.Thread(target=f)
             t.start()
@@ -357,6 +407,7 @@ def f():
 
     @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
     def test_fork(self):
+        interp = interpreters.create()
         import tempfile
         with tempfile.NamedTemporaryFile('w+') as file:
             file.write('')
@@ -371,24 +422,39 @@ def test_fork(self):
                     with open('{file.name}', 'w') as out:
                         out.write('{expected}')
                 """)
-            self.interp.run(script)
+            interp.run(script)
 
             file.seek(0)
             content = file.read()
             self.assertEqual(content, expected)
 
     def test_already_running(self):
-        with _running(self.interp):
+        interp = interpreters.create()
+        with _running(interp):
             with self.assertRaises(RuntimeError):
-                self.interp.run('print("spam")')
+                interp.run('print("spam")')
+
+    def test_does_not_exist(self):
+        interp = interpreters.Interpreter(1_000_000)
+        with self.assertRaises(RuntimeError):
+            interp.run('print("spam")')
+
+    def test_bad_id(self):
+        interp = interpreters.Interpreter(-1)
+        with self.assertRaises(ValueError):
+            interp.run('print("spam")')
 
     def test_bad_script(self):
+        interp = interpreters.create()
         with self.assertRaises(TypeError):
-            self.interp.run(10)
+            interp.run(10)
 
     def test_bytes_for_script(self):
+        interp = interpreters.create()
         with self.assertRaises(TypeError):
-            self.interp.run(b'print("spam")')
+            interp.run(b'print("spam")')
+
+    # test_xxsubinterpreters covers the remaining Interpreter.run() behavior.
 
 
 class TestIsShareable(TestBase):
@@ -405,8 +471,8 @@ def test_default_shareables(self):
                 ]
         for obj in shareables:
             with self.subTest(obj):
-                self.assertTrue(
-                    interpreters.is_shareable(obj))
+                shareable = interpreters.is_shareable(obj)
+                self.assertTrue(shareable)
 
     def test_not_shareable(self):
         class Cheese:
@@ -441,22 +507,71 @@ class SubBytes(bytes):
                     interpreters.is_shareable(obj))
 
 
-class TestChannel(TestBase):
+class TestChannels(TestBase):
 
-    def test_create_cid(self):
+    def test_create(self):
         r, s = interpreters.create_channel()
         self.assertIsInstance(r, interpreters.RecvChannel)
         self.assertIsInstance(s, interpreters.SendChannel)
 
-    def test_sequential_ids(self):
-        before = interpreters.list_all_channels()
-        channels1 = interpreters.create_channel()
-        channels2 = interpreters.create_channel()
-        channels3 = interpreters.create_channel()
-        after = interpreters.list_all_channels()
+    def test_list_all(self):
+        self.assertEqual(interpreters.list_all_channels(), [])
+        created = set()
+        for _ in range(3):
+            ch = interpreters.create_channel()
+            created.add(ch)
+        after = set(interpreters.list_all_channels())
+        self.assertEqual(after, created)
+
+
+class TestRecvChannelAttrs(TestBase):
+
+    def test_id_type(self):
+        rch, _ = interpreters.create_channel()
+        self.assertIsInstance(rch.id, _interpreters.ChannelID)
+
+    def test_custom_id(self):
+        rch = interpreters.RecvChannel(1)
+        self.assertEqual(rch.id, 1)
+
+        with self.assertRaises(TypeError):
+            interpreters.RecvChannel('1')
+
+    def test_id_readonly(self):
+        rch = interpreters.RecvChannel(1)
+        with self.assertRaises(AttributeError):
+            rch.id = 2
+
+    def test_equality(self):
+        ch1, _ = interpreters.create_channel()
+        ch2, _ = interpreters.create_channel()
+        self.assertEqual(ch1, ch1)
+        self.assertNotEqual(ch1, ch2)
+
+
+class TestSendChannelAttrs(TestBase):
+
+    def test_id_type(self):
+        _, sch = interpreters.create_channel()
+        self.assertIsInstance(sch.id, _interpreters.ChannelID)
 
-        self.assertEqual(len(set(after) - set(before)),
-                         len({channels1, channels2, channels3}))
+    def test_custom_id(self):
+        sch = interpreters.SendChannel(1)
+        self.assertEqual(sch.id, 1)
+
+        with self.assertRaises(TypeError):
+            interpreters.SendChannel('1')
+
+    def test_id_readonly(self):
+        sch = interpreters.SendChannel(1)
+        with self.assertRaises(AttributeError):
+            sch.id = 2
+
+    def test_equality(self):
+        _, ch1 = interpreters.create_channel()
+        _, ch2 = interpreters.create_channel()
+        self.assertEqual(ch1, ch1)
+        self.assertNotEqual(ch1, ch2)
 
 
 class TestSendRecv(TestBase):
@@ -464,7 +579,7 @@ class TestSendRecv(TestBase):
     def test_send_recv_main(self):
         r, s = interpreters.create_channel()
         orig = b'spam'
-        s.send(orig)
+        s.send_nowait(orig)
         obj = r.recv()
 
         self.assertEqual(obj, orig)
@@ -472,16 +587,40 @@ def test_send_recv_main(self):
 
     def test_send_recv_same_interpreter(self):
         interp = interpreters.create()
-        out = _run_output(interp, dedent("""
+        interp.run(dedent("""
             from test.support import interpreters
             r, s = interpreters.create_channel()
             orig = b'spam'
-            s.send(orig)
+            s.send_nowait(orig)
             obj = r.recv()
-            assert obj is not orig
-            assert obj == orig
+            assert obj == orig, 'expected: obj == orig'
+            assert obj is not orig, 'expected: obj is not orig'
             """))
 
+    @unittest.skip('broken (see BPO-...)')
+    def test_send_recv_different_interpreters(self):
+        r1, s1 = interpreters.create_channel()
+        r2, s2 = interpreters.create_channel()
+        orig1 = b'spam'
+        s1.send_nowait(orig1)
+        out = _run_output(
+            interpreters.create(),
+            dedent(f"""
+                obj1 = r.recv()
+                assert obj1 == b'spam', 'expected: obj1 == orig1'
+                # When going to another interpreter we get a copy.
+                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
+                orig2 = b'eggs'
+                print(id(orig2))
+                s.send_nowait(orig2)
+                """),
+            channels=dict(r=r1, s=s2),
+            )
+        obj2 = r2.recv()
+
+        self.assertEqual(obj2, b'eggs')
+        self.assertNotEqual(id(obj2), int(out))
+
     def test_send_recv_different_threads(self):
         r, s = interpreters.create_channel()
 
@@ -496,40 +635,108 @@ def f():
         t = threading.Thread(target=f)
         t.start()
 
-        s.send(b'spam')
+        orig = b'spam'
+        s.send(orig)
         t.join()
         obj = r.recv()
 
-        self.assertEqual(obj, b'spam')
+        self.assertEqual(obj, orig)
+        self.assertIsNot(obj, orig)
 
     def test_send_recv_nowait_main(self):
         r, s = interpreters.create_channel()
         orig = b'spam'
-        s.send(orig)
+        s.send_nowait(orig)
         obj = r.recv_nowait()
 
         self.assertEqual(obj, orig)
         self.assertIsNot(obj, orig)
 
+    def test_send_recv_nowait_main_with_default(self):
+        r, _ = interpreters.create_channel()
+        obj = r.recv_nowait(None)
+
+        self.assertIsNone(obj)
+
     def test_send_recv_nowait_same_interpreter(self):
         interp = interpreters.create()
-        out = _run_output(interp, dedent("""
+        interp.run(dedent("""
             from test.support import interpreters
             r, s = interpreters.create_channel()
             orig = b'spam'
-            s.send(orig)
+            s.send_nowait(orig)
             obj = r.recv_nowait()
-            assert obj is not orig
-            assert obj == orig
+            assert obj == orig, 'expected: obj == orig'
+            # When going back to the same interpreter we get the same object.
+            assert obj is not orig, 'expected: obj is not orig'
             """))
 
-        r, s = interpreters.create_channel()
-
-        def f():
-            while True:
-                try:
-                    obj = r.recv_nowait()
-                    break
-                except _interpreters.ChannelEmptyError:
-                    time.sleep(0.1)
-            s.send(obj)
+    @unittest.skip('broken (see BPO-...)')
+    def test_send_recv_nowait_different_interpreters(self):
+        r1, s1 = interpreters.create_channel()
+        r2, s2 = interpreters.create_channel()
+        orig1 = b'spam'
+        s1.send_nowait(orig1)
+        out = _run_output(
+            interpreters.create(),
+            dedent(f"""
+                obj1 = r.recv_nowait()
+                assert obj1 == b'spam', 'expected: obj1 == orig1'
+                # When going to another interpreter we get a copy.
+                assert id(obj1) != {id(orig1)}, 'expected: obj1 is not orig1'
+                orig2 = b'eggs'
+                print(id(orig2))
+                s.send_nowait(orig2)
+                """),
+            channels=dict(r=r1, s=s2),
+            )
+        obj2 = r2.recv_nowait()
+
+        self.assertEqual(obj2, b'eggs')
+        self.assertNotEqual(id(obj2), int(out))
+
+    def test_recv_channel_does_not_exist(self):
+        ch = interpreters.RecvChannel(1_000_000)
+        with self.assertRaises(interpreters.ChannelNotFoundError):
+            ch.recv()
+
+    def test_send_channel_does_not_exist(self):
+        ch = interpreters.SendChannel(1_000_000)
+        with self.assertRaises(interpreters.ChannelNotFoundError):
+            ch.send(b'spam')
+
+    def test_recv_nowait_channel_does_not_exist(self):
+        ch = interpreters.RecvChannel(1_000_000)
+        with self.assertRaises(interpreters.ChannelNotFoundError):
+            ch.recv_nowait()
+
+    def test_send_nowait_channel_does_not_exist(self):
+        ch = interpreters.SendChannel(1_000_000)
+        with self.assertRaises(interpreters.ChannelNotFoundError):
+            ch.send_nowait(b'spam')
+
+    def test_recv_nowait_empty(self):
+        ch, _ = interpreters.create_channel()
+        with self.assertRaises(interpreters.ChannelEmptyError):
+            ch.recv_nowait()
+
+    def test_recv_nowait_default(self):
+        default = object()
+        rch, sch = interpreters.create_channel()
+        obj1 = rch.recv_nowait(default)
+        sch.send_nowait(None)
+        sch.send_nowait(1)
+        sch.send_nowait(b'spam')
+        sch.send_nowait(b'eggs')
+        obj2 = rch.recv_nowait(default)
+        obj3 = rch.recv_nowait(default)
+        obj4 = rch.recv_nowait()
+        obj5 = rch.recv_nowait(default)
+        obj6 = rch.recv_nowait(default)
+
+        self.assertIs(obj1, default)
+        self.assertIs(obj2, None)
+        self.assertEqual(obj3, 1)
+        self.assertEqual(obj4, b'spam')
+        self.assertEqual(obj5, b'eggs')
+        self.assertIs(obj6, default)



More information about the Python-checkins mailing list