summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorge Kraft <george.kraft@calxeda.com>2012-08-22 14:59:45 -0500
committerGeorge Kraft <george.kraft@calxeda.com>2012-08-22 14:59:45 -0500
commit4dcdb1639f8687d79acf5421bd12a912fea2086e (patch)
tree9953bc80cdcae737d6b955dfd37e7ba5a696a239
parent349dcfe06e23cde7cf0836c812422fe91d16edef (diff)
downloadcxmanage-4dcdb1639f8687d79acf5421bd12a912fea2086e.tar.gz
cxmanage: Don't allow infinite-depth hostfile recursion
Hostfiles can contain any entry that the hostname argument could contain... including entries for other hostfiles. So don't allow a single hostfile to be parsed more than once, in case of circular dependencies. Also allow "file=" as well as "hostfile=" for hostfile entries.
-rw-r--r--cxmanage/controller.py24
-rwxr-xr-xscripts/cxmanage97
2 files changed, 70 insertions, 51 deletions
diff --git a/cxmanage/controller.py b/cxmanage/controller.py
index c0c054e..98833a5 100644
--- a/cxmanage/controller.py
+++ b/cxmanage/controller.py
@@ -202,30 +202,6 @@ class Controller:
return len(errors) > 0
- def get_addresses_in_range(self, start, end):
- """ Return a list of addresses in the given IP range """
- try:
- # Convert startaddr to int
- start_bytes = map(int, start.split("."))
- start_i = ((start_bytes[0] << 24) | (start_bytes[1] << 16)
- | (start_bytes[2] << 8) | (start_bytes[3]))
-
- # Convert endaddr to int
- end_bytes = map(int, end.split("."))
- end_i = ((end_bytes[0] << 24) | (end_bytes[1] << 16)
- | (end_bytes[2] << 8) | (end_bytes[3]))
-
- # Get ip addresses in range
- addresses = []
- for i in range(start_i, end_i + 1):
- address_bytes = [(i >> (24 - 8 * x)) & 0xff for x in range(4)]
- addresses.append("%i.%i.%i.%i" % tuple(address_bytes))
-
- return addresses
-
- except IndexError:
- raise ValueError("Invalid arguments to get_targets_in_range")
-
######################### Execution methods #########################
def power(self, mode):
diff --git a/scripts/cxmanage b/scripts/cxmanage
index b4fac67..019e3c6 100755
--- a/scripts/cxmanage
+++ b/scripts/cxmanage
@@ -300,7 +300,13 @@ def set_tftp(controller, args):
def add_targets(controller, args):
- hosts = parse_hosts(controller, args.hostname.split(','))
+ """add targets to controller"""
+ # Get a list of hosts
+ hosts = []
+ for entry in args.hostname.split(','):
+ hosts.extend(parse_host_entry(entry))
+
+ # Add hosts to controller
if args.all_nodes:
if controller.add_fabrics(hosts, args.user, args.password):
print "ERROR: Failed to get IP addresses. Aborting.\n"
@@ -309,32 +315,69 @@ def add_targets(controller, args):
for host in hosts:
controller.add_target(host, args.user, args.password)
-def parse_hosts(controller, hosts):
- """add targets to controller addresses"""
- results = []
- for entry in hosts:
- # Check if it's a hostfile
- if entry.startswith('hostfile='):
- try:
- hostfile_entries = []
- for line in open(entry[9:]):
- elements = line.partition('#')[0].split()
- for element in elements:
- hostfile_entries.extend(element.split(','))
- hosts.extend(parse_hosts(controller, hostfile_entries))
- except IOError:
- print 'ERROR: %s is not a valid hostfile entry' % entry
- sys.exit(1)
- else:
- # Not a hostfile, is it an IP range?
- try:
- start, end = entry.split('-')
- hosts.extend(controller.get_addresses_in_range(start, end))
- except ValueError:
- # Not a hostfile or IP range, add it as a regular host
- results.append(entry)
-
- return results
+
+def parse_host_entry(entry, hostfiles=set()):
+ """parse a host entry"""
+ try:
+ return parse_hostfile_entry(entry, hostfiles)
+ except ValueError:
+ try:
+ return parse_ip_range_entry(entry)
+ except ValueError:
+ return [entry]
+
+
+def parse_hostfile_entry(entry, hostfiles=set()):
+ """parse a hostfile entry, returning a list of hosts"""
+ if entry.startswith('file='):
+ filename = entry[5:]
+ elif entry.startswith('hostfile='):
+ filename = entry[9:]
+ else:
+ raise ValueError('%s is not a hostfile entry' % entry)
+
+ if filename in hostfiles:
+ return []
+ hostfiles.add(filename)
+
+ entries = []
+ try:
+ for line in open(filename):
+ for element in line.partition('#')[0].split():
+ for hostfile_entry in element.split(','):
+ entries.extend(parse_host_entry(hostfile_entry, hostfiles))
+ except IOError:
+ print 'ERROR: %s is not a valid hostfile entry' % entry
+ sys.exit(1)
+
+ return entries
+
+
+def parse_ip_range_entry(entry):
+ """ Get a list of ip addresses in a given range"""
+ try:
+ start, end = entry.split('-')
+
+ # Convert start address to int
+ start_bytes = map(int, start.split('.'))
+ start_i = ((start_bytes[0] << 24) | (start_bytes[1] << 16)
+ | (start_bytes[2] << 8) | (start_bytes[3]))
+
+ # Convert end address to int
+ end_bytes = map(int, end.split('.'))
+ end_i = ((end_bytes[0] << 24) | (end_bytes[1] << 16)
+ | (end_bytes[2] << 8) | (end_bytes[3]))
+
+ # Get ip addresses in range
+ addresses = []
+ for i in range(start_i, end_i + 1):
+ address_bytes = [(i >> (24 - 8 * x)) & 0xff for x in range(4)]
+ addresses.append('%i.%i.%i.%i' % tuple(address_bytes))
+
+ except (ValueError, IndexError):
+ raise ValueError('%s is not an IP range' % entry)
+
+ return addresses
def power_command(controller, args):