diff options
author | Jim Kitchen <jim22k@users.noreply.github.com> | 2023-04-25 18:57:12 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-25 19:57:12 -0400 |
commit | 8ed20630d5f09f1ced9a797e122e8c3383b08fe7 (patch) | |
tree | 4564f454f2dc3496006706adbc80de37582f33f8 | |
parent | 8982556f31c74ac58540f01ec21ebcd1507b9a6f (diff) | |
download | networkx-8ed20630d5f09f1ced9a797e122e8c3383b08fe7.tar.gz |
Allow multiple graphs for `@nx._dispatch` (#6628)
* Allow multiple graphs for nx._dispatch
A new `graphs` keyword is added. For the case of two graphs named
`G` and `H` as the first to arguments in `foo`, the new spelling is
@nx._dispatch(graphs="G,H")
def foo(G, H, other_arg, **kwargs):
...
* Use better default "G" for graphs kwarg
* fix
-rw-r--r-- | networkx/algorithms/cluster.py | 6 | ||||
-rw-r--r-- | networkx/algorithms/operators/binary.py | 7 | ||||
-rw-r--r-- | networkx/algorithms/operators/tests/test_binary.py | 19 | ||||
-rw-r--r-- | networkx/classes/backends.py | 112 |
4 files changed, 103 insertions, 41 deletions
diff --git a/networkx/algorithms/cluster.py b/networkx/algorithms/cluster.py index ec65e97f..620c8ce8 100644 --- a/networkx/algorithms/cluster.py +++ b/networkx/algorithms/cluster.py @@ -16,7 +16,7 @@ __all__ = [ ] -@nx._dispatch("triangles") +@nx._dispatch(graphs="G") @not_implemented_for("directed") def triangles(G, nodes=None): """Compute the number of triangles. @@ -395,7 +395,7 @@ def clustering(G, nodes=None, weight=None): return clusterc -@nx._dispatch("transitivity") +@nx._dispatch(name="transitivity") def transitivity(G): r"""Compute graph transitivity, the fraction of all possible triangles present in G. @@ -512,7 +512,7 @@ def square_clustering(G, nodes=None): return clustering -@nx._dispatch("generalized_degree") +@nx._dispatch(name="generalized_degree", graphs="G") @not_implemented_for("directed") def generalized_degree(G, nodes=None): r"""Compute the generalized degree for nodes. diff --git a/networkx/algorithms/operators/binary.py b/networkx/algorithms/operators/binary.py index 09f59d13..80e509c7 100644 --- a/networkx/algorithms/operators/binary.py +++ b/networkx/algorithms/operators/binary.py @@ -14,6 +14,7 @@ __all__ = [ ] +@nx._dispatch(graphs="G,H") def union(G, H, rename=()): """Combine graphs G and H. The names of nodes must be unique. @@ -69,6 +70,7 @@ def union(G, H, rename=()): return nx.union_all([G, H], rename) +@nx._dispatch(graphs="G,H") def disjoint_union(G, H): """Combine graphs G and H. The nodes are assumed to be unique (disjoint). @@ -122,6 +124,7 @@ def disjoint_union(G, H): return nx.disjoint_union_all([G, H]) +@nx._dispatch(graphs="G,H") def intersection(G, H): """Returns a new graph that contains only the nodes and the edges that exist in both G and H. @@ -166,6 +169,7 @@ def intersection(G, H): return nx.intersection_all([G, H]) +@nx._dispatch(graphs="G,H") def difference(G, H): """Returns a new graph that contains the edges that exist in G but not in H. @@ -220,6 +224,7 @@ def difference(G, H): return R +@nx._dispatch(graphs="G,H") def symmetric_difference(G, H): """Returns new graph with edges that exist in either G or H but not both. @@ -282,6 +287,7 @@ def symmetric_difference(G, H): return R +@nx._dispatch(graphs="G,H") def compose(G, H): """Compose graph G with H by combining nodes and edges into a single graph. @@ -358,6 +364,7 @@ def compose(G, H): return nx.compose_all([G, H]) +@nx._dispatch(graphs="G,H") def full_join(G, H, rename=(None, None)): """Returns the full join of graphs G and H. diff --git a/networkx/algorithms/operators/tests/test_binary.py b/networkx/algorithms/operators/tests/test_binary.py index b4e64f8e..d358265e 100644 --- a/networkx/algorithms/operators/tests/test_binary.py +++ b/networkx/algorithms/operators/tests/test_binary.py @@ -1,6 +1,9 @@ +import os + import pytest import networkx as nx +from networkx.classes.tests import dispatch_interface from networkx.utils import edges_equal @@ -39,6 +42,22 @@ def test_intersection(): assert set(I.nodes()) == {1, 2, 3, 4} assert sorted(I.edges()) == [(2, 3)] + ################## + # Tests for @nx._dispatch mechanism with multiple graph arguments + # nx.intersection is called as if it were a re-implementation + # from another package. + ################### + G2 = dispatch_interface.convert(G) + H2 = dispatch_interface.convert(H) + I2 = nx.intersection(G2, H2) + assert set(I2.nodes()) == {1, 2, 3, 4} + assert sorted(I2.edges()) == [(2, 3)] + if os.environ.get("NETWORKX_GRAPH_CONVERT", None) != "nx-loopback": + with pytest.raises(TypeError): + nx.intersection(G2, H) + with pytest.raises(TypeError): + nx.intersection(G, H2) + def test_intersection_node_sets_different(): G = nx.Graph() diff --git a/networkx/classes/backends.py b/networkx/classes/backends.py index 761d0a4b..06af4231 100644 --- a/networkx/classes/backends.py +++ b/networkx/classes/backends.py @@ -107,44 +107,75 @@ def _register_algo(name, wrapped_func): wrapped_func.dispatchname = name -def _dispatch(func=None, *, name=None): +def _dispatch(func=None, *, name=None, graphs="G"): """Dispatches to a backend algorithm when the first argument is a backend graph-like object. + + The algorithm name is assumed to be the name of the wrapped function unless + `name` is provided. This is useful to avoid name conflicts, as all + dispatched algorithms live in a single namespace. + + If more than one graph is required for the algorithm, provide a comma-separated + string of variable names as `graphs`. These must be the same order and name + as the variables passed to the algorithm. Dispatching does not support graphs + which are not the first argument(s) to an algorithm. """ # Allow any of the following decorator forms: # - @_dispatch # - @_dispatch() - # - @_dispatch("override_name") # - @_dispatch(name="override_name") + # - @_dispatch(graphs="G,H") + # - @_dispatch(name="override_name", graphs="G,H") if func is None: - if name is None: + if name is None and graphs == "G": return _dispatch - return functools.partial(_dispatch, name=name) + return functools.partial(_dispatch, name=name, graphs=graphs) if isinstance(func, str): - return functools.partial(_dispatch, name=func) + raise TypeError("'name' and 'graphs' must be passed by keyword") from None # If name not provided, use the name of the function if name is None: name = func.__name__ + graph_list = [g.strip() for g in graphs.split(",")] + if len(graph_list) <= 0: + raise KeyError("'graphs' must contain at least one variable name") from None + @functools.wraps(func) def wrapper(*args, **kwds): - if args: - graph = args[0] - else: - try: - graph = kwds["G"] - except KeyError: - raise TypeError(f"{name}() missing positional argument: 'G'") from None - if hasattr(graph, "__networkx_plugin__") and plugins: - plugin_name = graph.__networkx_plugin__ - if plugin_name in plugins: - backend = plugins[plugin_name].load() - if hasattr(backend, name): - return getattr(backend, name).__call__(*args, **kwds) - else: - raise NetworkXNotImplemented( - f"'{name}' not implemented by {plugin_name}" - ) + # Select overlap of args and graph_list + graphs_resolved = dict(zip(graph_list, args)) + # Check for duplicates from kwds + dups = set(kwds) & set(graphs_resolved) + if dups: + raise KeyError(f"{name}() got multiple values for {dups}") from None + # Add items from kwds + for gname in graph_list[len(graphs_resolved) :]: + if gname not in kwds: + raise TypeError( + f"{name}() missing required graph argument: {gname}" + ) from None + graphs_resolved[gname] = kwds[gname] + # Check if any graph comes from a plugin + if any(hasattr(g, "__networkx_plugin__") for g in graphs_resolved.values()): + # Find common plugin name + plugin_names = { + getattr(g, "__networkx_plugin__", "networkx") + for g in graphs_resolved.values() + } + if len(plugin_names) != 1: + raise TypeError( + f"{name}() graphs must all be from the same plugin, found {plugin_names}" + ) from None + plugin_name = plugin_names.pop() + if plugin_name not in plugins: + raise ImportError(f"Unable to load plugin: {plugin_name}") from None + backend = plugins[plugin_name].load() + if hasattr(backend, name): + return getattr(backend, name).__call__(*args, **kwds) + else: + raise NetworkXNotImplemented( + f"'{name}' not implemented by {plugin_name}" + ) return func(*args, **kwds) # Keep a handle to the original function to use when testing @@ -155,20 +186,24 @@ def _dispatch(func=None, *, name=None): return wrapper -def test_override_dispatch(func=None, *, name=None): - """Auto-converts the first argument into the backend equivalent, +def test_override_dispatch(func=None, *, name=None, graphs="G"): + """Auto-converts graph arguments into the backend equivalent, causing the dispatching mechanism to trigger for every decorated algorithm.""" if func is None: - if name is None: - return test_override_dispatch - return functools.partial(test_override_dispatch, name=name) + if name is None and graphs == "G": + return _dispatch + return functools.partial(_dispatch, name=name, graphs=graphs) if isinstance(func, str): - return functools.partial(test_override_dispatch, name=func) + raise TypeError("'name' and 'graphs' must be passed by keyword") from None # If name not provided, use the name of the function if name is None: name = func.__name__ + graph_list = [g.strip() for g in graphs.split(",")] + if len(graph_list) <= 0: + raise ValueError("'graphs' must contain at least one variable name") from None + sig = inspect.signature(func) @functools.wraps(func) @@ -182,14 +217,12 @@ def test_override_dispatch(func=None, *, name=None): pytest.xfail(f"'{name}' not implemented by {plugin_name}") bound = sig.bind(*args, **kwds) bound.apply_defaults() - if args: - graph, *args = args - else: - try: - graph = kwds.pop("G") - except KeyError: - raise TypeError(f"{name}() missing positional argument: 'G'") from None - # Convert graph into backend graph-like object + # Check that graph names are actually in the signature + if set(graph_list) - set(bound.arguments): + raise KeyError( + f"Invalid graph names: {set(graph_list) - set(bound.arguments)}" + ) + # Convert graphs into backend graph-like object # Include the weight label, if provided to the algorithm weight = None if "weight" in bound.arguments: @@ -200,8 +233,11 @@ def test_override_dispatch(func=None, *, name=None): weight = bound.arguments["data"] elif bound.arguments["data"]: weight = "weight" - graph = backend.convert_from_nx(graph, weight=weight, name=name) - result = getattr(backend, name).__call__(graph, *args, **kwds) + for gname in graph_list: + bound.arguments[gname] = backend.convert_from_nx( + bound.arguments[gname], weight=weight, name=name + ) + result = getattr(backend, name).__call__(**bound.arguments) return backend.convert_to_nx(result, name=name) wrapper._orig_func = func |