[py-svn] r6909 - in py/dist/py/execnet: . bin

hpk at codespeak.net hpk at codespeak.net
Wed Oct 13 05:01:04 CEST 2004


Author: hpk
Date: Wed Oct 13 05:01:02 2004
New Revision: 6909

Modified:
   py/dist/py/execnet/bin/startserver.py
   py/dist/py/execnet/gateway.py
   py/dist/py/execnet/gateway_test.py
   py/dist/py/execnet/register.py
Log:
- the first implementation of the "channel interface" 
  as described in doc/execnet 

- added lots of tests 

- gateway.py should probably be split up into multiple files 



Modified: py/dist/py/execnet/bin/startserver.py
==============================================================================
--- py/dist/py/execnet/bin/startserver.py	(original)
+++ py/dist/py/execnet/bin/startserver.py	Wed Oct 13 05:01:02 2004
@@ -42,7 +42,7 @@
         #    import traceback
         #    traceback.print_exc()
 
-def listen(hostport):
+def bind_and_listen(hostport):
     if isinstance(hostport, str):
         host, port = hostport.split(':')
         hostport = (host, int(port))
@@ -64,6 +64,6 @@
         hostport = sys.argv[1]
     else:
         hostport = ':8888'
-    serversock = listen(hostport) 
+    serversock = bind_and_listen(hostport) 
     startserver(serversock)
 

Modified: py/dist/py/execnet/gateway.py
==============================================================================
--- py/dist/py/execnet/gateway.py	(original)
+++ py/dist/py/execnet/gateway.py	Wed Oct 13 05:01:02 2004
@@ -1,35 +1,60 @@
 import sys, os, threading, struct, Queue, traceback
 import atexit
+import time
 
 from py.__impl__.execnet.source import Source
 
 debug = 0
 sysex = (KeyboardInterrupt, SystemExit) 
 
-class Gateway:
-    def __init__(self, io, ns = None):
-        if ns is None:
-            ns = {}
-        ns['gateway'] = self
+class Gateway(object):
+    num_worker_threads = 2
+
+    def __init__(self, io, startcount=2): 
         self.io = io
-        self.ns = ns
-        self._incoming = Queue.Queue()
+        self._execqueue = Queue.Queue()
         self._outgoing = Queue.Queue()
-        self._execevent = threading.Event() 
-
-        self.running = 1
-        self._replylock = threading.RLock()
-        self._replies = {}
-
-        for method in self._receiver, self._executor, self._sender:
-            t = threading.Thread(target=method, name=method.func_name)
-            setattr(self, 'thread'+method.func_name, t)
-            t.setDaemon(0) 
-            t.start() 
+        self.channelfactory = ChannelFactory(self, startcount) 
+        self.iothreads = [
+            threading.Thread(target=self.thread_receiver, name='receiver'),
+            threading.Thread(target=self.thread_sender, name='sender'),
+        ]
+        for x in self.iothreads:
+            x.start()
+        self.workerthreads = w = [] 
+        for x in range(self.num_worker_threads):
+            w.append(threading.Thread(target=self.thread_executor, 
+                                      name='executor %d' % x))
+        for x in w:
+            x.start()
         if not _gateways:
             atexit.register(cleanup_atexit) 
         _gateways.append(self) 
 
+    def _stopexec(self):
+        if self.workerthreads: 
+            for x in self.workerthreads: 
+                self._execqueue.put(None) 
+            for x in self.workerthreads:
+                if x.isAlive():
+                    self.trace("joining %r" % x)
+                    x.join()
+            self.workerthreads[:] = []
+
+    def exit(self): 
+        if self.workerthreads:
+            self._stopexec()
+            self._outgoing.put(Message.EXIT_GATEWAY()) 
+        else:
+            self.trace("exit() called, but gateway has not threads anymore!") 
+
+    def join(self):
+        current = threading.currentThread()
+        for x in self.iothreads: 
+            if x != current and x.isAlive():
+                print "joining", x
+                x.join()
+
     def trace(self, *args):
         if debug:
             try:
@@ -47,179 +72,258 @@
         errortext = "".join(l)
         self.trace(errortext)
 
-    def _receiver(self):
-        """ read sourcecode strings into the incoming queue. """ 
+    def thread_receiver(self):
+        """ thread to read and handle Messages half-sync-half-async. """ 
         try:
             while 1: 
                 try:
-                    header = self.io.read(struct.calcsize("!i"))
-                    stringlen, = struct.unpack("!i", header)
-                    string = self.io.read(stringlen)
-                except self.io.error: 
+                    msg = Message.readfrom(self.io) 
+                    self.trace("received <- %r" % msg) 
+                    msg.handle(self) 
+                except SystemExit: 
+                    self.io.close_read()
+                    return
+                except:
                     self.traceex(sys.exc_info()) 
-                    break 
-                else:
-                    if not string:  # empty string is signal to shutdown 
-                        break 
-                    self._incoming.put(string) 
-        finally:
-            self.io.close_read()
-            self.trace("receiver leaving")
-            self.exit()
+                    break
+        finally:
+            self.trace('leaving %r' % threading.currentThread())
 
-    def _sender(self):
-        """ send source objects from the outgoing queue to the remote side. """ 
+    def thread_sender(self):
+        """ thread to send Messages over the wire. """ 
         try:
             while 1: 
-                obj = self._outgoing.get()
-                if obj is None: 
-                    # if we are finished then we try to send an 
-                    # empty string #to the other side in order to signal 
-                    # it to shutdown its receiver loop 
-                    obj = ''
-                else:
-                    obj = str(obj) 
-                data = struct.pack("!i", len(obj)) + obj
+                msg = self._outgoing.get()
                 try:
-                    self.io.write(data)
+                    msg.writeto(self.io) 
                 except self.io.error:
                     self.traceex(sys.exc_info())
-                    break
                 else:
-                    self.trace('sent -> %d' % len(obj))
-                if not obj:
-                    break
+                    self.trace('sent -> %r' % msg) 
+                if isinstance(msg, (Message.STOP_RECEIVING, Message.EXIT_GATEWAY)): 
+                    self.io.close_write()
+                    break 
         finally: 
-            self.trace('sender leaving')
-            self.io.close_write()
+            self.trace('leaving %r' % threading.currentThread())
 
-    def _executor(self):
-        """ execute source objects from the incoming queue. """ 
+    def thread_executor(self):
+        """ worker thread to execute source objects from the execution queue. """ 
         try:
             while 1: 
-                source = self._incoming.get()
-                if source is None: 
-                    self._execevent.set() 
+                task = self._execqueue.get()
+                if task is None: 
                     break
+                channel, source = task 
                 try:
-                    self.trace("executing starts:", repr(source)[:50]) 
+                    loc = { 'channel' : channel } 
+                    self.trace("execution starts:", repr(source)[:50]) 
                     try: 
                         co = compile(source, '', 'exec')
-                        exec co in self.ns # globals(), self.ns
+                        exec co in loc 
                     finally: 
-                        self._execevent.set() 
                         self.trace("execution finished:", repr(source)[:50]) 
                 except (KeyboardInterrupt, SystemExit):
                     raise
                 except:
-                    self.traceex(sys.exc_info())
-        finally:
-           self.trace("executor leaving")
-
-    def wait_exec(self, timeout=10.0):
-        """ wait until the next execution event (i.e. sourcecode is executed). """ 
-        self._execevent.wait(timeout)
-        if not self._execevent.isSet(): 
-            raise IOError, "timeout waiting for execevent" 
-        self.clear_exec() 
-
-    def clear_exec(self):
-        """ clear execution event. """ 
-        self._execevent.clear()
-    
-    def exit(self): 
-        if self.running:
-            self.running = 0 
-            self._incoming.put(None)
-            self._outgoing.put(None)
-            for thread in (self.thread_receiver,
-                           self.thread_sender, 
-                           self.thread_executor): 
-                if thread != threading.currentThread():
-                    self.trace("joining %s" % str(thread)) 
-                    thread.join(10.0) 
+                    excinfo = sys.exc_info()
+                    l = traceback.format_exception(*excinfo) 
+                    errortext = "".join(l)
+                    self._outgoing.put(Message.CHANNEL_CLOSE_ERROR(channel.id, errortext)) 
+                    self.trace(errortext) 
                 else:
-                    self.trace("joining %s which is current" % str(thread)) 
-            self.trace("exit finished successfully") 
-        else:
-            self.trace("exit() called, but gateway was not running anymore!") 
-
-    def __nonzero__(self):
-        return self.running
+                    self._outgoing.put(Message.CHANNEL_CLOSE(channel.id))
+        finally:
+            self.trace('leaving %r' % threading.currentThread())
 
-    def _get_status(self, id=None):
-        self._replylock.acquire()
-        try:
-            if id is None:
-                id = len(self._replies)
-                self._replies[id] = x = Reply(self, id)
-                return x 
-            return self._replies[id]
-        finally:    
-            self._replylock.release()
     # _____________________________________________________________________
     #
     # High Level Interface 
     # _____________________________________________________________________
     
-    def remote_exec_oneshot(self, source):
-        """ execute source on the other side of the gateway disregarding feedback. """ 
-        if source is not None:
-            source = Source(source)
-        self._outgoing.put(source) 
-
     def remote_exec_async(self, source):
-        """ return reply object for the asynchornous execution of the 
-            given sourcecode. 
+        """ return channel object for communicating with the asynchronously 
+            executing source code which can use a corresponding 'channel' 
+            object in turn. 
         """ 
-        source = Source(source)
-        reply = self._get_status() 
-        source.putaround(
-            "try:",
-            Source("""
-               except:
-                   msg = gateway._format_exception(sys.exc_info())
-               else:
-                   msg = '' 
-               gateway.remote_exec_oneshot(
-                    'gateway._get_status(%r).set(%%r)' %% (msg, ))
-            """ % reply.id))
-        self.trace(str(source))
-        self.remote_exec_oneshot(source) 
-        return reply 
+        source = str(Source(source))
+        if debug:
+            import parser
+            parser.suite(source)
+        channel = self.channelfactory.new()
+        self._outgoing.put(Message.CHANNEL_OPEN(channel.id, source))
+        return channel
 
     def remote_exec_sync(self, source, timeout=10):
         """ synchronously execute source on the other side of the gateway
             return after execution of the source finishes on the other 
             side. 
         """ 
-        reply = self.remote_exec_async(source)
-        reply.wait(timeout=timeout) 
+        channel = self.remote_exec_async(source)
+        channel.waitclose(timeout=timeout) 
         return 
 
-class Reply:
-    """ Reply objects let you determine if a remote_exec_async method 
-        finished on the other site and possibly retrieve error information. 
-    """
+class Channel(object):
     def __init__(self, gateway, id):
+        assert isinstance(id, int)
         self.gateway = gateway
         self.id = id 
-        self._setevent = threading.Event()
-    
-    def set(self, msg):
-        self.msg = msg 
-        self._setevent.set()
-    
-    def wait(self, timeout=10.0):
-        self._setevent.wait(timeout=timeout)
-        if not self._setevent.isSet():
-            raise IOError, "timeout waiting to status reply" 
-        self._setevent.clear() 
+        self._items = Queue.Queue()
+        self._closeevent = threading.Event()
+
+    def _close(self, error=None):
+        self._error = error
+        self._closeevent.set()
+
+    def __repr__(self):
+        flag = self._closeevent.isSet() and "closed" or "open"
+        return "<Channel id=%d %s>" % (self.id, flag)
+
+    def waitclose(self, timeout=None): 
+        """ return error message (None if no error) after waiting for close event. """
+        self._closeevent.wait(timeout) 
+        if not self._closeevent.isSet():
+            raise IOError, "Timeout"
+        return self._error 
         
+    def send(self, item): 
+        """sends the given item to the other side of the channel, 
+        possibly blocking if the sender queue is full. 
+        Note that each value V of the items needs to have the
+        following property (all basic types in python have it):
+        eval(repr(V)) == V."""
+        self.gateway._outgoing.put(Message.CHANNEL_DATA(self.id, repr(item)))
+
+    def receive(self, timeout=None):
+        """receives an item that was sent from the other side, 
+        possibly blocking if there is none."""
+        return self._items.get(timeout=timeout) 
 
 # 
-# helper functions
+# helpers 
+#
+
+class ChannelFactory(object):
+    def __init__(self, gateway, startcount=1): 
+        self._dict = dict()
+        self._lock = threading.RLock()
+        self.gateway = gateway
+        self.count = startcount
+
+    def new(self): 
+        self._lock.acquire()
+        try:
+            channel = Channel(self.gateway, self.count)
+            self._dict[self.count] = channel
+            return channel
+        finally:
+            self.count += 2
+            self._lock.release()
+        
+    def __getitem__(self, key):
+        self._lock.acquire()
+        try:
+            return self._dict[key]
+        finally:
+            self._lock.release()
+    def __setitem__(self, key, value):
+        self._lock.acquire()
+        try:
+            self._dict[key] = value
+        finally:
+            self._lock.release()
+    def __delitem__(self, key):
+        self._lock.acquire()
+        try:
+            del self._dict[key]
+        finally:
+            self._lock.release()
+
+# ___________________________________________________________________________
 #
+# Messages 
+# ___________________________________________________________________________
+# the size of a number on the wire 
+numsize = struct.calcsize("!i")
+# header of a packet 
+# int message_type: 0==exitgateway,         
+#                   1==channelfinished_ok,
+#                   2==channelfinished_err, 
+#                   3==channelopen # executes source code 
+#                   4==channelsend # marshals obj
+class Message:
+    """ encapsulates Messages and their wire protocol. """
+    _types = {}
+    def __init__(self, channelid=0, data=''): 
+        self.channelid = channelid 
+        self.data = str(data)
+       
+    def writeto(self, io):
+        data = str(self.data)
+        header = struct.pack("!iii", self.msgtype, self.channelid, len(data))
+        io.write(header)
+        io.write(data) 
+
+    def readfrom(cls, io): 
+        header = io.read(numsize*3)  
+        msgtype, senderid, stringlen = struct.unpack("!iii", header)
+        if stringlen: 
+            string = io.read(stringlen)
+        else:
+            string = '' 
+        msg = cls._types[msgtype](senderid, string)
+        return msg 
+    readfrom = classmethod(readfrom) 
+
+    def __repr__(self):
+        if len(self.data) > 50:
+            return "<Message.%s channelid=%d len=%d>" %(self.__class__.__name__, 
+                        self.channelid, len(self.data))
+        else: 
+            return "<Message.%s channelid=%d %r>" %(self.__class__.__name__, 
+                        self.channelid, self.data)
+
+def _setupmessages():
+    class EXIT_GATEWAY(Message):
+        def handle(self, gateway):
+            gateway._stopexec()
+            gateway._outgoing.put(self.STOP_RECEIVING()) 
+            raise SystemExit 
+    class STOP_RECEIVING(Message):
+        def handle(self, gateway):
+            raise SystemExit 
+    class CHANNEL_OPEN(Message):
+        def handle(self, gateway):
+            channel = Channel(gateway, self.channelid) 
+            gateway.channelfactory[self.channelid] = channel 
+            gateway._execqueue.put((channel, self.data)) 
+    class CHANNEL_DATA(Message):
+        def handle(self, gateway):
+            channel = gateway.channelfactory[self.channelid]
+            channel._items.put(eval(self.data)) 
+    class CHANNEL_CLOSE(Message):
+        def handle(self, gateway):
+            channel = gateway.channelfactory[self.channelid]
+            channel._close()
+            del gateway.channelfactory[channel.id]
+    class CHANNEL_CLOSE_ERROR(Message):
+        def handle(self, gateway):
+            channel = gateway.channelfactory[self.channelid]
+            channel._close(self.data)
+            if debug:
+                for line in self.data.split('\n'):
+                    gateway.trace("remote error: " + line)
+    classes = [x for x in locals().values() if hasattr(x, '__bases__')]
+    classes.sort(lambda x,y : cmp(x.__name__, y.__name__))
+    i = 0
+    for cls in classes: 
+        Message._types[i] = cls  
+        cls.msgtype = i
+        setattr(Message, cls.__name__, cls) 
+        i+=1
+
+_setupmessages()
+
+                    
 def getid(gw, cache={}):
     name = gw.__class__.__name__ 
     try:
@@ -230,7 +334,7 @@
 
 _gateways = []
 def cleanup_atexit():
+    print "="*20 + "cleaning up" + "=" * 20
     for x in _gateways: 
-        if x.running:
+        if x.workerthreads:
             x.exit()
-

Modified: py/dist/py/execnet/gateway_test.py
==============================================================================
--- py/dist/py/execnet/gateway_test.py	(original)
+++ py/dist/py/execnet/gateway_test.py	Wed Oct 13 05:01:02 2004
@@ -1,8 +1,66 @@
 import os, sys
 import py 
 from py.__impl__.execnet.source import Source
-autopath = py.magic.autopath() 
+from py.__impl__.execnet import gateway 
+mypath = py.magic.autopath() 
 
+from StringIO import StringIO
+
+class TestMessage:
+    def test_wire_protocol(self):
+        for cls in gateway.Message._types.values():
+            one = StringIO()
+            cls(42, '23').writeto(one) 
+            two = StringIO(one.getvalue())
+            msg = gateway.Message.readfrom(two)
+            assert isinstance(msg, cls) 
+            assert msg.channelid == 42 
+            assert msg.data == '23'
+            assert isinstance(repr(msg), str)
+            # == "<Message.%s channelid=42 '23'>" %(msg.__class__.__name__, )
+
+class TestChannel:
+    def setup_method(self, method):
+        self.fac = gateway.ChannelFactory(None)
+
+    def test_factory_create(self):
+        chan1 = self.fac.new()
+        assert chan1.id == 1
+        chan2 = self.fac.new()
+        assert chan2.id == 3
+
+    def test_factory_getitem(self):
+        chan1 = self.fac.new()
+        assert self.fac[chan1.id] == chan1 
+        chan2 = self.fac.new()
+        assert self.fac[chan2.id] == chan2
+        
+    def test_factory_delitem(self):
+        chan1 = self.fac.new()
+        assert self.fac[chan1.id] == chan1 
+        del self.fac[chan1.id]
+        py.test.raises(KeyError, self.fac.__getitem__, chan1.id)
+
+    def test_factory_setitem(self):
+        channel = gateway.Channel(None, 12)
+        self.fac[channel.id] = channel
+        assert self.fac[channel.id] == channel 
+
+    def test_channel_timeouterror(self):
+        channel = self.fac.new() 
+        py.test.raises(IOError, channel.waitclose, timeout=0.1)
+
+    def test_channel_close(self):
+        channel = self.fac.new()
+        channel._close() 
+        channel.waitclose(0.1)
+
+    def test_channel_close_error(self):
+        channel = self.fac.new()
+        channel._close("error") 
+        err = channel.waitclose(0.1)
+        assert err == "error"
+    
 class PopenGatewayTestSetup: 
     disabled = True 
     def setup_class(cls):
@@ -11,90 +69,112 @@
     def teardown_class(cls):
         cls.gw.exit()  
 
+
+class BasicRemoteExecution: 
+    disabled = True 
+
+    def test_correct_setup(self):
+        assert self.gw.workerthreads and self.gw.iothreads 
+
+    def test_syntax_error_in_debug_mode(self):
+        debug = gateway.debug 
+        try:
+            gateway.debug = 1 
+            py.test.raises(SyntaxError, self.gw.remote_exec_async, "a='")
+        finally:
+            gateway.debug = debug
+
+    def test_remote_exec_async_waitclose(self): 
+        channel = self.gw.remote_exec_async('pass') 
+        channel.waitclose(timeout=3.0) 
+
+    def test_remote_exec_async_channel_anonymous(self):
+        channel = self.gw.remote_exec_async('''
+                    obj = channel.receive()
+                    channel.send(obj)
+                  ''')
+        channel.send(42)
+        result = channel.receive(timeout=3.0)
+        assert result == 42
+
+class TestBasicPopenGateway(PopenGatewayTestSetup, BasicRemoteExecution): 
+    disabled = False 
+    def test_many_popen(self):
+        num = 4
+        l = []
+        for i in range(num):
+            l.append(py.execnet.PopenGateway())
+        channels = []
+        for gw in l: 
+            channel = gw.remote_exec_async("""channel.send(42)""")
+            channels.append(channel)
+        try:
+            while channels: 
+                channel = channels.pop()
+                try:
+                    ret = channel.receive()
+                    assert ret == 42
+                finally:
+                    channel.gateway.exit()
+        finally:
+            for x in channels: 
+                x.gateway.exit()
+
 class SocketGatewayTestSetup:
     disabled = True
 
     def setup_class(cls):
         portrange = (7770, 7800)
         cls.proxygw = py.execnet.PopenGateway() 
-        s = Source(
-                autopath.dirpath('bin', 'startserver.py').read(), 
-                """
-                import socket
-                for i in range%(portrange)r: 
-                    try:
-                        sock = listen(("localhost", i))
-                    except socket.error: 
-                        continue
-                    else:
-                        gateway.remote_exec_oneshot("gateway._listenport=" + str(i))
-                        startserver(sock)
-                        print "started server with socket"
-                        break
+        socketserverbootstrap = Source( 
+            mypath.dirpath('bin', 'startserver.py').read(), 
+            """
+            import socket 
+            portrange = channel.receive() 
+            for i in portrange: 
+                try:
+                    sock = bind_and_listen(("localhost", i))
+                except socket.error: 
+                    continue
                 else:
-                    gateway.remote_exec_oneshot("gateway._listenport=None")""" % locals(), 
-                )
-                
-        cls.proxygw.remote_exec_oneshot(s) 
-        cls.proxygw.wait_exec(timeout=5)
-        if not cls.proxygw._listenport:
+                    channel.send(i) 
+                    startserver(sock)
+                    print "started server with socket"
+                    break
+            else:
+                channel.send(None) 
+    """)
+        # open a gateway to a fresh child process 
+        cls.proxygw = py.execnet.PopenGateway()
+
+        # execute asynchronously the above socketserverbootstrap on the other
+        channel = cls.proxygw.remote_exec_async(socketserverbootstrap) 
+
+        # send parameters for the for-loop
+        channel.send((7770, 7800)) 
+        #
+        # the other side should start the for loop now, we
+        # wait for the result
+        #
+        cls.listenport = channel.receive() 
+        if cls.listenport is None: 
             raise IOError, "could not setup remote SocketServer"
-        cls.gw = py.execnet.SocketGateway('localhost', cls.proxygw._listenport) 
-        print "initialized socket gateway on port", cls.proxygw._listenport 
+        cls.gw = py.execnet.SocketGateway('localhost', cls.listenport) 
+        print "initialized socket gateway on port", cls.listenport 
 
     def teardown_class(cls):
         print "trying to tear down remote socket gateway" 
         cls.gw.exit() 
-        if cls.proxygw._listenport:
+        if cls.gw.port: 
             print "trying to tear down remote socket loop" 
             import socket
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-            sock.connect(('localhost', cls.proxygw._listenport)) 
+            sock.connect(('localhost', cls.listenport))
             sock.sendall('"raise KeyboardInterrupt"') 
             sock.shutdown(2) 
         print "trying to tear proxy gateway" 
         cls.proxygw.exit() 
 
-class BasicRemoteExecution: 
-    disabled = True 
-    def test_remote_exec_oneshot(self):
-        self.gw.clear_exec() 
-        self.gw.remote_exec_oneshot('gateway.remote_exec_oneshot("test_1_back = 42")')
-        self.gw.wait_exec() 
-        assert self.gw.ns['test_1_back'] == 42 
-
-    def test_remote_exec_async_with_handle(self):
-        handle = self.gw.remote_exec_async('gateway.remote_exec_oneshot("test_2_back = 17")')
-        handle.wait(timeout=10.0) 
-        assert self.gw.ns['test_2_back'] == 17 
-
-    def test_remote_exec_sync(self):
-        self.gw.remote_exec_sync('gateway.remote_exec_oneshot("test_3_back = 43")') 
-        assert self.gw.ns['test_3_back'] == 43 
-
-class BasicPopenGatewayTest(PopenGatewayTestSetup, BasicRemoteExecution): 
-    disabled = False 
-    def test_many_popen(self):
-        num = 4
-        l = []
-        for i in range(num):
-            l.append(py.execnet.PopenGateway())
-        stati = []
-        for gw in l: 
-            status = gw.remote_exec_async("gateway.remote_exec_oneshot('back=42')")
-            stati.append(status)
-        try:
-            while stati: 
-                status = stati.pop()
-                try:
-                    status.wait(timeout=3.0) 
-                    assert status.gateway.ns['back'] == 42
-                finally:
-                    status.gateway.exit()
-        finally:
-            for x in stati: 
-                x.gateway.exit()
-
 class BasicSocketGatewayTest(SocketGatewayTestSetup, BasicRemoteExecution): 
     disabled = False 
 

Modified: py/dist/py/execnet/register.py
==============================================================================
--- py/dist/py/execnet/register.py	(original)
+++ py/dist/py/execnet/register.py	Wed Oct 13 05:01:02 2004
@@ -1,5 +1,6 @@
 
 from py.magic import autopath ; autopath = autopath()
+import Queue
 
 import os, inspect, socket
 
@@ -8,40 +9,56 @@
 py.magic.invoke(dyncode=True) 
 
 class InstallableGateway(gateway.Gateway):
-    def __init__(self, io, ns = None):
+    """ initialize gateways on both sides of a inputoutput object. """
+    def __init__(self, io): 
         self.remote_bootstrap_gateway(io) 
-        gateway.Gateway.__init__(self, io, ns)
+        gateway.Gateway.__init__(self, io=io, startcount=1)
 
     def remote_bootstrap_gateway(self, io): 
         """ return Gateway with a asynchronously remotely 
             initialized counterpart Gateway (which may or may not succeed). 
+            Note that the other sides gateways starts enumerating 
+            its channels with even numbers while the sender
+            gateway starts with odd numbers.  This allows to 
+            uniquely identify channels across both sides. 
         """
         bootstrap = [ 
             inspect.getsource(inputoutput), 
-            io.server_stmt, 
             inspect.getsource(gateway), 
-            "gateway = Gateway(io)",
-            "gateway.thread_executor.join()"
+            io.server_stmt, 
+            "Gateway(io=io, startcount=2).join()", 
+            "print 'exiting gateway'", 
         ]
         source = "\n".join(bootstrap)
         self.trace("sending gateway bootstrap code")
         io.write('%r\n' % source)
 
 class PopenGateway(InstallableGateway):
-    def __init__(self, python="python", ns=None):
+    def __init__(self, python="python"):
         cmd = '%s -u -c "exec input()"' % python
         infile, outfile = os.popen2(cmd)
         io = inputoutput.Popen2IO(infile, outfile) 
-        InstallableGateway.__init__(self, io, ns) 
+        InstallableGateway.__init__(self, io=io) 
+        self._pidchannel = self.remote_exec_async("import os ; channel.send(os.getpid())")
+
+    def exit(self):
+        super(PopenGateway, self).exit()
+        try:
+            pid = self._pidchannel.receive(0.5)
+        except Queue.Empty():
+            self.trace("could not receive child PID")
+        else:
+            self.trace("waiting for pid %s" % pid) 
+            os.waitpid(pid, 0) 
 
 class SocketGateway(InstallableGateway):
-    def __init__(self, host, port, ns=None): 
+    def __init__(self, host, port): 
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        host = str(host) 
-        port = int(port) 
+        self.host = host = str(host) 
+        self.port = port = int(port) 
         sock.connect((host, port))
         io = inputoutput.SocketIO(sock) 
-        InstallableGateway.__init__(self, io, ns) 
+        InstallableGateway.__init__(self, io=io) 
         
 class ExecGateway(PopenGateway):
     def remote_exec_sync_stdcapture(self, lines, callback):



More information about the pytest-commit mailing list