[py-svn] r31982 - in py/branch/distributed/py/execnet: . testing

arigo at codespeak.net arigo at codespeak.net
Mon Sep 4 13:27:25 CEST 2006


Author: arigo
Date: Mon Sep  4 13:27:23 2006
New Revision: 31982

Modified:
   py/branch/distributed/py/execnet/channel.py
   py/branch/distributed/py/execnet/gateway.py
   py/branch/distributed/py/execnet/register.py
   py/branch/distributed/py/execnet/testing/test_gateway.py
Log:
(pedronis, arigo)

- channel callbacks can now optionally be called with an end marker when
  the channel closes

- a repr for gateways that shows the remote address, when one is known

- never allow new channels over an already-closed connection - such
  channels would otherwise deadlock as soon as we try to read from them!



Modified: py/branch/distributed/py/execnet/channel.py
==============================================================================
--- py/branch/distributed/py/execnet/channel.py	(original)
+++ py/branch/distributed/py/execnet/channel.py	Mon Sep  4 13:27:23 2006
@@ -19,6 +19,8 @@
         # XXX do this better
         print >> sys.stderr, "Warning: unhandled %r" % (self,)
 
+NO_ENDMARKER_WANTED = object()
+
 
 class Channel(object):
     """Communication channel between two possibly remote threads of code. """
@@ -33,13 +35,14 @@
         self._receiveclosed = threading.Event()
         self._remoteerrors = []
 
-    def setcallback(self, callback):
+    def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED):
         queue = self._items
         lock = self.gateway.channelfactory._receivelock
         lock.acquire()
         try:
             _callbacks = self.gateway.channelfactory._callbacks
-            if _callbacks.setdefault(self.id, callback) is not callback:
+            dictvalue = (callback, endmarker)
+            if _callbacks.setdefault(self.id, dictvalue) != dictvalue:
                 raise IOError("%r has callback already registered" %(self,))
             self._items = None
             while 1:
@@ -55,10 +58,7 @@
                         callback(olditem)
             if self._closed or self._receiveclosed.isSet():
                 # no need to keep a callback
-                try:
-                    del _callbacks[self.id]
-                except KeyError:
-                    pass
+                self.gateway.channelfactory._close_callback(self.id)
         finally:
             lock.release()
          
@@ -201,11 +201,14 @@
         self._receivelock = threading.RLock()
         self.gateway = gateway
         self.count = startcount
+        self.finished = False
 
     def new(self, id=None):
         """ create a new Channel with 'id' (or create new id if None). """
         self._writelock.acquire()
         try:
+            if self.finished:
+                raise IOError("connexion already closed: %s" % (self.gateway,))
             if id is None:
                 id = self.count
                 self.count += 2
@@ -226,10 +229,16 @@
             del self._channels[id]
         except KeyError:
             pass
+        self._close_callback(id)
+
+    def _close_callback(self, id):
         try:
-            del self._callbacks[id]
+            callback, endmarker = self._callbacks.pop(id)
         except KeyError:
             pass
+        else:
+            if endmarker is not NO_ENDMARKER_WANTED:
+                callback(endmarker)
 
     def _local_close(self, id, remoteerror=None):
         channel = self._channels.get(id)
@@ -265,23 +274,30 @@
         # executes in receiver thread
         self._receivelock.acquire()
         try:
-            callback = self._callbacks.get(id)
-            if callback is not None:
-                callback(data)   # even if channel may be already closed
-            else:
+            try:
+                callback, endmarker = self._callbacks[id]
+            except KeyError:
                 channel = self._channels.get(id)
                 queue = channel and channel._items
                 if queue is None:
                     pass    # drop data
                 else:
                     queue.put(data)
+            else:
+                callback(data)   # even if channel may be already closed
         finally:
             self._receivelock.release()
 
     def _finished_receiving(self):
+        self._writelock.acquire()
+        try:
+            self.finished = True
+        finally:
+            self._writelock.release()
         for id in self._channels.keys():
             self._local_last_message(id)
-        self._callbacks.clear()
+        for id in self._callbacks.keys():
+            self._close_callback(id)
 
 
 class ChannelFile:

Modified: py/branch/distributed/py/execnet/gateway.py
==============================================================================
--- py/branch/distributed/py/execnet/gateway.py	(original)
+++ py/branch/distributed/py/execnet/gateway.py	Mon Sep  4 13:27:23 2006
@@ -49,14 +49,22 @@
         self.pool = NamedThreadPool(receiver = self.thread_receiver, 
                                     sender = self.thread_sender)
 
-    def __repr__(self): 
+    def __repr__(self):
+        addr = self.getremoteaddress()
+        if addr:
+            addr = '[%s]' % (addr,)
+        else:
+            addr = ''
         r = (len(self.pool.getstarted('receiver'))
              and "receiving" or "not receiving")
         s = (len(self.pool.getstarted('sender')) 
              and "sending" or "not sending")
         i = len(self.channelfactory.channels())
-        return "<%s %s/%s (%d active channels)>" %(
-                self.__class__.__name__, r, s, i) 
+        return "<%s%s %s/%s (%d active channels)>" %(
+                self.__class__.__name__, addr, r, s, i)
+
+    def getremoteaddress(self):
+        return None
 
 ##    def _local_trystopexec(self):
 ##        self._execpool.shutdown() 
@@ -118,7 +126,8 @@
                 except:
                     excinfo = exc_info()
                     self.traceex(excinfo)
-                    msg.post_sent(self, excinfo)
+                    if msg is not None:
+                        msg.post_sent(self, excinfo)
                     raise
                 else:
                     self.trace('sent -> %r' % msg)

Modified: py/branch/distributed/py/execnet/register.py
==============================================================================
--- py/branch/distributed/py/execnet/register.py	(original)
+++ py/branch/distributed/py/execnet/register.py	Mon Sep  4 13:27:23 2006
@@ -110,6 +110,9 @@
         io = inputoutput.SocketIO(sock)
         InstallableGateway.__init__(self, io=io)
 
+    def getremoteaddress(self):
+        return '%s:%d' % (self.host, self.port)
+
     def remote_install(cls, gateway, hostport=None): 
         """ return a connected socket gateway through the
             given gateway. 
@@ -151,6 +154,9 @@
         cmdline.insert(0, cmd) 
         super(SshGateway, self).__init__(' '.join(cmdline))
 
+    def getremoteaddress(self):
+        return self.sshaddress
+
 class ExecGateway(PopenGateway):
     def remote_exec_sync_stdcapture(self, lines, callback):
         # hack: turn the content of the cell into

Modified: py/branch/distributed/py/execnet/testing/test_gateway.py
==============================================================================
--- py/branch/distributed/py/execnet/testing/test_gateway.py	(original)
+++ py/branch/distributed/py/execnet/testing/test_gateway.py	Mon Sep  4 13:27:23 2006
@@ -240,6 +240,21 @@
         channel = self.test_channel_callback_stays_active(False)
         channel.waitclose(1.0) # freed automatically at the end of producer()
 
+    def test_channel_endmarker_callback(self):
+        l = []
+        channel = self.gw.remote_exec(source='''
+            channel.send(42)
+            channel.send(13)
+            channel.send(channel.gateway.newchannel())
+            ''') 
+        channel.setcallback(l.append, 999)
+        py.test.raises(IOError, channel.receive)
+        channel.waitclose(1.0) 
+        assert len(l) == 4
+        assert l[:2] == [42,13]
+        assert isinstance(l[2], channel.__class__) 
+        assert l[3] == 999
+
     def test_remote_redirect_stdout(self): 
         out = py.std.StringIO.StringIO() 
         handle = self.gw.remote_redirect(stdout=out) 
@@ -391,3 +406,14 @@
     def test_sshaddress(self):
         assert self.gw.sshaddress == option.sshtarget
 
+    def test_failed_connexion(self):
+        gw = py.execnet.SshGateway('nowhere.codespeak.net')
+        try:
+            channel = gw.remote_exec("...")
+        except IOError:
+            pass      # connexion failed already
+        else:
+            # connexion did not fail yet
+            py.test.raises(EOFError, channel.receive)
+            # now it did
+            py.test.raises(IOError, gw.remote_exec, "...")



More information about the pytest-commit mailing list