diff options
Diffstat (limited to 'taskflow/types/graph.py')
-rw-r--r-- | taskflow/types/graph.py | 121 |
1 files changed, 17 insertions, 104 deletions
diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py index 553e690..ad518e9 100644 --- a/taskflow/types/graph.py +++ b/taskflow/types/graph.py @@ -21,8 +21,6 @@ import networkx as nx from networkx.drawing import nx_pydot import six -from taskflow.utils import misc - def _common_format(g, edge_notation): lines = [] @@ -31,13 +29,13 @@ def _common_format(g, edge_notation): lines.append("Frozen: %s" % nx.is_frozen(g)) lines.append("Density: %0.3f" % nx.density(g)) lines.append("Nodes: %s" % g.number_of_nodes()) - for n, n_data in g.nodes_iter(data=True): + for n, n_data in g.nodes(data=True): if n_data: lines.append(" - %s (%s)" % (n, n_data)) else: lines.append(" - %s" % n) lines.append("Edges: %s" % g.number_of_edges()) - for (u, v, e_data) in g.edges_iter(data=True): + for (u, v, e_data) in g.edges(data=True): if e_data: lines.append(" %s %s %s (%s)" % (u, edge_notation, v, e_data)) else: @@ -48,11 +46,9 @@ def _common_format(g, edge_notation): class Graph(nx.Graph): """A graph subclass with useful utility functions.""" - def __init__(self, data=None, name=''): - if misc.nx_version() == '1': - super(Graph, self).__init__(name=name, data=data) - else: - super(Graph, self).__init__(name=name, incoming_graph_data=data) + def __init__(self, incoming_graph_data=None, name=''): + super(Graph, self).__init__(incoming_graph_data=incoming_graph_data, + name=name) self.frozen = False def freeze(self): @@ -69,45 +65,14 @@ class Graph(nx.Graph): """Pretty formats your graph into a string.""" return os.linesep.join(_common_format(self, "<->")) - def nodes_iter(self, data=False): - """Returns an iterable object over the nodes. - - Type of iterable returned object depends on which version - of networkx is used. When networkx < 2.0 is used , method - returns an iterator, but if networkx > 2.0 is used, it returns - NodeView of the Graph which is also iterable. - """ - if misc.nx_version() == '1': - return super(Graph, self).nodes_iter(data=data) - return super(Graph, self).nodes(data=data) - - def edges_iter(self, nbunch=None, data=False, default=None): - """Returns an iterable object over the edges. - - Type of iterable returned object depends on which version - of networkx is used. When networkx < 2.0 is used , method - returns an iterator, but if networkx > 2.0 is used, it returns - EdgeView of the Graph which is also iterable. - """ - if misc.nx_version() == '1': - return super(Graph, self).edges_iter(nbunch=nbunch, data=data, - default=default) - return super(Graph, self).edges(nbunch=nbunch, data=data, - default=default) - def add_edge(self, u, v, attr_dict=None, **attr): """Add an edge between u and v.""" - if misc.nx_version() == '1': - return super(Graph, self).add_edge(u, v, attr_dict=attr_dict, - **attr) if attr_dict is not None: return super(Graph, self).add_edge(u, v, **attr_dict) return super(Graph, self).add_edge(u, v, **attr) def add_node(self, n, attr_dict=None, **attr): """Add a single node n and update node attributes.""" - if misc.nx_version() == '1': - return super(Graph, self).add_node(n, attr_dict=attr_dict, **attr) if attr_dict is not None: return super(Graph, self).add_node(n, **attr_dict) return super(Graph, self).add_node(n, **attr) @@ -125,11 +90,9 @@ class Graph(nx.Graph): class DiGraph(nx.DiGraph): """A directed graph subclass with useful utility functions.""" - def __init__(self, data=None, name=''): - if misc.nx_version() == '1': - super(DiGraph, self).__init__(name=name, data=data) - else: - super(DiGraph, self).__init__(name=name, incoming_graph_data=data) + def __init__(self, incoming_graph_data=None, name=''): + super(DiGraph, self).__init__(incoming_graph_data=incoming_graph_data, + name=name) self.frozen = False def freeze(self): @@ -183,13 +146,13 @@ class DiGraph(nx.DiGraph): def no_successors_iter(self): """Returns an iterator for all nodes with no successors.""" - for n in self.nodes_iter(): + for n in self.nodes: if not len(list(self.successors(n))): yield n def no_predecessors_iter(self): """Returns an iterator for all nodes with no predecessors.""" - for n in self.nodes_iter(): + for n in self.nodes: if not len(list(self.predecessors(n))): yield n @@ -203,72 +166,28 @@ class DiGraph(nx.DiGraph): over more than once (this prevents infinite iteration). """ visited = set([n]) - queue = collections.deque(self.predecessors_iter(n)) + queue = collections.deque(self.predecessors(n)) while queue: pred = queue.popleft() if pred not in visited: yield pred visited.add(pred) - for pred_pred in self.predecessors_iter(pred): + for pred_pred in self.predecessors(pred): if pred_pred not in visited: queue.append(pred_pred) def add_edge(self, u, v, attr_dict=None, **attr): """Add an edge between u and v.""" - if misc.nx_version() == '1': - return super(DiGraph, self).add_edge(u, v, attr_dict=attr_dict, - **attr) if attr_dict is not None: return super(DiGraph, self).add_edge(u, v, **attr_dict) return super(DiGraph, self).add_edge(u, v, **attr) def add_node(self, n, attr_dict=None, **attr): """Add a single node n and update node attributes.""" - if misc.nx_version() == '1': - return super(DiGraph, self).add_node(n, attr_dict=attr_dict, - **attr) if attr_dict is not None: return super(DiGraph, self).add_node(n, **attr_dict) return super(DiGraph, self).add_node(n, **attr) - def successors_iter(self, n): - """Returns an iterator over successor nodes of n.""" - if misc.nx_version() == '1': - return super(DiGraph, self).successors_iter(n) - return super(DiGraph, self).successors(n) - - def predecessors_iter(self, n): - """Return an iterator over predecessor nodes of n.""" - if misc.nx_version() == '1': - return super(DiGraph, self).predecessors_iter(n) - return super(DiGraph, self).predecessors(n) - - def nodes_iter(self, data=False): - """Returns an iterable object over the nodes. - - Type of iterable returned object depends on which version - of networkx is used. When networkx < 2.0 is used , method - returns an iterator, but if networkx > 2.0 is used, it returns - NodeView of the Graph which is also iterable. - """ - if misc.nx_version() == '1': - return super(DiGraph, self).nodes_iter(data=data) - return super(DiGraph, self).nodes(data=data) - - def edges_iter(self, nbunch=None, data=False, default=None): - """Returns an iterable object over the edges. - - Type of iterable returned object depends on which version - of networkx is used. When networkx < 2.0 is used , method - returns an iterator, but if networkx > 2.0 is used, it returns - EdgeView of the Graph which is also iterable. - """ - if misc.nx_version() == '1': - return super(DiGraph, self).edges_iter(nbunch=nbunch, data=data, - default=default) - return super(DiGraph, self).edges(nbunch=nbunch, data=data, - default=default) - def fresh_copy(self): """Return a fresh copy graph with the same data structure. @@ -287,11 +206,8 @@ class OrderedDiGraph(DiGraph): order). """ node_dict_factory = collections.OrderedDict - if misc.nx_version() == '1': - adjlist_dict_factory = collections.OrderedDict - else: - adjlist_outer_dict_factory = collections.OrderedDict - adjlist_inner_dict_factory = collections.OrderedDict + adjlist_outer_dict_factory = collections.OrderedDict + adjlist_inner_dict_factory = collections.OrderedDict edge_attr_dict_factory = collections.OrderedDict def fresh_copy(self): @@ -312,11 +228,8 @@ class OrderedGraph(Graph): order). """ node_dict_factory = collections.OrderedDict - if misc.nx_version() == '1': - adjlist_dict_factory = collections.OrderedDict - else: - adjlist_outer_dict_factory = collections.OrderedDict - adjlist_inner_dict_factory = collections.OrderedDict + adjlist_outer_dict_factory = collections.OrderedDict + adjlist_inner_dict_factory = collections.OrderedDict edge_attr_dict_factory = collections.OrderedDict def fresh_copy(self): @@ -342,7 +255,7 @@ def merge_graphs(graph, *graphs, **kwargs): raise ValueError("Overlap detection callback expected to be callable") elif overlap_detector is None: overlap_detector = (lambda to_graph, from_graph: - len(to_graph.subgraph(from_graph.nodes_iter()))) + len(to_graph.subgraph(from_graph.nodes))) for g in graphs: # This should ensure that the nodes to be merged do not already exist # in the graph that is to be merged into. This could be problematic if |