summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorge Kraft <george.kraft@calxeda.com>2012-12-04 13:51:12 -0600
committerGeorge Kraft <george.kraft@calxeda.com>2012-12-04 13:51:12 -0600
commitc62a9791592d5af23bd590159a93d5222c9ee3af (patch)
treeeb0d67ed2f5547375ee872854d29e9dbb79eda49
parent33f3f06e1cb6ce0ee29c3458e555e10eacb6a81c (diff)
downloadcxmanage-c62a9791592d5af23bd590159a93d5222c9ee3af.tar.gz
CXMAN-150: Replace the Command class with a TaskQueue
TaskQueue uses a more traditional thread pool pattern, which gives us a bit more flexibility in how we use it. This means the fabric num_threads parameter is gone; it takes in an optional TaskQueue object instead. By default, all fabrics use a single, shared default TaskQueue with 48 threads.
-rw-r--r--cxmanage_api/command.py231
-rw-r--r--cxmanage_api/fabric.py39
-rw-r--r--cxmanage_api/tasks.py106
-rw-r--r--cxmanage_test/command_test.py139
-rw-r--r--cxmanage_test/fabric_test.py13
-rw-r--r--cxmanage_test/tasks_test.py71
-rwxr-xr-xrun_tests4
-rwxr-xr-xscripts/cxmanage50
8 files changed, 230 insertions, 423 deletions
diff --git a/cxmanage_api/command.py b/cxmanage_api/command.py
deleted file mode 100644
index 3983803..0000000
--- a/cxmanage_api/command.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Copyright (c) 2012, Calxeda Inc.
-#
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of Calxeda Inc. nor the names of its contributors
-# may be used to endorse or promote products derived from this software
-# without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
-# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
-# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
-# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
-# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
-# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
-# THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
-# DAMAGE.
-
-
-from time import sleep
-from threading import Thread, Lock
-from cxmanage_api.cx_exceptions import CommandFailedError
-
-
-class CommandWorker(Thread):
- """A worker thread for a `Command <command.html>`_.
-
- A CommandWorker will obtain nodes from the pool and run the named method on
- them once started. The thread terminates once there are no nodes remaining.
-
- >>> from cxmanage_api.command import CommandWorker
- >>> cw = CommandWorker(command=cmd)
-
- :param command: The command to run.
- :type command: Command
-
- """
-
- def __init__(self, command):
- """Default constructor for the CommandWorker class."""
- super(CommandWorker, self).__init__()
- self.daemon = True
- self.command = command
- self.results = {}
- self.errors = {}
-
- def run(self):
- """Runs the named method, stores results/errors, then terminates.
-
- >>> cw.run()
-
- """
- try:
- while (True):
- key, node = self.command._get_next_node()
- try:
- sleep(self.command._delay)
- method = getattr(node, self.command._name)
- result = method(*self.command._args)
- self.results[key] = result
- except Exception as e:
- self.errors[key] = e
- except StopIteration:
- pass
-
-
-class Command:
- """Commands are containers/managers of multi-threaded CommandWorkers.
-
- .. note::
- * Commands are designed to have an interface similar to a thread,
- but they are not threads. The CommandWorkers are the actual threads.
-
- >>> # Create a list of Nodes ...
- >>> from cxmanage_api.node import Node
- >>> n1 = Node('10.20.1.9')
- >>> n2 = Node('10.20.2.131')
- >>> nodes = [n1, n2]
- >>> #
- >>> # Typical instantiation ...
- >>> # (this example gets the mac addresses which takes no arguments)
- >>> #
- >>> from cxmanage_api.command import Command
- >>> cmd = Command(nodes=nodes, name='get_mac_addresses', args=[])
-
- :param nodes: Nodes to execute commands on.
- :type nodes: list
- :param name: Named command to run.
- :type name: string
- :param args: Arguments to pass on to the command.
- :type args: list
- :param delay: Time to wait before issuing the next command.
- :type delay: integer
- :param max_threads: Maximum number of threads to spawn at a time.
- :type max_threads: integer
-
- """
-
- def __init__(self, nodes, name, args, delay=0, max_threads=1):
- """Default constructor for the Command class."""
- self._lock = Lock()
-
- try:
- self._node_iterator = nodes.iteritems()
- except AttributeError:
- self._node_iterator = iter([(x, x) for x in nodes])
- self._node_count = len(nodes)
-
- self._name = name
- self._args = args
- self._delay = delay
-
- num_threads = min(max_threads, self._node_count)
- self._workers = [CommandWorker(self) for i in range(num_threads)]
-
- def start(self):
- """Starts the command.
-
- >>> cmd.start()
-
- """
- for worker in self._workers:
- worker.start()
-
- def join(self):
- """Waits for the command to finish.
-
- >>> cmd.join()
-
- """
- for worker in self._workers:
- worker.join()
-
- def is_alive(self):
- """Tests to see if the command is alive.
-
- >>> # Command is still in progress ... (i.e. has nodes still running it)
- >>> cmd.is_alive()
- True
- >>> # Command has completed. (or failed)
- >>> cmd.is_alive()
- False
-
- :return: Whether or not the command is alive.
- :rtype: boolean
-
- """
- return any([x.is_alive() for x in self._workers])
-
- def get_results(self):
- """Gets the command results.
-
- >>> cmd.get_results()
- {<cxmanage_api.node.Node object at 0x7f8a99940cd0>:
- ['fc:2f:40:3b:ec:40', 'fc:2f:40:3b:ec:41', 'fc:2f:40:3b:ec:42'],
- <cxmanage_api.node.Node object at 0x7f8a99237a90>:
- ['fc:2f:40:91:dc:40', 'fc:2f:40:91:dc:41', 'fc:2f:40:91:dc:42']}
-
- :return: Results of this commands.
- :rtype: dictionary
-
- :raises CommandFailedError: If The command results contains ANY errors.
-
- """
- results, errors = {}, {}
- for worker in self._workers:
- results.update(worker.results)
- errors.update(worker.errors)
- if (errors):
- raise CommandFailedError(results, errors)
- return results
-
- def get_status(self):
- """Gets the status of this command.
-
- >>> status = cmd.get_status()
- >>> status.successes
- 2
- >>> status.errors
- 0
- >>> status.nodes_left
- 0
-
- :return: The commands status.
- :rtype: CommandStatus
-
- """
-
-
- class CommandStatus:
- """Container for a commands status."""
-
- def __init__(self, successes, errors, nodes_left):
- """Default constructor for the CommandStatus class."""
- self.successes = successes
- self.errors = errors
- self.nodes_left = nodes_left
-
- #
- # get_status()
- #
- successes = 0
- errors = 0
- for worker in self._workers:
- successes += len(worker.results)
- errors += len(worker.errors)
- nodes_left = self._node_count - successes - errors
- return CommandStatus(successes, errors, nodes_left)
-
- def _get_next_node(self):
- """Gets the next node to operate on."""
- self._lock.acquire()
- try:
- return self._node_iterator.next()
- finally:
- self._lock.release()
-
-
-# End of file: ./command.py
diff --git a/cxmanage_api/fabric.py b/cxmanage_api/fabric.py
index 88c7747..3209b74 100644
--- a/cxmanage_api/fabric.py
+++ b/cxmanage_api/fabric.py
@@ -28,9 +28,10 @@
# THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
-from cxmanage_api.command import Command
+from cxmanage_api.tasks import DEFAULT_TASK_QUEUE
from cxmanage_api.tftp import InternalTftp
from cxmanage_api.node import Node as NODE
+from cxmanage_api.cx_exceptions import CommandFailedError
class Fabric(object):
@@ -56,25 +57,24 @@ class Fabric(object):
"""
def __init__(self, ip_address, username="admin", password="admin",
- tftp=None, max_threads=48, command_delay=0, verbose=False,
+ tftp=None, task_queue=DEFAULT_TASK_QUEUE, verbose=False,
node=None):
"""Default constructor for the Fabric class."""
- self._tftp = tftp
- self.max_threads = max_threads
- self.command_delay = command_delay
- self.verbose = verbose
- self.node = node
self.ip_address = ip_address
self.username = username
self.password = password
+ self._tftp = tftp
+ self.task_queue = task_queue
+ self.verbose = verbose
+ self.node = node
self._nodes = {}
if (not self.node):
self.node = NODE
- if (not self.tftp):
- self.tftp = InternalTftp()
+ if (not self._tftp):
+ self._tftp = InternalTftp()
def __eq__(self, other):
"""__eq__() override."""
@@ -463,15 +463,24 @@ class Fabric(object):
def _run_command(self, async, name, *args):
"""Start a command on the given nodes."""
- command = Command(self.nodes, name, args, self.command_delay,
- self.max_threads)
- command.start()
+ tasks = {}
+ for node_id, node in self.nodes.iteritems():
+ tasks[node_id] = self.task_queue.put(getattr(node, name), *args)
if async:
- return command
+ return tasks
else:
- command.join()
- return command.get_results()
+ results = {}
+ errors = {}
+ for node_id, task in tasks.iteritems():
+ task.join()
+ if task.status == "Completed":
+ results[node_id] = task.result
+ else:
+ errors[node_id] = task.errors
+ if errors:
+ raise CommandFailedError(results, errors)
+ return results
# End of file: ./fabric.py
diff --git a/cxmanage_api/tasks.py b/cxmanage_api/tasks.py
new file mode 100644
index 0000000..50fdfa0
--- /dev/null
+++ b/cxmanage_api/tasks.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2012, Calxeda Inc.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of Calxeda Inc. nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+# THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+# DAMAGE.
+
+from Queue import Queue
+from threading import Thread
+from time import sleep
+
+class Task(object):
+ """ A task object representing some unit of work to be done. """
+ def __init__(self, method, *args):
+ self._method = method
+ self._args = args
+
+ self.status = "Queued"
+
+ def join(self):
+ """ Wait for this task to finish """
+ while self.is_alive():
+ pass # TODO: don't busy wait here
+
+ def is_alive(self):
+ """ Return true if this task hasn't been finished """
+ return not self.status in ["Completed", "Failed"]
+
+class TaskQueue(object):
+ """ A task queue, consisting of a queue and a number of workers. """
+
+ def __init__(self, threads=48, delay=0):
+ self._queue = Queue()
+ self._workers = [TaskWorker(task_queue=self, delay=delay)
+ for x in xrange(threads)]
+
+ def put(self, method, *args):
+ """
+ Add a task to the task queue.
+
+ method: a method to call
+ args: args to pass to the method
+
+ returns: a Task object that will be executed by a worker at a later
+ time.
+ """
+ task = Task(method, *args)
+ self._queue.put(task)
+ return task
+
+ def get(self):
+ """
+ Get a task from the task queue. Mainly used by workers.
+
+ returns: a Task object that hasn't been executed yet.
+ """
+ return self._queue.get()
+
+class TaskWorker(Thread):
+ """ A TaskWorker that executes tasks from a TaskQueue. """
+ def __init__(self, task_queue, delay=0):
+ super(TaskWorker, self).__init__()
+ self.daemon = True
+
+ self._task_queue = task_queue
+ self._delay = delay
+
+ self.start()
+
+ def run(self):
+ """ Repeatedly get tasks from the TaskQueue and execute them. """
+ while True:
+ sleep(self._delay)
+ task = self._task_queue.get()
+ task.status = "In Progress"
+ try:
+ task.result = task._method(*task._args)
+ task.status = "Completed"
+ except Exception as e:
+ task.error = e
+ task.status = "Failed"
+
+DEFAULT_TASK_QUEUE = TaskQueue()
diff --git a/cxmanage_test/command_test.py b/cxmanage_test/command_test.py
deleted file mode 100644
index ff57052..0000000
--- a/cxmanage_test/command_test.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright (c) 2012, Calxeda Inc.
-#
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of Calxeda Inc. nor the names of its contributors
-# may be used to endorse or promote products derived from this software
-# without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
-# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
-# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
-# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
-# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
-# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
-# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
-# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
-# THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
-# DAMAGE.
-
-import unittest
-import random
-import time
-
-from cxmanage_api.command import Command, CommandWorker, CommandFailedError
-
-NUM_NODES = 128
-ADDRESSES = ["192.168.100.%i" % x for x in range(1, NUM_NODES+1)]
-
-class CommandTest(unittest.TestCase):
- def setUp(self):
- addresses = ADDRESSES[:]
- random.shuffle(addresses)
-
- self.good_targets = [DummyTarget(x) for x in addresses[:NUM_NODES/2]]
- self.bad_targets = [DummyTarget(x, True)
- for x in addresses[NUM_NODES/2:]]
- self.targets = self.good_targets + self.bad_targets
-
- def test_worker(self):
- """ Test the command worker thread """
- command = DummyCommand(self.targets)
- worker = CommandWorker(command)
- worker.start()
- worker.join()
-
- self.assertEqual(len(worker.results), len(self.good_targets))
- self.assertEqual(len(worker.errors), len(self.bad_targets))
-
- for target in self.good_targets:
- self.assertEqual(target.executed, [("action", ("a", "b", "c"))])
- self.assertEqual(worker.results[target], "action_result")
- for target in self.bad_targets:
- self.assertEqual(target.executed, [("action", ("a", "b", "c"))])
- self.assertEqual(str(worker.errors[target]), "action_error")
-
- def test_command(self):
- """ Test the command spawner """
- command = Command(self.targets, "action", ("a", "b", "c"),
- max_threads=32)
- command.start()
- command.join()
-
- try:
- command.get_results()
- self.fail()
- except CommandFailedError as e:
- results = e.results
- errors = e.errors
-
- for target in self.good_targets:
- self.assertEqual(target.executed, [("action", ("a", "b", "c"))])
- self.assertEqual(results[target], "action_result")
- for target in self.bad_targets:
- self.assertEqual(target.executed, [("action", ("a", "b", "c"))])
- self.assertEqual(str(errors[target]), "action_error")
-
- def test_command_delay(self):
- """ Test the command delay argument """
- delay = 1
- expected_duration = 4
- max_threads = NUM_NODES / expected_duration
-
- command = Command(self.targets, "action", ("a", "b", "c"),
- delay=delay, max_threads=max_threads)
-
- start_time = time.time()
- command.start()
- command.join()
- end_time = time.time()
-
- self.assertGreaterEqual(end_time - start_time, expected_duration)
-
- def test_command_get_status(self):
- """ Test the get_status method """
- command = Command(self.targets, "action", ("a", "b", "c"),
- max_threads=32)
- command.start()
- command.join()
-
- status = command.get_status()
- self.assertEqual(status.successes, NUM_NODES/2)
- self.assertEqual(status.errors, NUM_NODES/2)
- self.assertEqual(status.nodes_left, 0)
-
-class DummyTarget:
- def __init__(self, address, fail=False):
- self.address = address
- self.fail = fail
- self.executed = []
-
- def action(self, *args):
- self.executed.append(("action", args))
- if self.fail:
- raise ValueError("action_error")
- return "action_result"
-
-class DummyCommand:
- def __init__(self, targets):
- self.results = []
- self.errors = []
-
- self._iterator = iter((x, x) for x in targets)
- self._name = "action"
- self._args = ("a", "b", "c")
- self._delay = 0
-
- def _get_next_node(self):
- return self._iterator.next()
diff --git a/cxmanage_test/fabric_test.py b/cxmanage_test/fabric_test.py
index 787bdb7..53f83d3 100644
--- a/cxmanage_test/fabric_test.py
+++ b/cxmanage_test/fabric_test.py
@@ -45,8 +45,7 @@ class FabricTest(unittest.TestCase):
""" Test the various Fabric commands """
def setUp(self):
# Set up the controller and add targets
- self.fabric = Fabric("192.168.100.1", max_threads=32,
- node=DummyNode)
+ self.fabric = Fabric("192.168.100.1", node=DummyNode)
self.nodes = [DummyNode(x) for x in ADDRESSES]
self.fabric._nodes = dict((i, self.nodes[i])
for i in xrange(NUM_NODES))
@@ -65,16 +64,6 @@ class FabricTest(unittest.TestCase):
for node in self.nodes:
self.assertTrue(node.tftp is tftp)
- def test_command_delay(self):
- """Test that we delay for at least command_delay"""
- delay = random.randint(1, 5)
- self.fabric.command_delay = delay
- self.fabric._nodes = {0: self.fabric.nodes[0]}
- start = time.time()
- self.fabric.info_basic()
- finish = time.time()
- self.assertLess(delay, finish - start)
-
def test_get_mac_addresses(self):
""" Test get_mac_addresses command """
self.fabric.get_mac_addresses()
diff --git a/cxmanage_test/tasks_test.py b/cxmanage_test/tasks_test.py
new file mode 100644
index 0000000..a936b06
--- /dev/null
+++ b/cxmanage_test/tasks_test.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2012, Calxeda Inc.
+#
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of Calxeda Inc. nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
+# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
+# THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
+# DAMAGE.
+
+import unittest
+import time
+
+from cxmanage_api.tasks import TaskQueue
+
+class TaskTest(unittest.TestCase):
+ def test_task_queue(self):
+ """ Test the task queue """
+ task_queue = TaskQueue()
+ counters = [Counter() for x in xrange(128)]
+ tasks = [task_queue.put(counters[i].add, i) for i in xrange(128)]
+
+ for task in tasks:
+ task.join()
+
+ for i in xrange(128):
+ self.assertEqual(counters[i].value, i)
+
+ def test_sequential_delay(self):
+ """ Test that a single thread delays between tasks """
+ task_queue = TaskQueue(threads=1, delay=0.25)
+ counters = [Counter() for x in xrange(8)]
+
+ start = time.time()
+
+ tasks = [task_queue.put(x.add, 1) for x in counters]
+ for task in tasks:
+ task.join()
+
+ finish = time.time()
+
+ self.assertGreaterEqual(finish - start, 2.0)
+
+class Counter(object):
+ """ Simple counter object for testing purposes """
+ def __init__(self):
+ self.value = 0
+
+ def add(self, value):
+ """ Increment this counter's value by some amount """
+ self.value += value
diff --git a/run_tests b/run_tests
index 4140dda..f0d7831 100755
--- a/run_tests
+++ b/run_tests
@@ -34,8 +34,8 @@
import unittest
from cxmanage_test import tftp_test, image_test, node_test, fabric_test, \
- command_test
-test_modules = [tftp_test, image_test, node_test, fabric_test, command_test]
+ tasks_test
+test_modules = [tftp_test, image_test, node_test, fabric_test, tasks_test]
def main():
""" Load and run tests """
diff --git a/scripts/cxmanage b/scripts/cxmanage
index a8307c3..6f27256 100755
--- a/scripts/cxmanage
+++ b/scripts/cxmanage
@@ -45,7 +45,7 @@ import time
from cxmanage_api.tftp import InternalTftp, ExternalTftp
from cxmanage_api.node import Node
-from cxmanage_api.command import Command, CommandFailedError
+from cxmanage_api.tasks import TaskQueue
from cxmanage_api.image import Image
from cxmanage_api.firmware_package import FirmwarePackage
from cxmanage_api.ubootenv import UbootEnv
@@ -972,39 +972,40 @@ def ipmitool_command(args):
def _run_command(args, nodes, name, *method_args):
- command = Command(nodes, name, method_args, args.command_delay,
- args.threads)
- command.start()
+ task_queue = TaskQueue(threads=args.threads, delay=args.command_delay)
+ tasks = {}
+ for node in nodes:
+ tasks[node] = task_queue.put(getattr(node, name), *method_args)
results = {}
errors = {}
try:
counter = 0
- while command.is_alive():
+ while any(x.is_alive() for x in tasks.values()):
if not args.quiet:
- _print_command_status(command, counter)
+ _print_command_status(tasks, counter)
counter += 1
time.sleep(0.25)
- try:
- results = command.get_results()
- except CommandFailedError as e:
- results = e.results
- errors = e.errors
+ for node, task in tasks.iteritems():
+ if task.status == "Completed":
+ results[node] = task.result
+ else:
+ errors[node] = task.error
except KeyboardInterrupt:
args.retry = 0
- try:
- results = command.get_results()
- except CommandFailedError as e:
- results = e.results
- errors = e.errors
- for node in nodes:
- if not (node.ip_address in results or node.ip_address in errors):
- errors[nodes.ip_address] = "Aborted by keyboard interrupt"
+
+ for node, task in tasks.iteritems():
+ if task.status == "Completed":
+ results[node] = task.result
+ elif task.status == "Failed":
+ errors[node] = task.error
+ else:
+ errors[node] = KeyboardInterrupt("Aborted by keyboard interrupt")
if not args.quiet:
- _print_command_status(command, counter)
+ _print_command_status(tasks, counter)
print "\n"
# Handle errors
@@ -1049,13 +1050,14 @@ def _print_errors(nodes, errors):
print
-def _print_command_status(command, counter):
+def _print_command_status(tasks, counter):
""" Print the status of a command """
message = "\r%i successes | %i errors | %i nodes left | %s"
- status = command.get_status()
+ successes = len([x for x in tasks.values() if x.status == "Completed"])
+ errors = len([x for x in tasks.values() if x.status == "Failed"])
+ nodes_left = len(tasks) - successes - errors
dots = "".join(["." for x in range(counter % 4)]).ljust(3)
- sys.stdout.write(message % (status.successes, status.errors,
- status.nodes_left, dots))
+ sys.stdout.write(message % (successes, errors, nodes_left, dots))
sys.stdout.flush()