[py-svn] r10607 - in py/branch/py-collect: . execnet execnet/testing

hpk at codespeak.net
Thu Apr 14 12:33:11 CEST 2005


Author: hpk
Date: Thu Apr 14 12:33:11 2005
New Revision: 10607

Added:
   py/branch/py-collect/execnet/testing/test_threadpool.py
Removed:
   py/branch/py-collect/execnet/testing/sshtesting.py
Modified:
   py/branch/py-collect/conftest.py
   py/branch/py-collect/execnet/gateway.py
   py/branch/py-collect/execnet/message.py
   py/branch/py-collect/execnet/testing/test_gateway.py
Log:
- first implementation: allow arbitrarily many worker threads
  (further unification of thread handling is still outstanding)

- added a py-lib-specific test option "-S" for setting an ssh
  target.  If given, additional tests are run against SshGateway
  (this replaces the manually run sshtesting.py).



Modified: py/branch/py-collect/conftest.py
==============================================================================
--- py/branch/py-collect/conftest.py	(original)
+++ py/branch/py-collect/conftest.py	Thu Apr 14 12:33:11 2005
@@ -15,3 +15,12 @@
 fulltrace = False
 showlocals = False
 nomagic = False
+
+import py 
+Option = py.test.Config.Option 
+
+option = py.test.Config.addoptions("execnet options",  # test_gateway.py reads option.sshtarget
+        Option('-S', '',
+               action="store", dest="sshtarget", default=None,  # None: ssh tests are skipped
+               help="target to run tests requiring ssh, e.g. user at codespeak.net"), 
+    )

Modified: py/branch/py-collect/execnet/gateway.py
==============================================================================
--- py/branch/py-collect/execnet/gateway.py	(original)
+++ py/branch/py-collect/execnet/gateway.py	Thu Apr 14 12:33:11 2005
@@ -4,6 +4,7 @@
 import Queue
 import traceback
 import atexit
+import time
 
 
 # XXX the following line should not be here
@@ -32,6 +33,94 @@
     def __repr__(self):
         return "%s: %s" %(self.__class__.__name__, self.formatted)
 
+class WorkerThread(threading.Thread): 
+    """ daemon thread running (func, args, kwargs) tasks handed to it 
+        by its WorkerPool, one at a time. """
+
+    def __init__(self, pool): 
+        super(WorkerThread, self).__init__() 
+        self._queue = Queue.Queue() 
+        self._pool = pool 
+        self.setDaemon(1) 
+
+    def run(self): 
+        while 1: 
+            task = self._queue.get() 
+            # the pool removes us from _ready before queueing anything 
+            assert self not in self._pool._ready 
+            if task is None: # sentinel queued by stop()
+                break 
+            try: 
+                func, args, kwargs = task 
+                func(*args, **kwargs) 
+            except: 
+                # report and keep serving; uses the module-level traceback
+                traceback.print_exc() 
+            # become idle again so the pool can reuse or stop us 
+            self._pool._ready.append(self) 
+
+    def handle(self, task): 
+        """ queue a (func, args, kwargs) tuple for execution. """
+        self._queue.put(task) 
+
+    def stop(self): 
+        """ make run() return after the tasks queued before this call. """
+        self._queue.put(None) 
+
+class WorkerPool(object): 
+    """ runs calls in worker threads: idle threads are reused, new ones 
+        are started on demand (at most 'maxthreads' if that is set). """
+    _shutdown = False 
+    def __init__(self, maxthreads=None): 
+        self.maxthreads = maxthreads
+        self._numthreads = 0 
+        self._ready = []   # idle WorkerThreads 
+
+    def dispatch(self, func, *args, **kwargs): 
+        """ call func(*args, **kwargs) in a worker thread. raises IOError 
+            after shutdown() or if maxthreads threads are all busy. """
+        if self._shutdown: 
+            raise IOError("WorkerPool is already shutting down") 
+        task = (func, args, kwargs) 
+        try: 
+            thread = self._ready.pop() 
+        except IndexError: # pop from empty list
+            thread = self._newthread() 
+        thread.handle(task) 
+
+    def __del__(self): 
+        self.shutdown() 
+
+    def shutdown(self, timeout=1.0): 
+        """ stop all worker threads, waiting up to 'timeout' seconds for 
+            busy ones to become idle; raises IOError if they don't. """
+        if not self._shutdown: 
+            self._shutdown = True
+            now = time.time() 
+            while self._numthreads: 
+                try: 
+                    thread = self._ready.pop() 
+                except IndexError: 
+                    if now + timeout < time.time(): 
+                        raise IOError("Timeout: could not shut down WorkerPool") 
+                    time.sleep(0.1) 
+                else: 
+                    # wake the idle thread with the stop sentinel; without 
+                    # this it would block in its queue.get() forever 
+                    thread.stop() 
+                    self._numthreads -= 1 
+
+    def _newthread(self): 
+        """ start and return a new WorkerThread, honouring maxthreads. """
+        if self.maxthreads: 
+            if self._numthreads >= self.maxthreads: 
+                raise IOError("cannot create more threads, "
+                              "maxthreads=%d" % (self.maxthreads,))
+        thread = WorkerThread(self) 
+        self._numthreads += 1 
+        thread.start() 
+        return thread 
+
 class NamedThreadPool: 
     def __init__(self, **kw): 
         self._namedthreads = {}
@@ -74,18 +147,15 @@
     num_worker_threads = 2
     RemoteError = RemoteError
 
-    def __init__(self, io, startcount=2):
+    def __init__(self, io, startcount=2, maxthreads=None):
+        self._execpool = WorkerPool(maxthreads=maxthreads)  # None: no thread limit
         self.running = True 
         self.io = io
-        self._execqueue = Queue.Queue()
         self._outgoing = Queue.Queue()
         self.channelfactory = ChannelFactory(self, startcount)
         self._exitlock = threading.Lock()
         self.pool = NamedThreadPool(receiver = self.thread_receiver, 
                                     sender = self.thread_sender) 
-        for x in range(self.num_worker_threads):
-            self.pool.start('executor', self.thread_executor)
-            self.trace("started executor thread") 
         if not _gateways:
             atexit.register(cleanup_atexit)
         _gateways.append(self)
@@ -99,9 +169,7 @@
 
     def _stopexec(self):
         #self.pool.prunestopped()
-        for x in range(self.num_worker_threads): 
-            self.trace("putting None to execqueue") 
-            self._execqueue.put(None)
+        self._execpool.shutdown()  # raises IOError if busy workers outlast its default timeout
 
     def exit(self):
         # note that threads may still be scheduled to start
@@ -178,35 +246,32 @@
         finally:
             self.trace('leaving %r' % threading.currentThread())
 
-    def thread_executor(self):
-        """ worker thread to execute source objects from the execution queue. """
+    def thread_executor(self, channel, source): 
+        """ execute 'source' with 'channel' bound in its namespace, then
+            close the channel (passing the traceback text on errors). """
         try:
-            while 1:
-                task = self._execqueue.get()
-                if task is None:
-                    self.trace("executor found none, leaving ...") 
-                    break
-                channel, source = task
-                try:
-                    loc = { 'channel' : channel }
-                    self.trace("execution starts:", repr(source)[:50])
-                    try:
-                        co = compile(source+'\n', '', 'exec', 4096)
-                        exec co in loc
-                    finally:
-                        self.trace("execution finished:", repr(source)[:50])
-                except (KeyboardInterrupt, SystemExit):
-                    raise
-                except:
-                    excinfo = sys.exc_info()
-                    l = traceback.format_exception(*excinfo)
-                    errortext = "".join(l)
-                    channel.close(errortext)
-                    self.trace(errortext)
-                else:
-                    channel.close()
-        finally:
-            self.trace('leaving %r' % threading.currentThread())
+            loc = { 'channel' : channel }
+            self.trace("execution starts:", repr(source)[:50])
+            try:
+                co = compile(source+'\n', '', 'exec', 4096)
+                exec co in loc
+            finally:
+                self.trace("execution finished:", repr(source)[:50])
+        except (KeyboardInterrupt, SystemExit):
+            raise  # NOTE(review): the channel is left unclosed on this path
+        except:
+            excinfo = sys.exc_info()
+            l = traceback.format_exception(*excinfo)
+            errortext = "".join(l)
+            channel.close(errortext)
+            self.trace(errortext)
+        else:
+            channel.close()
+
+    def _scheduleexec(self, channel, source): 
+        """ run thread_executor(channel, source) in a worker of self._execpool. """
+        self.trace("dispatching exec")
+        self._execpool.dispatch(self.thread_executor, channel, source) 
 
     def _dispatchcallback(self, callback, data): 
         # XXX this should run in a separate thread because

Modified: py/branch/py-collect/execnet/message.py
==============================================================================
--- py/branch/py-collect/execnet/message.py	(original)
+++ py/branch/py-collect/execnet/message.py	Thu Apr 14 12:33:11 2005
@@ -90,8 +90,7 @@
     class CHANNEL_OPEN(Message):
         def received(self, gateway):
             channel = gateway.channelfactory.new(self.channelid)
-            #gateway._scheduleexec((channel, self.data))
-            gateway._execqueue.put((channel, self.data))
+            gateway._scheduleexec(channel, self.data)  # executed by a worker of the gateway's pool
 
     class CHANNEL_NEW(Message):
         def received(self, gateway):

Deleted: /py/branch/py-collect/execnet/testing/sshtesting.py
==============================================================================
--- /py/branch/py-collect/execnet/testing/sshtesting.py	Thu Apr 14 12:33:11 2005
+++ (empty file)
@@ -1,16 +0,0 @@
-"""
-A test file that doesn't run by default to test SshGateway.
-"""
-
-import py
-
-#REMOTE_HOST = 'codespeak.net'
-#REMOTE_HOSTNAME = 'thoth.codespeak.net'
-
-def test_sshgateway():
-    REMOTE_HOST = 'localhost'    # you need to have a local ssh-daemon running!
-    REMOTE_HOSTNAME = py.std.socket.gethostname() # the remote's socket.gethostname()
-    gw = py.execnet.SshGateway(REMOTE_HOST)
-    c = gw.remote_exec('import socket; channel.send(socket.gethostname())')
-    msg = c.receive()
-    assert msg == REMOTE_HOSTNAME

Modified: py/branch/py-collect/execnet/testing/test_gateway.py
==============================================================================
--- py/branch/py-collect/execnet/testing/test_gateway.py	(original)
+++ py/branch/py-collect/execnet/testing/test_gateway.py	Thu Apr 14 12:33:11 2005
@@ -1,6 +1,7 @@
 import os, sys
 import py
 from py.__impl__.execnet import gateway
+from py.__impl__.conftest import option 
 mypath = py.magic.autopath()
 
 from StringIO import StringIO
@@ -59,7 +60,7 @@
 
 class BasicRemoteExecution:
     def test_correct_setup(self):
-        for x in 'sender', 'receiver', 'executor': 
+        for x in 'sender', 'receiver': # executors now come from gw._execpool
             assert self.gw.pool.getstarted(x) 
 
     def test_remote_exec_waitclose(self):
@@ -209,3 +210,10 @@
 class TestSocketGateway(SocketGatewaySetup, BasicRemoteExecution):
     disabled = sys.platform == "win32"
     pass
+
+class TestSshGateway(BasicRemoteExecution):  # runs only if "-S <target>" is given
+    def setup_class(cls): 
+        if option.sshtarget is None: 
+            py.test.skip("no known ssh target, use -S to set one")
+        cls.gw = py.execnet.SshGateway(option.sshtarget)  # shared by all tests of the class
+

Added: py/branch/py-collect/execnet/testing/test_threadpool.py
==============================================================================
--- (empty file)
+++ py/branch/py-collect/execnet/testing/test_threadpool.py	Thu Apr 14 12:33:11 2005
@@ -0,0 +1,43 @@
+
+from py.__impl__.execnet.gateway import WorkerPool 
+import py
+
+def test_some(): 
+    pool = WorkerPool() 
+    l = []
+    try: 
+        pool.dispatch(l.append, 1) 
+        pool.dispatch(l.append, 2) 
+        pool.dispatch(l.append, 3) 
+        pool.dispatch(l.append, 4) 
+    finally: 
+        pool.shutdown() 
+    # shutdown() only returns once every started thread went idle again, 
+    # i.e. all tasks have run -- in no particular order, since several 
+    # threads may have been started 
+    assert pool._numthreads == 0 
+    assert not pool._ready 
+    l.sort() 
+    assert l == [1, 2, 3, 4]
+
+def test_maxthreads(): 
+    pool = WorkerPool(maxthreads=1) 
+    def f(): 
+        py.std.time.sleep(0.5) 
+    try: 
+        pool.dispatch(f) 
+        py.test.raises(IOError, pool.dispatch, f)
+    finally: 
+        pool.shutdown() 
+
+def test_dispatch_after_shutdown(): 
+    pool = WorkerPool() 
+    pool.shutdown() 
+    py.test.raises(IOError, pool.dispatch, lambda: None) 
+
+def test_shutdown_timeout(): 
+    pool = WorkerPool() 
+    def f(): 
+        py.std.time.sleep(1.5) 
+    pool.dispatch(f) 
+    py.test.raises(IOError, pool.shutdown, 0.2) 



More information about the pytest-commit mailing list