[Python-checkins] cpython (2.7): Backport Fix for Issue #7776: Fix ``Host:'' header and reconnection when using

senthil.kumaran python-checkins at python.org
Sat May 17 03:52:01 CEST 2014


http://hg.python.org/cpython/rev/568041fd8090
changeset:   90728:568041fd8090
branch:      2.7
parent:      90722:0a6d51ccff54
user:        Senthil Kumaran <senthil at uthcode.com>
date:        Fri May 16 18:51:46 2014 -0700
summary:
  Backport fix for Issue #7776: Fix ``Host:'' header and reconnection when using httplib.HTTPConnection.set_tunnel() (http.client in Python 3).

Patch by Nikolaus Rath.

files:
  Lib/httplib.py           |  54 +++++++++++++++++++--------
  Lib/test/test_httplib.py |  48 +++++++++++++++++++++++-
  Misc/NEWS                |   4 ++
  3 files changed, 88 insertions(+), 18 deletions(-)


diff --git a/Lib/httplib.py b/Lib/httplib.py
--- a/Lib/httplib.py
+++ b/Lib/httplib.py
@@ -700,17 +700,33 @@
         self._tunnel_host = None
         self._tunnel_port = None
         self._tunnel_headers = {}
-
-        self._set_hostport(host, port)
         if strict is not None:
             self.strict = strict
 
+        (self.host, self.port) = self._get_hostport(host, port)
+
+        # This is stored as an instance variable to allow unittests
+        # to replace with a suitable mock
+        self._create_connection = socket.create_connection
+
     def set_tunnel(self, host, port=None, headers=None):
-        """ Sets up the host and the port for the HTTP CONNECT Tunnelling.
+        """ Set up host and port for HTTP CONNECT tunnelling.
+
+        In a connection that uses HTTP CONNECT tunnelling, the host passed to
+        the constructor is used as a proxy server that relays all communication
+        to the endpoint passed to set_tunnel. This is done by sending an HTTP
+        CONNECT request to the proxy when the connection is established.
+
+        This method must be called before the HTTP connection has been
+        established; otherwise RuntimeError is raised.
 
         The headers argument should be a mapping of extra HTTP headers
         to send with the CONNECT request.
         """
+        # Verify if this is required.
+        if self.sock:
+            raise RuntimeError("Can't setup tunnel for established connection.")
+
         self._tunnel_host = host
         self._tunnel_port = port
         if headers:
@@ -718,7 +734,7 @@
         else:
             self._tunnel_headers.clear()
 
-    def _set_hostport(self, host, port):
+    def _get_hostport(self, host, port):
         if port is None:
             i = host.rfind(':')
             j = host.rfind(']')         # ipv6 addresses have [...]
@@ -735,15 +751,14 @@
                 port = self.default_port
             if host and host[0] == '[' and host[-1] == ']':
                 host = host[1:-1]
-        self.host = host
-        self.port = port
+        return (host, port)
 
     def set_debuglevel(self, level):
         self.debuglevel = level
 
     def _tunnel(self):
-        self._set_hostport(self._tunnel_host, self._tunnel_port)
-        self.send("CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port))
+        (host, port) = self._get_hostport(self._tunnel_host, self._tunnel_port)
+        self.send("CONNECT %s:%d HTTP/1.0\r\n" % (host, port))
         for header, value in self._tunnel_headers.iteritems():
             self.send("%s: %s\r\n" % (header, value))
         self.send("\r\n")
@@ -768,8 +783,8 @@
 
     def connect(self):
         """Connect to the host and port specified in __init__."""
-        self.sock = socket.create_connection((self.host,self.port),
-                                             self.timeout, self.source_address)
+        self.sock = self._create_connection((self.host,self.port),
+                                           self.timeout, self.source_address)
 
         if self._tunnel_host:
             self._tunnel()
@@ -907,17 +922,24 @@
                         netloc_enc = netloc.encode("idna")
                     self.putheader('Host', netloc_enc)
                 else:
+                    if self._tunnel_host:
+                        host = self._tunnel_host
+                        port = self._tunnel_port
+                    else:
+                        host = self.host
+                        port = self.port
+
                     try:
-                        host_enc = self.host.encode("ascii")
+                        host_enc = host.encode("ascii")
                     except UnicodeEncodeError:
-                        host_enc = self.host.encode("idna")
+                        host_enc = host.encode("idna")
                     # Wrap the IPv6 Host Header with [] (RFC 2732)
                     if host_enc.find(':') >= 0:
                         host_enc = "[" + host_enc + "]"
-                    if self.port == self.default_port:
+                    if port == self.default_port:
                         self.putheader('Host', host_enc)
                     else:
-                        self.putheader('Host', "%s:%s" % (host_enc, self.port))
+                        self.putheader('Host', "%s:%s" % (host_enc, port))
 
             # note: we are assuming that clients will not attempt to set these
             #       headers since *this* library must deal with the
@@ -1168,8 +1190,8 @@
         def connect(self):
             "Connect to a host on a given (SSL) port."
 
-            sock = socket.create_connection((self.host, self.port),
-                                            self.timeout, self.source_address)
+            sock = self._create_connection((self.host, self.port),
+                                          self.timeout, self.source_address)
             if self._tunnel_host:
                 self.sock = sock
                 self._tunnel()
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -13,10 +13,12 @@
 HOST = test_support.HOST
 
 class FakeSocket:
-    def __init__(self, text, fileclass=StringIO.StringIO):
+    def __init__(self, text, fileclass=StringIO.StringIO, host=None, port=None):
         self.text = text
         self.fileclass = fileclass
         self.data = ''
+        self.host = host
+        self.port = port
 
     def sendall(self, data):
         self.data += ''.join(data)
@@ -26,6 +28,9 @@
             raise httplib.UnimplementedFileMode()
         return self.fileclass(self.text)
 
+    def close(self):
+        pass
+
 class EPipeSocket(FakeSocket):
 
     def __init__(self, text, pipe_trigger):
@@ -526,9 +531,48 @@
                 self.fail("Port incorrectly parsed: %s != %s" % (p, c.host))
 
 
+class TunnelTests(TestCase):
+    def test_connect(self):
+        response_text = (
+            'HTTP/1.0 200 OK\r\n\r\n'   # Reply to CONNECT
+            'HTTP/1.1 200 OK\r\n'       # Reply to HEAD
+            'Content-Length: 42\r\n\r\n'
+        )
+        # Fake socket factory: records the address instead of connecting.
+        def create_connection(address, timeout=None, source_address=None):
+            return FakeSocket(response_text, host=address[0], port=address[1])
+
+        conn = httplib.HTTPConnection('proxy.com')
+        conn._create_connection = create_connection
+
+        # set_tunnel() must be rejected while a socket is open.
+        conn.connect()
+        self.assertRaises(RuntimeError, conn.set_tunnel, 'destination.com')
+
+        # After close(), the tunnel can be configured again.
+        conn.close()
+        conn.set_tunnel('destination.com')
+        conn.request('HEAD', '/', '')
+        # NOTE(review): port unset, so Host is sent as 'destination.com:None'.
+        self.assertEqual(conn.sock.host, 'proxy.com')
+        self.assertEqual(conn.sock.port, 80)
+        self.assertTrue('CONNECT destination.com' in conn.sock.data)
+        self.assertTrue('Host: destination.com' in conn.sock.data)
+        # The proxy's own name must not be sent as the Host header.
+        self.assertTrue('Host: proxy.com' not in conn.sock.data)
+
+        conn.close()
+        # A fresh request must reconnect through the proxy and re-tunnel.
+        conn.request('PUT', '/', '')
+        self.assertEqual(conn.sock.host, 'proxy.com')
+        self.assertEqual(conn.sock.port, 80)
+        self.assertTrue('CONNECT destination.com' in conn.sock.data)
+        self.assertTrue('Host: destination.com' in conn.sock.data)
+
+
 def test_main(verbose=None):
     test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
-                              HTTPSTimeoutTest, SourceAddressTest)
+                              HTTPSTimeoutTest, SourceAddressTest, TunnelTests)
 
 if __name__ == '__main__':
     test_main()
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -49,6 +49,10 @@
 Library
 -------
 
+- Issue #7776: Backport fix from Python 3 for ``Host:'' header and
+  reconnection when using httplib.HTTPConnection.set_tunnel().
+  Patch by Nikolaus Rath.
+
 - Issue #21306: Backport hmac.compare_digest from Python 3. This is part of PEP
   466.
 

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list