summaryrefslogtreecommitdiff
path: root/networkx/classes/backends.py
blob: 135fa76c704459d6c939c6a161f7fc12c2dc8dae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""
Code to support various backends in a plugin dispatch architecture.

Create a Dispatcher
-------------------

To be a valid plugin, a package must register an entry_point
of `networkx.plugins` with a key pointing to the handler.

For example::

    entry_points={'networkx.plugins': 'sparse = networkx_plugin_sparse'}

The plugin must create a Graph-like object which contains an attribute
``__networkx_plugin__`` with a value of the entry point name.

Continuing the example above::

    class WrappedSparse:
        __networkx_plugin__ = "sparse"
        ...

When a dispatchable NetworkX algorithm encounters a Graph-like object
with a ``__networkx_plugin__`` attribute, it will look for the associated
dispatch object in the entry_points, load it, and dispatch the work to it.


Testing
-------
To assist in validating the backend algorithm implementations, if an
environment variable ``NETWORKX_GRAPH_CONVERT`` is set to a registered
plugin keys, the dispatch machinery will automatically convert regular
networkx Graphs and DiGraphs to the backend equivalent by calling
``<backend dispatcher>.convert_from_nx(G, weight=weight, name=name)``.

The converted object is then passed to the backend implementation of
the algorithm. The result is then passed to
``<backend dispatcher>.convert_to_nx(result, name=name)`` to convert back
to a form expected by the NetworkX tests.

By defining ``convert_from_nx`` and ``convert_to_nx`` methods and setting
the environment variable, NetworkX will automatically route tests on
dispatchable algorithms to the backend, allowing the full networkx test
suite to be run against the backend implementation.

Example pytest invocation::

    NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx

Dispatchable algorithms which are not implemented by the backend
will cause a ``pytest.xfail()``, giving some indication that not all
tests are working, while avoiding causing an explicit failure.

A special ``on_start_tests(items)`` function may be defined by the backend.
It will be called with the list of NetworkX tests discovered. Each item
is a test object that can be marked as xfail if the backend does not support
the test using `item.add_marker(pytest.mark.xfail(reason=...))`.
"""
import functools
import inspect
import os
import sys
from importlib.metadata import entry_points

from ..exception import NetworkXNotImplemented

__all__ = ["_dispatch", "_mark_tests"]


class PluginInfo:
    """Lazily loaded entry_points plugin information"""

    def __init__(self):
        self._items = None

    def __bool__(self):
        return len(self.items) > 0

    @property
    def items(self):
        if self._items is None:
            if sys.version_info < (3, 10):
                self._items = entry_points()["networkx.plugins"]
            else:
                self._items = entry_points(group="networkx.plugins")
        return self._items

    def __contains__(self, name):
        if sys.version_info < (3, 10):
            return len([ep for ep in self.items if ep.name == name]) > 0
        return name in self.items.names

    def __getitem__(self, name):
        if sys.version_info < (3, 10):
            return [ep for ep in self.items if ep.name == name][0]
        return self.items[name]


plugins = PluginInfo()
_registered_algorithms = {}


def _register_algo(name, wrapped_func):
    if name in _registered_algorithms:
        raise KeyError(f"Algorithm already exists in dispatch registry: {name}")
    _registered_algorithms[name] = wrapped_func
    wrapped_func.dispatchname = name


def _dispatch(func=None, *, name=None, graphs="G"):
    """Dispatches to a backend algorithm
    when the first argument is a backend graph-like object.

    The algorithm name is assumed to be the name of the wrapped function unless
    `name` is provided. This is useful to avoid name conflicts, as all
    dispatched algorithms live in a single namespace.

    If more than one graph is required for the algorithm, provide a comma-separated
    string of variable names as `graphs`. These must be the same order and name
    as the variables passed to the algorithm. Dispatching does not support graphs
    which are not the first argument(s) to an algorithm.
    """
    # Allow any of the following decorator forms:
    #  - @_dispatch
    #  - @_dispatch()
    #  - @_dispatch(name="override_name")
    #  - @_dispatch(graphs="G,H")
    #  - @_dispatch(name="override_name", graphs="G,H")
    if func is None:
        if name is None and graphs == "G":
            return _dispatch
        return functools.partial(_dispatch, name=name, graphs=graphs)
    if isinstance(func, str):
        raise TypeError("'name' and 'graphs' must be passed by keyword") from None
    # If name not provided, use the name of the function
    if name is None:
        name = func.__name__

    graph_list = [g.strip() for g in graphs.split(",")]
    if len(graph_list) <= 0:
        raise KeyError("'graphs' must contain at least one variable name") from None

    @functools.wraps(func)
    def wrapper(*args, **kwds):
        # Select overlap of args and graph_list
        graphs_resolved = dict(zip(graph_list, args))
        # Check for duplicates from kwds
        dups = set(kwds) & set(graphs_resolved)
        if dups:
            raise KeyError(f"{name}() got multiple values for {dups}") from None
        # Add items from kwds
        for gname in graph_list[len(graphs_resolved) :]:
            if gname not in kwds:
                raise TypeError(
                    f"{name}() missing required graph argument: {gname}"
                ) from None
            graphs_resolved[gname] = kwds[gname]
        # Check if any graph comes from a plugin
        if any(hasattr(g, "__networkx_plugin__") for g in graphs_resolved.values()):
            # Find common plugin name
            plugin_names = {
                getattr(g, "__networkx_plugin__", "networkx")
                for g in graphs_resolved.values()
            }
            if len(plugin_names) != 1:
                raise TypeError(
                    f"{name}() graphs must all be from the same plugin, found {plugin_names}"
                ) from None
            plugin_name = plugin_names.pop()
            if plugin_name not in plugins:
                raise ImportError(f"Unable to load plugin: {plugin_name}") from None
            backend = plugins[plugin_name].load()
            if hasattr(backend, name):
                return getattr(backend, name).__call__(*args, **kwds)
            else:
                raise NetworkXNotImplemented(
                    f"'{name}' not implemented by {plugin_name}"
                )
        return func(*args, **kwds)

    # Keep a handle to the original function to use when testing
    # the dispatch mechanism internally
    wrapper._orig_func = func

    _register_algo(name, wrapper)
    return wrapper


def test_override_dispatch(func=None, *, name=None, graphs="G"):
    """Auto-converts graph arguments into the backend equivalent,
    causing the dispatching mechanism to trigger for every
    decorated algorithm."""
    if func is None:
        if name is None and graphs == "G":
            return _dispatch
        return functools.partial(_dispatch, name=name, graphs=graphs)
    if isinstance(func, str):
        raise TypeError("'name' and 'graphs' must be passed by keyword") from None
    # If name not provided, use the name of the function
    if name is None:
        name = func.__name__

    graph_list = [g.strip() for g in graphs.split(",")]
    if len(graph_list) <= 0:
        raise ValueError("'graphs' must contain at least one variable name") from None

    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwds):
        backend = plugins[plugin_name].load()
        if not hasattr(backend, name):
            if plugin_name == "nx-loopback":
                raise NetworkXNotImplemented(
                    f"'{name}' not found in {backend.__class__.__name__}"
                )
            pytest.xfail(f"'{name}' not implemented by {plugin_name}")
        bound = sig.bind(*args, **kwds)
        bound.apply_defaults()
        # Check that graph names are actually in the signature
        if set(graph_list) - set(bound.arguments):
            raise KeyError(
                f"Invalid graph names: {set(graph_list) - set(bound.arguments)}"
            )
        # Convert graphs into backend graph-like object
        #   Include the weight label, if provided to the algorithm
        weight = None
        if "weight" in bound.arguments:
            weight = bound.arguments["weight"]
        elif "data" in bound.arguments and "default" in bound.arguments:
            # This case exists for several MultiGraph edge algorithms
            if isinstance(bound.arguments["data"], str):
                weight = bound.arguments["data"]
            elif bound.arguments["data"]:
                weight = "weight"
        elif "distance" in bound.arguments:
            # For ego_graph
            weight = bound.arguments["distance"]
        for gname in graph_list:
            bound.arguments[gname] = backend.convert_from_nx(
                bound.arguments[gname], weight=weight, name=name
            )
        result = getattr(backend, name).__call__(**bound.arguments)
        return backend.convert_to_nx(result, name=name)

    wrapper._orig_func = func
    _register_algo(name, wrapper)
    return wrapper


# Check for auto-convert testing
# This allows existing NetworkX tests to be run against a backend
# implementation without any changes to the testing code. The only
# required change is to set an environment variable prior to running
# pytest.
if os.environ.get("NETWORKX_GRAPH_CONVERT"):
    plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
    if not plugins:
        raise Exception("No registered networkx.plugins entry_points")
    if plugin_name not in plugins:
        raise Exception(
            f"No registered networkx.plugins entry_point named {plugin_name}"
        )

    try:
        import pytest
    except ImportError:
        raise ImportError(
            f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT"
        )

    # Override `dispatch` for testing
    _dispatch = test_override_dispatch


def _mark_tests(items):
    """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them"""
    if os.environ.get("NETWORKX_GRAPH_CONVERT"):
        plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
        backend = plugins[plugin_name].load()
        if hasattr(backend, "on_start_tests"):
            getattr(backend, "on_start_tests")(items)