[Python-checkins] bpo-23749: Implement loop.start_tls() (#5039)

Yury Selivanov webhook-mailer at python.org
Sat Dec 30 00:35:39 EST 2017


https://github.com/python/cpython/commit/f111b3dcb414093a4efb9d74b69925e535ddc470
commit: f111b3dcb414093a4efb9d74b69925e535ddc470
branch: master
author: Yury Selivanov <yury at magic.io>
committer: GitHub <noreply at github.com>
date: 2017-12-30T00:35:36-05:00
summary:

bpo-23749: Implement loop.start_tls() (#5039)

files:
A Lib/test/test_asyncio/functional.py
A Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
M Doc/library/asyncio-eventloop.rst
M Lib/asyncio/base_events.py
M Lib/asyncio/events.py
M Lib/asyncio/proactor_events.py
M Lib/asyncio/selector_events.py
M Lib/test/test_asyncio/test_events.py
M Lib/test/test_asyncio/test_sslproto.py
M Lib/test/test_asyncio/utils.py

diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst
index 5dd258df312..33b86d6f033 100644
--- a/Doc/library/asyncio-eventloop.rst
+++ b/Doc/library/asyncio-eventloop.rst
@@ -537,6 +537,38 @@ Creating listening connections
    .. versionadded:: 3.5.3
 
 
+TLS Upgrade
+-----------
+
+.. coroutinemethod:: AbstractEventLoop.start_tls(transport, protocol, sslcontext, \*, server_side=False, server_hostname=None, ssl_handshake_timeout=None)
+
+   Upgrades an existing connection to TLS.
+
+   Returns a new transport instance that the *protocol* must start using
+   immediately after the *await*.  The *transport* instance passed to
+   the *start_tls* method must never be used again.
+
+   Parameters:
+
+   * *transport* and *protocol*: the transport and protocol of an existing
+     connection, such as one set up by :meth:`~AbstractEventLoop.create_server`
+     or :meth:`~AbstractEventLoop.create_connection`.
+
+   * *sslcontext*: a configured instance of :class:`~ssl.SSLContext`.
+
+   * *server_side*: pass ``True`` when a server-side connection is being
+     upgraded (like the one created by :meth:`~AbstractEventLoop.create_server`).
+
+   * *server_hostname*: sets or overrides the host name that the target
+     server's certificate will be matched against.
+
+   * *ssl_handshake_timeout*: the time in seconds to wait for the TLS
+     handshake to complete before aborting the connection.  Defaults to
+     ``10.0`` seconds if ``None`` (the default).
+
+   .. versionadded:: 3.7
+
+
 Watch file descriptors
 ----------------------
 
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 96cc4f02588..00831b39853 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -29,9 +29,15 @@
 import warnings
 import weakref
 
+try:
+    import ssl
+except ImportError:  # pragma: no cover
+    ssl = None
+
 from . import coroutines
 from . import events
 from . import futures
+from . import sslproto
 from . import tasks
 from .log import logger
 
@@ -279,7 +285,8 @@ def _make_ssl_transport(
             self, rawsock, protocol, sslcontext, waiter=None,
             *, server_side=False, server_hostname=None,
             extra=None, server=None,
-            ssl_handshake_timeout=None):
+            ssl_handshake_timeout=None,
+            call_connection_made=True):
         """Create SSL transport."""
         raise NotImplementedError
 
@@ -795,6 +802,42 @@ def _getaddrinfo_debug(self, host, port, family, type, proto, flags):
 
         return transport, protocol
 
+    async def start_tls(self, transport, protocol, sslcontext, *,
+                        server_side=False,
+                        server_hostname=None,
+                        ssl_handshake_timeout=None):
+        """Upgrade transport to TLS.
+
+        Return a new transport that *protocol* should start using
+        immediately.
+        """
+        if ssl is None:
+            raise RuntimeError('Python ssl module is not available')
+
+        if not isinstance(sslcontext, ssl.SSLContext):
+            raise TypeError(
+                f'sslcontext is expected to be an instance of ssl.SSLContext, '
+                f'got {sslcontext!r}')
+
+        if not getattr(transport, '_start_tls_compatible', False):
+            raise TypeError(
+                f'transport {self!r} is not supported by start_tls()')
+
+        waiter = self.create_future()
+        ssl_protocol = sslproto.SSLProtocol(
+            self, protocol, sslcontext, waiter,
+            server_side, server_hostname,
+            ssl_handshake_timeout=ssl_handshake_timeout,
+            call_connection_made=False)
+
+        transport.set_protocol(ssl_protocol)
+        self.call_soon(ssl_protocol.connection_made, transport)
+        if not transport.is_reading():
+            self.call_soon(transport.resume_reading)
+
+        await waiter
+        return ssl_protocol._app_transport
+
     async def create_datagram_endpoint(self, protocol_factory,
                                        local_addr=None, remote_addr=None, *,
                                        family=0, proto=0, flags=0,
diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py
index 3a5dbadbb10..9496d5c765f 100644
--- a/Lib/asyncio/events.py
+++ b/Lib/asyncio/events.py
@@ -305,6 +305,17 @@ def set_default_executor(self, executor):
         """
         raise NotImplementedError
 
+    async def start_tls(self, transport, protocol, sslcontext, *,
+                        server_side=False,
+                        server_hostname=None,
+                        ssl_handshake_timeout=None):
+        """Upgrade a transport to TLS.
+
+        Return a new transport that *protocol* should start using
+        immediately.
+        """
+        raise NotImplementedError
+
     async def create_unix_connection(
             self, protocol_factory, path=None, *,
             ssl=None, sock=None,
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 2661cddef73..ab1285b7999 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -223,6 +223,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
                                       transports.WriteTransport):
     """Transport for write pipes."""
 
+    _start_tls_compatible = True
+
     def write(self, data):
         if not isinstance(data, (bytes, bytearray, memoryview)):
             raise TypeError(
diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py
index 1e4bd83a1b1..5692e38486a 100644
--- a/Lib/asyncio/selector_events.py
+++ b/Lib/asyncio/selector_events.py
@@ -694,6 +694,8 @@ def get_write_buffer_size(self):
 
 class _SelectorSocketTransport(_SelectorTransport):
 
+    _start_tls_compatible = True
+
     def __init__(self, loop, sock, protocol, waiter=None,
                  extra=None, server=None):
         super().__init__(loop, sock, protocol, extra, server)
diff --git a/Lib/test/test_asyncio/functional.py b/Lib/test/test_asyncio/functional.py
new file mode 100644
index 00000000000..5fd174b6f43
--- /dev/null
+++ b/Lib/test/test_asyncio/functional.py
@@ -0,0 +1,279 @@
+import asyncio
+import asyncio.events
+import contextlib
+import os
+import pprint
+import select
+import socket
+import ssl
+import tempfile
+import threading
+
+
+class FunctionalTestCaseMixin:
+
+    def new_loop(self):
+        return asyncio.new_event_loop()
+
+    def run_loop_briefly(self, *, delay=0.01):
+        self.loop.run_until_complete(asyncio.sleep(delay, loop=self.loop))
+
+    def loop_exception_handler(self, loop, context):
+        self.__unhandled_exceptions.append(context)
+        self.loop.default_exception_handler(context)
+
+    def setUp(self):
+        self.loop = self.new_loop()
+        asyncio.set_event_loop(None)
+
+        self.loop.set_exception_handler(self.loop_exception_handler)
+        self.__unhandled_exceptions = []
+
+        # Disable `_get_running_loop`.
+        self._old_get_running_loop = asyncio.events._get_running_loop
+        asyncio.events._get_running_loop = lambda: None
+
+    def tearDown(self):
+        try:
+            self.loop.close()
+
+            if self.__unhandled_exceptions:
+                print('Unexpected calls to loop.call_exception_handler():')
+                pprint.pprint(self.__unhandled_exceptions)
+                self.fail('unexpected calls to loop.call_exception_handler()')
+
+        finally:
+            asyncio.events._get_running_loop = self._old_get_running_loop
+            asyncio.set_event_loop(None)
+            self.loop = None
+
+    def tcp_server(self, server_prog, *,
+                   family=socket.AF_INET,
+                   addr=None,
+                   timeout=5,
+                   backlog=1,
+                   max_clients=10):
+
+        if addr is None:
+            if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
+                with tempfile.NamedTemporaryFile() as tmp:
+                    addr = tmp.name
+            else:
+                addr = ('127.0.0.1', 0)
+
+        sock = socket.socket(family, socket.SOCK_STREAM)
+
+        if timeout is None:
+            raise RuntimeError('timeout is required')
+        if timeout <= 0:
+            raise RuntimeError('only blocking sockets are supported')
+        sock.settimeout(timeout)
+
+        try:
+            sock.bind(addr)
+            sock.listen(backlog)
+        except OSError as ex:
+            sock.close()
+            raise ex
+
+        return TestThreadedServer(
+            self, sock, server_prog, timeout, max_clients)
+
+    def tcp_client(self, client_prog,
+                   family=socket.AF_INET,
+                   timeout=10):
+
+        sock = socket.socket(family, socket.SOCK_STREAM)
+
+        if timeout is None:
+            raise RuntimeError('timeout is required')
+        if timeout <= 0:
+            raise RuntimeError('only blocking sockets are supported')
+        sock.settimeout(timeout)
+
+        return TestThreadedClient(
+            self, sock, client_prog, timeout)
+
+    def unix_server(self, *args, **kwargs):
+        if not hasattr(socket, 'AF_UNIX'):
+            raise NotImplementedError
+        return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
+
+    def unix_client(self, *args, **kwargs):
+        if not hasattr(socket, 'AF_UNIX'):
+            raise NotImplementedError
+        return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
+
+    @contextlib.contextmanager
+    def unix_sock_name(self):
+        with tempfile.TemporaryDirectory() as td:
+            fn = os.path.join(td, 'sock')
+            try:
+                yield fn
+            finally:
+                try:
+                    os.unlink(fn)
+                except OSError:
+                    pass
+
+    def _abort_socket_test(self, ex):
+        try:
+            self.loop.stop()
+        finally:
+            self.fail(ex)
+
+
+##############################################################################
+# Socket Testing Utilities
+##############################################################################
+
+
+class TestSocketWrapper:
+
+    def __init__(self, sock):
+        self.__sock = sock
+
+    def recv_all(self, n):
+        buf = b''
+        while len(buf) < n:
+            data = self.recv(n - len(buf))
+            if data == b'':
+                raise ConnectionAbortedError
+            buf += data
+        return buf
+
+    def start_tls(self, ssl_context, *,
+                  server_side=False,
+                  server_hostname=None):
+
+        assert isinstance(ssl_context, ssl.SSLContext)
+
+        ssl_sock = ssl_context.wrap_socket(
+            self.__sock, server_side=server_side,
+            server_hostname=server_hostname,
+            do_handshake_on_connect=False)
+
+        ssl_sock.do_handshake()
+
+        self.__sock.close()
+        self.__sock = ssl_sock
+
+    def __getattr__(self, name):
+        return getattr(self.__sock, name)
+
+    def __repr__(self):
+        return '<{} {!r}>'.format(type(self).__name__, self.__sock)
+
+
+class SocketThread(threading.Thread):
+
+    def stop(self):
+        self._active = False
+        self.join()
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, *exc):
+        self.stop()
+
+
+class TestThreadedClient(SocketThread):
+
+    def __init__(self, test, sock, prog, timeout):
+        threading.Thread.__init__(self, None, None, 'test-client')
+        self.daemon = True
+
+        self._timeout = timeout
+        self._sock = sock
+        self._active = True
+        self._prog = prog
+        self._test = test
+
+    def run(self):
+        try:
+            self._prog(TestSocketWrapper(self._sock))
+        except Exception as ex:
+            self._test._abort_socket_test(ex)
+
+
+class TestThreadedServer(SocketThread):
+    """Daemon thread that accepts connections on a listening socket.
+
+    Each accepted connection is wrapped in TestSocketWrapper and passed to
+    *prog*, one client at a time, in this thread.  The server exits after
+    *max_clients* connections, when stop() is called, or when a client
+    program raises (which also aborts the owning test).
+    """
+
+    def __init__(self, test, sock, prog, timeout, max_clients):
+        threading.Thread.__init__(self, None, None, 'test-server')
+        self.daemon = True
+
+        self._clients = 0
+        # NOTE(review): not read anywhere in this file.
+        self._finished_clients = 0
+        self._max_clients = max_clients
+        self._timeout = timeout
+        self._sock = sock
+        self._active = True
+
+        self._prog = prog
+
+        # Wake-up channel: stop() writes to _s2 so that the select() in
+        # _run() sees _s1 become readable and returns.
+        self._s1, self._s2 = socket.socketpair()
+        self._s1.setblocking(False)
+
+        self._test = test
+
+    def stop(self):
+        """Wake the server loop, then clear _active and join the thread."""
+        try:
+            # _s2 may already be closed if run() has finished.
+            if self._s2 and self._s2.fileno() != -1:
+                try:
+                    self._s2.send(b'stop')
+                except OSError:
+                    pass
+        finally:
+            super().stop()
+
+    def run(self):
+        try:
+            # The listening socket is closed when serving ends.
+            with self._sock:
+                self._sock.setblocking(0)
+                self._run()
+        finally:
+            self._s1.close()
+            self._s2.close()
+
+    def _run(self):
+        while self._active:
+            if self._clients >= self._max_clients:
+                return
+
+            r, w, x = select.select(
+                [self._sock, self._s1], [], [], self._timeout)
+
+            # stop() was called.
+            if self._s1 in r:
+                return
+
+            if self._sock in r:
+                try:
+                    conn, addr = self._sock.accept()
+                except BlockingIOError:
+                    # Non-blocking listener: readiness was spurious.
+                    continue
+                except socket.timeout:
+                    if not self._active:
+                        return
+                    else:
+                        raise
+                else:
+                    self._clients += 1
+                    conn.settimeout(self._timeout)
+                    try:
+                        with conn:
+                            self._handle_client(conn)
+                    except Exception as ex:
+                        self._active = False
+                        # Re-raise the client error while guaranteeing the
+                        # test is aborted.  _abort_socket_test() ends in
+                        # self.fail(), whose AssertionError then propagates
+                        # in place of the original exception.
+                        try:
+                            raise
+                        finally:
+                            self._test._abort_socket_test(ex)
+
+    def _handle_client(self, sock):
+        self._prog(TestSocketWrapper(sock))
+
+    @property
+    def addr(self):
+        """Address the listening socket is bound to."""
+        return self._sock.getsockname()
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index 79e8d79e6b1..da2e036648b 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -31,21 +31,7 @@
 from asyncio import proactor_events
 from asyncio import selector_events
 from test.test_asyncio import utils as test_utils
-try:
-    from test import support
-except ImportError:
-    from asyncio import test_support as support
-
-
-def data_file(filename):
-    if hasattr(support, 'TEST_HOME_DIR'):
-        fullname = os.path.join(support.TEST_HOME_DIR, filename)
-        if os.path.isfile(fullname):
-            return fullname
-    fullname = os.path.join(os.path.dirname(__file__), filename)
-    if os.path.isfile(fullname):
-        return fullname
-    raise FileNotFoundError(filename)
+from test import support
 
 
 def osx_tiger():
@@ -80,23 +66,6 @@ def __await__(self):
         pass
 
 
-ONLYCERT = data_file('ssl_cert.pem')
-ONLYKEY = data_file('ssl_key.pem')
-SIGNED_CERTFILE = data_file('keycert3.pem')
-SIGNING_CA = data_file('pycacert.pem')
-PEERCERT = {'serialNumber': 'B09264B1F2DA21D1',
-            'version': 1,
-            'subject': ((('countryName', 'XY'),),
-                    (('localityName', 'Castle Anthrax'),),
-                    (('organizationName', 'Python Software Foundation'),),
-                    (('commonName', 'localhost'),)),
-            'issuer': ((('countryName', 'XY'),),
-                    (('organizationName', 'Python Software Foundation CA'),),
-                    (('commonName', 'our-ca-server'),)),
-            'notAfter': 'Nov 13 19:47:07 2022 GMT',
-            'notBefore': 'Jan  4 19:47:07 2013 GMT'}
-
-
 class MyBaseProto(asyncio.Protocol):
     connected = None
     done = None
@@ -853,16 +822,8 @@ def test_ssl_connect_accepted_socket(self):
                 'SSL not supported with proactor event loops before Python 3.5'
                 )
 
-        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
-        server_context.load_cert_chain(ONLYCERT, ONLYKEY)
-        if hasattr(server_context, 'check_hostname'):
-            server_context.check_hostname = False
-        server_context.verify_mode = ssl.CERT_NONE
-
-        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
-        if hasattr(server_context, 'check_hostname'):
-            client_context.check_hostname = False
-        client_context.verify_mode = ssl.CERT_NONE
+        server_context = test_utils.simple_server_sslcontext()
+        client_context = test_utils.simple_client_sslcontext()
 
         self.test_connect_accepted_socket(server_context, client_context)
 
@@ -1048,7 +1009,7 @@ def _make_ssl_unix_server(self, factory, certfile, keyfile=None):
     def test_create_server_ssl(self):
         proto = MyProto(loop=self.loop)
         server, host, port = self._make_ssl_server(
-            lambda: proto, ONLYCERT, ONLYKEY)
+            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
 
         f_c = self.loop.create_connection(MyBaseProto, host, port,
                                           ssl=test_utils.dummy_ssl_context())
@@ -1081,7 +1042,7 @@ def test_create_server_ssl(self):
     def test_create_unix_server_ssl(self):
         proto = MyProto(loop=self.loop)
         server, path = self._make_ssl_unix_server(
-            lambda: proto, ONLYCERT, ONLYKEY)
+            lambda: proto, test_utils.ONLYCERT, test_utils.ONLYKEY)
 
         f_c = self.loop.create_unix_connection(
             MyBaseProto, path, ssl=test_utils.dummy_ssl_context(),
@@ -1111,7 +1072,7 @@ def test_create_unix_server_ssl(self):
     def test_create_server_ssl_verify_failed(self):
         proto = MyProto(loop=self.loop)
         server, host, port = self._make_ssl_server(
-            lambda: proto, SIGNED_CERTFILE)
+            lambda: proto, test_utils.SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
@@ -1141,7 +1102,7 @@ def test_create_server_ssl_verify_failed(self):
     def test_create_unix_server_ssl_verify_failed(self):
         proto = MyProto(loop=self.loop)
         server, path = self._make_ssl_unix_server(
-            lambda: proto, SIGNED_CERTFILE)
+            lambda: proto, test_utils.SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
@@ -1170,13 +1131,13 @@ def test_create_unix_server_ssl_verify_failed(self):
     def test_create_server_ssl_match_failed(self):
         proto = MyProto(loop=self.loop)
         server, host, port = self._make_ssl_server(
-            lambda: proto, SIGNED_CERTFILE)
+            lambda: proto, test_utils.SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
         sslcontext_client.verify_mode = ssl.CERT_REQUIRED
         sslcontext_client.load_verify_locations(
-            cafile=SIGNING_CA)
+            cafile=test_utils.SIGNING_CA)
         if hasattr(sslcontext_client, 'check_hostname'):
             sslcontext_client.check_hostname = True
 
@@ -1199,12 +1160,12 @@ def test_create_server_ssl_match_failed(self):
     def test_create_unix_server_ssl_verified(self):
         proto = MyProto(loop=self.loop)
         server, path = self._make_ssl_unix_server(
-            lambda: proto, SIGNED_CERTFILE)
+            lambda: proto, test_utils.SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
         sslcontext_client.verify_mode = ssl.CERT_REQUIRED
-        sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
         if hasattr(sslcontext_client, 'check_hostname'):
             sslcontext_client.check_hostname = True
 
@@ -1224,12 +1185,12 @@ def test_create_unix_server_ssl_verified(self):
     def test_create_server_ssl_verified(self):
         proto = MyProto(loop=self.loop)
         server, host, port = self._make_ssl_server(
-            lambda: proto, SIGNED_CERTFILE)
+            lambda: proto, test_utils.SIGNED_CERTFILE)
 
         sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         sslcontext_client.options |= ssl.OP_NO_SSLv2
         sslcontext_client.verify_mode = ssl.CERT_REQUIRED
-        sslcontext_client.load_verify_locations(cafile=SIGNING_CA)
+        sslcontext_client.load_verify_locations(cafile=test_utils.SIGNING_CA)
         if hasattr(sslcontext_client, 'check_hostname'):
             sslcontext_client.check_hostname = True
 
@@ -1241,7 +1202,7 @@ def test_create_server_ssl_verified(self):
 
         # extra info is available
         self.check_ssl_extra_info(client,peername=(host, port),
-                                  peercert=PEERCERT)
+                                  peercert=test_utils.PEERCERT)
 
         # close connection
         proto.transport.close()
diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py
index a7498e85c25..886c5cf3626 100644
--- a/Lib/test/test_asyncio/test_sslproto.py
+++ b/Lib/test/test_asyncio/test_sslproto.py
@@ -13,6 +13,7 @@
 from asyncio import sslproto
 from asyncio import tasks
 from test.test_asyncio import utils as test_utils
+from test.test_asyncio import functional as func_tests
 
 
 @unittest.skipIf(ssl is None, 'No ssl module')
@@ -158,5 +159,156 @@ def test_set_new_app_protocol(self):
         self.assertIs(ssl_proto._app_protocol, new_app_proto)
 
 
+##############################################################################
+# Start TLS Tests
+##############################################################################
+
+
+class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
+    """Functional tests for loop.start_tls().
+
+    Concrete subclasses choose the event loop by overriding new_loop().
+    """
+
+    def new_loop(self):
+        raise NotImplementedError
+
+    def test_start_tls_client_1(self):
+        # The loop is the client: it sends plaintext, upgrades with
+        # start_tls(), then exchanges data over TLS with a blocking
+        # server thread that performs the matching upgrade.
+        HELLO_MSG = b'1' * 1024 * 1024 * 5
+
+        server_context = test_utils.simple_server_sslcontext()
+        client_context = test_utils.simple_client_sslcontext()
+
+        def serve(sock):
+            # Runs in the server thread on a blocking socket.
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.start_tls(server_context, server_side=True)
+
+            sock.sendall(b'O')
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+            sock.close()
+
+        class ClientProto(asyncio.Protocol):
+            def __init__(self, on_data, on_eof):
+                self.on_data = on_data
+                self.on_eof = on_eof
+                self.con_made_cnt = 0
+
+            # The first parameter is named `proto` so that `self` below
+            # still refers to the enclosing TestCase (for assertEqual).
+            def connection_made(proto, tr):
+                proto.con_made_cnt += 1
+                # Ensure connection_made gets called only once.
+                self.assertEqual(proto.con_made_cnt, 1)
+
+            def data_received(self, data):
+                self.on_data.set_result(data)
+
+            def eof_received(self):
+                self.on_eof.set_result(True)
+
+        async def client(addr):
+            on_data = self.loop.create_future()
+            on_eof = self.loop.create_future()
+
+            tr, proto = await self.loop.create_connection(
+                lambda: ClientProto(on_data, on_eof), *addr)
+
+            tr.write(HELLO_MSG)
+            new_tr = await self.loop.start_tls(tr, proto, client_context)
+
+            # From here on, only the TLS transport returned by start_tls()
+            # is used.
+            self.assertEqual(await on_data, b'O')
+            new_tr.write(HELLO_MSG)
+            await on_eof
+
+            new_tr.close()
+
+        with self.tcp_server(serve) as srv:
+            self.loop.run_until_complete(
+                asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
+
+    def test_start_tls_server_1(self):
+        # The loop is the server: it writes plaintext, then upgrades with
+        # start_tls(server_side=True) while a blocking client thread
+        # upgrades its end and sends HELLO_MSG over TLS.
+        HELLO_MSG = b'1' * 1024 * 1024 * 5
+
+        server_context = test_utils.simple_server_sslcontext()
+        client_context = test_utils.simple_client_sslcontext()
+
+        def client(sock, addr):
+            # Runs in the client thread on a blocking socket.
+            sock.connect(addr)
+            data = sock.recv_all(len(HELLO_MSG))
+            self.assertEqual(len(data), len(HELLO_MSG))
+
+            sock.start_tls(client_context)
+            sock.sendall(HELLO_MSG)
+            sock.close()
+
+        class ServerProto(asyncio.Protocol):
+            def __init__(self, on_con, on_eof):
+                self.on_con = on_con
+                self.on_eof = on_eof
+                self.data = b''
+
+            def connection_made(self, tr):
+                self.on_con.set_result(tr)
+
+            def data_received(self, data):
+                self.data += data
+
+            def eof_received(self):
+                self.on_eof.set_result(1)
+
+        async def main():
+            tr = await on_con
+            tr.write(HELLO_MSG)
+
+            # The client sends nothing before its own TLS upgrade.
+            self.assertEqual(proto.data, b'')
+
+            new_tr = await self.loop.start_tls(
+                tr, proto, server_context,
+                server_side=True)
+
+            await on_eof
+            self.assertEqual(proto.data, HELLO_MSG)
+            new_tr.close()
+
+            server.close()
+            await server.wait_closed()
+
+        on_con = self.loop.create_future()
+        on_eof = self.loop.create_future()
+        # A single protocol instance is returned by the factory: the test
+        # expects exactly one incoming connection.
+        proto = ServerProto(on_con, on_eof)
+
+        server = self.loop.run_until_complete(
+            self.loop.create_server(
+                lambda: proto, '127.0.0.1', 0))
+        addr = server.sockets[0].getsockname()
+
+        with self.tcp_client(lambda sock: client(sock, addr)):
+            self.loop.run_until_complete(
+                asyncio.wait_for(main(), loop=self.loop, timeout=10))
+
+    def test_start_tls_wrong_args(self):
+        # Argument validation happens before the transport is touched, so
+        # None can stand in for the transport and protocol.
+        async def main():
+            with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
+                await self.loop.start_tls(None, None, None)
+
+            sslctx = test_utils.simple_server_sslcontext()
+            with self.assertRaisesRegex(TypeError, 'is not supported'):
+                await self.loop.start_tls(None, None, sslctx)
+
+        self.loop.run_until_complete(main())
+
+
+ at unittest.skipIf(ssl is None, 'No ssl module')
+class SelectorStartTLS(BaseStartTLS, unittest.TestCase):
+
+    def new_loop(self):
+        return asyncio.SelectorEventLoop()
+
+
+ at unittest.skipIf(ssl is None, 'No ssl module')
+ at unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
+class ProactorStartTLS(BaseStartTLS, unittest.TestCase):
+
+    def new_loop(self):
+        return asyncio.ProactorEventLoop()
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Lib/test/test_asyncio/utils.py b/Lib/test/test_asyncio/utils.py
index eaafe3af8b8..a78e01988d8 100644
--- a/Lib/test/test_asyncio/utils.py
+++ b/Lib/test/test_asyncio/utils.py
@@ -35,6 +35,49 @@
 from test import support
 
 
+def data_file(filename):
+    if hasattr(support, 'TEST_HOME_DIR'):
+        fullname = os.path.join(support.TEST_HOME_DIR, filename)
+        if os.path.isfile(fullname):
+            return fullname
+    fullname = os.path.join(os.path.dirname(__file__), filename)
+    if os.path.isfile(fullname):
+        return fullname
+    raise FileNotFoundError(filename)
+
+
+# Certificate and private key used by simple_server_sslcontext().
+ONLYCERT = data_file('ssl_cert.pem')
+ONLYKEY = data_file('ssl_key.pem')
+# Server certificate that test_events verifies against SIGNING_CA
+# (presumably it also contains its key, since it is passed alone as the
+# certfile argument there — confirm).
+SIGNED_CERTFILE = data_file('keycert3.pem')
+SIGNING_CA = data_file('pycacert.pem')
+# Certificate fields test_events expects as the peer certificate of a
+# server using SIGNED_CERTFILE (presumably in SSLSocket.getpeercert() form).
+# NOTE(review): must be kept in sync if the .pem files are regenerated; the
+# recorded notAfter date is in 2022.
+PEERCERT = {'serialNumber': 'B09264B1F2DA21D1',
+            'version': 1,
+            'subject': ((('countryName', 'XY'),),
+                    (('localityName', 'Castle Anthrax'),),
+                    (('organizationName', 'Python Software Foundation'),),
+                    (('commonName', 'localhost'),)),
+            'issuer': ((('countryName', 'XY'),),
+                    (('organizationName', 'Python Software Foundation CA'),),
+                    (('commonName', 'our-ca-server'),)),
+            'notAfter': 'Nov 13 19:47:07 2022 GMT',
+            'notBefore': 'Jan  4 19:47:07 2013 GMT'}
+
+
+def simple_server_sslcontext():
+    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+    server_context.load_cert_chain(ONLYCERT, ONLYKEY)
+    server_context.check_hostname = False
+    server_context.verify_mode = ssl.CERT_NONE
+    return server_context
+
+
+def simple_client_sslcontext():
+    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    client_context.check_hostname = False
+    client_context.verify_mode = ssl.CERT_NONE
+    return client_context
+
+
 def dummy_ssl_context():
     if ssl is None:
         return None
diff --git a/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst b/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
new file mode 100644
index 00000000000..d6de1fef901
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2017-12-29-00-44-42.bpo-23749.QL1Cxd.rst
@@ -0,0 +1 @@
+asyncio: Implement loop.start_tls()



More information about the Python-checkins mailing list