summaryrefslogtreecommitdiff
path: root/taskflow/types/graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'taskflow/types/graph.py')
-rw-r--r--taskflow/types/graph.py121
1 files changed, 17 insertions, 104 deletions
diff --git a/taskflow/types/graph.py b/taskflow/types/graph.py
index 553e690..ad518e9 100644
--- a/taskflow/types/graph.py
+++ b/taskflow/types/graph.py
@@ -21,8 +21,6 @@ import networkx as nx
from networkx.drawing import nx_pydot
import six
-from taskflow.utils import misc
-
def _common_format(g, edge_notation):
lines = []
@@ -31,13 +29,13 @@ def _common_format(g, edge_notation):
lines.append("Frozen: %s" % nx.is_frozen(g))
lines.append("Density: %0.3f" % nx.density(g))
lines.append("Nodes: %s" % g.number_of_nodes())
- for n, n_data in g.nodes_iter(data=True):
+ for n, n_data in g.nodes(data=True):
if n_data:
lines.append(" - %s (%s)" % (n, n_data))
else:
lines.append(" - %s" % n)
lines.append("Edges: %s" % g.number_of_edges())
- for (u, v, e_data) in g.edges_iter(data=True):
+ for (u, v, e_data) in g.edges(data=True):
if e_data:
lines.append(" %s %s %s (%s)" % (u, edge_notation, v, e_data))
else:
@@ -48,11 +46,9 @@ def _common_format(g, edge_notation):
class Graph(nx.Graph):
"""A graph subclass with useful utility functions."""
- def __init__(self, data=None, name=''):
- if misc.nx_version() == '1':
- super(Graph, self).__init__(name=name, data=data)
- else:
- super(Graph, self).__init__(name=name, incoming_graph_data=data)
+ def __init__(self, incoming_graph_data=None, name=''):
+ super(Graph, self).__init__(incoming_graph_data=incoming_graph_data,
+ name=name)
self.frozen = False
def freeze(self):
@@ -69,45 +65,14 @@ class Graph(nx.Graph):
"""Pretty formats your graph into a string."""
return os.linesep.join(_common_format(self, "<->"))
- def nodes_iter(self, data=False):
- """Returns an iterable object over the nodes.
-
- Type of iterable returned object depends on which version
- of networkx is used. When networkx < 2.0 is used , method
- returns an iterator, but if networkx > 2.0 is used, it returns
- NodeView of the Graph which is also iterable.
- """
- if misc.nx_version() == '1':
- return super(Graph, self).nodes_iter(data=data)
- return super(Graph, self).nodes(data=data)
-
- def edges_iter(self, nbunch=None, data=False, default=None):
- """Returns an iterable object over the edges.
-
- Type of iterable returned object depends on which version
- of networkx is used. When networkx < 2.0 is used , method
- returns an iterator, but if networkx > 2.0 is used, it returns
- EdgeView of the Graph which is also iterable.
- """
- if misc.nx_version() == '1':
- return super(Graph, self).edges_iter(nbunch=nbunch, data=data,
- default=default)
- return super(Graph, self).edges(nbunch=nbunch, data=data,
- default=default)
-
def add_edge(self, u, v, attr_dict=None, **attr):
"""Add an edge between u and v."""
- if misc.nx_version() == '1':
- return super(Graph, self).add_edge(u, v, attr_dict=attr_dict,
- **attr)
if attr_dict is not None:
return super(Graph, self).add_edge(u, v, **attr_dict)
return super(Graph, self).add_edge(u, v, **attr)
def add_node(self, n, attr_dict=None, **attr):
"""Add a single node n and update node attributes."""
- if misc.nx_version() == '1':
- return super(Graph, self).add_node(n, attr_dict=attr_dict, **attr)
if attr_dict is not None:
return super(Graph, self).add_node(n, **attr_dict)
return super(Graph, self).add_node(n, **attr)
@@ -125,11 +90,9 @@ class Graph(nx.Graph):
class DiGraph(nx.DiGraph):
"""A directed graph subclass with useful utility functions."""
- def __init__(self, data=None, name=''):
- if misc.nx_version() == '1':
- super(DiGraph, self).__init__(name=name, data=data)
- else:
- super(DiGraph, self).__init__(name=name, incoming_graph_data=data)
+ def __init__(self, incoming_graph_data=None, name=''):
+ super(DiGraph, self).__init__(incoming_graph_data=incoming_graph_data,
+ name=name)
self.frozen = False
def freeze(self):
@@ -183,13 +146,13 @@ class DiGraph(nx.DiGraph):
def no_successors_iter(self):
"""Returns an iterator for all nodes with no successors."""
- for n in self.nodes_iter():
+ for n in self.nodes:
if not len(list(self.successors(n))):
yield n
def no_predecessors_iter(self):
"""Returns an iterator for all nodes with no predecessors."""
- for n in self.nodes_iter():
+ for n in self.nodes:
if not len(list(self.predecessors(n))):
yield n
@@ -203,72 +166,28 @@ class DiGraph(nx.DiGraph):
over more than once (this prevents infinite iteration).
"""
visited = set([n])
- queue = collections.deque(self.predecessors_iter(n))
+ queue = collections.deque(self.predecessors(n))
while queue:
pred = queue.popleft()
if pred not in visited:
yield pred
visited.add(pred)
- for pred_pred in self.predecessors_iter(pred):
+ for pred_pred in self.predecessors(pred):
if pred_pred not in visited:
queue.append(pred_pred)
def add_edge(self, u, v, attr_dict=None, **attr):
"""Add an edge between u and v."""
- if misc.nx_version() == '1':
- return super(DiGraph, self).add_edge(u, v, attr_dict=attr_dict,
- **attr)
if attr_dict is not None:
return super(DiGraph, self).add_edge(u, v, **attr_dict)
return super(DiGraph, self).add_edge(u, v, **attr)
def add_node(self, n, attr_dict=None, **attr):
"""Add a single node n and update node attributes."""
- if misc.nx_version() == '1':
- return super(DiGraph, self).add_node(n, attr_dict=attr_dict,
- **attr)
if attr_dict is not None:
return super(DiGraph, self).add_node(n, **attr_dict)
return super(DiGraph, self).add_node(n, **attr)
- def successors_iter(self, n):
- """Returns an iterator over successor nodes of n."""
- if misc.nx_version() == '1':
- return super(DiGraph, self).successors_iter(n)
- return super(DiGraph, self).successors(n)
-
- def predecessors_iter(self, n):
- """Return an iterator over predecessor nodes of n."""
- if misc.nx_version() == '1':
- return super(DiGraph, self).predecessors_iter(n)
- return super(DiGraph, self).predecessors(n)
-
- def nodes_iter(self, data=False):
- """Returns an iterable object over the nodes.
-
- Type of iterable returned object depends on which version
- of networkx is used. When networkx < 2.0 is used , method
- returns an iterator, but if networkx > 2.0 is used, it returns
- NodeView of the Graph which is also iterable.
- """
- if misc.nx_version() == '1':
- return super(DiGraph, self).nodes_iter(data=data)
- return super(DiGraph, self).nodes(data=data)
-
- def edges_iter(self, nbunch=None, data=False, default=None):
- """Returns an iterable object over the edges.
-
- Type of iterable returned object depends on which version
- of networkx is used. When networkx < 2.0 is used , method
- returns an iterator, but if networkx > 2.0 is used, it returns
- EdgeView of the Graph which is also iterable.
- """
- if misc.nx_version() == '1':
- return super(DiGraph, self).edges_iter(nbunch=nbunch, data=data,
- default=default)
- return super(DiGraph, self).edges(nbunch=nbunch, data=data,
- default=default)
-
def fresh_copy(self):
"""Return a fresh copy graph with the same data structure.
@@ -287,11 +206,8 @@ class OrderedDiGraph(DiGraph):
order).
"""
node_dict_factory = collections.OrderedDict
- if misc.nx_version() == '1':
- adjlist_dict_factory = collections.OrderedDict
- else:
- adjlist_outer_dict_factory = collections.OrderedDict
- adjlist_inner_dict_factory = collections.OrderedDict
+ adjlist_outer_dict_factory = collections.OrderedDict
+ adjlist_inner_dict_factory = collections.OrderedDict
edge_attr_dict_factory = collections.OrderedDict
def fresh_copy(self):
@@ -312,11 +228,8 @@ class OrderedGraph(Graph):
order).
"""
node_dict_factory = collections.OrderedDict
- if misc.nx_version() == '1':
- adjlist_dict_factory = collections.OrderedDict
- else:
- adjlist_outer_dict_factory = collections.OrderedDict
- adjlist_inner_dict_factory = collections.OrderedDict
+ adjlist_outer_dict_factory = collections.OrderedDict
+ adjlist_inner_dict_factory = collections.OrderedDict
edge_attr_dict_factory = collections.OrderedDict
def fresh_copy(self):
@@ -342,7 +255,7 @@ def merge_graphs(graph, *graphs, **kwargs):
raise ValueError("Overlap detection callback expected to be callable")
elif overlap_detector is None:
overlap_detector = (lambda to_graph, from_graph:
- len(to_graph.subgraph(from_graph.nodes_iter())))
+ len(to_graph.subgraph(from_graph.nodes)))
for g in graphs:
# This should ensure that the nodes to be merged do not already exist
# in the graph that is to be merged into. This could be problematic if