[py-svn] r10656 - in py/branch/py-collect/execnet: . testing

hpk at codespeak.net hpk at codespeak.net
Fri Apr 15 02:00:53 CEST 2005


Author: hpk
Date: Fri Apr 15 02:00:53 2005
New Revision: 10656

Modified:
   py/branch/py-collect/execnet/channel.py
   py/branch/py-collect/execnet/gateway.py
   py/branch/py-collect/execnet/testing/test_gateway.py
   py/branch/py-collect/execnet/testing/test_threadpool.py
Log:
- add stdout/stderr redirection per remote_exec() 

  introduce a ThreadOut object which encapsulates 
  e.g.  sys.stdout to demultiplex to different writers 
  according to threads. 

- adjust remote_redirect-implementation accordingly



Modified: py/branch/py-collect/execnet/channel.py
==============================================================================
--- py/branch/py-collect/execnet/channel.py	(original)
+++ py/branch/py-collect/execnet/channel.py	Fri Apr 15 02:00:53 2005
@@ -43,8 +43,8 @@
                 put(Message.CHANNEL_CLOSE(self.id))
             self._close()
 
-    def remote_exec(self, source): 
-        self.gateway._remote_exec(self, source) 
+    def remote_exec(self, source, stdout=None, stderr=None): 
+        self.gateway._remote_exec(self, source, stdout, stderr) 
 
     def _close(self, finalitem=EOFError()):
         if self.id in self.gateway.channelfactory:

Modified: py/branch/py-collect/execnet/gateway.py
==============================================================================
--- py/branch/py-collect/execnet/gateway.py	(original)
+++ py/branch/py-collect/execnet/gateway.py	Fri Apr 15 02:00:53 2005
@@ -1,6 +1,6 @@
 import sys
 import os
-import threading
+import thread, threading
 import Queue
 import traceback
 import atexit
@@ -34,7 +34,6 @@
         return "%s: %s" %(self.__class__.__name__, self.formatted)
 
 class WorkerThread(threading.Thread): 
-
     def __init__(self, pool): 
         super(WorkerThread, self).__init__() 
         self._queue = Queue.Queue() 
@@ -155,9 +154,79 @@
         thread.start() 
         l.append(thread) 
 
+class ThreadOut(object): 
+    def __new__(cls, obj, attrname): 
+        """ Divert file output to per-thread writefuncs. 
+            the given obj and attrname describe the destination 
+            of the file.  
+        """ 
+        current = getattr(obj, attrname)
+        if isinstance(current, cls): 
+            current._used += 1
+            return current 
+        self = object.__new__(cls) 
+        self._tid2out = {}
+        self._used = 1 
+        self._oldout = getattr(obj, attrname) 
+        self._defaultwriter = self._oldout.write 
+        self._address = (obj, attrname) 
+        setattr(obj, attrname, self) 
+        return self 
+
+    def setdefaultwriter(self, writefunc): 
+        self._defaultwriter = writefunc 
+
+    def resetdefault(self): 
+        self._defaultwriter = self._oldout.write
+
+    def softspace(): 
+        def fget(self): 
+            return self._get()[0]
+        def fset(self, value): 
+            self._get()[0] = value 
+        return property(fget, fset, None, "software attribute") 
+    softspace = softspace()
+
+    def deinstall(self): 
+        self._used -= 1 
+        x = self._used 
+        if x <= 0: 
+            obj, attrname = self._address 
+            setattr(obj, attrname, self._oldout) 
+        
+    def setwritefunc(self, writefunc, tid=None): 
+        assert callable(writefunc)
+        if tid is None: 
+            tid = thread.get_ident() 
+        self._tid2out[tid] = [0, writefunc]
+
+    def delwritefunc(self, tid=None, ignoremissing=True): 
+        if tid is None: 
+            tid = thread.get_ident() 
+        try: 
+            del self._tid2out[tid] 
+        except KeyError: 
+            if not ignoremissing: 
+                raise 
+
+    def _get(self): 
+        tid = thread.get_ident() 
+        try: 
+            return self._tid2out[tid]
+        except KeyError: 
+            return getattr(self._defaultwriter, 'softspace', 0), self._defaultwriter 
+
+    def write(self, data): 
+        softspace, out = self._get() 
+        out(data) 
+
+    def flush(self): 
+        pass 
+   
 class Gateway(object):
     num_worker_threads = 2
     RemoteError = RemoteError
+    ThreadOut = ThreadOut
 
     def __init__(self, io, startcount=2, maxthreads=None):
         self._execpool = WorkerPool() 
@@ -259,15 +328,36 @@
         finally:
             self.trace('leaving %r' % threading.currentThread())
 
-    def thread_executor(self, channel, source): 
+    def _redirect_thread_output(self, outid, errid): 
+        l = []
+        for name, id in ('stdout', outid), ('stderr', errid): 
+            if id: 
+                channel = self._makechannel(outid) 
+                out = ThreadOut(sys, name)
+                out.setwritefunc(channel.send) 
+                l.append((out, channel))
+        def close(): 
+            for out, channel in l: 
+                out.delwritefunc() 
+                channel.close() 
+        return close 
+
+    def _makechannel(self, newid): 
+        newchannel = Channel(self, newid) 
+        self.channelfactory[newid] = newchannel
+        return newchannel 
+
+    def thread_executor(self, channel, (source, outid, errid)): 
         """ worker thread to execute source objects from the execution queue. """
         try:
             loc = { 'channel' : channel }
             self.trace("execution starts:", repr(source)[:50])
+            close = self._redirect_thread_output(outid, errid) 
             try:
                 co = compile(source+'\n', '', 'exec', 4096)
                 exec co in loc
             finally:
+                close() 
                 self.trace("execution finished:", repr(source)[:50])
         except (KeyboardInterrupt, SystemExit):
             raise
@@ -280,9 +370,9 @@
         else:
             channel.close()
 
-    def _scheduleexec(self, channel, source): 
+    def _scheduleexec(self, channel, sourcetask): 
         self.trace("dispatching exec")
-        self._execpool.dispatch(self.thread_executor, channel, source) 
+        self._execpool.dispatch(self.thread_executor, channel, sourcetask) 
 
     def _dispatchcallback(self, callback, data): 
         # XXX this should run in a separate thread because
@@ -290,7 +380,7 @@
         #     where we get called from 
         callback(data) 
 
-    def _remote_exec(self, channel, source): 
+    def _remote_exec(self, channel, source, stdout=None, stderr=None): 
         try:
             source = str(Source(source))
         except NameError: 
@@ -299,7 +389,20 @@
                 source = str(py.code.Source(source))
             except ImportError: 
                 pass 
-        self._outgoing.put(Message.CHANNEL_OPEN(channel.id, source))
+        outid = self._redirectchannelid(stdout) 
+        errid = self._redirectchannelid(stderr) 
+        self._outgoing.put(Message.CHANNEL_OPEN(channel.id, 
+                                                (source, outid, errid))) 
+
+    def _redirectchannelid(self, callback): 
+        if callback is None: 
+            return  
+        if hasattr(callback, 'write'): 
+            callback = callback.write 
+        assert callable(callback) 
+        chan = self.newchannel() 
+        chan.setcallback(callback) 
+        return chan.id 
 
     # _____________________________________________________________________
     #
@@ -310,7 +413,7 @@
         """ return new channel object. """ 
         return self.channelfactory.new() 
 
-    def remote_exec(self, source): 
+    def remote_exec(self, source, stdout=None, stderr=None): 
         """ return channel object for communicating with the asynchronously
             executing 'source' code which will have a corresponding 'channel'
             object in its executing namespace. If a channel object is not
@@ -318,55 +421,39 @@
             is will be returned as well. 
         """
         channel = self.newchannel() 
-        channel.remote_exec(source) 
+        channel.remote_exec(source, stdout=stdout, stderr=stderr) 
         return channel 
 
-
-    def remote_redirect(self, stdout): 
+    def remote_redirect(self, stdout=None, stderr=None): 
         """ return a handle representing a redirection of of remote 
             end's stdout to a local file object.  with handle.close() 
             the redirection will be reverted.   
         """ 
-        handle = RedirectHandle(self, stdout) 
-        handle.open() 
-        return handle 
-
-class RedirectHandle(object): 
-    def __init__(self, gateway, stdout): 
-        self.gateway = gateway 
-        self.stdout = stdout 
-
-    def open(self): 
-        self.outchannel = self.gateway.newchannel() 
-        self.outchannel.setcallback(self.stdout.write) 
-        channel = self.gateway.remote_exec(""" 
-            import sys
-            outchannel = channel.receive() 
-            sys.__dict__.setdefault('_stdoutsubst', []).append(sys.stdout)
-            sys.stdout = outchannel.open('w') 
-        """)
-        channel.send(self.outchannel) 
-        channel.waitclose(1.0)
-
-    def __del__(self): 
-        self.close() 
+        clist = []
+        for name, out in ('stdout', stdout), ('stderr', stderr): 
+            if out: 
+                outchannel = self.newchannel() 
+                outchannel.setcallback(getattr(out, 'write', out))
+                channel = self.remote_exec(""" 
+                    import sys
+                    outchannel = channel.receive() 
+                    outchannel.gateway.ThreadOut(sys, %r).setdefaultwriter(outchannel.send)
+                """ % name) 
+                channel.send(outchannel)
+                clist.append(channel)
+        for c in clist: 
+            c.waitclose(1.0) 
+        class Handle: 
+            def close(_): 
+                for name, out in ('stdout', stdout), ('stderr', stderr): 
+                    if out: 
+                        c = self.remote_exec("""
+                            import sys
+                            channel.gateway.ThreadOut(sys, %r).resetdefault()
+                        """ % name) 
+                        c.waitclose(1.0) 
+        return Handle()
 
-    def close(self, timeout=1.0): 
-        """ close redirection on remote side and wait 
-            for closing. If timeout==0 we will not 
-            wait for the remote side to finish 
-            resetting the redirection pipe. 
-        """ 
-        c = self.gateway.remote_exec(""" 
-            import sys
-            outchannel = sys.stdout.channel 
-            sys.stdout = sys._stdoutsubst.pop() 
-            outchannel.close() 
-        """) 
-        if timeout != 0: 
-            self.outchannel.waitclose(timeout=1.0) 
-            c.waitclose(1.0)
-        
 def getid(gw, cache={}):
     name = gw.__class__.__name__
     try:

Modified: py/branch/py-collect/execnet/testing/test_gateway.py
==============================================================================
--- py/branch/py-collect/execnet/testing/test_gateway.py	(original)
+++ py/branch/py-collect/execnet/testing/test_gateway.py	Fri Apr 15 02:00:53 2005
@@ -169,14 +169,26 @@
     def test_remote_redirect_stdout(self): 
         out = py.std.StringIO.StringIO() 
         handle = self.gw.remote_redirect(stdout=out) 
-        try: 
-            c = self.gw.remote_exec("print 42")
-        finally: 
-            handle.close() 
+        c = self.gw.remote_exec("print 42")
         c.waitclose(1.0)
+        handle.close() 
         s = out.getvalue() 
         assert s.strip() == "42" 
 
+    def test_remote_exec_redirect_multi(self): 
+        num = 3
+        l = [[] for x in range(num)]
+        channels = [self.gw.remote_exec("print %d" % i, stdout=l[i].append)
+                        for i in range(num)]
+        for x in channels: 
+            x.waitclose(1.0) 
+
+        for i in range(num): 
+            subl = l[i] 
+            assert subl 
+            s = subl[0]
+            assert s.strip() == str(i)
+
 class TestBasicPopenGateway(PopenGatewayTestSetup, BasicRemoteExecution):
     #disabled = True
     def test_many_popen(self):

Modified: py/branch/py-collect/execnet/testing/test_threadpool.py
==============================================================================
--- py/branch/py-collect/execnet/testing/test_threadpool.py	(original)
+++ py/branch/py-collect/execnet/testing/test_threadpool.py	Fri Apr 15 02:00:53 2005
@@ -1,6 +1,7 @@
 
-from py.__impl__.execnet.gateway import WorkerPool 
+from py.__impl__.execnet.gateway import WorkerPool, ThreadOut
 import py
+import sys 
 
 def test_some(): 
     pool = WorkerPool() 
@@ -11,7 +12,6 @@
     for i in range(num): 
         q.get() 
     assert len(pool._alive) == 4 
-    assert len(pool._ready) == 4 
     pool.shutdown() 
     assert len(pool._alive) == 0 
     assert len(pool._ready) == 0 
@@ -42,3 +42,63 @@
     pool.shutdown()
     assert not pool._alive  
     assert not pool._ready  
+
+def test_threadout_install_deinstall(): 
+    old = sys.stdout 
+    out = ThreadOut(sys, 'stdout') 
+    out.deinstall() 
+    assert old == sys.stdout 
+
+class TestThreadOut: 
+    def setup_method(self, method): 
+        self.out = ThreadOut(sys, 'stdout') 
+    def teardown_method(self, method): 
+        self.out.deinstall() 
+        
+    def test_threadout_one(self): 
+        l = []
+        self.out.setwritefunc(l.append) 
+        print 42,13,
+        x = l.pop(0) 
+        assert x == '42' 
+        x = l.pop(0) 
+        assert x == ' '
+        x = l.pop(0) 
+        assert x == '13' 
+
+
+    def test_threadout_multi_and_default(self): 
+        num = 3 
+        defaults = []
+        def f(l): 
+            self.out.setwritefunc(l.append) 
+            print id(l),
+            self.out.delwritefunc() 
+            print 1 
+
+        self.out.setdefaultwriter(defaults.append) 
+        pool = WorkerPool() 
+        listlist = []
+        for x in range(num): 
+            l = []
+            listlist.append(l) 
+            pool.dispatch(f, l) 
+        pool.shutdown() 
+        for name, value in self.out.__dict__.items(): 
+            print >>sys.stderr, "%s: %s" %(name, value) 
+        for i in range(num): 
+            item = listlist[i]
+            assert item ==[str(id(item))]
+        assert not self.out._tid2out 
+        assert defaults 
+        expect = ['1' for x in range(num)]
+        defaults = [x for x in defaults if x.strip()]
+        assert defaults == expect 
+
+    def test_threadout_nested(self): 
+        # we want ThreadOuts to coexist 
+        last = sys.stdout
+        out = ThreadOut(sys, 'stdout') 
+        assert last == sys.stdout 
+        out.deinstall() 
+        assert last == sys.stdout 



More information about the pytest-commit mailing list