[Jython-checkins] jython: Fix bugs in select.poll, threadpool group closing, and SSL handshaking.
jim.baker
jython-checkins at python.org
Mon Jul 14 22:19:48 CEST 2014
http://hg.python.org/jython/rev/2c45f75a5406
changeset: 7348:2c45f75a5406
user: Jim Baker <jim.baker at rackspace.com>
date: Mon Jul 14 14:19:57 2014 -0600
summary:
Fix bugs in select.poll, threadpool group closing, and SSL handshaking.
Does not yet support STARTTLS for the server side of an SSL connection;
that feature has been deferred indefinitely.
Merged from https://bitbucket.org/jimbaker/jython-fix-socket-bugs
Fixes http://bugs.jython.org/issue2094,
http://bugs.jython.org/issue2147, http://bugs.jython.org/issue2174
files:
Lib/_socket.py | 280 ++++++++++++++++++---------
Lib/select.py | 2 +-
Lib/ssl.py | 211 ++++++++++++++++----
Lib/test/test_socket.py | 28 +-
4 files changed, 369 insertions(+), 152 deletions(-)
diff --git a/Lib/_socket.py b/Lib/_socket.py
--- a/Lib/_socket.py
+++ b/Lib/_socket.py
@@ -3,7 +3,9 @@
import errno
import jarray
import logging
+import numbers
import pprint
+import struct
import sys
import time
import _google_ipaddr_r234
@@ -15,18 +17,19 @@
from StringIO import StringIO
from threading import Condition, Lock
from types import MethodType, NoneType
+from weakref import WeakKeyDictionary
import java
from java.io import IOException, InterruptedIOException
-from java.lang import Thread
+from java.lang import Thread, IllegalStateException
from java.net import InetAddress, InetSocketAddress
from java.nio.channels import ClosedChannelException
from java.util import NoSuchElementException
from java.util.concurrent import (
ArrayBlockingQueue, CopyOnWriteArrayList, CountDownLatch, LinkedBlockingQueue,
RejectedExecutionException, ThreadFactory, TimeUnit)
-from java.util.concurrent.atomic import AtomicBoolean
-from javax.net.ssl import SSLPeerUnverifiedException
+from java.util.concurrent.atomic import AtomicBoolean, AtomicLong
+from javax.net.ssl import SSLPeerUnverifiedException, SSLException
try:
# jarjar-ed version
@@ -53,6 +56,8 @@
FORMAT = '%(asctime)-15s %(threadName)s %(levelname)s %(funcName)s %(message)s %(sock)s'
logging.basicConfig(format=FORMAT, level=logging.DEBUG)
+# _debug() # UNCOMMENT to get logging of socket activity
+
# Constants
###########
@@ -185,21 +190,27 @@
# because these threads only handle ephemeral data, such as performing
# SSL wrap/unwrap.
+
class DaemonThreadFactory(ThreadFactory):
+
+ thread_count = AtomicLong()
+
+ def __init__(self, label):
+ self.label = label
+
def newThread(self, runnable):
t = Thread(runnable)
t.daemon = True
+ t.name = self.label % (self.thread_count.getAndIncrement())
return t
-# This number should be configurable by the user. 10 is the default
-# number as of 4.0.17 of Netty. FIXME this default may be based on core count.
+NIO_GROUP = NioEventLoopGroup(10, DaemonThreadFactory("Jython-Netty-Client-%s"))
-NIO_GROUP = NioEventLoopGroup(10, DaemonThreadFactory())
-def _check_threadpool_for_pending_threads():
+def _check_threadpool_for_pending_threads(group):
pending_threads = []
- for t in NIO_GROUP:
+ for t in group:
pending_count = t.pendingTasks()
if pending_count > 0:
pending_threads.append((t, pending_count))
@@ -272,6 +283,7 @@
IOException : lambda x: error(errno.ECONNRESET, 'Software caused connection abort'),
InterruptedIOException : lambda x: timeout(None, 'timed out'),
+ IllegalStateException : lambda x: error(errno.EPIPE, 'Illegal state exception'),
java.net.BindException : lambda x: error(errno.EADDRINUSE, 'Address already in use'),
java.net.ConnectException : lambda x: error(errno.ECONNREFUSED, 'Connection refused'),
@@ -299,7 +311,8 @@
java.nio.channels.UnresolvedAddressException : lambda x: gaierror(errno.EGETADDRINFOFAILED, 'getaddrinfo failed'),
java.nio.channels.UnsupportedAddressTypeException : None,
- SSLPeerUnverifiedException: lambda x: SSLError(SSL_ERROR_SSL, "FIXME"),
+ SSLPeerUnverifiedException: lambda x: SSLError(SSL_ERROR_SSL, x.message),
+ SSLException: lambda x: SSLError(SSL_ERROR_SSL, x.message),
}
@@ -391,7 +404,9 @@
# shortcircuiting if the socket was in fact ready for
# reading/writing/exception before the select call
if selected_rlist or selected_wlist:
- return sorted(selected_rlist), sorted(selected_wlist), sorted(selected_xlist)
+ completed = sorted(selected_rlist), sorted(selected_wlist), sorted(selected_xlist)
+ log.debug("Completed select %s", completed, extra={"sock": "*"})
+ return completed
elif timeout is not None and time.time() - started >= timeout:
return [], [], []
self.cv.wait(timeout)
@@ -400,7 +415,12 @@
# poll support
##############
-_PollNotification = namedtuple("_PollNotification", ["sock", "fd", "exception", "hangup"])
+_PollNotification = namedtuple(
+ "_PollNotification",
+ ["sock", # the real socket
+ "fd", # could be the real socket (as returned by fileno) or a wrapping socket object
+ "exception",
+ "hangup"])
class poll(object):
@@ -408,30 +428,41 @@
def __init__(self):
self.queue = LinkedBlockingQueue()
self.registered = dict() # fd -> eventmask
+ self.socks2fd = WeakKeyDictionary() # sock -> fd
def notify(self, sock, exception=None, hangup=False):
notification = _PollNotification(
sock=sock,
- fd=sock.fileno(),
+ fd=self.socks2fd.get(sock),
exception=exception,
hangup=hangup)
+ log.debug("Notify %s", notification, extra={"sock": "*"})
+
self.queue.put(notification)
def register(self, fd, eventmask=POLLIN|POLLPRI|POLLOUT):
+ if not hasattr(fd, "fileno"):
+ raise TypeError("argument must have a fileno() method")
+ sock = fd.fileno()
+ log.debug("Register fd=%s eventmask=%s", fd, eventmask, extra={"sock": sock})
self.registered[fd] = eventmask
- # NOTE in case fd != sock in a future release, modifiy accordingly
- sock = fd
+ self.socks2fd[sock] = fd
sock._register_selector(self)
self.notify(sock) # Ensure we get an initial notification
def modify(self, fd, eventmask):
+ if not hasattr(fd, "fileno"):
+ raise TypeError("argument must have a fileno() method")
if fd not in self.registered:
raise error(errno.ENOENT, "No such file or directory")
self.registered[fd] = eventmask
def unregister(self, fd):
+ if not hasattr(fd, "fileno"):
+ raise TypeError("argument must have a fileno() method")
+ log.debug("Unregister socket fd=%s", fd, extra={"sock": fd.fileno()})
del self.registered[fd]
- sock = fd
+ sock = fd.fileno()
sock._unregister_selector(self)
def _event_test(self, notification):
@@ -439,7 +470,8 @@
# edges around errors and hangup
if notification is None:
return None, 0
- mask = self.registered.get(notification.sock, 0) # handle if concurrently removed, by simply ignoring
+ mask = self.registered.get(notification.fd, 0) # handle if concurrently removed, by simply ignoring
+ log.debug("Testing notification=%s mask=%s", notification, mask, extra={"sock": "*"})
event = 0
if mask & POLLIN and notification.sock._readable():
event |= POLLIN
@@ -451,53 +483,58 @@
event |= POLLHUP
if mask & POLLNVAL and not notification.sock.peer_closed:
event |= POLLNVAL
+ log.debug("Tested notification=%s event=%s", notification, event, extra={"sock": "*"})
return notification.fd, event
+ def _handle_poll(self, poller):
+ notification = poller()
+ if notification is None:
+ return []
+
+ # Pull as many outstanding notifications as possible out
+ # of the queue
+ notifications = [notification]
+ self.queue.drainTo(notifications)
+ log.debug("Got notification(s) %s", notifications, extra={"sock": "MODULE"})
+ result = []
+ socks = set()
+
+ # But given how we notify, it's possible to see possible
+ # multiple notifications. Just return one (fd, event) for a
+ # given socket
+ for notification in notifications:
+ if notification.sock not in socks:
+ fd, event = self._event_test(notification)
+ if event:
+ result.append((fd, event))
+ socks.add(notification.sock)
+
+ # Repump sockets to pick up a subsequent level change
+ for sock in socks:
+ self.notify(sock)
+
+ return result
+
def poll(self, timeout=None):
- if not timeout or timeout < 0:
- # Simplify logic around timeout resets
+ if not (timeout is None or isinstance(timeout, numbers.Real)):
+ raise TypeError("timeout must be a number or None, got %r" % (timeout,))
+ if timeout < 0:
timeout = None
+ log.debug("Polling timeout=%s", timeout, extra={"sock": "*"})
+ if timeout is None:
+ return self._handle_poll(self.queue.take)
+ elif timeout == 0:
+ return self._handle_poll(self.queue.poll)
else:
- timeout /= 1000. # convert from milliseconds to seconds
-
- while True:
- if timeout is None:
- notification = self.queue.take()
- elif timeout > 0:
+ timeout = float(timeout) / 1000. # convert from milliseconds to seconds
+ while timeout > 0:
started = time.time()
timeout_in_ns = int(timeout * _TO_NANOSECONDS)
- notification = self.queue.poll(timeout_in_ns, TimeUnit.NANOSECONDS)
- # Need to reset the timeout, because this notification
- # may not be of interest when masked out
+ result = self._handle_poll(partial(self.queue.poll, timeout_in_ns, TimeUnit.NANOSECONDS))
+ if result:
+ return result
timeout = timeout - (time.time() - started)
- else:
- return []
-
- if notification is None:
- continue
-
- # Pull as many outstanding notifications as possible out
- # of the queue
- notifications = [notification]
- self.queue.drainTo(notifications)
- log.debug("Got notification(s) %s", notifications, extra={"sock": "MODULE"})
- result = []
- socks = set()
-
- # But given how we notify, it's possible to see possible
- # multiple notifications. Just return one (fd, event) for a
- # given socket
- for notification in notifications:
- if notification.sock not in socks:
- fd, event = self._event_test(notification)
- if event:
- result.append((fd, event))
- socks.add(notification.sock)
- # Repump sockets to pick up a subsequent level change
- for sock in socks:
- self.notify(sock)
- if result:
- return result
+ return []
# integration with Netty
@@ -538,7 +575,7 @@
self.parent_socket = parent_socket
def initChannel(self, child_channel):
- child = ChildSocket()
+ child = ChildSocket(self.parent_socket)
child.proto = IPPROTO_TCP
child._init_client_mode(child_channel)
@@ -551,7 +588,7 @@
log.debug("Setting inherited options %s", child.options, extra={"sock": child})
config = child_channel.config()
for option, value in child.options.iteritems():
- config.setOption(option, value)
+ _set_option(config.setOption, option, value)
log.debug("Notifing listeners of parent socket %s", self.parent_socket, extra={"sock": child})
self.parent_socket.child_queue.put(child)
@@ -578,6 +615,35 @@
# FIXME raise exceptions for ops not permitted on client socket, server socket
UNKNOWN_SOCKET, CLIENT_SOCKET, SERVER_SOCKET, DATAGRAM_SOCKET = range(4)
+_socket_types = {
+ UNKNOWN_SOCKET: "unknown",
+ CLIENT_SOCKET: "client",
+ SERVER_SOCKET: "server",
+ DATAGRAM_SOCKET: "datagram"
+}
+
+
+
+
+def _identity(value):
+ return value
+
+
+def _set_option(setter, option, value):
+ if option in (ChannelOption.SO_LINGER, ChannelOption.SO_TIMEOUT):
+ # FIXME consider implementing these options. Note these are not settable
+ # via config.setOption in any event:
+ #
+ # * SO_TIMEOUT does not work for NIO sockets, need to use
+ # IdleStateHandler instead
+ #
+ # * SO_LINGER does not work for nonblocking sockets, so need
+ # to emulate in calling close on the socket by attempting to
+ # send any unsent data (it's not clear this actually is
+ # needed in Netty however...)
+ return
+ else:
+ setter(option, value)
# These are the only socket protocols we currently support, so it's easy to map as follows:
@@ -585,19 +651,15 @@
_socket_options = {
IPPROTO_TCP: {
(SOL_SOCKET, SO_KEEPALIVE): (ChannelOption.SO_KEEPALIVE, bool),
- (SOL_SOCKET, SO_LINGER): (ChannelOption.SO_LINGER, int),
+ (SOL_SOCKET, SO_LINGER): (ChannelOption.SO_LINGER, _identity),
(SOL_SOCKET, SO_RCVBUF): (ChannelOption.SO_RCVBUF, int),
(SOL_SOCKET, SO_REUSEADDR): (ChannelOption.SO_REUSEADDR, bool),
(SOL_SOCKET, SO_SNDBUF): (ChannelOption.SO_SNDBUF, int),
- # FIXME SO_TIMEOUT needs to be handled by an IdleStateHandler -
- # ChannelOption.SO_TIMEOUT really only applies to OIO (old) socket channels,
- # we want to use NIO ones
(SOL_SOCKET, SO_TIMEOUT): (ChannelOption.SO_TIMEOUT, int),
(IPPROTO_TCP, TCP_NODELAY): (ChannelOption.TCP_NODELAY, bool),
},
IPPROTO_UDP: {
(SOL_SOCKET, SO_BROADCAST): (ChannelOption.SO_BROADCAST, bool),
- (SOL_SOCKET, SO_LINGER): (ChannelOption.SO_LINGER, int),
(SOL_SOCKET, SO_RCVBUF): (ChannelOption.SO_RCVBUF, int),
(SOL_SOCKET, SO_REUSEADDR): (ChannelOption.SO_REUSEADDR, bool),
(SOL_SOCKET, SO_SNDBUF): (ChannelOption.SO_SNDBUF, int),
@@ -629,6 +691,7 @@
proto = IPPROTO_UDP
self.proto = proto
+ self._sock = self # some Python code wants to see a socket
self._last_error = 0 # supports SO_ERROR
self.connected = False
self.timeout = _defaulttimeout
@@ -654,7 +717,7 @@
def __repr__(self):
return "<_realsocket at {:#x} type={} open_count={} channel={} timeout={}>".format(
- id(self), self.socket_type, self.open_count, self.channel, self.timeout)
+ id(self), _socket_types[self.socket_type], self.open_count, self.channel, self.timeout)
def _unlatch(self):
pass # no-op once mutated from ChildSocket to normal _socketobject
@@ -689,6 +752,7 @@
elif self.timeout:
self._handle_timeout(future.await, reason)
if not future.isSuccess():
+ log.exception("Got this failure %s during %s", future.cause(), reason, extra={"sock": self})
raise future.cause()
return future
else:
@@ -757,7 +821,7 @@
self.python_inbound_handler = PythonInboundHandler(self)
bootstrap = Bootstrap().group(NIO_GROUP).channel(NioSocketChannel)
for option, value in self.options.iteritems():
- bootstrap.option(option, value)
+ _set_option(bootstrap.option, option, value)
# FIXME really this is just for SSL handling, so make more
# specific than a list of connect_handlers
@@ -772,13 +836,12 @@
bind_future = bootstrap.bind(self.bind_addr)
self._handle_channel_future(bind_future, "local bind")
self.channel = bind_future.channel()
- future = self.channel.connect(addr)
else:
log.debug("Connect to %s", addr, extra={"sock": self})
- future = bootstrap.connect(addr)
- self.channel = future.channel()
-
- self._handle_channel_future(future, "connect")
+ self.channel = bootstrap.channel()
+
+ connect_future = self.channel.connect(addr)
+ self._handle_channel_future(connect_future, "connect")
self.bind_timestamp = time.time()
def _post_connect(self):
@@ -818,19 +881,21 @@
def listen(self, backlog):
self.socket_type = SERVER_SOCKET
+ self.child_queue = ArrayBlockingQueue(backlog)
+ self.accepted_children = 1 # include the parent as well to simplify close logic
b = ServerBootstrap()
- self.group = NioEventLoopGroup(10, DaemonThreadFactory())
- b.group(self.group)
+ self.parent_group = NioEventLoopGroup(2, DaemonThreadFactory("Jython-Netty-Parent-%s"))
+ self.child_group = NioEventLoopGroup(2, DaemonThreadFactory("Jython-Netty-Child-%s"))
+ b.group(self.parent_group, self.child_group)
b.channel(NioServerSocketChannel)
b.option(ChannelOption.SO_BACKLOG, backlog)
for option, value in self.options.iteritems():
- b.option(option, value)
+ _set_option(b.option, option, value)
# Note that child options are set in the child handler so
# that they can take into account any subsequent changes,
# plus have shadow support
- self.child_queue = ArrayBlockingQueue(backlog)
self.child_handler = ChildSocketHandler(self)
b.childHandler(self.child_handler)
@@ -854,6 +919,9 @@
raise error(errno.EWOULDBLOCK, "Resource temporarily unavailable")
peername = child.getpeername() if child else None
log.debug("Got child %s connected to %s", child, peername, extra={"sock": self})
+ child.accepted = True
+ with self.open_lock:
+ self.accepted_children += 1
return child, peername
# DATAGRAM METHODS
@@ -870,7 +938,7 @@
bootstrap = Bootstrap().group(NIO_GROUP).channel(NioDatagramChannel)
bootstrap.handler(self.python_inbound_handler)
for option, value in self.options.iteritems():
- bootstrap.option(option, value)
+ _set_option(bootstrap.option, option, value)
future = bootstrap.register()
self._handle_channel_future(future, "register")
@@ -903,14 +971,19 @@
self._handle_channel_future(future, "sendto")
return len(string)
-
- # FIXME implement these methods
-
def recvfrom_into(self, buffer, nbytes=0, flags=0):
- raise NotImplementedError()
+ if nbytes == 0:
+ nbytes = len(buffer)
+ data, remote_addr = self.recvfrom(nbytes, flags)
+ buffer[0:len(data)] = data
+ return len(data), remote_addr
def recv_into(self, buffer, nbytes=0, flags=0):
- raise NotImplementedError()
+ if nbytes == 0:
+ nbytes = len(buffer)
+ data = self.recv(nbytes, flags)
+ buffer[0:len(data)] = data
+ return len(data)
# GENERAL METHODS
@@ -930,7 +1003,9 @@
# Do not care about tasks that attempt to schedule after close
pass
if self.socket_type == SERVER_SOCKET:
- self.group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
+ log.debug("Shutting down server socket parent group", extra={"sock": self})
+ self.parent_group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
+ self.accepted_children -= 1
while True:
child = self.child_queue.poll()
if child is None:
@@ -941,6 +1016,7 @@
log.debug("Closed socket", extra={"sock": self})
def shutdown(self, how):
+ log.debug("Got request to shutdown socket how=%s", how, extra={"sock": self})
self._verify_channel()
if how & SHUT_RD:
try:
@@ -952,8 +1028,10 @@
def _readable(self):
if self.socket_type == CLIENT_SOCKET or self.socket_type == DATAGRAM_SOCKET:
- return ((self.incoming_head is not None and self.incoming_head.readableBytes()) or
- self.incoming.peek())
+ log.debug("Incoming head=%s queue=%s", self.incoming_head, self.incoming, extra={"sock": self})
+ return (
+ (self.incoming_head is not None and self.incoming_head.readableBytes()) or
+ self.incoming.peek())
elif self.socket_type == SERVER_SOCKET:
return bool(self.child_queue.peek())
else:
@@ -986,6 +1064,7 @@
raise error(errno.ENOTCONN, 'Socket not connected')
future = self.channel.writeAndFlush(Unpooled.wrappedBuffer(data))
self._handle_channel_future(future, "send")
+ log.debug("Sent data <<<{!r:.20}>>>".format(data), extra={"sock": self})
# FIXME are we sure we are going to be able to send this much data, especially async?
return len(data)
@@ -1008,9 +1087,9 @@
log.debug("No data yet for socket", extra={"sock": self})
raise error(errno.EAGAIN, "Resource temporarily unavailable")
- # Only return _PEER_CLOSED once
msg = self.incoming_head
if msg is _PEER_CLOSED:
+ # Only return _PEER_CLOSED once
self.incoming_head = None
self.peer_closed = True
return msg
@@ -1073,13 +1152,11 @@
except KeyError:
raise error(errno.ENOPROTOOPT, "Protocol not available")
- # FIXME for NIO sockets, SO_TIMEOUT doesn't work - should use
- # IdleStateHandler instead
cast_value = cast(value)
self.options[option] = cast_value
log.debug("Setting option %s to %s", optname, value, extra={"sock": self})
if self.channel:
- self.channel.config().setOption(option, cast(value))
+ _set_option(self.channel.config().setOption, option, cast_value)
def getsockopt(self, level, optname, buflen=None):
# Pseudo options for interrogating the status of this socket
@@ -1247,10 +1324,12 @@
class ChildSocket(_realsocket):
- def __init__(self):
+ def __init__(self, parent_socket):
super(ChildSocket, self).__init__()
+ self.parent_socket = parent_socket
self.active = AtomicBoolean()
self.active_latch = CountDownLatch(1)
+ self.accepted = False
def _ensure_post_connect(self):
do_post_connect = not self.active.getAndSet(True)
@@ -1293,6 +1372,14 @@
def close(self):
self._ensure_post_connect()
super(ChildSocket, self).close()
+ if self.open_count > 0:
+ return
+ if self.accepted:
+ with self.parent_socket.open_lock:
+ self.parent_socket.accepted_children -= 1
+ if self.parent_socket.accepted_children == 0:
+ log.debug("Shutting down child group for parent socket=%s", self.parent_socket, extra={"sock": self})
+ self.parent_socket.child_group.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS)
def shutdown(self, how):
self._ensure_post_connect()
@@ -1413,6 +1500,12 @@
def _get_jsockaddr(address_object, family, sock_type, proto, flags):
+ if family is None:
+ family = AF_UNSPEC
+ if sock_type is None:
+ sock_type = 0
+ if proto is None:
+ proto = 0
addr = _get_jsockaddr2(address_object, family, sock_type, proto, flags)
log.debug("Address %s for %s", addr, address_object, extra={"sock": "*"})
return addr
@@ -1427,9 +1520,10 @@
address_object = ("", 0)
error_message = "Address must be a 2-tuple (ipv4: (host, port)) or a 4-tuple (ipv6: (host, port, flow, scope))"
if not isinstance(address_object, tuple) or \
- ((family == AF_INET and len(address_object) != 2) or (family == AF_INET6 and len(address_object) not in [2,4] )) or \
- not isinstance(address_object[0], (basestring, NoneType)) or \
- not isinstance(address_object[1], (int, long)):
+ ((family == AF_INET and len(address_object) != 2) or \
+ (family == AF_INET6 and len(address_object) not in [2,4] )) or \
+ not isinstance(address_object[0], (basestring, NoneType)) or \
+ not isinstance(address_object[1], (int, long)):
raise TypeError(error_message)
if len(address_object) == 4 and not isinstance(address_object[3], (int, long)):
raise TypeError(error_message)
@@ -1520,8 +1614,10 @@
@raises_java_exception
def getaddrinfo(host, port, family=AF_UNSPEC, socktype=0, proto=0, flags=0):
- if _ipv4_addresses_only:
- family = AF_INET
+ if family is None:
+ family = AF_UNSPEC
+ if socktype is None:
+ socktype = 0
if not family in [AF_INET, AF_INET6, AF_UNSPEC]:
raise gaierror(errno.EIO, 'ai_family not supported')
host = _getaddrinfo_get_host(host, family, flags)
diff --git a/Lib/select.py b/Lib/select.py
--- a/Lib/select.py
+++ b/Lib/select.py
@@ -8,5 +8,5 @@
POLLHUP,
POLLNVAL,
error,
- #poll,
+ poll,
select)
diff --git a/Lib/ssl.py b/Lib/ssl.py
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -1,5 +1,10 @@
+import base64
+import errno
import logging
+import os.path
+import textwrap
import time
+import threading
try:
# jarjar-ed version
@@ -20,27 +25,33 @@
SSL_ERROR_ZERO_RETURN,
SSL_ERROR_WANT_CONNECT,
SSL_ERROR_EOF,
- SSL_ERROR_INVALID_ERROR_CODE)
+ SSL_ERROR_INVALID_ERROR_CODE,
+ error as socket_error)
from _sslcerts import _get_ssl_context
from java.text import SimpleDateFormat
-from java.util import Locale, TimeZone
+from java.util import ArrayList, Locale, TimeZone
+from java.util.concurrent import CountDownLatch
from javax.naming.ldap import LdapName
from javax.security.auth.x500 import X500Principal
-log = logging.getLogger("socket")
+log = logging.getLogger("_socket")
+# Pretend to be OpenSSL
+OPENSSL_VERSION = "OpenSSL 1.0.0 (as emulated by Java SSL)"
+OPENSSL_VERSION_NUMBER = 0x1000000L
+OPENSSL_VERSION_INFO = (1, 0, 0, 0, 0)
+
CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED = range(3)
-# FIXME need to map to java names as well; there's also possibility some difference between
-# SSLv2 (Java) and PROTOCOL_SSLv23 (Python) but reading the docs suggest not
-# http://docs.oracle.com/javase/7/docs/technotes/guides/security/StandardNames.html#SSLContext
-
-# Currently ignored, since we just use the default in Java. FIXME
-PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 = range(4)
-_PROTOCOL_NAMES = {PROTOCOL_SSLv2: 'SSLv2', PROTOCOL_SSLv3: 'SSLv3', PROTOCOL_SSLv23: 'SSLv23', PROTOCOL_TLSv1: 'TLSv1'}
+# Do not support PROTOCOL_SSLv2, it is highly insecure and it is optional
+_, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 = range(4)
+_PROTOCOL_NAMES = {
+ PROTOCOL_SSLv3: 'SSLv3',
+ PROTOCOL_SSLv23: 'SSLv23',
+ PROTOCOL_TLSv1: 'TLSv1'}
_rfc2822_date_format = SimpleDateFormat("MMM dd HH:mm:ss yyyy z", Locale.US)
_rfc2822_date_format.setTimeZone(TimeZone.getTimeZone("GMT"))
@@ -59,8 +70,7 @@
}
_cert_name_types = [
- # FIXME only entry 2 - DNS - has been confirmed w/ cpython;
- # everything else is coming from this doc:
+ # Fields documented in
# http://docs.oracle.com/javase/7/docs/api/java/security/cert/X509Certificate.html#getSubjectAlternativeNames()
"other",
"rfc822",
@@ -80,7 +90,7 @@
def initChannel(self, ch):
pipeline = ch.pipeline()
- pipeline.addLast("ssl", self.ssl_handler)
+ pipeline.addFirst("ssl", self.ssl_handler)
class SSLSocket(object):
@@ -89,47 +99,61 @@
keyfile, certfile, ca_certs,
do_handshake_on_connect, server_side):
self.sock = sock
+ self.do_handshake_on_connect = do_handshake_on_connect
self._sock = sock._sock # the real underlying socket
self.context = _get_ssl_context(keyfile, certfile, ca_certs)
self.engine = self.context.createSSLEngine()
+ self.server_side = server_side
self.engine.setUseClientMode(not server_side)
- self.ssl_handler = SslHandler(self.engine)
- self.already_handshaked = False
- self.do_handshake_on_connect = do_handshake_on_connect
+ self.ssl_handler = None
+ # _sslobj is used to follow CPython convention that an object
+ # means we have handshaked, as used by existing code that
+ # looks at this internal
+ self._sslobj = None
+ self.handshake_count = 0
- if self.do_handshake_on_connect and hasattr(self._sock, "connected") and self._sock.connected:
- self.already_handshaked = True
- log.debug("Adding SSL handler to pipeline after connection", extra={"sock": self._sock})
- self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
- self._sock._post_connect()
- self._sock._notify_selectors()
- self._sock._unlatch()
-
- def handshake_step(result):
- log.debug("SSL handshaking %s", result, extra={"sock": self._sock})
- if not hasattr(self._sock, "activity_latch"): # need a better discriminant
- self._sock._post_connect()
- self._sock._notify_selectors()
-
- self.ssl_handler.handshakeFuture().addListener(handshake_step)
- if self.do_handshake_on_connect and self.already_handshaked:
- time.sleep(0.1) # FIXME do we need this sleep
- self.ssl_handler.handshakeFuture().sync()
- log.debug("SSL handshaking completed", extra={"sock": self._sock})
+ if self.do_handshake_on_connect and self.sock._sock.connected:
+ self.do_handshake()
def connect(self, addr):
log.debug("Connect SSL with handshaking %s", self.do_handshake_on_connect, extra={"sock": self._sock})
self._sock._connect(addr)
if self.do_handshake_on_connect:
- self.already_handshaked = True
- if self._sock.connected:
- log.debug("Already connected, adding SSL handler to pipeline...", extra={"sock": self._sock})
+ self.do_handshake()
+
+ def unwrap(self):
+ self._sock.channel.pipeline().remove("ssl")
+ self.ssl_handler.close()
+ return self._sock
+
+ def do_handshake(self):
+ log.debug("SSL handshaking", extra={"sock": self._sock})
+
+ def handshake_step(result):
+ log.debug("SSL handshaking completed %s", result, extra={"sock": self._sock})
+ if not hasattr(self._sock, "active_latch"):
+ log.debug("Post connect step", extra={"sock": self._sock})
+ self._sock._post_connect()
+ self._sock._unlatch()
+ self._sslobj = object() # we have now handshaked
+ self._notify_selectors()
+
+ if self.ssl_handler is None:
+ self.ssl_handler = SslHandler(self.engine)
+ self.ssl_handler.handshakeFuture().addListener(handshake_step)
+
+ if hasattr(self._sock, "connected") and self._sock.connected:
+ # The underlying socket is already connected, so some extra work to manage
+ log.debug("Adding SSL handler to pipeline after connection", extra={"sock": self._sock})
self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
else:
log.debug("Not connected, adding SSL initializer...", extra={"sock": self._sock})
self._sock.connect_handlers.append(SSLInitializer(self.ssl_handler))
- # Various pass through methods to the wrapper socket
+ handshake = self.ssl_handler.handshakeFuture()
+ self._sock._handle_channel_future(handshake, "SSL handshake")
+
+ # Various pass through methods to the wrapped socket
def send(self, data):
return self.sock.send(data)
@@ -140,6 +164,18 @@
def recv(self, bufsize, flags=0):
return self.sock.recv(bufsize, flags)
+ def recvfrom(self, bufsize, flags=0):
+ return self.sock.recvfrom(bufsize, flags)
+
+ def recvfrom_into(self, buffer, nbytes=0, flags=0):
+ return self.sock.recvfrom_into(buffer, nbytes, flags)
+
+ def recv_into(self, buffer, nbytes=0, flags=0):
+ return self.sock.recv_into(buffer, nbytes, flags)
+
+ def sendto(self, string, arg1, arg2=None):
+ raise socket_error(errno.EPROTO)
+
def close(self):
self.sock.close()
@@ -175,12 +211,6 @@
def _notify_selectors(self):
self._sock._notify_selectors()
- def do_handshake(self):
- if not self.already_handshaked:
- log.debug("Not handshaked, so adding SSL handler", extra={"sock": self._sock})
- self.already_handshaked = True
- self._sock.channel.pipeline().addFirst("ssl", self.ssl_handler)
-
def getpeername(self):
return self.sock.getpeername()
@@ -240,9 +270,91 @@
do_handshake_on_connect=do_handshake_on_connect)
-def unwrap_socket(sock):
- # FIXME removing SSL handler from pipeline should suffice, but low pri for now
- raise NotImplemented()
+# some utility functions
+
+def cert_time_to_seconds(cert_time):
+
+ """Takes a date-time string in standard ASN1_print form
+ ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
+ a Python time value in seconds past the epoch."""
+
+ import time
+ return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
+
+PEM_HEADER = "-----BEGIN CERTIFICATE-----"
+PEM_FOOTER = "-----END CERTIFICATE-----"
+
+def DER_cert_to_PEM_cert(der_cert_bytes):
+
+ """Takes a certificate in binary DER format and returns the
+ PEM version of it as a string."""
+
+ if hasattr(base64, 'standard_b64encode'):
+ # preferred because older API gets line-length wrong
+ f = base64.standard_b64encode(der_cert_bytes)
+ return (PEM_HEADER + '\n' +
+ textwrap.fill(f, 64) + '\n' +
+ PEM_FOOTER + '\n')
+ else:
+ return (PEM_HEADER + '\n' +
+ base64.encodestring(der_cert_bytes) +
+ PEM_FOOTER + '\n')
+
+def PEM_cert_to_DER_cert(pem_cert_string):
+
+ """Takes a certificate in ASCII PEM format and returns the
+ DER-encoded version of it as a byte sequence"""
+
+ if not pem_cert_string.startswith(PEM_HEADER):
+ raise ValueError("Invalid PEM encoding; must start with %s"
+ % PEM_HEADER)
+ if not pem_cert_string.strip().endswith(PEM_FOOTER):
+ raise ValueError("Invalid PEM encoding; must end with %s"
+ % PEM_FOOTER)
+ d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
+ return base64.decodestring(d)
+
+def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
+
+ """Retrieve the certificate from the server at the specified address,
+ and return it as a PEM-encoded string.
+ If 'ca_certs' is specified, validate the server cert against it.
+ If 'ssl_version' is specified, use it in the connection attempt."""
+
+ host, port = addr
+ if (ca_certs is not None):
+ cert_reqs = CERT_REQUIRED
+ else:
+ cert_reqs = CERT_NONE
+ s = wrap_socket(socket(), ssl_version=ssl_version,
+ cert_reqs=cert_reqs, ca_certs=ca_certs)
+ s.connect(addr)
+ dercert = s.getpeercert(True)
+ s.close()
+ return DER_cert_to_PEM_cert(dercert)
+
+def get_protocol_name(protocol_code):
+ return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
+
+# a replacement for the old socket.ssl function
+
+def sslwrap_simple(sock, keyfile=None, certfile=None):
+
+ """A replacement for the old socket.ssl function. Designed
+ for compability with Python 2.5 and earlier. Will disappear in
+ Python 3.0."""
+
+ ssl_sock = wrap_socket(sock, keyfile=keyfile, certfile=certfile, ssl_version=PROTOCOL_SSLv23)
+ try:
+ sock.getpeername()
+ except socket_error:
+ # no, no connection yet
+ pass
+ else:
+ # yes, do the handshake
+ ssl_sock.do_handshake()
+
+ return ssl_sock
# Underlying Java does a good job of managing entropy, so these are just no-ops
@@ -251,7 +363,8 @@
return True
def RAND_egd(path):
- pass
+ if os.path.abspath(str(path)) != path:
+ raise TypeError("Must be an absolute path, but ignoring it regardless")
def RAND_add(bytes, entropy):
pass
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -16,7 +16,7 @@
import thread, threading
from weakref import proxy
from StringIO import StringIO
-from _socket import _check_threadpool_for_pending_threads
+from _socket import _check_threadpool_for_pending_threads, NIO_GROUP
PORT = 50100
HOST = 'localhost'
@@ -126,6 +126,17 @@
if not self.server_ready.isSet():
self.server_ready.set()
self.client_ready.wait()
+
+ def _assert_no_pending_threads(self, group, msg):
+ # Wait up to one second for there not to be pending threads
+ for i in xrange(10):
+ pending_threads = _check_threadpool_for_pending_threads(group)
+ if len(pending_threads) == 0:
+ break
+ time.sleep(0.1)
+
+ if pending_threads:
+ self.fail("Pending threads in Netty msg={} pool={}".format(msg, pprint.pformat(pending_threads)))
def _tearDown(self):
self.done.wait() # wait for the client to exit
@@ -134,16 +145,13 @@
msg = None
if not self.queue.empty():
msg = self.queue.get()
+
+ self._assert_no_pending_threads(NIO_GROUP, "Client thread pool")
+ if hasattr(self, "srv"):
+ self._assert_no_pending_threads(self.srv.group, "Server thread pool")
- # Wait up to one second for there not to be pending threads
- for i in xrange(10):
- pending_threads = _check_threadpool_for_pending_threads()
- if len(pending_threads) == 0:
- break
- time.sleep(0.1)
-
- if pending_threads or msg:
- self.fail("msg={} Pending threads in Netty pool={}".format(msg, pprint.pformat(pending_threads)))
+ if msg:
+ self.fail("msg={}".format(msg))
def clientRun(self, test_func):
--
Repository URL: http://hg.python.org/jython
More information about the Jython-checkins
mailing list