[Python-checkins] cpython (3.5): Issue #27392: Add loop.connect_accepted_socket().

yury.selivanov python-checkins at python.org
Tue Jul 12 18:24:31 EDT 2016


https://hg.python.org/cpython/rev/3e44c449433a
changeset:   102334:3e44c449433a
branch:      3.5
parent:      102331:420030a5e854
user:        Yury Selivanov <yury at magic.io>
date:        Tue Jul 12 18:23:10 2016 -0400
summary:
  Issue #27392: Add loop.connect_accepted_socket().

Patch by Jim Fulton.

files:
  Lib/asyncio/base_events.py           |  28 ++++-
  Lib/test/test_asyncio/test_events.py |  79 ++++++++++++++++
  Misc/NEWS                            |   3 +
  3 files changed, 106 insertions(+), 4 deletions(-)


diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -707,8 +707,6 @@
             raise ValueError(
                 'host and port was not specified and no sock specified')
 
-        sock.setblocking(False)
-
         transport, protocol = yield from self._create_connection_transport(
             sock, protocol_factory, ssl, server_hostname)
         if self._debug:
@@ -721,14 +719,17 @@
 
     @coroutine
     def _create_connection_transport(self, sock, protocol_factory, ssl,
-                                     server_hostname):
+                                     server_hostname, server_side=False):
+
+        sock.setblocking(False)
+
         protocol = protocol_factory()
         waiter = self.create_future()
         if ssl:
             sslcontext = None if isinstance(ssl, bool) else ssl
             transport = self._make_ssl_transport(
                 sock, protocol, sslcontext, waiter,
-                server_side=False, server_hostname=server_hostname)
+                server_side=server_side, server_hostname=server_hostname)
         else:
             transport = self._make_socket_transport(sock, protocol, waiter)
 
@@ -980,6 +981,25 @@
         return server
 
     @coroutine
+    def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None):
+        """Handle an accepted connection.
+
+        This is used by servers that accept connections outside of
+        asyncio but that use asyncio to handle connections.  sock is put
+        into non-blocking mode; if ssl is true, the SSL transport is
+        created with server_side=True.  This method is a coroutine.  When
+        completed, the coroutine returns a (transport, protocol) pair.
+        """
+        transport, protocol = yield from self._create_connection_transport(
+            sock, protocol_factory, ssl, '', server_side=True)
+        if self._debug:
+            # Get the socket from the transport because SSL transport closes
+            # the old socket and creates a new SSL socket
+            sock = transport.get_extra_info('socket')
+            logger.debug("%r handled: (%r, %r)", sock, transport, protocol)
+        return transport, protocol
+
+    @coroutine
     def connect_read_pipe(self, protocol_factory, pipe):
         protocol = protocol_factory()
         waiter = self.create_future()
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -744,6 +744,85 @@
             self.assertEqual(cm.exception.errno, errno.EADDRINUSE)
             self.assertIn(str(httpd.address), cm.exception.strerror)
 
+    def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
+        loop = self.loop
+        # Server-side protocol: answer any data, stop the loop on disconnect.
+        class MyProto(MyBaseProto):
+
+            def connection_lost(self, exc):
+                super().connection_lost(exc)
+                loop.call_soon(loop.stop)
+
+            def data_received(self, data):
+                super().data_received(data)
+                self.transport.write(expected_response)
+
+        lsock = socket.socket()
+        lsock.bind(('127.0.0.1', 0))
+        lsock.listen(1)
+        addr = lsock.getsockname()
+
+        message = b'test data'
+        response = None  # set by client(); stays None if the client fails
+        expected_response = b'roger'
+
+        def client():
+            nonlocal response  # keep the result per-test, not module-global
+            try:
+                csock = socket.socket()
+                if client_ssl is not None:
+                    csock = client_ssl.wrap_socket(csock)
+                csock.connect(addr)
+                csock.sendall(message)
+                response = csock.recv(99)
+                csock.close()
+            except Exception as exc:
+                print(
+                    "Failure in client thread in test_connect_accepted_socket",
+                    exc)
+
+        thread = threading.Thread(target=client, daemon=True)
+        thread.start()
+
+        conn, _ = lsock.accept()
+        proto = MyProto(loop=loop)
+        proto.loop = loop
+        f = loop.create_task(
+            loop.connect_accepted_socket(
+                (lambda : proto), conn, ssl=server_ssl))
+        loop.run_forever()
+        conn.close()
+        lsock.close()
+
+        thread.join(1)
+        self.assertFalse(thread.is_alive())
+        self.assertEqual(proto.state, 'CLOSED')
+        self.assertEqual(proto.nbytes, len(message))
+        self.assertEqual(response, expected_response)
+
+    @unittest.skipIf(ssl is None, 'No ssl module')
+    def test_ssl_connect_accepted_socket(self):
+        if (sys.platform == 'win32' and
+            sys.version_info < (3, 5) and
+            isinstance(self.loop, proactor_events.BaseProactorEventLoop)
+            ):
+            raise unittest.SkipTest(
+                'SSL not supported with proactor event loops before Python 3.5'
+                )
+        # Neither side verifies the peer: hostname checks off, CERT_NONE.
+        server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        server_context.load_cert_chain(ONLYCERT, ONLYKEY)
+        if hasattr(server_context, 'check_hostname'):
+            server_context.check_hostname = False
+        server_context.verify_mode = ssl.CERT_NONE
+
+        client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+        if hasattr(client_context, 'check_hostname'):
+            client_context.check_hostname = False
+        client_context.verify_mode = ssl.CERT_NONE
+
+        self.test_connect_accepted_socket(server_context, client_context)
+
     @mock.patch('asyncio.base_events.socket')
     def create_server_multiple_hosts(self, family, hosts, mock_sock):
         @asyncio.coroutine
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -72,6 +72,9 @@
 
 - Issue #26930: Update Windows builds to use OpenSSL 1.0.2h.
 
+- Issue #27392: Add loop.connect_accepted_socket().
+  Patch by Jim Fulton.
+
 IDLE
 ----
 

-- 
Repository URL: https://hg.python.org/cpython


More information about the Python-checkins mailing list