summaryrefslogtreecommitdiff
path: root/sphinx/ext/autodoc/preserve_defaults.py
blob: d451d0973ceb3c0fb3e31ab48fe0bb8858a0ba6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
    sphinx.ext.autodoc.preserve_defaults
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Preserve the default argument values of function signatures as written in
    the source code, rather than their evaluated values, for readability.

    :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS.
    :license: BSD, see LICENSE for details.
"""

import ast
import inspect
import sys
from typing import Any, Dict, List, Optional

from sphinx.application import Sphinx
from sphinx.locale import __
from sphinx.pycode.ast import parse as ast_parse
from sphinx.pycode.ast import unparse as ast_unparse
from sphinx.util import logging

logger = logging.getLogger(__name__)


class DefaultValue:
    """Placeholder for a parameter default that must be shown verbatim.

    ``repr()`` of an instance is exactly the wrapped text.  Anything that
    renders a default through ``repr()`` therefore shows the source text of
    the default rather than the evaluated object.
    """

    def __init__(self, name: str) -> None:
        # Text of the default value as it should be displayed.
        self.name = name

    def __repr__(self) -> str:
        text = self.name
        return text


def get_function_def(obj: Any) -> Optional[ast.FunctionDef]:
    """Get FunctionDef object from living object.

    This tries to parse original code for living object and returns
    AST node for given *obj*.

    If the source is indented (by spaces or tabs, e.g. a method inside a class
    body), it is parsed below a dummy ``if True:`` line.  Every line number in
    the returned tree is then one greater than the matching line of
    ``inspect.getsource(obj)``; :func:`update_defvalue` compensates for this.

    Returns ``None`` when the source cannot be retrieved.  The returned node is
    simply the first statement of the parsed source.  For objects that are not
    ``def`` statements (classes, lambdas, ...) it is therefore another node
    type without an ``args`` attribute.
    """
    try:
        source = inspect.getsource(obj)
        # '\t' must be a real tab character (a raw r'\t' would never match).
        if source.startswith((' ', '\t')):
            # subject is placed inside class or block.  To make the indented
            # code parseable, this adds an if-block before the declaration.
            module = ast_parse('if True:\n' + source)
            return module.body[0].body[0]  # type: ignore
        else:
            module = ast_parse(source)
            return module.body[0]  # type: ignore
    except (OSError, TypeError):  # failed to load source code
        return None


def get_default_value(lines: List[str], position: ast.AST) -> Optional[str]:
    """Return the exact source text of the expression node *position*.

    :param lines: source lines numbered like the tree *position* comes from
                  (``lines[0]`` is line 1).
    :param position: an AST expression node, e.g. a parameter default.
    :return: the text of the expression, or ``None`` when it cannot be
             recovered.  That happens on Python < 3.8 (no end positions), for
             expressions spanning several lines, when *position* has no
             location info, or when the line is out of range.
    """
    if sys.version_info < (3, 8):  # end_lineno/end_col_offset are py38+ only
        return None

    try:
        if position.lineno != position.end_lineno:
            # multiline value is not supported now
            return None
        line = lines[position.lineno - 1]
        # col_offset/end_col_offset are UTF-8 *byte* offsets.  Slicing the str
        # directly would misplace the span whenever a non-ASCII character
        # precedes or appears inside the expression.
        encoded = line.encode('utf-8')
        segment = encoded[position.col_offset:position.end_col_offset]
        return segment.decode('utf-8')
    except (AttributeError, IndexError, UnicodeDecodeError):
        return None


def update_defvalue(app: Sphinx, obj: Any, bound_method: bool) -> None:
    """Update defvalue info of *obj* using the source text of its defaults.

    Handler for the ``autodoc-before-process-signature`` event.  It does
    nothing unless ``autodoc_preserve_defaults`` is enabled.  When enabled,
    every parameter default in the signature of *obj* is replaced by a
    :class:`DefaultValue`.  That value holds the default expression as
    written in the source, or the ``ast_unparse`` output if the exact text
    cannot be recovered.  The result is stored as ``__signature__`` on the
    function.

    Objects whose source cannot be retrieved or parsed, and objects that are
    not ``def`` functions, are left untouched.  A warning is logged when a
    default value cannot be unparsed.

    *bound_method* is part of the event's signature and is not used here.
    """
    if not app.config.autodoc_preserve_defaults:
        return

    try:
        lines = inspect.getsource(obj).splitlines()
    except (OSError, TypeError):
        # no source available; TypeError is raised for built-in callables
        lines = []
    if lines and lines[0].startswith((' ', '\t')):
        # insert a dummy line to follow what get_function_def() does.
        lines.insert(0, '')

    # Bound methods (e.g. classmethods looked up on their class) reject new
    # attributes, so the signature is attached to the underlying function.
    # inspect.signature() on the bound method reads it from there and drops
    # the first parameter by itself.
    target = obj.__func__ if inspect.ismethod(obj) else obj

    try:
        function = get_function_def(obj)
        if function.args.defaults or function.args.kw_defaults:
            sig = inspect.signature(target)
            defaults = list(function.args.defaults)
            kw_defaults = list(function.args.kw_defaults)
            parameters = list(sig.parameters.values())
            for i, param in enumerate(parameters):
                if param.default is param.empty:
                    if param.kind == param.KEYWORD_ONLY:
                        # kw_defaults has one entry per keyword-only argument,
                        # None for those without a default; consume it so the
                        # following entries stay aligned with their parameters.
                        kw_defaults.pop(0)
                    continue

                if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
                    default = defaults.pop(0)
                else:
                    default = kw_defaults.pop(0)
                value = get_default_value(lines, default)
                if value is None:
                    value = ast_unparse(default)  # type: ignore
                parameters[i] = param.replace(default=DefaultValue(value))
            target.__signature__ = sig.replace(parameters=parameters)
    except (AttributeError, TypeError, IndexError, SyntaxError):
        # failed to update signature (ex. built-in or extension types, source
        # that cannot be parsed, or a signature that does not match the AST)
        pass
    except NotImplementedError as exc:  # failed to ast.unparse()
        logger.warning(__("Failed to parse a default argument value for %r: %s"), obj, exc)


def setup(app: Sphinx) -> Dict[str, Any]:
    """Register this extension with *app*.

    Declares the ``autodoc_preserve_defaults`` config value (``False`` by
    default, with ``True`` passed as the rebuild argument).  It also hooks
    :func:`update_defvalue` to the ``autodoc-before-process-signature`` event.
    Returns the extension metadata dictionary.
    """
    app.add_config_value('autodoc_preserve_defaults', False, True)
    app.connect('autodoc-before-process-signature', update_defvalue)

    metadata = {}  # type: Dict[str, Any]
    metadata['version'] = '1.0'
    metadata['parallel_read_safe'] = True
    return metadata