summaryrefslogtreecommitdiff
path: root/coverage/tomlconfig.py
blob: 8212cfe67e97153553342ab2c29d73782890e098 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
# For details: https://github.com/nedbat/coveragepy/blob/master/NOTICE.txt

"""TOML configuration support for coverage.py"""

import configparser
import os
import re

from coverage.exceptions import CoverageException
from coverage.misc import substitute_variables

# TOML support is an install-time extra option.
try:
    import tomli
except ImportError:         # pragma: not covered
    tomli = None


class TomlDecodeError(Exception):
    """Signals that a TOML configuration file could not be parsed.

    This is defined locally, rather than reusing the TOML library's own
    exception, so that callers can name and catch it even when no TOML
    library is installed.
    """


class TomlConfigParser:
    """TOML file reading with the interface of HandyConfigParser.

    Section names are looked up under the "tool.coverage." prefix, as they
    appear in pyproject.toml.  When `our_file` is true, unprefixed section
    names are accepted as well.

    """

    # This class has the same interface as config.HandyConfigParser, no
    # need for docstrings.
    # pylint: disable=missing-function-docstring

    def __init__(self, our_file):
        # When true, the whole file is treated as coverage configuration:
        # unprefixed section names are accepted, and a file we can't read for
        # lack of TOML support is an error rather than being skipped.
        self.our_file = our_file
        # The parsed TOML document (a dict), or None until read() succeeds.
        self.data = None

    def read(self, filenames):
        # Returns [filename] if the file was read and parsed, else [].
        # Raises TomlDecodeError for malformed TOML, and CoverageException if
        # the file looks like coverage TOML but no TOML library is installed.

        # RawConfigParser takes a filename or list of filenames, but we only
        # ever call this with a single filename.
        assert isinstance(filenames, (bytes, str, os.PathLike))
        filename = os.fspath(filenames)

        try:
            with open(filename, encoding='utf-8') as fp:
                toml_text = fp.read()
        except OSError:
            return []
        if tomli is not None:
            # Expand environment variable references before parsing.
            toml_text = substitute_variables(toml_text, os.environ)
            try:
                self.data = tomli.loads(toml_text)
            except tomli.TOMLDecodeError as err:
                raise TomlDecodeError(str(err)) from err
            return [filename]

        # No TOML library.  Detect coverage settings both as sub-tables
        # ("[tool.coverage.run]") and as the bare "[tool.coverage]" table.
        has_toml = re.search(r"^\[tool\.coverage[.\]]", toml_text, flags=re.MULTILINE)
        if self.our_file or has_toml:
            # Looks like they meant to read TOML, but we can't read it.
            msg = "Can't read {!r} without TOML support. Install with [toml] extra"
            raise CoverageException(msg.format(filename))
        return []

    def _get_section(self, section):
        """Get a section from the data.

        Arguments:
            section (str): A section name, which can be dotted.

        Returns:
            name (str): the actual name of the section that was found, if any,
                or None.
            data (dict): the dict of data in the section, or None if not found.

        Only TOML tables (dicts) count as sections: a scalar or list found at,
        or along, the dotted path means there is no section by that name.

        """
        if self.data is None:
            # Nothing has been read, so there are no sections.
            return None, None
        prefixes = ["tool.coverage."]
        if self.our_file:
            prefixes.append("")
        for prefix in prefixes:
            real_section = prefix + section
            data = self.data
            for part in real_section.split("."):
                if not isinstance(data, dict) or part not in data:
                    break
                data = data[part]
            else:
                if isinstance(data, dict):
                    return real_section, data
        return None, None

    def _get(self, section, option):
        """Like .get, but returns the real section name and the value."""
        name, data = self._get_section(section)
        if data is None:
            raise configparser.NoSectionError(section)
        try:
            return name, data[option]
        except KeyError as exc:
            raise configparser.NoOptionError(option, name) from exc

    def has_option(self, section, option):
        _, data = self._get_section(section)
        return data is not None and option in data

    def has_section(self, section):
        # A bool, like configparser's has_section.
        name, _ = self._get_section(section)
        return name is not None

    def options(self, section):
        _, data = self._get_section(section)
        if data is None:
            raise configparser.NoSectionError(section)
        return list(data.keys())

    def get_section(self, section):
        # The section's dict, or None if there is no such section.
        _, data = self._get_section(section)
        return data

    def get(self, section, option):
        _, value = self._get(section, option)
        return value

    def _check_type(self, section, option, value, type_, type_desc):
        # bool is a subclass of int in Python, but a TOML boolean is never an
        # acceptable number, so bools only pass when a bool is requested.
        wrong_bool = isinstance(value, bool) and type_ is not bool
        if wrong_bool or not isinstance(value, type_):
            raise ValueError(
                'Option {!r} in section {!r} is not {}: {!r}'
                    .format(option, section, type_desc, value)
            )

    def getboolean(self, section, option):
        name, value = self._get(section, option)
        self._check_type(name, option, value, bool, "a boolean")
        return value

    def getlist(self, section, option):
        name, values = self._get(section, option)
        self._check_type(name, option, values, list, "a list")
        return values

    def getregexlist(self, section, option):
        # Each element must be a string holding a valid regex (checked after
        # stripping surrounding whitespace).  The list is returned as-is.
        name, values = self._get(section, option)
        self._check_type(name, option, values, list, "a list")
        if not all(isinstance(value, str) for value in values):
            raise ValueError(
                'Option {!r} in section {!r} is not {}: {!r}'
                    .format(option, name, "a list of strings", values)
            )
        for value in values:
            value = value.strip()
            try:
                re.compile(value)
            except re.error as e:
                raise CoverageException(
                    f"Invalid [{name}].{option} value {value!r}: {e}"
                ) from e
        return values

    def getint(self, section, option):
        name, value = self._get(section, option)
        self._check_type(name, option, value, int, "an integer")
        return value

    def getfloat(self, section, option):
        name, value = self._get(section, option)
        # TOML integers are acceptable floats; TOML booleans are not, so
        # don't let True/False slip through the int-to-float conversion.
        if isinstance(value, int) and not isinstance(value, bool):
            value = float(value)
        self._check_type(name, option, value, float, "a float")
        return value