[Jython-checkins] jython: Fix bugs in select.poll, threadpool group closing, and SSL handshaking.

jim.baker jython-checkins at python.org
Mon Jul 14 22:19:48 CEST 2014


http://hg.python.org/jython/rev/2c45f75a5406
changeset:   7348:2c45f75a5406
user:        Jim Baker <jim.baker at rackspace.com>
date:        Mon Jul 14 14:19:57 2014 -0600
summary:
  Fix bugs in select.poll, threadpool group closing, and SSL handshaking.
Does not yet support Start TLS for the server side of an SSL connection;
that support has been deferred indefinitely.

Merged from https://bitbucket.org/jimbaker/jython-fix-socket-bugs
Fixes http://bugs.jython.org/issue2094,
http://bugs.jython.org/issue2147, http://bugs.jython.org/issue2174
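
As a quick illustration of the poll interface this change re-exports from
Lib/select.py (it was previously commented out) and whose notification
handling is reworked in Lib/_socket.py, here is a minimal, hedged sketch of
typical usage; the loopback address and the 100 ms timeout are illustrative
only:

    import select
    import socket

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(('127.0.0.1', 0))        # any free port on the loopback interface
    server.listen(1)

    p = select.poll()
    p.register(server, select.POLLIN)    # the argument must provide fileno()
    events = p.poll(100)                 # timeout is given in milliseconds
    print(events)                        # [] until a client actually connects
    p.unregister(server)
    server.close()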

files:
  Lib/_socket.py          |  280 ++++++++++++++++++---------
  Lib/select.py           |    2 +-
  Lib/ssl.py              |  211 ++++++++++++++++----
  Lib/test/test_socket.py |   28 +-
  4 files changed, 369 insertions(+), 152 deletions(-)
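
The SSL handshaking changes move the handshake into an explicit
do_handshake() that can also run automatically on connect. A hedged sketch of
the client-side path exercised here, using a placeholder host name and
requiring outbound network access:

    import socket
    import ssl

    raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    conn = ssl.wrap_socket(raw, do_handshake_on_connect=True)
    conn.connect(('www.example.org', 443))   # handshake completes during connect
    conn.send('HEAD / HTTP/1.0\r\nHost: www.example.org\r\n\r\n')
    print(conn.recv(64))
    conn.close()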


diff --git a/Lib/_socket.py b/Lib/_socket.py
--- a/Lib/_socket.py
+++ b/Lib/_socket.py
@@ -3,7 +3,9 @@
 import errno
 import jarray
 import logging
+import numbers
 import pprint
+import struct
 import sys
 import time
 import _google_ipaddr_r234
@@ -15,18 +17,19 @@
 from StringIO import StringIO
 from threading import Condition, Lock
 from types import MethodType, NoneType
+from weakref import WeakKeyDictionary
 
 import java
 from java.io import IOException, InterruptedIOException
-from java.lang import Thread
+from java.lang import Thread, IllegalStateException
 from java.net import InetAddress, InetSocketAddress
 from java.nio.channels import ClosedChannelException
 from java.util import NoSuchElementException
 from java.util.concurrent import (
     ArrayBlockingQueue, CopyOnWriteArrayList, CountDownLatch, LinkedBlockingQueue,
     RejectedExecutionException, ThreadFactory, TimeUnit)
-from java.util.concurrent.atomic import AtomicBoolean
-from javax.net.ssl import SSLPeerUnverifiedException
+from java.util.concurrent.atomic import AtomicBoolean, AtomicLong
+from javax.net.ssl import SSLPeerUnverifiedException, SSLException
 
 try:
     # jarjar-ed version
@@ -53,6 +56,8 @@
     FORMAT = '%(asctime)-15s %(threadName)s %(levelname)s %(funcName)s %(message)s %(sock)s'
     logging.basicConfig(format=FORMAT, level=logging.DEBUG)
 
+# _debug()  # UNCOMMENT to get logging of socket activity
+
 
 # Constants
 ###########
@@ -185,21 +190,27 @@
 # because these threads only handle ephemeral data, such as performing
 # SSL wrap/unwrap.
 
+
 class DaemonThreadFactory(ThreadFactory):
+
+    thread_count = AtomicLong()
+
+    def __init__(self, label):
+        self.label = label
+
     def newThread(self, runnable):
         t = Thread(runnable)
         t.daemon = True
+        t.name = self.label % (self.thread_count.getAndIncrement())
         return t
 
 
-# This number should be configurable by the user. 10 is the default
-# number as of 4.0.17 of Netty. FIXME this default may be based on core count.
+NIO_GROUP = NioEventLoopGroup(10, DaemonThreadFactory("Jython-Netty-Client-%s"))
 
-NIO_GROUP = NioEventLoopGroup(10, DaemonThreadFactory())
 
-def _check_threadpool_for_pending_threads():
+def _check_threadpool_for_pending_threads(group):
     pending_threads = []
-    for t in NIO_GROUP:
+    for t in group:
         pending_count = t.pendingTasks()
         if pending_count > 0:
             pending_threads.append((t, pending_count))
@@ -272,6 +283,7 @@
 
     IOException            : lambda x: error(errno.ECONNRESET, 'Software caused connection abort'),
     InterruptedIOException : lambda x: timeout(None, 'timed out'),
+    IllegalStateException  : lambda x: error(errno.EPIPE, 'Illegal state exception'),
     
     java.net.BindException            : lambda x: error(errno.EADDRINUSE, 'Address already in use'),
     java.net.ConnectException         : lambda x: error(errno.ECONNREFUSED, 'Connection refused'),
@@ -299,7 +311,8 @@
     java.nio.channels.UnresolvedAddressException      : lambda x: gaierror(errno.EGETADDRINFOFAILED, 'getaddrinfo failed'),
     java.nio.channels.UnsupportedAddressTypeException : None,
 
-    SSLPeerUnverifiedException: lambda x: SSLError(SSL_ERROR_SSL, "FIXME"),
+    SSLPeerUnverifiedException: lambda x: SSLError(SSL_ERROR_SSL, x.message),
+    SSLException: lambda x: SSLError(SSL_ERROR_SSL, x.message),
 }
 
 
@@ -391,7 +404,9 @@
                 # shortcircuiting if the socket was in fact ready for
                 # reading/writing/exception before the select call
                 if selected_rlist or selected_wlist:
-                    return sorted(selected_rlist), sorted(selected_wlist), sorted(selected_xlist)
+                    completed = sorted(selected_rlist), sorted(selected_wlist), sorted(selected_xlist)
+                    log.debug("Completed select %s", completed, extra={"sock": "*"})
+                    return completed
                 elif timeout is not None and time.time() - started >= timeout:
                     return [], [], []
                 self.cv.wait(timeout)
@@ -400,7 +415,12 @@
 # poll support
 ##############
 
-_PollNotification = namedtuple("_PollNotification", ["sock", "fd", "exception", "hangup"])
+_PollNotification = namedtuple(
+    "_PollNotification",
+    ["sock",  # the real socket
+     "fd",    # could be the real socket (as returned by fileno) or a wrapping socket object
+     "exception",
+     "hangup"])
 
 
 class poll(object):
@@ -408,30 +428,41 @@
     def __init__(self):
         self.queue = LinkedBlockingQueue()
         self.registered = dict()  # fd -> eventmask
+        self.socks2fd = WeakKeyDictionary()  # sock -> fd
 
     def notify(self, sock, exception=None, hangup=False):
         notification = _PollNotification(
             sock=sock,
-            fd=sock.fileno(),
+            fd=self.socks2fd.get(sock),
             exception=exception,
             hangup=hangup)
+        log.debug("Notify %s", notification, extra={"sock": "*"})
+
         self.queue.put(notification)
 
     def register(self, fd, eventmask=POLLIN|POLLPRI|POLLOUT):
+        if not hasattr(fd, "fileno"):
+            raise TypeError("argument must have a fileno() method")
+        sock = fd.fileno()
+        log.debug("Register fd=%s eventmask=%s", fd, eventmask, extra={"sock": sock})
         self.registered[fd] = eventmask
-        # NOTE in case fd != sock in a future release, modifiy accordingly
-        sock = fd
+        self.socks2fd[sock] = fd
         sock._register_selector(self)
         self.notify(sock)  # Ensure we get an initial notification
 
     def modify(self, fd, eventmask):
+        if not hasattr(fd, "fileno"):
+            raise TypeError("argument must have a fileno() method")
         if fd not in self.registered:
             raise error(errno.ENOENT, "No such file or directory")
         self.registered[fd] = eventmask
 
     def unregister(self, fd):
+        if not hasattr(fd, "fileno"):
+            raise TypeError("argument must have a fileno() method")
+        log.debug("Unregister socket fd=%s", fd, extra={"sock": fd.fileno()})
         del self.registered[fd]
-        sock = fd
+        sock = fd.fileno()
         sock._unregister_selector(self)
 
     def _event_test(self, notification):
@@ -439,7 +470,8 @@
         # edges around errors and hangup
         if notification is None:
             return None, 0
-        mask = self.registered.get(notification.sock, 0)   # handle if concurrently removed, by simply ignoring
+        mask = self.registered.get(notification.fd, 0)   # handle if concurrently removed, by simply ignoring
+        log.debug("Testing notification=%s mask=%s", notification, mask, extra={"sock": "*"}) 
         event = 0
         if mask & POLLIN and notification.sock._readable():
             event |= POLLIN
@@ -451,53 +483,58 @@
             event |= POLLHUP
         if mask & POLLNVAL and not notification.sock.peer_closed:
             event |= POLLNVAL
+        log.debug("Tested notification=%s event=%s", notification, event, extra={"sock": "*"}) 
         return notification.fd, event
 
+    def _handle_poll(self, poller):
+        notification = poller()
+        if notification is None:
+            return []
+            
+        # Pull as many outstanding notifications as possible out
+        # of the queue
+        notifications = [notification]
+        self.queue.drainTo(notifications)
+        log.debug("Got notification(s) %s", notifications, extra={"sock": "MODULE"})
+        result = []
+        socks = set()
+
+        # But given how we notify, it's possible to see multiple
+        # notifications for the same socket. Just return one (fd, event)
+        # per socket
+        for notification in notifications:
+            if notification.sock not in socks:
+                fd, event = self._event_test(notification)
+                if event:
+                    result.append((fd, event))
+                    socks.add(notification.sock)
+
+        # Repump sockets to pick up a subsequent level change
+        for sock in socks:
+            self.notify(sock)
+
+        return result
+
     def poll(self, timeout=None):
-        if not timeout or timeout < 0:
-            # Simplify logic around timeout resets
+        if not (timeout is None or isinstance(timeout, numbers.Real)):
+            raise TypeError("timeout must be a number or None, got %r" % (timeout,))
+        if timeout < 0:
             timeout = None
+        log.debug("Polling timeout=%s", timeout, extra={"sock": "*"})
+        if timeout is None:
+            return self._handle_poll(self.queue.take)
+        elif timeout == 0:
+            return self._handle_poll(self.queue.poll)
         else:
-            timeout /= 1000.  # convert from milliseconds to seconds
-
-        while True:
-            if timeout is None:
-                notification = self.queue.take()
-            elif timeout > 0:
+            timeout = float(timeout) / 1000.  # convert from milliseconds to seconds
+            while timeout > 0:
                 started = time.time()
                 timeout_in_ns = int(timeout * _TO_NANOSECONDS)
-                notification = self.queue.poll(timeout_in_ns, TimeUnit.NANOSECONDS)
-                # Need to reset the timeout, because this notification
-                # may not be of interest when masked out
+                result = self._handle_poll(partial(self.queue.poll, timeout_in_ns, TimeUnit.NANOSECONDS))
+                if result:
+                    return result
                 timeout = timeout - (time.time() - started)
-            else:
-                return []
-
-            if notification is None:
-                continue
-            
-            # Pull as many outstanding notifications as possible out
-            # of the queue
-            notifications = [notification]
-            self.queue.drainTo(notifications)
-            log.debug("Got notification(s) %s", notifications, extra={"sock": "MODULE"})
-            result = []
-            socks = set()
-
-            # But given how we notify, it's possible to see possible
-            # multiple notifications. Just return one (fd, event) for a
-            # given socket
-            for notification in notifications:
-                if notification.sock not in socks:
-                    fd, event = self._event_test(notification)
-                    if event:
-                        result.append((fd, event))
-                        socks.add(notification.sock)
-            # Repump sockets to pick up a subsequent level change
-            for sock in socks:
-                self.notify(sock)
-            if result:
-                return result
+            return []
 
 
 # integration with Netty
@@ -538,7 +575,7 @@
         self.parent_socket = parent_socket
 
     def initChannel(self, child_channel):
-        child = ChildSocket()
+        child = ChildSocket(self.parent_socket)
         child.proto = IPPROTO_TCP
         child._init_client_mode(child_channel)
 
@@ -551,7 +588,7 @@
             log.debug("Setting inherited options %s", child.options, extra={"sock": child})
             config = child_channel.config()
             for option, value in child.options.iteritems():
-                config.setOption(option, value)
+                _set_option(config.setOption, option, value)
 
         log.debug("Notifing listeners of parent socket %s", self.parent_socket, extra={"sock": child})
         self.parent_socket.child_queue.put(child)
@@ -578,6 +615,35 @@
 
 # FIXME raise exceptions for ops not permitted on client socket, server socket
 UNKNOWN_SOCKET, CLIENT_SOCKET, SERVER_SOCKET, DATAGRAM_SOCKET = range(4)
+_socket_types = {
+    UNKNOWN_SOCKET:  "unknown",
+    CLIENT_SOCKET:   "client", 
+    SERVER_SOCKET:   "server",
+    DATAGRAM_SOCKET: "datagram"
+}
+
+
+
+
+def _identity(value):
+    return value
+
+
+def _set_option(setter, option, value):
+    if option in (ChannelOption.SO_LINGER, ChannelOption.SO_TIMEOUT):
+        # FIXME consider implementing these options. Note these are not settable
+        # via config.setOption in any event:
+        #
+        # * SO_TIMEOUT does not work for NIO sockets, need to use
+        #   IdleStateHandler instead
+        #
+        # * SO_LINGER does not work for nonblocking sockets, so need
+        #   to emulate in calling close on the socket by attempting to
+        #   send any unsent data (it's not clear this actually is
+        #   needed in Netty however...)
+        return
+    else:
+        setter(option, value)
 
 
 # These are the only socket protocols we currently support, so it's easy to map as follows:
@@ -585,19 +651,15 @@
 _socket_options = {
     IPPROTO_TCP: {
         (SOL_SOCKET,  SO_KEEPALIVE):   (ChannelOption.SO_KEEPALIVE, bool),
-        (SOL_SOCKET,  SO_LINGER):      (ChannelOption.SO_LINGER, int),
+        (SOL_SOCKET,  SO_LINGER):      (ChannelOption.SO_LINGER, _identity),
         (SOL_SOCKET,  SO_RCVBUF):      (ChannelOption.SO_RCVBUF, int),
         (SOL_SOCKET,  SO_REUSEADDR):   (ChannelOption.SO_REUSEADDR, bool),
         (SOL_SOCKET,  SO_SNDBUF):      (ChannelOption.SO_SNDBUF, int),
-        # FIXME SO_TIMEOUT needs to be handled by an IdleStateHandler -
-        # ChannelOption.SO_TIMEOUT really only applies to OIO (old) socket channels,
-        # we want to use NIO ones
         (SOL_SOCKET,  SO_TIMEOUT):     (ChannelOption.SO_TIMEOUT, int),
         (IPPROTO_TCP, TCP_NODELAY):    (ChannelOption.TCP_NODELAY, bool),
     },
     IPPROTO_UDP: {
         (SOL_SOCKET,  SO_BROADCAST):   (ChannelOption.SO_BROADCAST, bool),
-        (SOL_SOCKET,  SO_LINGER):      (ChannelOption.SO_LINGER, int),
         (SOL_SOCKET,  SO_RCVBUF):      (ChannelOption.SO_RCVBUF, int),
         (SOL_SOCKET,  SO_REUSEADDR):   (ChannelOption.SO_REUSEADDR, bool),
         (SOL_SOCKET,  SO_SNDBUF):      (ChannelOption.SO_SNDBUF, int),
@@ -629,6 +691,7 @@
                 proto = IPPROTO_UDP
         self.proto = proto
 
+        self._sock = self  # some Python code wants to see a socket
         self._last_error = 0  # supports SO_ERROR
         self.connected = False
         self.timeout = _defaulttimeout
@@ -654,7 +717,7 @@
 
     def __repr__(self):
         return "<_realsocket at {:#x} type={} open_count={} channel={} timeout={}>".format(
-            id(self), self.socket_type, self.open_count, self.channel, self.timeout)
+            id(self), _socket_types[self.socket_type], self.open_count, self.channel, self.timeout)
 
     def _unlatch(self):
         pass  # no-op once mutated from ChildSocket to normal _socketobject
@@ -689,6 +752,7 @@
         elif self.timeout:
             self._handle_timeout(future.await, reason)
             if not future.isSuccess():
+                log.exception("Got this failure %s during %s", future.cause(), reason, extra={"sock": self})
                 raise future.cause()
             return future
         else:
@@ -757,7 +821,7 @@
         self.python_inbound_handler = PythonInboundHandler(self)
         bootstrap = Bootstrap().group(NIO_GROUP).channel(NioSocketChannel)
         for option, value in self.options.iteritems():
-            bootstrap.option(option, value)
+            _set_option(bootstrap.option, option, value)
 
         # FIXME really this is just for SSL handling, so make more
         # specific than a list of connect_handlers
@@ -772,13 +836,12 @@
             bind_future = bootstrap.bind(self.bind_addr)
             self._handle_channel_future(bind_future, "local bind")
             self.channel = bind_future.channel()
-            future = self.channel.connect(addr)
         else:
             log.debug("Connect to %s", addr, extra={"sock": self})
-            future = bootstrap.connect(addr)
-            self.channel = future.channel()
-            
-        self._handle_channel_future(future, "connect")
+            self.channel = bootstrap.channel()
+
+        connect_future = self.channel.connect(addr)
+        self._handle_channel_future(connect_future, "connect")
         self.bind_timestamp = time.time()
 
     def _post_connect(self):
@@ -818,19 +881,21 @@
 
     def listen(self, backlog):
         self.socket_type = SERVER_SOCKET
+        self.child_queue = ArrayBlockingQueue(backlog)
+        self.accepted_children = 1  # include the parent as well to simplify close logic
 
         b = ServerBootstrap()
-        self.group = NioEventLoopGroup(10, DaemonThreadFactory())
-        b.group(self.group)
+        self.parent_group = NioEventLoopGroup(2, DaemonThreadFactory("Jython-Netty-Parent-%s"))
+        self.child_group = NioEventLoopGroup(2, DaemonThreadFactory("Jython-Netty-Child-%s"))
+        b.group(self.parent_group, self.child_group)
         b.channel(NioServerSocketChannel)
         b.option(ChannelOption.SO_BACKLOG, backlog)
         for option, value in self.options.iteritems():
-            b.option(option, value)
+            _set_option(b.option, option, value)
             # Note that child options are set in the child handler so
             # that they can take into account any subsequent changes,
             # plus have shadow support
 
-        self.child_queue = ArrayBlockingQueue(backlog)
         self.child_handler = ChildSocketHandler(self)
         b.childHandler(self.child_handler)
 
@@ -854,6 +919,9 @@
                 raise error(errno.EWOULDBLOCK, "Resource temporarily unavailable")
         peername = child.getpeername() if child else None
         log.debug("Got child %s connected to %s", child, peername, extra={"sock": self})
+        child.accepted = True
+        with self.open_lock:
+            self.accepted_children += 1
         return child, peername
 
     # DATAGRAM METHODS
@@ -870,7 +938,7 @@
             bootstrap = Bootstrap().group(NIO_GROUP).channel(NioDatagramChannel)
             bootstrap.handler(self.python_inbound_handler)
             for option, value in self.options.iteritems():
-                bootstrap.option(option, value)
+                _set_option(bootstrap.option, option, value)
 
             future = bootstrap.register()
             self._handle_channel_future(future, "register")
@@ -903,14 +971,19 @@
         self._handle_channel_future(future, "sendto")
         return len(string)
 
-
-    # FIXME implement these methods
-
     def recvfrom_into(self, buffer, nbytes=0, flags=0):
-        raise NotImplementedError()
+        if nbytes == 0:
+            nbytes = len(buffer)
+        data, remote_addr = self.recvfrom(nbytes, flags)
+        buffer[0:len(data)] = data
+        return len(data), remote_addr
 
     def recv_into(self, buffer, nbytes=0, flags=0):
-        raise NotImplementedError()
+        if nbytes == 0:
+            nbytes = len(buffer)
+        data = self.recv(nbytes, flags)
+        buffer[0:len(data)] = data
+        return len(data)
 
     # GENERAL METHODS
                                              
@@ -930,7 +1003,9 @@
                 # Do not care about tasks that attempt to schedule after close
                 pass
             if self.socket_type == SERVER_SOCKET:
-                self.group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
+                log.debug("Shutting down server socket parent group", extra={"sock": self})
+                self.parent_group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
+                self.accepted_children -= 1
                 while True:
                     child = self.child_queue.poll()
                     if child is None:
@@ -941,6 +1016,7 @@
             log.debug("Closed socket", extra={"sock": self})
 
     def shutdown(self, how):
+        log.debug("Got request to shutdown socket how=%s", how, extra={"sock": self})
         self._verify_channel()
         if how & SHUT_RD:
             try:
@@ -952,8 +1028,10 @@
             
     def _readable(self):
         if self.socket_type == CLIENT_SOCKET or self.socket_type == DATAGRAM_SOCKET:
-            return ((self.incoming_head is not None and self.incoming_head.readableBytes()) or
-                    self.incoming.peek())
+            log.debug("Incoming head=%s queue=%s", self.incoming_head, self.incoming, extra={"sock": self})
+            return (
+                (self.incoming_head is not None and self.incoming_head.readableBytes()) or
+                self.incoming.peek())
         elif self.socket_type == SERVER_SOCKET:
             return bool(self.child_queue.peek())
         else:
@@ -986,6 +1064,7 @@
             raise error(errno.ENOTCONN, 'Socket not connected')
         future = self.channel.writeAndFlush(Unpooled.wrappedBuffer(data))
         self._handle_channel_future(future, "send")
+        log.debug("Sent data <<<{!r:.20}>>>".format(data), extra={"sock": self})
         # FIXME are we sure we are going to be able to send this much data, especially async?
         return len(data)
     
@@ -1008,9 +1087,9 @@
                     log.debug("No data yet for socket", extra={"sock": self})
                     raise error(errno.EAGAIN, "Resource temporarily unavailable")
 
-        # Only return _PEER_CLOSED once
         msg = self.incoming_head
         if msg is _PEER_CLOSED:
+            # Only return _PEER_CLOSED once
             self.incoming_head = None
             self.peer_closed = True
         return msg
@@ -1073,13 +1152,11 @@
         except KeyError:
             raise error(errno.ENOPROTOOPT, "Protocol not available")
 
-        # FIXME for NIO sockets, SO_TIMEOUT doesn't work - should use
-        # IdleStateHandler instead
         cast_value = cast(value)
         self.options[option] = cast_value
         log.debug("Setting option %s to %s", optname, value, extra={"sock": self})
         if self.channel:
-            self.channel.config().setOption(option, cast(value))
+            _set_option(self.channel.config().setOption, option, cast_value)
 
     def getsockopt(self, level, optname, buflen=None):
         # Pseudo options for interrogating the status of this socket
@@ -1247,10 +1324,12 @@
 
 class ChildSocket(_realsocket):
     
-    def __init__(self):
+    def __init__(self, parent_socket):
         super(ChildSocket, self).__init__()
+        self.parent_socket = parent_socket
         self.active = AtomicBoolean()
         self.active_latch = CountDownLatch(1)
+        self.accepted = False
 
     def _ensure_post_connect(self):
         do_post_connect = not self.active.getAndSet(True)
@@ -1293,6 +1372,14 @@
     def close(self):
         self._ensure_post_connect()
         super(ChildSocket, self).close()
+        if self.open_count > 0:
+            return
+        if self.accepted:
+            with self.parent_socket.open_lock:
+                self.parent_socket.accepted_children -= 1
+                if self.parent_socket.accepted_children == 0:
+                    log.debug("Shutting down child group for parent socket=%s", self.parent_socket, extra={"sock": self})
+                    self.parent_socket.child_group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
 
     def shutdown(self, how):
         self._ensure_post_connect()
@@ -1413,6 +1500,12 @@
 
 
 def _get_jsockaddr(address_object, family, sock_type, proto, flags):
+    if family is None:
+        family = AF_UNSPEC
+    if sock_type is None:
+        sock_type = 0
+    if proto is None:
+        proto = 0
     addr = _get_jsockaddr2(address_object, family, sock_type, proto, flags)
     log.debug("Address %s for %s", addr, address_object, extra={"sock": "*"})
     return addr
@@ -1427,9 +1520,10 @@
         address_object = ("", 0)
     error_message = "Address must be a 2-tuple (ipv4: (host, port)) or a 4-tuple (ipv6: (host, port, flow, scope))"
     if not isinstance(address_object, tuple) or \
-            ((family == AF_INET and len(address_object) != 2) or (family == AF_INET6 and len(address_object) not in [2,4] )) or \
-            not isinstance(address_object[0], (basestring, NoneType)) or \
-            not isinstance(address_object[1], (int, long)):
+       ((family == AF_INET and len(address_object) != 2) or \
+        (family == AF_INET6 and len(address_object) not in [2,4] )) or \
+       not isinstance(address_object[0], (basestring, NoneType)) or \
+       not isinstance(address_object[1], (int, long)):
         raise TypeError(error_message)
     if len(address_object) == 4 and not isinstance(address_object[3], (int, long)):
         raise TypeError(error_message)
@@ -1520,8 +1614,10 @@
 
 @raises_java_exception
 def getaddrinfo(host, port, family=AF_UNSPEC, socktype=0, proto=0, flags=0):
-    if _ipv4_addresses_only:
-        family = AF_INET
+    if family is None:
+        family = AF_UNSPEC
+    if socktype is None:
+        socktype = 0
     if not family in [AF_INET, AF_INET6, AF_UNSPEC]:
         raise gaierror(errno.EIO, 'ai_family not supported')
     host = _getaddrinfo_get_host(host, family, flags)
diff --git a/Lib/select.py b/Lib/select.py
--- a/Lib/select.py
+++ b/Lib/select.py
@@ -8,5 +8,5 @@
     POLLHUP,
     POLLNVAL,
     error,
-    #poll,
+    poll,
     select)
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -1,5 +1,10 @@
+import base64
+import errno
 import logging
+import os.path
+import textwrap
 import time
+import threading
 
 try:
     # jarjar-ed version
@@ -20,27 +25,33 @@
     SSL_ERROR_ZERO_RETURN,
     SSL_ERROR_WANT_CONNECT,
     SSL_ERROR_EOF,
-    SSL_ERROR_INVALID_ERROR_CODE)
+    SSL_ERROR_INVALID_ERROR_CODE,
+    error as socket_error)
 from _sslcerts import _get_ssl_context
 
 from java.text import SimpleDateFormat
-from java.util import Locale, TimeZone
+from java.util import ArrayList, Locale, TimeZone
+from java.util.concurrent import CountDownLatch
 from javax.naming.ldap import LdapName
 from javax.security.auth.x500 import X500Principal
 
 
-log = logging.getLogger("socket")
+log = logging.getLogger("_socket")
 
 
+# Pretend to be OpenSSL
+OPENSSL_VERSION = "OpenSSL 1.0.0 (as emulated by Java SSL)"
+OPENSSL_VERSION_NUMBER = 0x1000000L
+OPENSSL_VERSION_INFO = (1, 0, 0, 0, 0)
+
 CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED = range(3)
 
-# FIXME need to map to java names as well; there's also possibility some difference between 
-# SSLv2 (Java) and PROTOCOL_SSLv23 (Python) but reading the docs suggest not
-# http://docs.oracle.com/javase/7/docs/technotes/guides/security/StandardNames.html#SSLContext
-
-# Currently ignored, since we just use the default in Java. FIXME
-PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 = range(4)
-_PROTOCOL_NAMES = {PROTOCOL_SSLv2: 'SSLv2', PROTOCOL_SSLv3: 'SSLv3', PROTOCOL_SSLv23: 'SSLv23', PROTOCOL_TLSv1: 'TLSv1'}
+# Do not support PROTOCOL_SSLv2, it is highly insecure and it is optional
+_, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 = range(4)
+_PROTOCOL_NAMES = {
+    PROTOCOL_SSLv3: 'SSLv3', 
+    PROTOCOL_SSLv23: 'SSLv23',
+    PROTOCOL_TLSv1: 'TLSv1'}
 
 _rfc2822_date_format = SimpleDateFormat("MMM dd HH:mm:ss yyyy z", Locale.US)
 _rfc2822_date_format.setTimeZone(TimeZone.getTimeZone("GMT"))
@@ -59,8 +70,7 @@
 }
 
 _cert_name_types = [
-    # FIXME only entry 2 - DNS - has been confirmed w/ cpython;
-    # everything else is coming from this doc:
+    # Fields documented in 
     # http://docs.oracle.com/javase/7/docs/api/java/security/cert/X509Certificate.html#getSubjectAlternativeNames()
     "other",
     "rfc822",
@@ -80,7 +90,7 @@
 
     def initChannel(self, ch):
         pipeline = ch.pipeline()
-        pipeline.addLast("ssl", self.ssl_handler) 
+        pipeline.addFirst("ssl", self.ssl_handler)
 
 
 class SSLSocket(object):
@@ -89,47 +99,61 @@
                  keyfile, certfile, ca_certs,
                  do_handshake_on_connect, server_side):
         self.sock = sock
+        self.do_handshake_on_connect = do_handshake_on_connect
         self._sock = sock._sock  # the real underlying socket
         self.context = _get_ssl_context(keyfile, certfile, ca_certs)
         self.engine = self.context.createSSLEngine()
+        self.server_side = server_side
         self.engine.setUseClientMode(not server_side)
-        self.ssl_handler = SslHandler(self.engine)
-        self.already_handshaked = False
-        self.do_handshake_on_connect = do_handshake_on_connect
+        self.ssl_handler = None
+        # _sslobj is used to follow CPython convention that an object
+        # means we have handshaked, as used by existing code that
+        # looks at this internal
+        self._sslobj = None
+        self.handshake_count = 0
 
-        if self.do_handshake_on_connect and hasattr(self._sock, "connected") and self._sock.connected:
-            self.already_handshaked = True
-            log.debug("Adding SSL handler to pipeline after connection", extra={"sock": self._sock})
-            self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
-            self._sock._post_connect()
-            self._sock._notify_selectors()
-            self._sock._unlatch()
-
-        def handshake_step(result):
-            log.debug("SSL handshaking %s", result, extra={"sock": self._sock})
-            if not hasattr(self._sock, "activity_latch"):  # need a better discriminant
-                self._sock._post_connect()
-            self._sock._notify_selectors()
-
-        self.ssl_handler.handshakeFuture().addListener(handshake_step)
-        if self.do_handshake_on_connect and self.already_handshaked:
-            time.sleep(0.1)  # FIXME do we need this sleep
-            self.ssl_handler.handshakeFuture().sync()
-            log.debug("SSL handshaking completed", extra={"sock": self._sock})
+        if self.do_handshake_on_connect and self.sock._sock.connected:
+            self.do_handshake()
 
     def connect(self, addr):
         log.debug("Connect SSL with handshaking %s", self.do_handshake_on_connect, extra={"sock": self._sock})
         self._sock._connect(addr)
         if self.do_handshake_on_connect:
-            self.already_handshaked = True
-            if self._sock.connected:
-                log.debug("Already connected, adding SSL handler to pipeline...", extra={"sock": self._sock})
+            self.do_handshake()
+
+    def unwrap(self):
+        self._sock.channel.pipeline().remove("ssl")
+        self.ssl_handler.close()
+        return self._sock
+
+    def do_handshake(self):
+        log.debug("SSL handshaking", extra={"sock": self._sock})
+
+        def handshake_step(result):
+            log.debug("SSL handshaking completed %s", result, extra={"sock": self._sock})
+            if not hasattr(self._sock, "active_latch"):
+                log.debug("Post connect step", extra={"sock": self._sock})
+                self._sock._post_connect()
+                self._sock._unlatch()
+            self._sslobj = object()  # we have now handshaked
+            self._notify_selectors()
+
+        if self.ssl_handler is None:
+            self.ssl_handler = SslHandler(self.engine)
+            self.ssl_handler.handshakeFuture().addListener(handshake_step)
+
+            if hasattr(self._sock, "connected") and self._sock.connected:
+                # The underlying socket is already connected, so extra work is needed to manage the pipeline
+                log.debug("Adding SSL handler to pipeline after connection", extra={"sock": self._sock})
                 self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
             else:
                 log.debug("Not connected, adding SSL initializer...", extra={"sock": self._sock})
                 self._sock.connect_handlers.append(SSLInitializer(self.ssl_handler))
 
-    # Various pass through methods to the wrapper socket
+        handshake = self.ssl_handler.handshakeFuture()
+        self._sock._handle_channel_future(handshake, "SSL handshake")
+
+    # Various pass through methods to the wrapped socket
 
     def send(self, data):
         return self.sock.send(data)
@@ -140,6 +164,18 @@
     def recv(self, bufsize, flags=0):
         return self.sock.recv(bufsize, flags)
 
+    def recvfrom(self, bufsize, flags=0):
+        return self.sock.recvfrom(bufsize, flags)
+
+    def recvfrom_into(self, buffer, nbytes=0, flags=0):
+        return self.sock.recvfrom_into(buffer, nbytes, flags)
+
+    def recv_into(self, buffer, nbytes=0, flags=0):
+        return self.sock.recv_into(buffer, nbytes, flags)
+
+    def sendto(self, string, arg1, arg2=None):
+        raise socket_error(errno.EPROTO)
+
     def close(self):
         self.sock.close()
 
@@ -175,12 +211,6 @@
     def _notify_selectors(self):
         self._sock._notify_selectors()
 
-    def do_handshake(self):
-        if not self.already_handshaked:
-            log.debug("Not handshaked, so adding SSL handler", extra={"sock": self._sock})
-            self.already_handshaked = True
-            self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
-
     def getpeername(self):
         return self.sock.getpeername()
 
@@ -240,9 +270,91 @@
         do_handshake_on_connect=do_handshake_on_connect)
 
 
-def unwrap_socket(sock):
-    # FIXME removing SSL handler from pipeline should suffice, but low pri for now
-    raise NotImplemented()
+# some utility functions
+
+def cert_time_to_seconds(cert_time):
+
+    """Takes a date-time string in standard ASN1_print form
+    ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
+    a Python time value in seconds past the epoch."""
+
+    import time
+    return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
+
+PEM_HEADER = "-----BEGIN CERTIFICATE-----"
+PEM_FOOTER = "-----END CERTIFICATE-----"
+
+def DER_cert_to_PEM_cert(der_cert_bytes):
+
+    """Takes a certificate in binary DER format and returns the
+    PEM version of it as a string."""
+
+    if hasattr(base64, 'standard_b64encode'):
+        # preferred because older API gets line-length wrong
+        f = base64.standard_b64encode(der_cert_bytes)
+        return (PEM_HEADER + '\n' +
+                textwrap.fill(f, 64) + '\n' +
+                PEM_FOOTER + '\n')
+    else:
+        return (PEM_HEADER + '\n' +
+                base64.encodestring(der_cert_bytes) +
+                PEM_FOOTER + '\n')
+
+def PEM_cert_to_DER_cert(pem_cert_string):
+
+    """Takes a certificate in ASCII PEM format and returns the
+    DER-encoded version of it as a byte sequence"""
+
+    if not pem_cert_string.startswith(PEM_HEADER):
+        raise ValueError("Invalid PEM encoding; must start with %s"
+                         % PEM_HEADER)
+    if not pem_cert_string.strip().endswith(PEM_FOOTER):
+        raise ValueError("Invalid PEM encoding; must end with %s"
+                         % PEM_FOOTER)
+    d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
+    return base64.decodestring(d)
+
+def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
+
+    """Retrieve the certificate from the server at the specified address,
+    and return it as a PEM-encoded string.
+    If 'ca_certs' is specified, validate the server cert against it.
+    If 'ssl_version' is specified, use it in the connection attempt."""
+
+    host, port = addr
+    if (ca_certs is not None):
+        cert_reqs = CERT_REQUIRED
+    else:
+        cert_reqs = CERT_NONE
+    s = wrap_socket(socket(), ssl_version=ssl_version,
+                    cert_reqs=cert_reqs, ca_certs=ca_certs)
+    s.connect(addr)
+    dercert = s.getpeercert(True)
+    s.close()
+    return DER_cert_to_PEM_cert(dercert)
+
+def get_protocol_name(protocol_code):
+    return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
+
+# a replacement for the old socket.ssl function
+
+def sslwrap_simple(sock, keyfile=None, certfile=None):
+
+    """A replacement for the old socket.ssl function.  Designed
+    for compatibility with Python 2.5 and earlier.  Will disappear in
+    Python 3.0."""
+
+    ssl_sock = wrap_socket(sock, keyfile=keyfile, certfile=certfile, ssl_version=PROTOCOL_SSLv23)
+    try:
+        sock.getpeername()
+    except socket_error:
+        # no, no connection yet
+        pass
+    else:
+        # yes, do the handshake
+        ssl_sock.do_handshake()
+
+    return ssl_sock
 
 
 # Underlying Java does a good job of managing entropy, so these are just no-ops
@@ -251,7 +363,8 @@
     return True
 
 def RAND_egd(path):
-    pass
+    if os.path.abspath(str(path)) != path:
+        raise TypeError("Must be an absolute path, but ignoring it regardless")
 
 def RAND_add(bytes, entropy):
     pass
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -16,7 +16,7 @@
 import thread, threading
 from weakref import proxy
 from StringIO import StringIO
-from _socket import _check_threadpool_for_pending_threads
+from _socket import _check_threadpool_for_pending_threads, NIO_GROUP
 
 PORT = 50100
 HOST = 'localhost'
@@ -126,6 +126,17 @@
         if not self.server_ready.isSet():
             self.server_ready.set()
         self.client_ready.wait()
+
+    def _assert_no_pending_threads(self, group, msg):
+        # Wait up to one second for there not to be pending threads
+        for i in xrange(10):
+            pending_threads = _check_threadpool_for_pending_threads(group)
+            if len(pending_threads) == 0:
+                break
+            time.sleep(0.1)
+            
+        if pending_threads:
+            self.fail("Pending threads in Netty msg={} pool={}".format(msg, pprint.pformat(pending_threads)))
         
     def _tearDown(self):
         self.done.wait()   # wait for the client to exit
@@ -134,16 +145,13 @@
         msg = None
         if not self.queue.empty():
             msg = self.queue.get()
+        
+        self._assert_no_pending_threads(NIO_GROUP, "Client thread pool")
+        if hasattr(self, "srv"):
+            self._assert_no_pending_threads(self.srv.group, "Server thread pool")
 
-        # Wait up to one second for there not to be pending threads
-        for i in xrange(10):
-            pending_threads = _check_threadpool_for_pending_threads()
-            if len(pending_threads) == 0:
-                break
-            time.sleep(0.1)
-
-        if pending_threads or msg:
-            self.fail("msg={} Pending threads in Netty pool={}".format(msg, pprint.pformat(pending_threads)))
+        if msg:
+            self.fail("msg={}".format(msg))
 
             
     def clientRun(self, test_func):
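
The previously unimplemented recv_into/recvfrom_into methods now copy
received data into a caller-supplied buffer. A self-contained, hedged sketch
using a UDP socket pair on the loopback interface (port chosen by the OS):

    import socket

    receiver = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    receiver.bind(('127.0.0.1', 0))
    sender = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    sender.sendto('ping', receiver.getsockname())

    buf = bytearray(16)
    nbytes, addr = receiver.recvfrom_into(buf)   # fills buf in place
    print(nbytes, buf[:nbytes], addr)

    sender.close()
    receiver.close()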

-- 
Repository URL: http://hg.python.org/jython

