summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJim Kitchen <jim22k@users.noreply.github.com>2023-04-25 18:57:12 -0500
committerGitHub <noreply@github.com>2023-04-25 19:57:12 -0400
commit8ed20630d5f09f1ced9a797e122e8c3383b08fe7 (patch)
tree4564f454f2dc3496006706adbc80de37582f33f8
parent8982556f31c74ac58540f01ec21ebcd1507b9a6f (diff)
downloadnetworkx-8ed20630d5f09f1ced9a797e122e8c3383b08fe7.tar.gz
Allow multiple graphs for `@nx._dispatch` (#6628)
* Allow multiple graphs for nx._dispatch A new `graphs` keyword is added. For the case of two graphs named `G` and `H` as the first to arguments in `foo`, the new spelling is @nx._dispatch(graphs="G,H") def foo(G, H, other_arg, **kwargs): ... * Use better default "G" for graphs kwarg * fix
-rw-r--r--networkx/algorithms/cluster.py6
-rw-r--r--networkx/algorithms/operators/binary.py7
-rw-r--r--networkx/algorithms/operators/tests/test_binary.py19
-rw-r--r--networkx/classes/backends.py112
4 files changed, 103 insertions, 41 deletions
diff --git a/networkx/algorithms/cluster.py b/networkx/algorithms/cluster.py
index ec65e97f..620c8ce8 100644
--- a/networkx/algorithms/cluster.py
+++ b/networkx/algorithms/cluster.py
@@ -16,7 +16,7 @@ __all__ = [
]
-@nx._dispatch("triangles")
+@nx._dispatch(graphs="G")
@not_implemented_for("directed")
def triangles(G, nodes=None):
"""Compute the number of triangles.
@@ -395,7 +395,7 @@ def clustering(G, nodes=None, weight=None):
return clusterc
-@nx._dispatch("transitivity")
+@nx._dispatch(name="transitivity")
def transitivity(G):
r"""Compute graph transitivity, the fraction of all possible triangles
present in G.
@@ -512,7 +512,7 @@ def square_clustering(G, nodes=None):
return clustering
-@nx._dispatch("generalized_degree")
+@nx._dispatch(name="generalized_degree", graphs="G")
@not_implemented_for("directed")
def generalized_degree(G, nodes=None):
r"""Compute the generalized degree for nodes.
diff --git a/networkx/algorithms/operators/binary.py b/networkx/algorithms/operators/binary.py
index 09f59d13..80e509c7 100644
--- a/networkx/algorithms/operators/binary.py
+++ b/networkx/algorithms/operators/binary.py
@@ -14,6 +14,7 @@ __all__ = [
]
+@nx._dispatch(graphs="G,H")
def union(G, H, rename=()):
"""Combine graphs G and H. The names of nodes must be unique.
@@ -69,6 +70,7 @@ def union(G, H, rename=()):
return nx.union_all([G, H], rename)
+@nx._dispatch(graphs="G,H")
def disjoint_union(G, H):
"""Combine graphs G and H. The nodes are assumed to be unique (disjoint).
@@ -122,6 +124,7 @@ def disjoint_union(G, H):
return nx.disjoint_union_all([G, H])
+@nx._dispatch(graphs="G,H")
def intersection(G, H):
"""Returns a new graph that contains only the nodes and the edges that exist in
both G and H.
@@ -166,6 +169,7 @@ def intersection(G, H):
return nx.intersection_all([G, H])
+@nx._dispatch(graphs="G,H")
def difference(G, H):
"""Returns a new graph that contains the edges that exist in G but not in H.
@@ -220,6 +224,7 @@ def difference(G, H):
return R
+@nx._dispatch(graphs="G,H")
def symmetric_difference(G, H):
"""Returns new graph with edges that exist in either G or H but not both.
@@ -282,6 +287,7 @@ def symmetric_difference(G, H):
return R
+@nx._dispatch(graphs="G,H")
def compose(G, H):
"""Compose graph G with H by combining nodes and edges into a single graph.
@@ -358,6 +364,7 @@ def compose(G, H):
return nx.compose_all([G, H])
+@nx._dispatch(graphs="G,H")
def full_join(G, H, rename=(None, None)):
"""Returns the full join of graphs G and H.
diff --git a/networkx/algorithms/operators/tests/test_binary.py b/networkx/algorithms/operators/tests/test_binary.py
index b4e64f8e..d358265e 100644
--- a/networkx/algorithms/operators/tests/test_binary.py
+++ b/networkx/algorithms/operators/tests/test_binary.py
@@ -1,6 +1,9 @@
+import os
+
import pytest
import networkx as nx
+from networkx.classes.tests import dispatch_interface
from networkx.utils import edges_equal
@@ -39,6 +42,22 @@ def test_intersection():
assert set(I.nodes()) == {1, 2, 3, 4}
assert sorted(I.edges()) == [(2, 3)]
+ ##################
+ # Tests for @nx._dispatch mechanism with multiple graph arguments
+ # nx.intersection is called as if it were a re-implementation
+ # from another package.
+ ###################
+ G2 = dispatch_interface.convert(G)
+ H2 = dispatch_interface.convert(H)
+ I2 = nx.intersection(G2, H2)
+ assert set(I2.nodes()) == {1, 2, 3, 4}
+ assert sorted(I2.edges()) == [(2, 3)]
+ if os.environ.get("NETWORKX_GRAPH_CONVERT", None) != "nx-loopback":
+ with pytest.raises(TypeError):
+ nx.intersection(G2, H)
+ with pytest.raises(TypeError):
+ nx.intersection(G, H2)
+
def test_intersection_node_sets_different():
G = nx.Graph()
diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py
index 761d0a4b..06af4231 100644
--- a/networkx/classes/backends.py
+++ b/networkx/classes/backends.py
@@ -107,44 +107,75 @@ def _register_algo(name, wrapped_func):
wrapped_func.dispatchname = name
-def _dispatch(func=None, *, name=None):
+def _dispatch(func=None, *, name=None, graphs="G"):
"""Dispatches to a backend algorithm
when the first argument is a backend graph-like object.
+
+ The algorithm name is assumed to be the name of the wrapped function unless
+ `name` is provided. This is useful to avoid name conflicts, as all
+ dispatched algorithms live in a single namespace.
+
+ If more than one graph is required for the algorithm, provide a comma-separated
+ string of variable names as `graphs`. These must be the same order and name
+ as the variables passed to the algorithm. Dispatching does not support graphs
+ which are not the first argument(s) to an algorithm.
"""
# Allow any of the following decorator forms:
# - @_dispatch
# - @_dispatch()
- # - @_dispatch("override_name")
# - @_dispatch(name="override_name")
+ # - @_dispatch(graphs="G,H")
+ # - @_dispatch(name="override_name", graphs="G,H")
if func is None:
- if name is None:
+ if name is None and graphs == "G":
return _dispatch
- return functools.partial(_dispatch, name=name)
+ return functools.partial(_dispatch, name=name, graphs=graphs)
if isinstance(func, str):
- return functools.partial(_dispatch, name=func)
+ raise TypeError("'name' and 'graphs' must be passed by keyword") from None
# If name not provided, use the name of the function
if name is None:
name = func.__name__
+ graph_list = [g.strip() for g in graphs.split(",")]
+ if len(graph_list) <= 0:
+ raise KeyError("'graphs' must contain at least one variable name") from None
+
@functools.wraps(func)
def wrapper(*args, **kwds):
- if args:
- graph = args[0]
- else:
- try:
- graph = kwds["G"]
- except KeyError:
- raise TypeError(f"{name}() missing positional argument: 'G'") from None
- if hasattr(graph, "__networkx_plugin__") and plugins:
- plugin_name = graph.__networkx_plugin__
- if plugin_name in plugins:
- backend = plugins[plugin_name].load()
- if hasattr(backend, name):
- return getattr(backend, name).__call__(*args, **kwds)
- else:
- raise NetworkXNotImplemented(
- f"'{name}' not implemented by {plugin_name}"
- )
+ # Select overlap of args and graph_list
+ graphs_resolved = dict(zip(graph_list, args))
+ # Check for duplicates from kwds
+ dups = set(kwds) & set(graphs_resolved)
+ if dups:
+ raise KeyError(f"{name}() got multiple values for {dups}") from None
+ # Add items from kwds
+ for gname in graph_list[len(graphs_resolved) :]:
+ if gname not in kwds:
+ raise TypeError(
+ f"{name}() missing required graph argument: {gname}"
+ ) from None
+ graphs_resolved[gname] = kwds[gname]
+ # Check if any graph comes from a plugin
+ if any(hasattr(g, "__networkx_plugin__") for g in graphs_resolved.values()):
+ # Find common plugin name
+ plugin_names = {
+ getattr(g, "__networkx_plugin__", "networkx")
+ for g in graphs_resolved.values()
+ }
+ if len(plugin_names) != 1:
+ raise TypeError(
+ f"{name}() graphs must all be from the same plugin, found {plugin_names}"
+ ) from None
+ plugin_name = plugin_names.pop()
+ if plugin_name not in plugins:
+ raise ImportError(f"Unable to load plugin: {plugin_name}") from None
+ backend = plugins[plugin_name].load()
+ if hasattr(backend, name):
+ return getattr(backend, name).__call__(*args, **kwds)
+ else:
+ raise NetworkXNotImplemented(
+ f"'{name}' not implemented by {plugin_name}"
+ )
return func(*args, **kwds)
# Keep a handle to the original function to use when testing
@@ -155,20 +186,24 @@ def _dispatch(func=None, *, name=None):
return wrapper
-def test_override_dispatch(func=None, *, name=None):
- """Auto-converts the first argument into the backend equivalent,
+def test_override_dispatch(func=None, *, name=None, graphs="G"):
+ """Auto-converts graph arguments into the backend equivalent,
causing the dispatching mechanism to trigger for every
decorated algorithm."""
if func is None:
- if name is None:
- return test_override_dispatch
- return functools.partial(test_override_dispatch, name=name)
+ if name is None and graphs == "G":
+ return _dispatch
+ return functools.partial(_dispatch, name=name, graphs=graphs)
if isinstance(func, str):
- return functools.partial(test_override_dispatch, name=func)
+ raise TypeError("'name' and 'graphs' must be passed by keyword") from None
# If name not provided, use the name of the function
if name is None:
name = func.__name__
+ graph_list = [g.strip() for g in graphs.split(",")]
+ if len(graph_list) <= 0:
+ raise ValueError("'graphs' must contain at least one variable name") from None
+
sig = inspect.signature(func)
@functools.wraps(func)
@@ -182,14 +217,12 @@ def test_override_dispatch(func=None, *, name=None):
pytest.xfail(f"'{name}' not implemented by {plugin_name}")
bound = sig.bind(*args, **kwds)
bound.apply_defaults()
- if args:
- graph, *args = args
- else:
- try:
- graph = kwds.pop("G")
- except KeyError:
- raise TypeError(f"{name}() missing positional argument: 'G'") from None
- # Convert graph into backend graph-like object
+ # Check that graph names are actually in the signature
+ if set(graph_list) - set(bound.arguments):
+ raise KeyError(
+ f"Invalid graph names: {set(graph_list) - set(bound.arguments)}"
+ )
+ # Convert graphs into backend graph-like object
# Include the weight label, if provided to the algorithm
weight = None
if "weight" in bound.arguments:
@@ -200,8 +233,11 @@ def test_override_dispatch(func=None, *, name=None):
weight = bound.arguments["data"]
elif bound.arguments["data"]:
weight = "weight"
- graph = backend.convert_from_nx(graph, weight=weight, name=name)
- result = getattr(backend, name).__call__(graph, *args, **kwds)
+ for gname in graph_list:
+ bound.arguments[gname] = backend.convert_from_nx(
+ bound.arguments[gname], weight=weight, name=name
+ )
+ result = getattr(backend, name).__call__(**bound.arguments)
return backend.convert_to_nx(result, name=name)
wrapper._orig_func = func