[pypy-commit] pypy default: Fix socket.recvfrom() so that it takes advantage of the fact that, nowadays, a lot of buffers have a get_raw_address()

arigo pypy.commits at gmail.com
Tue Aug 9 04:52:19 EDT 2016


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r86098:e53ea5c9c384
Date: 2016-08-09 10:51 +0200
http://bitbucket.org/pypy/pypy/changeset/e53ea5c9c384/

Log:	Fix socket.recvfrom() so that it takes advantage of the fact that,
	nowadays, a lot of buffers have a get_raw_address()

diff --git a/rpython/rlib/rsocket.py b/rpython/rlib/rsocket.py
--- a/rpython/rlib/rsocket.py
+++ b/rpython/rlib/rsocket.py
@@ -862,23 +862,30 @@
         string."""
         self.wait_for_data(False)
         with rffi.scoped_alloc_buffer(buffersize) as buf:
-            read_bytes = _c.socketrecv(self.fd,
-                                       rffi.cast(rffi.VOIDP, buf.raw),
-                                       buffersize, flags)
+            read_bytes = _c.socketrecv(self.fd, buf.raw, buffersize, flags)
             if read_bytes >= 0:
                 return buf.str(read_bytes)
         raise self.error_handler()
 
     def recvinto(self, rwbuffer, nbytes, flags=0):
-        buf = self.recv(nbytes, flags)
-        rwbuffer.setslice(0, buf)
-        return len(buf)
+        try:
+            rwbuffer.get_raw_address()
+        except ValueError:
+            buf = self.recv(nbytes, flags)
+            rwbuffer.setslice(0, buf)
+            return len(buf)
+        else:
+            self.wait_for_data(False)
+            raw = rwbuffer.get_raw_address()
+            read_bytes = _c.socketrecv(self.fd, raw, nbytes, flags)
+            if read_bytes >= 0:
+                return read_bytes
+            raise self.error_handler()
 
     @jit.dont_look_inside
     def recvfrom(self, buffersize, flags=0):
         """Like recv(buffersize, flags) but also return the sender's
         address."""
-        read_bytes = -1
         self.wait_for_data(False)
         with rffi.scoped_alloc_buffer(buffersize) as buf:
             address, addr_p, addrlen_p = self._addrbuf()
@@ -899,9 +906,30 @@
         raise self.error_handler()
 
     def recvfrom_into(self, rwbuffer, nbytes, flags=0):
-        buf, addr = self.recvfrom(nbytes, flags)
-        rwbuffer.setslice(0, buf)
-        return len(buf), addr
+        try:
+            rwbuffer.get_raw_address()
+        except ValueError:
+            buf, addr = self.recvfrom(nbytes, flags)
+            rwbuffer.setslice(0, buf)
+            return len(buf), addr
+        else:
+            self.wait_for_data(False)
+            address, addr_p, addrlen_p = self._addrbuf()
+            try:
+                raw = rwbuffer.get_raw_address()
+                read_bytes = _c.recvfrom(self.fd, raw, nbytes, flags,
+                                         addr_p, addrlen_p)
+                addrlen = rffi.cast(lltype.Signed, addrlen_p[0])
+            finally:
+                lltype.free(addrlen_p, flavor='raw')
+                address.unlock()
+            if read_bytes >= 0:
+                if addrlen:
+                    address.addrlen = addrlen
+                else:
+                    address = None
+                return (read_bytes, address)
+            raise self.error_handler()
 
     def send_raw(self, dataptr, length, flags=0):
         """Send data from a CCHARP buffer."""
diff --git a/rpython/rlib/test/test_rsocket.py b/rpython/rlib/test/test_rsocket.py
--- a/rpython/rlib/test/test_rsocket.py
+++ b/rpython/rlib/test/test_rsocket.py
@@ -119,25 +119,111 @@
     s1.close()
     s2.close()
 
-def test_socketpair_recvinto():
+def test_socketpair_recvinto_1():
     class Buffer:
         def setslice(self, start, string):
             self.x = string
 
-        def as_str(self):
-            return self.x
+        def get_raw_address(self):
+            raise ValueError
 
     if sys.platform == "win32":
         py.test.skip('No socketpair on Windows')
     s1, s2 = socketpair()
     buf = Buffer()
     s1.sendall('?')
-    s2.recvinto(buf, 1)
-    assert buf.as_str() == '?'
+    n = s2.recvinto(buf, 1)
+    assert n == 1
+    assert buf.x == '?'
     count = s2.send('x'*99)
     assert 1 <= count <= 99
-    s1.recvinto(buf, 100)
-    assert buf.as_str() == 'x'*count
+    n = s1.recvinto(buf, 100)
+    assert n == count
+    assert buf.x == 'x'*count
+    s1.close()
+    s2.close()
+
+def test_socketpair_recvinto_2():
+    class Buffer:
+        def __init__(self):
+            self._p = lltype.malloc(rffi.CCHARP.TO, 100, flavor='raw',
+                                    track_allocation=False)
+
+        def _as_str(self, count):
+            return rffi.charpsize2str(self._p, count)
+
+        def get_raw_address(self):
+            return self._p
+
+    if sys.platform == "win32":
+        py.test.skip('No socketpair on Windows')
+    s1, s2 = socketpair()
+    buf = Buffer()
+    s1.sendall('?')
+    n = s2.recvinto(buf, 1)
+    assert n == 1
+    assert buf._as_str(1) == '?'
+    count = s2.send('x'*99)
+    assert 1 <= count <= 99
+    n = s1.recvinto(buf, 100)
+    assert n == count
+    assert buf._as_str(n) == 'x'*count
+    s1.close()
+    s2.close()
+
+def test_socketpair_recvfrom_into_1():
+    class Buffer:
+        def setslice(self, start, string):
+            self.x = string
+
+        def get_raw_address(self):
+            raise ValueError
+
+    if sys.platform == "win32":
+        py.test.skip('No socketpair on Windows')
+    s1, s2 = socketpair()
+    buf = Buffer()
+    s1.sendall('?')
+    n, addr = s2.recvfrom_into(buf, 1)
+    assert n == 1
+    assert addr is None
+    assert buf.x == '?'
+    count = s2.send('x'*99)
+    assert 1 <= count <= 99
+    n, addr = s1.recvfrom_into(buf, 100)
+    assert n == count
+    assert addr is None
+    assert buf.x == 'x'*count
+    s1.close()
+    s2.close()
+
+def test_socketpair_recvfrom_into_2():
+    class Buffer:
+        def __init__(self):
+            self._p = lltype.malloc(rffi.CCHARP.TO, 100, flavor='raw',
+                                    track_allocation=False)
+
+        def _as_str(self, count):
+            return rffi.charpsize2str(self._p, count)
+
+        def get_raw_address(self):
+            return self._p
+
+    if sys.platform == "win32":
+        py.test.skip('No socketpair on Windows')
+    s1, s2 = socketpair()
+    buf = Buffer()
+    s1.sendall('?')
+    n, addr = s2.recvfrom_into(buf, 1)
+    assert n == 1
+    assert addr is None
+    assert buf._as_str(1) == '?'
+    count = s2.send('x'*99)
+    assert 1 <= count <= 99
+    n, addr = s1.recvfrom_into(buf, 100)
+    assert n == count
+    assert addr is None
+    assert buf._as_str(n) == 'x'*count
     s1.close()
     s2.close()
 


More information about the pypy-commit mailing list