diff options
author | Laurent Peuch <cortex@worlddomination.be> | 2020-03-20 19:39:33 +0100 |
---|---|---|
committer | Laurent Peuch <cortex@worlddomination.be> | 2020-03-20 19:39:33 +0100 |
commit | 5c3c1a5dd8ccea45cce331e07c5ca39a63b51660 (patch) | |
tree | c910a2c64206e460a0f5d0514d8d3b9e8d41c9bf | |
parent | 2f92ba46d9801839063d940dfcf1f0d46c576b9d (diff) | |
download | logilab-common-5c3c1a5dd8ccea45cce331e07c5ca39a63b51660.tar.gz |
[types] clean type annotations generation from pyannotation
32 files changed, 1089 insertions, 643 deletions
diff --git a/__pkginfo__.py b/__pkginfo__.py index f1bffe5..6ad6cb6 100644 --- a/__pkginfo__.py +++ b/__pkginfo__.py @@ -45,6 +45,8 @@ include_dirs = [join('test', 'data')] install_requires = [ 'setuptools', + 'mypy-extensions', + 'typing_extensions', ] tests_require = [ 'pytz', diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py index bf35711..0d7f183 100644 --- a/logilab/common/__init__.py +++ b/logilab/common/__init__.py @@ -29,13 +29,16 @@ __docformat__ = "restructuredtext en" import sys import types import pkg_resources +from typing import List, Sequence __version__ = pkg_resources.get_distribution('logilab-common').version # deprecated, but keep compatibility with pylint < 1.4.4 __pkginfo__ = types.ModuleType('__pkginfo__') __pkginfo__.__package__ = __name__ -__pkginfo__.version = __version__ +# mypy output: Module has no attribute "version" +# logilab's magic +__pkginfo__.version = __version__ # type: ignore sys.modules['logilab.common.__pkginfo__'] = __pkginfo__ STD_BLACKLIST = ('CVS', '.svn', '.hg', '.git', '.tox', 'debian', 'dist', 'build') @@ -49,7 +52,7 @@ USE_MX_DATETIME = True class attrdict(dict): """A dictionary for which keys are also accessible as attributes.""" - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> str: try: return self[attr] except KeyError: @@ -170,7 +173,7 @@ def make_domains(lists): # private stuff ################################################################ -def _handle_blacklist(blacklist, dirnames, filenames): +def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames: List[str]) -> None: """remove files/directories in the black list dirnames/filenames are usually from os.walk diff --git a/logilab/common/cache.py b/logilab/common/cache.py index 11ed137..c47f481 100644 --- a/logilab/common/cache.py +++ b/logilab/common/cache.py @@ -27,9 +27,14 @@ __docformat__ = "restructuredtext en" from threading import Lock from logilab.common.decorators import locked +from typing import Union, TypeVar, List _marker = object() + +_KeyType = TypeVar("_KeyType") + + class Cache(dict): """A dictionary like cache. @@ -38,23 +43,23 @@ class Cache(dict): len(self.data) <= self.size """ - def __init__(self, size=100): + def __init__(self, size: int = 100) -> None: """ Warning : Cache.__init__() != dict.__init__(). Constructor does not take any arguments beside size. """ assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)' self.size = size - self._usage = [] + self._usage: List = [] self._lock = Lock() super(Cache, self).__init__() - def _acquire(self): + def _acquire(self) -> None: self._lock.acquire() - def _release(self): + def _release(self) -> None: self._lock.release() - def _update_usage(self, key): + def _update_usage(self, key: _KeyType) -> None: if not self._usage: self._usage.append(key) elif self._usage[-1] != key: @@ -71,20 +76,20 @@ class Cache(dict): else: pass # key is already the most recently used key - def __getitem__(self, key): + def __getitem__(self, key: _KeyType): value = super(Cache, self).__getitem__(key) self._update_usage(key) return value __getitem__ = locked(_acquire, _release)(__getitem__) - def __setitem__(self, key, item): + def __setitem__(self, key: _KeyType, item): # Just make sure that size > 0 before inserting a new item in the cache if self.size > 0: super(Cache, self).__setitem__(key, item) self._update_usage(key) __setitem__ = locked(_acquire, _release)(__setitem__) - def __delitem__(self, key): + def __delitem__(self, key: _KeyType): super(Cache, self).__delitem__(key) self._usage.remove(key) __delitem__ = locked(_acquire, _release)(__delitem__) @@ -94,7 +99,7 @@ class Cache(dict): self._usage = [] clear = locked(_acquire, _release)(clear) - def pop(self, key, default=_marker): + def pop(self, key: _KeyType, default=_marker): if key in self: self._usage.remove(key) #if default is _marker: diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py index 6eb8432..c128eb7 100644 --- a/logilab/common/changelog.py +++ b/logilab/common/changelog.py @@ -49,6 +49,8 @@ __docformat__ = "restructuredtext en" import sys from stat import S_IWRITE import codecs +from typing import List, Any, Optional, Tuple +from _io import StringIO BULLET = '*' SUBBULLET = '-' @@ -76,7 +78,7 @@ class Version(tuple): return tuple.__new__(cls, parsed) @classmethod - def parse(cls, versionstr): + def parse(cls, versionstr: str) -> List[int]: versionstr = versionstr.strip(' :') try: return [int(i) for i in versionstr.split('.')] @@ -84,7 +86,7 @@ class Version(tuple): raise ValueError("invalid literal for version '%s' (%s)" % (versionstr, ex)) - def __str__(self): + def __str__(self) -> str: return '.'.join([str(i) for i in self]) @@ -96,20 +98,21 @@ class ChangeLogEntry(object): """ version_class = Version - def __init__(self, date=None, version=None, **kwargs): + def __init__(self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any) -> None: self.__dict__.update(kwargs) + self.version: Optional[Version] if version: self.version = self.version_class(version) else: self.version = None self.date = date - self.messages = [] + self.messages: List[Tuple[List[str], List[List[str]]]] = [] - def add_message(self, msg): + def add_message(self, msg: str) -> None: """add a new message""" self.messages.append(([msg], [])) - def complete_latest_message(self, msg_suite): + def complete_latest_message(self, msg_suite: str) -> None: """complete the latest added message """ if not self.messages: @@ -120,7 +123,7 @@ class ChangeLogEntry(object): else: # message self.messages[-1][0].append(msg_suite) - def add_sub_message(self, sub_msg, key=None): + def add_sub_message(self, sub_msg: str, key: Optional[Any] = None) -> None: if not self.messages: raise ValueError('unable to complete last message as ' 'there is no previous message)') @@ -130,7 +133,7 @@ class ChangeLogEntry(object): raise NotImplementedError('sub message to specific key ' 'are not implemented yet') - def write(self, stream=sys.stdout): + def write(self, stream: StringIO = sys.stdout) -> None: """write the entry to file """ stream.write(u'%s -- %s\n' % (self.date or '', self.version or '')) for msg, sub_msgs in self.messages: @@ -152,19 +155,19 @@ class ChangeLog(object): entry_class = ChangeLogEntry - def __init__(self, changelog_file, title=u''): + def __init__(self, changelog_file: str, title: str = u'') -> None: self.file = changelog_file assert isinstance(title, type(u'')), 'title must be a unicode object' self.title = title self.additional_content = u'' - self.entries = [] + self.entries: List[ChangeLogEntry] = [] self.load() def __repr__(self): return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self), len(self.entries)) - def add_entry(self, entry): + def add_entry(self, entry: ChangeLogEntry) -> None: """add a new entry to the change log""" self.entries.append(entry) @@ -191,26 +194,31 @@ class ChangeLog(object): entry = self.get_entry(create=create) entry.add_message(msg) - def load(self): + def load(self) -> None: """ read a logilab's ChangeLog from file """ try: stream = codecs.open(self.file, encoding='utf-8') except IOError: return - last = None + + last: Optional[ChangeLogEntry] = None expect_sub = False + for line in stream: sline = line.strip() words = sline.split() + # if new entry if len(words) == 1 and words[0] == '--': expect_sub = False last = self.entry_class() + self.add_entry(last) # if old entry elif len(words) == 3 and words[1] == '--': expect_sub = False last = self.entry_class(words[0], words[2]) + self.add_entry(last) # if title elif sline and last is None: @@ -218,19 +226,23 @@ class ChangeLog(object): # if new entry elif sline and sline[0] == BULLET: expect_sub = False + + assert last is not None last.add_message(sline[1:].strip()) # if new sub_entry elif expect_sub and sline and sline[0] == SUBBULLET: + assert last is not None last.add_sub_message(sline[1:].strip()) # if new line for current entry - elif sline and last.messages: + elif sline and (last and last.messages): last.complete_latest_message(line) else: expect_sub = True self.additional_content += line + stream.close() - def format_title(self): + def format_title(self) -> str: return u'%s\n\n' % self.title.strip() def save(self): @@ -240,7 +252,7 @@ class ChangeLog(object): ensure_fs_mode(self.file, S_IWRITE) self.write(codecs.open(self.file, 'w', encoding='utf-8')) - def write(self, stream=sys.stdout): + def write(self, stream: StringIO = sys.stdout) -> None: """write changelog to stream""" stream.write(self.format_title()) for entry in self.entries: diff --git a/logilab/common/compat.py b/logilab/common/compat.py index eee0a61..4ca540b 100644 --- a/logilab/common/compat.py +++ b/logilab/common/compat.py @@ -33,6 +33,7 @@ import os import sys import types from warnings import warn +from typing import Union # not used here, but imported to preserve API import builtins @@ -41,7 +42,7 @@ def str_to_bytes(string): return str.encode(string) # we have to ignore the encoding in py3k to be able to write a string into a # TextIOWrapper or like object (which expect an unicode string) -def str_encode(string, encoding): +def str_encode(string: Union[int, str], encoding: str) -> str: return str(string) # See also http://bugs.python.org/issue11776 diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py index 8489de5..61c2e97 100644 --- a/logilab/common/configuration.py +++ b/logilab/common/configuration.py @@ -119,25 +119,31 @@ import os import sys import re from os.path import exists, expanduser +from optparse import OptionGroup from copy import copy +from _io import StringIO, TextIOWrapper +from mypy_extensions import NoReturn +from typing import Any, Optional, Union, Dict, List, Tuple, Iterator, Callable, Sequence from warnings import warn import configparser as cp +from logilab.common.types import OptionParser, Option, attrdict from logilab.common.compat import str_encode as _encode from logilab.common.deprecation import deprecated from logilab.common.textutils import normalize_text, unquote from logilab.common import optik_ext + OptionError = optik_ext.OptionError -REQUIRED = [] +REQUIRED: List = [] class UnsupportedAction(Exception): """raised by set_option when it doesn't know what to do for an action""" -def _get_encoding(encoding, stream): +def _get_encoding(encoding: Optional[str], stream: Union[StringIO, TextIOWrapper]) -> str: encoding = encoding or getattr(stream, 'encoding', None) if not encoding: import locale @@ -145,12 +151,14 @@ def _get_encoding(encoding, stream): return encoding +_ValueType = Union[List[str], Tuple[str, ...], str] + # validation functions ######################################################## # validators will return the validated value or raise optparse.OptionValueError # XXX add to documentation -def choice_validator(optdict, name, value): +def choice_validator(optdict: Dict[str, Any], name: str, value: str) -> str: """validate and return a converted value for option of type 'choice' """ if not value in optdict['choices']: @@ -158,7 +166,8 @@ def choice_validator(optdict, name, value): raise optik_ext.OptionValueError(msg % (name, value, optdict['choices'])) return value -def multiple_choice_validator(optdict, name, value): + +def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType: """validate and return a converted value for option of type 'choice' """ choices = optdict['choices'] @@ -169,17 +178,17 @@ def multiple_choice_validator(optdict, name, value): raise optik_ext.OptionValueError(msg % (name, value, choices)) return values -def csv_validator(optdict, name, value): +def csv_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType: """validate and return a converted value for option of type 'csv' """ return optik_ext.check_csv(None, name, value) -def yn_validator(optdict, name, value): +def yn_validator(optdict: Dict[str, Any], name: str, value: Union[bool, str]) -> bool: """validate and return a converted value for option of type 'yn' """ return optik_ext.check_yn(None, name, value) -def named_validator(optdict, name, value): +def named_validator(optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str]) -> Dict[str, str]: """validate and return a converted value for option of type 'named' """ return optik_ext.check_named(None, name, value) @@ -204,31 +213,32 @@ def time_validator(optdict, name, value): """validate and return a time object for option of type 'time'""" return optik_ext.check_time(None, name, value) -def bytes_validator(optdict, name, value): +def bytes_validator(optdict: Dict[str, str], name: str, value: Union[int, str]) -> int: """validate and return an integer for option of type 'bytes'""" return optik_ext.check_bytes(None, name, value) -VALIDATORS = {'string': unquote, - 'int': int, - 'float': float, - 'file': file_validator, - 'font': unquote, - 'color': color_validator, - 'regexp': re.compile, - 'csv': csv_validator, - 'yn': yn_validator, - 'bool': yn_validator, - 'named': named_validator, - 'password': password_validator, - 'date': date_validator, - 'time': time_validator, - 'bytes': bytes_validator, - 'choice': choice_validator, - 'multiple_choice': multiple_choice_validator, - } - -def _call_validator(opttype, optdict, option, value): +VALIDATORS: Dict[str, Callable] = { + 'string': unquote, + 'int': int, + 'float': float, + 'file': file_validator, + 'font': unquote, + 'color': color_validator, + 'regexp': re.compile, + 'csv': csv_validator, + 'yn': yn_validator, + 'bool': yn_validator, + 'named': named_validator, + 'password': password_validator, + 'date': date_validator, + 'time': time_validator, + 'bytes': bytes_validator, + 'choice': choice_validator, + 'multiple_choice': multiple_choice_validator, +} + +def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str]) -> Union[List[str], int, str]: if opttype not in VALIDATORS: raise Exception('Unsupported type "%s"' % opttype) try: @@ -274,10 +284,11 @@ def _make_input_function(opttype): print('bad value: %s' % msg) return input_validator -INPUT_FUNCTIONS = { + +INPUT_FUNCTIONS: Dict[str, Callable] = { 'string': input_string, 'password': input_password, - } +} for opttype in VALIDATORS.keys(): INPUT_FUNCTIONS.setdefault(opttype, _make_input_function(opttype)) @@ -306,7 +317,7 @@ def expand_default(self, option): return option.help.replace(self.default_tag, str(value)) -def _validate(value, optdict, name=''): +def _validate(value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = '') -> Union[List[str], int, str]: """return a validated value for an option according to its type optional argument name is only used for error message formatting @@ -343,7 +354,7 @@ def format_time(value): return '%sh' % nbhour return '%sd' % nbday -def format_bytes(value): +def format_bytes(value: int) -> str: if not value: return '0' if value != int(value): @@ -358,7 +369,7 @@ def format_bytes(value): value = next return '%s%s' % (value, unit) -def format_option_value(optdict, value): +def format_option_value(optdict: Dict[str, Any], value: Any) -> Union[None, int, str]: """return the user input's value from a 'compiled' value""" if isinstance(value, (list, tuple)): value = ','.join(value) @@ -377,7 +388,7 @@ def format_option_value(optdict, value): value = format_bytes(value) return value -def ini_format_section(stream, section, options, encoding=None, doc=None): +def ini_format_section(stream: Union[StringIO, TextIOWrapper], section: str, options: Any, encoding: str = None, doc: Optional[Any] = None) -> None: """format an options section using the INI format""" encoding = _get_encoding(encoding, stream) if doc: @@ -385,7 +396,7 @@ def ini_format_section(stream, section, options, encoding=None, doc=None): print('[%s]' % section, file=stream) ini_format(stream, options, encoding) -def ini_format(stream, options, encoding): +def ini_format(stream: Union[StringIO, TextIOWrapper], options: Any, encoding: str) -> None: """format options using the INI format""" for optname, optdict, value in options: value = format_option_value(optdict, value) @@ -428,34 +439,37 @@ def rest_format_section(stream, section, options, encoding=None, doc=None): # Options Manager ############################################################## + class OptionsManagerMixIn(object): """MixIn to handle a configuration from both a configuration file and command line options """ - def __init__(self, usage, config_file=None, version=None, quiet=0): + def __init__(self, usage: Optional[str], config_file: Optional[Any] = None, version: Optional[Any] = None, quiet: int = 0) -> None: self.config_file = config_file self.reset_parsers(usage, version=version) # list of registered options providers - self.options_providers = [] + self.options_providers: List[ConfigurationMixIn] = [] # dictionary associating option name to checker - self._all_options = {} - self._short_options = {} - self._nocallback_options = {} - self._mygroups = dict() + self._all_options: Dict[str, ConfigurationMixIn] = {} + self._short_options: Dict[str, str] = {} + self._nocallback_options: Dict[ConfigurationMixIn, str] = {} + self._mygroups: Dict[str, optik_ext.OptionGroup] = {} # verbosity self.quiet = quiet self._maxlevel = 0 - def reset_parsers(self, usage='', version=None): + def reset_parsers(self, usage: Optional[str] = '', version: Optional[Any] = None) -> None: # configuration file parser self.cfgfile_parser = cp.ConfigParser() # command line parser self.cmdline_parser = optik_ext.OptionParser(usage=usage, version=version) - self.cmdline_parser.options_manager = self + # mypy: "OptionParser" has no attribute "options_manager" + # dynamic attribute? + self.cmdline_parser.options_manager = self # type: ignore self._optik_option_attrs = set(self.cmdline_parser.option_class.ATTRS) - def register_options_provider(self, provider, own_group=True): + def register_options_provider(self, provider: 'ConfigurationMixIn', own_group: bool = True) -> None: """register an options provider""" assert provider.priority <= 0, "provider's priority can't be >= 0" for i in range(len(self.options_providers)): @@ -464,8 +478,12 @@ class OptionsManagerMixIn(object): break else: self.options_providers.append(provider) - non_group_spec_options = [option for option in provider.options - if 'group' not in option[1]] + + # mypy: Need type annotation for 'option' + # you can't type variable of a list comprehension, right? + non_group_spec_options: List = [option for option in provider.options # type: ignore + if 'group' not in option[1]] # type: ignore + groups = getattr(provider, 'option_groups', ()) if own_group and non_group_spec_options: self.add_option_group(provider.name.upper(), provider.__doc__, @@ -475,11 +493,14 @@ class OptionsManagerMixIn(object): self.add_optik_option(provider, self.cmdline_parser, opt, optdict) for gname, gdoc in groups: gname = gname.upper() - goptions = [option for option in provider.options - if option[1].get('group', '').upper() == gname] + + # mypy: Need type annotation for 'option' + # you can't type variable of a list comprehension, right? + goptions: List = [option for option in provider.options # type: ignore + if option[1].get('group', '').upper() == gname] # type: ignore self.add_option_group(gname, gdoc, goptions, provider) - def add_option_group(self, group_name, doc, options, provider): + def add_option_group(self, group_name: str, doc: Optional[str], options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]], provider: 'ConfigurationMixIn') -> None: """add an option group including the listed options """ assert options @@ -490,7 +511,9 @@ class OptionsManagerMixIn(object): group = optik_ext.OptionGroup(self.cmdline_parser, title=group_name.capitalize()) self.cmdline_parser.add_option_group(group) - group.level = provider.level + # mypy: "OptionGroup" has no attribute "level" + # dynamic attribute + group.level = provider.level # type: ignore self._mygroups[group_name] = group # add section to the config file if group_name != "DEFAULT": @@ -499,7 +522,7 @@ class OptionsManagerMixIn(object): for opt, optdict in options: self.add_optik_option(provider, group, opt, optdict) - def add_optik_option(self, provider, optikcontainer, opt, optdict): + def add_optik_option(self, provider: 'ConfigurationMixIn', optikcontainer: Union[OptionParser, OptionGroup], opt: str, optdict: Dict[str, Any]) -> None: if 'inputlevel' in optdict: warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,' ' use "level"' % opt, DeprecationWarning) @@ -509,12 +532,11 @@ class OptionsManagerMixIn(object): self._all_options[opt] = provider self._maxlevel = max(self._maxlevel, option.level or 0) - def optik_option(self, provider, opt, optdict): + def optik_option(self, provider: 'ConfigurationMixIn', opt: str, optdict: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any]]: """get our personal option definition and return a suitable form for use with optik/optparse """ optdict = copy(optdict) - others = {} if 'action' in optdict: self._nocallback_options[provider] = opt else: @@ -539,7 +561,7 @@ class OptionsManagerMixIn(object): optdict.pop(key) return args, optdict - def cb_set_provider_option(self, option, opt, value, parser): + def cb_set_provider_option(self, option: 'Option', opt: str, value: Union[List[str], int, str], parser: 'OptionParser') -> None: """optik callback for option setting""" if opt.startswith('--'): # remove -- on long option @@ -552,16 +574,17 @@ class OptionsManagerMixIn(object): value = 1 self.global_set_option(opt, value) - def global_set_option(self, opt, value): + def global_set_option(self, opt: str, value: Union[List[str], int, str]) -> None: """set option on the correct option provider""" self._all_options[opt].set_option(opt, value) - def generate_config(self, stream=None, skipsections=(), encoding=None): + def generate_config(self, stream: Union[StringIO, TextIOWrapper] = None, skipsections: Tuple[()] = (), encoding: Optional[Any] = None) -> None: """write a configuration file according to the current configuration into the given stream or stdout """ - options_by_section = {} + options_by_section: Dict[Any, List] = {} sections = [] + for provider in self.options_providers: for section, options in provider.options_by_section(): if section is None: @@ -586,7 +609,7 @@ class OptionsManagerMixIn(object): encoding) printed = True - def generate_manpage(self, pkginfo, section=1, stream=None): + def generate_manpage(self, pkginfo: attrdict, section: int = 1, stream: StringIO = None) -> None: """write a man page for the current configuration into the given stream or stdout """ @@ -600,17 +623,17 @@ class OptionsManagerMixIn(object): # initialization methods ################################################## - def load_provider_defaults(self): + def load_provider_defaults(self) -> None: """initialize configuration using default values""" for provider in self.options_providers: provider.load_defaults() - def load_file_configuration(self, config_file=None): + def load_file_configuration(self, config_file: str = None) -> None: """load the configuration from file""" self.read_config_file(config_file) self.load_config_file() - def read_config_file(self, config_file=None): + def read_config_file(self, config_file: str = None) -> None: """read the configuration file but do not load it (i.e. dispatching values to each options provider) """ @@ -637,9 +660,11 @@ class OptionsManagerMixIn(object): parser = self.cfgfile_parser parser.read([config_file]) # normalize sections'title - for sect, values in list(parser._sections.items()): + # mypy: "ConfigParser" has no attribute "_sections" + # dynamic attribute? + for sect, values in list(parser._sections.items()): # type: ignore if not sect.isupper() and values: - parser._sections[sect.upper()] = values + parser._sections[sect.upper()] = values # type: ignore elif not self.quiet: msg = 'No config file found, using default configuration' print(msg, file=sys.stderr) @@ -663,7 +688,7 @@ class OptionsManagerMixIn(object): if stream is not None: self.generate_config(stream) - def load_config_file(self): + def load_config_file(self) -> None: """dispatch values previously read from a configuration file to each options provider) """ @@ -676,7 +701,7 @@ class OptionsManagerMixIn(object): # TODO handle here undeclared options appearing in the config file continue - def load_configuration(self, **kwargs): + def load_configuration(self, **kwargs: Any) -> None: """override configuration according to given parameters """ for opt, opt_value in kwargs.items(): @@ -684,12 +709,13 @@ class OptionsManagerMixIn(object): provider = self._all_options[opt] provider.set_option(opt, opt_value) - def load_command_line_configuration(self, args=None): + def load_command_line_configuration(self, args: List[str] = None) -> List[str]: """override configuration according to command line parameters return additional arguments """ self._monkeypatch_expand_default() + try: if args is None: args = sys.argv[1:] @@ -710,32 +736,41 @@ class OptionsManagerMixIn(object): # help methods ############################################################ - def add_help_section(self, title, description, level=0): + def add_help_section(self, title: str, description: str, level: int = 0) -> None: """add a dummy option section for help purpose """ group = optik_ext.OptionGroup(self.cmdline_parser, title=title.capitalize(), description=description) - group.level = level + # mypy: "OptionGroup" has no attribute "level" + # it does, it is set in the optik_ext module + group.level = level # type: ignore self._maxlevel = max(self._maxlevel, level) self.cmdline_parser.add_option_group(group) - def _monkeypatch_expand_default(self): + def _monkeypatch_expand_default(self) -> None: # monkey patch optik_ext to deal with our default values try: self.__expand_default_backup = optik_ext.HelpFormatter.expand_default - optik_ext.HelpFormatter.expand_default = expand_default + # mypy: Cannot assign to a method + # it's dirty but you can + optik_ext.HelpFormatter.expand_default = expand_default # type: ignore except AttributeError: # python < 2.4: nothing to be done pass - def _unmonkeypatch_expand_default(self): + def _unmonkeypatch_expand_default(self) -> None: # remove monkey patch if hasattr(optik_ext.HelpFormatter, 'expand_default'): + # mypy: Cannot assign to a method + # it's dirty but you can + # unpatch optik_ext to avoid side effects - optik_ext.HelpFormatter.expand_default = self.__expand_default_backup + optik_ext.HelpFormatter.expand_default = self.__expand_default_backup # type: ignore - def help(self, level=0): + def help(self, level: int = 0) -> str: """return the usage string for available options """ - self.cmdline_parser.formatter.output_level = level + # mypy: "HelpFormatter" has no attribute "output_level" + # set in optik_ext + self.cmdline_parser.formatter.output_level = level # type: ignore self._monkeypatch_expand_default() try: return self.cmdline_parser.format_help() @@ -751,12 +786,12 @@ class Method(object): self.method = methname self._inst = None - def bind(self, instance): + def bind(self, instance: 'Configuration') -> None: """bind the method to its instance""" if self._inst is None: self._inst = instance - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, str]: assert self._inst, 'unbound method' return getattr(self._inst, self.method)(*args, **kwargs) @@ -768,23 +803,23 @@ class OptionsProviderMixIn(object): # those attributes should be overridden priority = -1 name = 'default' - options = () + options: Tuple = () level = 0 - def __init__(self): + def __init__(self) -> None: self.config = optik_ext.Values() - for option in self.options: + for option_tuple in self.options: try: - option, optdict = option + option, optdict = option_tuple except ValueError: - raise Exception('Bad option: %r' % option) + raise Exception('Bad option: %s' % str(option_tuple)) if isinstance(optdict.get('default'), Method): optdict['default'].bind(self) elif isinstance(optdict.get('callback'), Method): optdict['callback'].bind(self) self.load_defaults() - def load_defaults(self): + def load_defaults(self) -> None: """initialize the provider using default values""" for opt, optdict in self.options: action = optdict.get('action') @@ -883,8 +918,10 @@ class OptionsProviderMixIn(object): for option in self.options: if option[0] == opt: return option[1] + # mypy: Argument 2 to "OptionError" has incompatible type "str"; expected "Option" + # seems to be working? raise OptionError('no such option %s in section %r' - % (opt, self.name), opt) + % (opt, self.name), opt) # type: ignore def all_options(self): @@ -900,17 +937,19 @@ class OptionsProviderMixIn(object): for option, optiondict, value in options: yield section, option, optiondict - def options_by_section(self): + def options_by_section(self) -> Iterator[Any]: """return an iterator on options grouped by section (section, [list of (optname, optdict, optvalue)]) """ - sections = {} + sections: Dict[str, List[Tuple[str, Dict[str, Any], Any]]] = {} for optname, optdict in self.options: sections.setdefault(optdict.get('group'), []).append( (optname, optdict, self.option_value(optname))) if None in sections: - yield None, sections.pop(None) + # mypy: No overload variant of "pop" of "MutableMapping" matches argument type "None" + # it actually works + yield None, sections.pop(None) # type: ignore for section, options in sorted(sections.items()): yield section.upper(), options @@ -926,14 +965,14 @@ class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn): """basic mixin for simple configurations which don't need the manager / providers model """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: if not args: kwargs.setdefault('usage', '') kwargs.setdefault('quiet', 1) OptionsManagerMixIn.__init__(self, *args, **kwargs) OptionsProviderMixIn.__init__(self) if not getattr(self, 'option_groups', None): - self.option_groups = [] + self.option_groups: List[Tuple[Any, str]] = [] for option, optdict in self.options: try: gdef = (optdict['group'].upper(), '') diff --git a/logilab/common/date.py b/logilab/common/date.py index cdf2317..2d2ed22 100644 --- a/logilab/common/date.py +++ b/logilab/common/date.py @@ -27,6 +27,8 @@ from locale import getlocale, LC_TIME from datetime import date, time, datetime, timedelta from time import strptime as time_strptime from calendar import monthrange, timegm +from typing import Union, List, Any, Iterator, Optional, Generator + try: from mx.DateTime import RelativeDateTime, Date, DateTimeType @@ -90,13 +92,13 @@ FRENCH_MOBILE_HOLIDAYS = { # XXX this implementation cries for multimethod dispatching -def get_step(dateobj, nbdays=1): +def get_step(dateobj: Union[date, datetime], nbdays: int = 1) -> timedelta: # assume date is either a python datetime or a mx.DateTime object if isinstance(dateobj, date): return ONEDAY * nbdays return nbdays # mx.DateTime is ok with integers -def datefactory(year, month, day, sampledate): +def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetime]) -> Union[date, datetime]: # assume date is either a python datetime or a mx.DateTime object if isinstance(sampledate, datetime): return datetime(year, month, day) @@ -104,20 +106,23 @@ def datefactory(year, month, day, sampledate): return date(year, month, day) return Date(year, month, day) -def weekday(dateobj): +def weekday(dateobj: Union[date, datetime]) -> int: # assume date is either a python datetime or a mx.DateTime object if isinstance(dateobj, date): return dateobj.weekday() return dateobj.day_of_week -def str2date(datestr, sampledate): +def str2date(datestr: str, sampledate: Union[date, datetime]) -> Union[date, datetime]: # NOTE: datetime.strptime is not an option until we drop py2.4 compat year, month, day = [int(chunk) for chunk in datestr.split('-')] return datefactory(year, month, day, sampledate) -def days_between(start, end): +def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> int: if isinstance(start, date): - delta = end - start + # mypy: No overload variant of "__sub__" of "datetime" matches argument type "date" + # we ensure that end is a date + assert isinstance(end, date) + delta = end - start # type: ignore # datetime.timedelta.days is always an integer (floored) if delta.seconds: return delta.days + 1 @@ -125,7 +130,7 @@ def days_between(start, end): else: return int(math.ceil((end - start).days)) -def get_national_holidays(begin, end): +def get_national_holidays(begin: Union[date, datetime], end: Union[date, datetime]) -> Union[List[date], List[datetime]]: """return french national days off between begin and end""" begin = datefactory(begin.year, begin.month, begin.day, begin) end = datefactory(end.year, end.month, end.day, end) @@ -138,7 +143,7 @@ def get_national_holidays(begin, end): holidays.append(date) return [day for day in holidays if begin <= day < end] -def add_days_worked(start, days): +def add_days_worked(start: date, days: int) -> date: """adds date but try to only take days worked into account""" step = get_step(start) weeks, plus = divmod(days, 5) @@ -151,7 +156,7 @@ def add_days_worked(start, days): end += (2 * step) return end -def nb_open_days(start, end): +def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> int: assert start <= end step = get_step(start) days = days_between(start, end) @@ -168,7 +173,8 @@ def nb_open_days(start, end): return 0 return open_days -def date_range(begin, end, incday=None, incmonth=None): + +def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None) -> Generator[date, Any, None]: """yields each date between begin and end :param begin: the start date @@ -193,13 +199,13 @@ def date_range(begin, end, incday=None, incmonth=None): else: incr = get_step(begin, incday or 1) while begin < end: - yield begin - begin += incr + yield begin + begin += incr # makes py datetime usable ##################################################### -ONEDAY = timedelta(days=1) -ONEWEEK = timedelta(days=7) +ONEDAY: timedelta = timedelta(days=1) +ONEWEEK: timedelta = timedelta(days=7) try: strptime = datetime.strptime @@ -211,7 +217,7 @@ except AttributeError: # py < 2.5 def strptime_time(value, format='%H:%M'): return time(*time_strptime(value, format)[3:6]) -def todate(somedate): +def todate(somedate: date) -> date: """return a date from a date (leaving unchanged) or a datetime""" if isinstance(somedate, datetime): return date(somedate.year, somedate.month, somedate.day) @@ -234,10 +240,10 @@ def todatetime(somedate): assert isinstance(somedate, (date, DateTimeType)), repr(somedate) return datetime(somedate.year, somedate.month, somedate.day) -def datetime2ticks(somedate): +def datetime2ticks(somedate: Union[date, datetime]) -> int: return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, 'microsecond', 0) / 1000) -def ticks2datetime(ticks): +def ticks2datetime(ticks: int) -> datetime: miliseconds, microseconds = divmod(ticks, 1000) try: return datetime.fromtimestamp(miliseconds) @@ -250,7 +256,7 @@ def ticks2datetime(ticks): except (ValueError, OverflowError): raise -def days_in_month(somedate): +def days_in_month(somedate: date) -> int: return monthrange(somedate.year, somedate.month)[1] def days_in_year(somedate): @@ -266,7 +272,7 @@ def previous_month(somedate, nbmonth=1): nbmonth -= 1 return somedate -def next_month(somedate, nbmonth=1): +def next_month(somedate: date, nbmonth: int = 1) -> date: while nbmonth: somedate = last_day(somedate) + ONEDAY nbmonth -= 1 @@ -275,10 +281,10 @@ def next_month(somedate, nbmonth=1): def first_day(somedate): return date(somedate.year, somedate.month, 1) -def last_day(somedate): +def last_day(somedate: date) -> date: return date(somedate.year, somedate.month, days_in_month(somedate)) -def ustrftime(somedate, fmt='%Y-%m-%d'): +def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str: """like strftime, but returns a unicode string instead of an encoded string which may be problematic with localized date. """ @@ -309,10 +315,11 @@ def ustrftime(somedate, fmt='%Y-%m-%d'): fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt) return unicode(fmt) % fields -def utcdatetime(dt): +def utcdatetime(dt: datetime) -> datetime: if dt.tzinfo is None: return dt - return (dt.replace(tzinfo=None) - dt.utcoffset()) + # mypy: No overload variant of "__sub__" of "datetime" matches argument type "None" + return (dt.replace(tzinfo=None) - dt.utcoffset()) # type: ignore def utctime(dt): if dt.tzinfo is None: diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py index 909169c..2df84ad 100644 --- a/logilab/common/debugger.py +++ b/logilab/common/debugger.py @@ -34,7 +34,10 @@ __docformat__ = "restructuredtext en" try: import readline except ImportError: - readline = None + # mypy: Incompatible types in assignment (expression has type "None", + # mypy: variable has type Module)) + # conditional import + readline = None # type: ignore import os import os.path as osp import sys diff --git a/logilab/common/decorators.py b/logilab/common/decorators.py index 8dd2e7f..27ed7ee 100644 --- a/logilab/common/decorators.py +++ b/logilab/common/decorators.py @@ -25,6 +25,8 @@ import sys import types from time import process_time, time from inspect import isgeneratorfunction +from mypy_extensions import NoReturn +from typing import Any, Optional, Callable, Union from inspect import getfullargspec @@ -33,12 +35,13 @@ from logilab.common.compat import method_type # XXX rewrite so we can use the decorator syntax when keyarg has to be specified class cached_decorator(object): - def __init__(self, cacheattr=None, keyarg=None): + def __init__(self, cacheattr: Optional[str] = None, keyarg: Optional[int] = None) -> None: self.cacheattr = cacheattr self.keyarg = keyarg - def __call__(self, callableobj=None): + def __call__(self, callableobj: Optional[Callable] = None) -> Callable: assert not isgeneratorfunction(callableobj), \ 'cannot cache generator function: %s' % callableobj + assert callableobj is not None if len(getfullargspec(callableobj).args) == 1 or self.keyarg == 0: cache = _SingleValueCache(callableobj, self.cacheattr) elif self.keyarg: @@ -48,7 +51,7 @@ class cached_decorator(object): return cache.closure() class _SingleValueCache(object): - def __init__(self, callableobj, cacheattr=None): + def __init__(self, callableobj: Callable, cacheattr: Optional[str] = None) -> None: self.callable = callableobj if cacheattr is None: self.cacheattr = '_%s_cache_' % callableobj.__name__ @@ -64,10 +67,12 @@ class _SingleValueCache(object): setattr(self, __me.cacheattr, value) return value - def closure(self): + def closure(self) -> Callable: def wrapped(*args, **kwargs): return self.__call__(*args, **kwargs) - wrapped.cache_obj = self + # mypy: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no attribute "cache_obj" + # dynamic attribute for magic + wrapped.cache_obj = self # type: ignore try: wrapped.__doc__ = self.callable.__doc__ wrapped.__name__ = self.callable.__name__ @@ -97,7 +102,7 @@ class _MultiValuesCache(_SingleValueCache): return _cache[args] class _MultiValuesKeyArgCache(_MultiValuesCache): - def __init__(self, callableobj, keyarg, cacheattr=None): + def __init__(self, callableobj: Callable, keyarg: int, cacheattr: Optional[str] = None) -> None: super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr) self.keyarg = keyarg @@ -111,7 +116,7 @@ class _MultiValuesKeyArgCache(_MultiValuesCache): return _cache[key] -def cached(callableobj=None, keyarg=None, **kwargs): +def cached(callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any) -> Union[Callable, cached_decorator]: """Simple decorator to cache result of method call.""" kwargs['keyarg'] = keyarg decorator = cached_decorator(**kwargs) @@ -145,8 +150,10 @@ class cachedproperty(object): wrapped) self.wrapped = wrapped + # mypy: Signature of "__doc__" incompatible with supertype "object" + # but this works? @property - def __doc__(self): + def __doc__(self) -> str: # type: ignore doc = getattr(self.wrapped, '__doc__', None) return ('<wrapped by the cachedproperty decorator>%s' % ('\n%s' % doc if doc else '')) @@ -241,7 +248,7 @@ def locked(acquire, release): having called acquire(self) et will call release(self) afterwards. """ def decorator(f): - def wrapper(self, *args, **kwargs): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: acquire(self) try: return f(self, *args, **kwargs) @@ -251,7 +258,7 @@ def locked(acquire, release): return decorator -def monkeypatch(klass, methodname=None): +def monkeypatch(klass: type, methodname: Optional[str] = None) -> Callable: """Decorator extending class with the decorated callable. This is basically a syntactic sugar vs class assignment. diff --git a/logilab/common/fileutils.py b/logilab/common/fileutils.py index 93439d3..102cd7c 100644 --- a/logilab/common/fileutils.py +++ b/logilab/common/fileutils.py @@ -36,13 +36,15 @@ from os.path import isabs, isdir, islink, split, exists, normpath, join from os.path import abspath from os import sep, mkdir, remove, listdir, stat, chmod, walk from stat import ST_MODE, S_IWRITE +from typing import Optional, List, Tuple +from _io import TextIOWrapper from logilab.common import STD_BLACKLIST as BASE_BLACKLIST, IGNORED_EXTENSIONS from logilab.common.shellutils import find from logilab.common.deprecation import deprecated from logilab.common.compat import FileIO -def first_level_directory(path): +def first_level_directory(path: str) -> str: """Return the first level directory of a path. >>> first_level_directory('home/syt/work') @@ -73,7 +75,7 @@ def abspath_listdir(path): return [join(path, filename) for filename in listdir(path)] -def is_binary(filename): +def is_binary(filename: str) -> int: """Return true if filename may be a binary file, according to it's extension. @@ -86,12 +88,14 @@ def is_binary(filename): isn't beginning by text/) """ try: - return not mimetypes.guess_type(filename)[0].startswith('text') + # mypy: Item "None" of "Optional[str]" has no attribute "startswith" + # it's handle by the exception + return not mimetypes.guess_type(filename)[0].startswith('text') # type: ignore except AttributeError: return 1 -def write_open_mode(filename): +def write_open_mode(filename: str) -> str: """Return the write mode that should used to open file. :type filename: str @@ -143,7 +147,7 @@ class ProtectedFile(FileIO): - on close()/del(), write/append the StringIO content to the file and do the chmod only once """ - def __init__(self, filepath, mode): + def __init__(self, filepath: str, mode: str) -> None: self.original_mode = stat(filepath)[ST_MODE] self.mode_changed = False if mode in ('w', 'a', 'wb', 'ab'): @@ -152,19 +156,19 @@ class ProtectedFile(FileIO): self.mode_changed = True FileIO.__init__(self, filepath, mode) - def _restore_mode(self): + def _restore_mode(self) -> None: """restores the original mode if needed""" if self.mode_changed: chmod(self.name, self.original_mode) # Don't re-chmod in case of several restore self.mode_changed = False - def close(self): + def close(self) -> None: """restore mode before closing""" self._restore_mode() FileIO.close(self) - def __del__(self): + def __del__(self) -> None: if not self.closed: self.close() @@ -265,7 +269,7 @@ def norm_open(path): return open(path, 'U') norm_open = deprecated("use \"open(path, 'U')\"")(norm_open) -def lines(path, comments=None): +def lines(path: str, comments: Optional[str] = None) -> List[str]: """Return a list of non empty lines in the file located at `path`. :type path: str @@ -287,7 +291,7 @@ def lines(path, comments=None): return stream_lines(stream, comments) -def stream_lines(stream, comments=None): +def stream_lines(stream: TextIOWrapper, comments: Optional[str] = None) -> List[str]: """Return a list of non empty lines in the given `stream`. :type stream: object implementing 'xreadlines' or 'readlines' @@ -317,9 +321,9 @@ def stream_lines(stream, comments=None): return result -def export(from_dir, to_dir, - blacklist=BASE_BLACKLIST, ignore_ext=IGNORED_EXTENSIONS, - verbose=0): +def export(from_dir: str, to_dir: str, + blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST, ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS, + verbose: int = 0) -> None: """Make a mirror of `from_dir` in `to_dir`, omitting directories and files listed in the black list or ending with one of the given extensions. diff --git a/logilab/common/graph.py b/logilab/common/graph.py index cef1c98..fffa172 100644 --- a/logilab/common/graph.py +++ b/logilab/common/graph.py @@ -30,6 +30,7 @@ import sys import tempfile import codecs import errno +from typing import Dict, List, Tuple, Union, Any, Optional, Set, TypeVar, Iterable def escape(value): """Make <value> usable in a dot file.""" @@ -176,7 +177,12 @@ class GraphGenerator: class UnorderableGraph(Exception): pass -def ordered_nodes(graph): + +V = TypeVar("V") +_Graph = Dict[V, List[V]] + + +def ordered_nodes(graph: _Graph) -> Tuple[V, ...]: """takes a dependency graph dict as arguments and return an ordered tuple of nodes starting with nodes without dependencies and up to the outermost node. @@ -185,84 +191,115 @@ def ordered_nodes(graph): Also the given graph dict will be emptied. """ # check graph consistency - cycles = get_cycles(graph) + cycles: List[List[V]] = get_cycles(graph) + if cycles: - cycles = '\n'.join([' -> '.join(cycle) for cycle in cycles]) - raise UnorderableGraph('cycles in graph: %s' % cycles) + bad_cycles = '\n'.join([' -> '.join(map(str, cycle)) for cycle in cycles]) + raise UnorderableGraph('cycles in graph: %s' % bad_cycles) + vertices = set(graph) to_vertices = set() + for edges in graph.values(): to_vertices |= set(edges) + missing_vertices = to_vertices - vertices if missing_vertices: raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices)) + # order vertices order = [] order_set = set() old_len = None + while graph: if old_len == len(graph): raise UnorderableGraph('unknown problem with %s' % graph) + old_len = len(graph) deps_ok = [] + for node, node_deps in graph.items(): for dep in node_deps: if dep not in order_set: break else: deps_ok.append(node) + order.append(deps_ok) order_set |= set(deps_ok) + for node in deps_ok: del graph[node] + result = [] + for grp in reversed(order): result.extend(sorted(grp)) + return tuple(result) -def get_cycles(graph_dict, vertices=None): +def get_cycles(graph_dict: _Graph, + vertices: Optional[Iterable] = None) -> List[List]: '''given a dictionary representing an ordered graph (i.e. key are vertices and values is a list of destination vertices representing edges), return a list of detected cycles ''' if not graph_dict: - return () - result = [] + return [] + + result: List[List] = [] if vertices is None: vertices = graph_dict.keys() + for vertice in vertices: _get_cycles(graph_dict, [], set(), result, vertice) + return result -def _get_cycles(graph_dict, path, visited, result, vertice): + +def _get_cycles(graph_dict: _Graph, + path: List, + visited: Set, + result: List[List], + vertice: V) -> None: """recursive function doing the real work for get_cycles""" if vertice in path: cycle = [vertice] + for node in path[::-1]: if node == vertice: break + cycle.insert(0, node) + # make a canonical representation start_from = min(cycle) index = cycle.index(start_from) cycle = cycle[index:] + cycle[0:index] + # append it to result if not already in - if not cycle in result: + if cycle not in result: result.append(cycle) return + path.append(vertice) + try: for node in graph_dict[vertice]: # don't check already visited nodes again if node not in visited: _get_cycles(graph_dict, path, visited, result, node) visited.add(node) + except KeyError: pass + path.pop() -def has_path(graph_dict, fromnode, tonode, path=None): + +def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None) -> Optional[List[str]]: """generic function taking a simple graph definition as a dictionary, with node has key associated to a list of nodes directly reachable from it. diff --git a/logilab/common/interface.py b/logilab/common/interface.py index ab0e529..8248a27 100644 --- a/logilab/common/interface.py +++ b/logilab/common/interface.py @@ -29,11 +29,11 @@ __docformat__ = "restructuredtext en" class Interface(object): """Base class for interfaces.""" @classmethod - def is_implemented_by(cls, instance): + def is_implemented_by(cls, instance: type) -> bool: return implements(instance, cls) -def implements(obj, interface): +def implements(obj: type, interface: type) -> bool: """Return true if the give object (maybe an instance or class) implements the interface. """ @@ -46,7 +46,7 @@ def implements(obj, interface): return False -def extend(klass, interface, _recurs=False): +def extend(klass: type, interface: type, _recurs: bool = False) -> None: """Add interface to klass'__implements__ if not already implemented in. If klass is subclassed, ensure subclasses __implements__ it as well. @@ -55,14 +55,14 @@ def extend(klass, interface, _recurs=False): """ if not implements(klass, interface): try: - kimplements = klass.__implements__ + kimplements = klass.__implements__ # type: ignore kimplementsklass = type(kimplements) kimplements = list(kimplements) except AttributeError: kimplementsklass = tuple kimplements = [] kimplements.append(interface) - klass.__implements__ = kimplementsklass(kimplements) + klass.__implements__ = kimplementsklass(kimplements) #type: ignore for subklass in klass.__subclasses__(): extend(subklass, interface, _recurs=True) elif _recurs: diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py index 34419bd..76c4ac4 100644 --- a/logilab/common/modutils.py +++ b/logilab/common/modutils.py @@ -35,13 +35,18 @@ import os from os.path import (splitext, join, abspath, isdir, dirname, exists, basename, expanduser, normcase, realpath) from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY -from distutils.sysconfig import get_config_var, get_python_lib, get_python_version +from distutils.sysconfig import get_config_var, get_python_lib from distutils.errors import DistutilsPlatformError +from typing import Dict, List, Optional, Any, Tuple, Union, Sequence +from types import ModuleType +from _frozen_importlib_external import FileFinder try: import zipimport except ImportError: - zipimport = None + # mypy: Incompatible types in assignment (expression has type "None", variable has type Module) + # conditional import + zipimport = None # type: ignore ZIPFILE = object() @@ -86,8 +91,8 @@ class LazyObject(object): def _getobj(self): if self._imported is None: - self._imported = getattr(load_module_from_name(self.module), - self.obj) + self._imported = getattr(load_module_from_name(self.module), + self.obj) return self._imported def __getattribute__(self, attr): @@ -100,7 +105,7 @@ class LazyObject(object): return self._getobj()(*args, **kwargs) -def load_module_from_name(dotted_name, path=None, use_sys=True): +def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys: int = True) -> ModuleType: """Load a Python module from its name. :type dotted_name: str @@ -122,10 +127,13 @@ def load_module_from_name(dotted_name, path=None, use_sys=True): :rtype: module :return: the loaded module """ - return load_module_from_modpath(dotted_name.split('.'), path, use_sys) + module = load_module_from_modpath(dotted_name.split('.'), path, use_sys) + if module is None: + raise ImportError("module %s doesn't exist" % dotted_name) + return module -def load_module_from_modpath(parts, path=None, use_sys=True): +def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_sys: int = True) -> Optional[ModuleType]: """Load a python module from its splitted name. :type parts: list(str) or tuple(str) @@ -208,7 +216,7 @@ def load_module_from_file(filepath, path=None, use_sys=True, extrapath=None): return load_module_from_modpath(modpath, path, use_sys) -def _check_init(path, mod_path): +def _check_init(path: str, mod_path: List[str]) -> bool: """check there are some __init__.py all along the way""" modpath = [] for part in mod_path: @@ -219,13 +227,13 @@ def _check_init(path, mod_path): return True -def _canonicalize_path(path): +def _canonicalize_path(path: str) -> str: return realpath(expanduser(path)) @deprecated('you should avoid using modpath_from_file()') -def modpath_from_file(filename, extrapath=None): +def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None) -> List[str]: """DEPRECATED: doens't play well with symlinks and sys.meta_path Given a file path return the corresponding splitted module's name @@ -269,7 +277,7 @@ def modpath_from_file(filename, extrapath=None): filename, ', \n'.join(sys.path))) -def file_from_modpath(modpath, path=None, context_file=None): +def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None) -> Optional[str]: """given a mod path (i.e. splitted module / package name), return the corresponding file, giving priority to source file over precompiled file if it exists @@ -299,6 +307,7 @@ def file_from_modpath(modpath, path=None, context_file=None): the path to the module's file or None if it's an integrated builtin module such as 'sys' """ + context: Optional[str] if context_file is not None: context = dirname(context_file) else: @@ -316,7 +325,7 @@ def file_from_modpath(modpath, path=None, context_file=None): -def get_module_part(dotted_name, context_file=None): +def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str: """given a dotted name return the module part of the name : >>> get_module_part('logilab.common.modutils.get_module_part') @@ -354,7 +363,7 @@ def get_module_part(dotted_name, context_file=None): raise ImportError(dotted_name) return parts[0] # don't use += or insert, we want a new list to be created ! - path = None + path: Optional[List] = None starti = 0 if parts[0] == '': assert context_file is not None, \ @@ -363,6 +372,7 @@ def get_module_part(dotted_name, context_file=None): starti = 1 while parts[starti] == '': # for all further dots: change context starti += 1 + assert context_file is not None context_file = dirname(context_file) for i in range(starti, len(parts)): try: @@ -375,7 +385,7 @@ def get_module_part(dotted_name, context_file=None): return dotted_name -def get_modules(package, src_directory, blacklist=STD_BLACKLIST): +def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]: """given a package directory return a list of all available python modules in the package and its subpackages @@ -415,7 +425,7 @@ def get_modules(package, src_directory, blacklist=STD_BLACKLIST): -def get_module_files(src_directory, blacklist=STD_BLACKLIST): +def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]: """given a package directory return a list of all available python module's files in the package and its subpackages @@ -447,7 +457,7 @@ def get_module_files(src_directory, blacklist=STD_BLACKLIST): return files -def get_source_file(filename, include_no_ext=False): +def get_source_file(filename: str, include_no_ext: bool = False) -> str: """given a python module's file name return the matching source file name (the filename will be returned identically if it's a already an absolute path to a python source file...) @@ -505,7 +515,7 @@ def is_python_source(filename): return splitext(filename)[1][1:] in PY_SOURCE_EXTS -def is_standard_module(modname, std_path=(STD_LIB_DIR,)): +def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,)) -> bool: """try to guess if a module is a standard python module (by default, see `std_path` parameter's description) @@ -547,7 +557,7 @@ def is_standard_module(modname, std_path=(STD_LIB_DIR,)): -def is_relative(modname, from_file): +def is_relative(modname: str, from_file: str) -> bool: """return true if the given module name is relative to the given file name @@ -575,7 +585,7 @@ def is_relative(modname, from_file): # internal only functions ##################################################### -def _file_from_modpath(modpath, path=None, context=None): +def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None) -> Optional[str]: """given a mod path (i.e. splitted module / package name), return the corresponding file @@ -592,6 +602,7 @@ def _file_from_modpath(modpath, path=None, context=None): mtype, mp_filename = _module_file(modpath, path) if mtype == PY_COMPILED: try: + assert mp_filename is not None return get_source_file(mp_filename) except NoSourceFile: return mp_filename @@ -599,10 +610,11 @@ def _file_from_modpath(modpath, path=None, context=None): # integrated builtin module return None elif mtype == PKG_DIRECTORY: + assert mp_filename is not None mp_filename = _has_init(mp_filename) return mp_filename -def _search_zip(modpath, pic): +def _search_zip(modpath: List[str], pic: Dict[str, Optional[FileFinder]]) -> Tuple[object, str, str]: for filepath, importer in pic.items(): if importer is not None: if importer.find_module(modpath[0]): @@ -615,15 +627,19 @@ def _search_zip(modpath, pic): try: import pkg_resources except ImportError: - pkg_resources = None + # mypy: Incompatible types in assignment (expression has type "None", variable has type Module) + # conditional import + pkg_resources = None # type: ignore -def _is_namespace(modname): +def _is_namespace(modname: str) -> bool: + # mypy: Module has no attribute "_namespace_packages"; maybe "fixup_namespace_packages"?" + # but is still has? or is it a failure from python3 port? return (pkg_resources is not None - and modname in pkg_resources._namespace_packages) + and modname in pkg_resources._namespace_packages) # type: ignore -def _module_file(modpath, path=None): +def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[Union[int, object], Optional[str]]: """get a module type / file path :type modpath: list or tuple @@ -643,7 +659,7 @@ def _module_file(modpath, path=None): # egg support compat try: pic = sys.path_importer_cache - _path = (path is None and sys.path or path) + _path = path if path is not None else sys.path for __path in _path: if not __path in pic: try: @@ -660,10 +676,14 @@ def _module_file(modpath, path=None): module = sys.modules[modpath.pop(0)] # use list() to protect against _NamespacePath instance we get with python 3, which # find_module later doesn't like - path = list(module.__path__) + # mypy: Module has no attribute "__path__" + # I guess it does thanks to logilab's magic? + path = list(module.__path__) # type: ignore if not modpath: return C_BUILTIN, None + imported = [] + while modpath: modname = modpath[0] # take care to changes in find_module implementation wrt builtin modules @@ -719,7 +739,7 @@ def _module_file(modpath, path=None): path = [mp_filename] return mtype, mp_filename -def _is_python_file(filename): +def _is_python_file(filename: str) -> bool: """return true if the given filename should be considered as a python file .pyc and .pyo are ignored @@ -730,12 +750,14 @@ def _is_python_file(filename): return False -def _has_init(directory): +def _has_init(directory: str) -> Optional[str]: """if the given directory has a valid __init__ file, return its path, else return None """ mod_or_pack = join(directory, '__init__') + for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'): if exists(mod_or_pack + '.' + ext): return mod_or_pack + '.' + ext + return None diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py index 07365a7..11e2155 100644 --- a/logilab/common/optik_ext.py +++ b/logilab/common/optik_ext.py @@ -55,6 +55,11 @@ import sys import time from copy import copy from os.path import exists +from logilab.common import attrdict + +from typing import Any, Union, List, Optional, Tuple, Dict +from optparse import Values, IndentedHelpFormatter, OptionGroup +from _io import StringIO # python >= 2.3 from optparse import OptionParser as BaseParser, Option as BaseOption, \ @@ -83,7 +88,7 @@ def check_regexp(option, opt, value): raise OptionValueError( "option %s: invalid regexp value: %r" % (opt, value)) -def check_csv(option, opt, value): +def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tuple[str, ...], str]) -> Union[List[str], Tuple[str, ...]]: """check a csv value by trying to split it return the list of separated values """ @@ -95,7 +100,7 @@ def check_csv(option, opt, value): raise OptionValueError( "option %s: invalid csv value: %r" % (opt, value)) -def check_yn(option, opt, value): +def check_yn(option: Optional['Option'], opt: str, value: Union[bool, str]) -> bool: """check a yn value return true for yes and false for no """ @@ -108,18 +113,21 @@ def check_yn(option, opt, value): msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)" raise OptionValueError(msg % (opt, value)) -def check_named(option, opt, value): +def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], str]) -> Dict[str, str]: """check a named value return a dictionary containing (name, value) associations """ if isinstance(value, dict): return value - values = [] + values: List[Tuple[str, str]] = [] for value in check_csv(option, opt, value): + # mypy: Argument 1 to "append" of "list" has incompatible type "List[str]"; + # mypy: expected "Tuple[str, str]" + # we know that the split will give a 2 items list if value.find('=') != -1: - values.append(value.split('=', 1)) + values.append(value.split('=', 1)) # type: ignore elif value.find(':') != -1: - values.append(value.split(':', 1)) + values.append(value.split(':', 1)) # type: ignore if values: return dict(values) msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \ @@ -173,10 +181,12 @@ def check_time(option, opt, value): return value return apply_units(value, TIME_UNITS) -def check_bytes(option, opt, value): +def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int: if hasattr(value, '__int__'): return value - return apply_units(value, BYTE_UNITS, final=int) + # mypy: Incompatible return value type (got "Union[float, int]", expected "int") + # we force "int" using "final=int" + return apply_units(value, BYTE_UNITS, final=int) # type: ignore class Option(BaseOption): @@ -201,50 +211,62 @@ class Option(BaseOption): TYPES += ('date',) TYPE_CHECKER['date'] = check_date - def __init__(self, *opts, **attrs): + def __init__(self, *opts: str, **attrs: Any) -> None: BaseOption.__init__(self, *opts, **attrs) - if hasattr(self, "hide") and self.hide: + # mypy: "Option" has no attribute "hide" + # we test that in the if + if hasattr(self, "hide") and self.hide: # type: ignore self.help = SUPPRESS_HELP - def _check_choice(self): + def _check_choice(self) -> None: """FIXME: need to override this due to optik misdesign""" if self.type in ("choice", "multiple_choice"): - if self.choices is None: + # mypy: "Option" has no attribute "choices" + # we know that option of this type has this attribute + if self.choices is None: # type: ignore raise OptionError( "must supply a list of choices for type 'choice'", self) - elif not isinstance(self.choices, (tuple, list)): + elif not isinstance(self.choices, (tuple, list)): # type: ignore raise OptionError( "choices must be a list of strings ('%s' supplied)" - % str(type(self.choices)).split("'")[1], self) - elif self.choices is not None: + % str(type(self.choices)).split("'")[1], self) # type: ignore + elif self.choices is not None: # type: ignore raise OptionError( "must not supply choices for type %r" % self.type, self) - BaseOption.CHECK_METHODS[2] = _check_choice - + # mypy: Unsupported target for indexed assignment + # black magic? + BaseOption.CHECK_METHODS[2] = _check_choice # type: ignore - def process(self, opt, value, values, parser): + def process(self, opt: str, value: str, values: Values, parser: BaseParser) -> int: # First, convert the value(s) to the right type. Howl if any # value(s) are bogus. value = self.convert_value(opt, value) if self.type == 'named': + assert self.dest is not None existant = getattr(values, self.dest) if existant: existant.update(value) value = existant - # And then take whatever action is expected of us. + # And then take whatever action is expected of us. # This is a separate method to make life easier for # subclasses to add new actions. + # mypy: Argument 2 to "take_action" of "Option" has incompatible type "Optional[str]"; + # mypy: expected "str" + # is it ok? return self.take_action( - self.action, self.dest, opt, value, values, parser) + self.action, self.dest, opt, value, values, parser) # type: ignore class OptionParser(BaseParser): """override optik.OptionParser to use our Option class """ - def __init__(self, option_class=Option, *args, **kwargs): - BaseParser.__init__(self, option_class=Option, *args, **kwargs) + def __init__(self, option_class: type = Option, *args: Any, **kwargs: Any) -> None: + # mypy: Argument "option_class" to "__init__" of "OptionParser" has incompatible type + # mypy: "type"; expected "Option" + # mypy is doing really weird things with *args/**kwargs and looks buggy + BaseParser.__init__(self, option_class=option_class, *args, **kwargs) # type: ignore - def format_option_help(self, formatter=None): + def format_option_help(self, formatter: Optional[HelpFormatter] = None) -> str: if formatter is None: formatter = self.formatter outputlevel = getattr(formatter, 'output_level', 0) @@ -256,7 +278,9 @@ class OptionParser(BaseParser): result.append(OptionContainer.format_option_help(self, formatter)) result.append("\n") for group in self.option_groups: - if group.level <= outputlevel and ( + # mypy: "OptionParser" has no attribute "level" + # but it has one no? + if group.level <= outputlevel and ( # type: ignore group.description or level_options(group, outputlevel)): result.append(group.format_help(formatter)) result.append("\n") @@ -265,12 +289,16 @@ class OptionParser(BaseParser): return "".join(result[:-1]) -OptionGroup.level = 0 +# mypy error: error: "Type[OptionGroup]" has no attribute "level" +# monkeypatching +OptionGroup.level = 0 # type: ignore -def level_options(group, outputlevel): +def level_options(group: BaseParser, outputlevel: int) -> List[BaseOption]: + # mypy: "Option" has no attribute "help" + # but it does return [option for option in group.option_list if (getattr(option, 'level', 0) or 0) <= outputlevel - and not option.help is SUPPRESS_HELP] + and not option.help is SUPPRESS_HELP] # type: ignore def format_option_help(self, formatter): result = [] @@ -278,33 +306,42 @@ def format_option_help(self, formatter): for option in level_options(self, outputlevel): result.append(formatter.format_option(option)) return "".join(result) -OptionContainer.format_option_help = format_option_help +# mypy error: Cannot assign to a method +# but we still do it because magic +OptionContainer.format_option_help = format_option_help # type: ignore class ManHelpFormatter(HelpFormatter): """Format help using man pages ROFF format""" def __init__ (self, - indent_increment=0, - max_help_position=24, - width=79, - short_first=0): + indent_increment: int = 0, + max_help_position: int = 24, + width: int = 79, + short_first: int = 0) -> None: HelpFormatter.__init__ ( self, indent_increment, max_help_position, width, short_first) - def format_heading(self, heading): + def format_heading(self, heading: str) -> str: return '.SH %s\n' % heading.upper() def format_description(self, description): return description - def format_option(self, option): + def format_option(self, option: BaseParser) -> str: try: - optstring = option.option_strings + # mypy: "Option" has no attribute "option_strings" + # we handle if it doesn't + optstring = option.option_strings # type: ignore except AttributeError: optstring = self.format_option_strings(option) - if option.help: - help_text = self.expand_default(option) + # mypy: "OptionParser" has no attribute "help" + # it does + if option.help: # type: ignore + # mypy: Argument 1 to "expand_default" of "HelpFormatter" has incompatible type + # mypy: "OptionParser"; expected "Option" + # it still works? + help_text = self.expand_default(option) # type: ignore help = ' '.join([l.strip() for l in help_text.splitlines()]) else: help = '' @@ -312,13 +349,9 @@ class ManHelpFormatter(HelpFormatter): %s ''' % (optstring, help) - def format_head(self, optparser, pkginfo, section=1): + def format_head(self, optparser: OptionParser, pkginfo: attrdict, section: int = 1) -> str: long_desc = "" - try: - pgm = optparser._get_prog_name() - except AttributeError: - # py >= 2.4.X (dunno which X exactly, at least 2) - pgm = optparser.get_prog_name() + pgm = optparser.get_prog_name() short_desc = self.format_short_description(pgm, pkginfo.description) if hasattr(pkginfo, "long_desc"): long_desc = self.format_long_description(pgm, pkginfo.long_desc) @@ -326,17 +359,17 @@ class ManHelpFormatter(HelpFormatter): short_desc, self.format_synopsis(pgm), long_desc) - def format_title(self, pgm, section): + def format_title(self, pgm: str, section: int) -> str: date = '-'.join([str(num) for num in time.localtime()[:3]]) return '.TH %s %s "%s" %s' % (pgm, section, date, pgm) - def format_short_description(self, pgm, short_desc): + def format_short_description(self, pgm: str, short_desc: str) -> str: return '''.SH NAME .B %s \- %s ''' % (pgm, short_desc.strip()) - def format_synopsis(self, pgm): + def format_synopsis(self, pgm: str) -> str: return '''.SH SYNOPSIS .B %s [ @@ -357,7 +390,7 @@ class ManHelpFormatter(HelpFormatter): %s ''' % (pgm, long_desc.strip()) - def format_tail(self, pkginfo): + def format_tail(self, pkginfo: attrdict) -> str: tail = '''.SH SEE ALSO /usr/share/doc/pythonX.Y-%s/ @@ -378,10 +411,12 @@ Please report bugs on the project\'s mailing list: return tail -def generate_manpage(optparser, pkginfo, section=1, stream=sys.stdout, level=0): +def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int = 1, stream: StringIO = sys.stdout, level: int = 0) -> None: """generate a man page from an optik parser""" formatter = ManHelpFormatter() - formatter.output_level = level + # mypy: "ManHelpFormatter" has no attribute "output_level" + # dynamic attribute? + formatter.output_level = level # type: ignore formatter.parser = optparser print(formatter.format_head(optparser, pkginfo, section), file=stream) print(optparser.format_option_help(formatter), file=stream) diff --git a/logilab/common/proc.py b/logilab/common/proc.py index c27356c..30e9494 100644 --- a/logilab/common/proc.py +++ b/logilab/common/proc.py @@ -133,14 +133,9 @@ class ProcInfoLoader: pass -try: - class ResourceError(BaseException): - """Error raise when resource limit is reached""" - limit = "Unknown Resource Limit" -except NameError: - class ResourceError(Exception): - """Error raise when resource limit is reached""" - limit = "Unknown Resource Limit" +class ResourceError(Exception): + """Error raise when resource limit is reached""" + limit = "Unknown Resource Limit" class XCPUError(ResourceError): diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py index 5c62816..6819c01 100644 --- a/logilab/common/pytest.py +++ b/logilab/common/pytest.py @@ -116,13 +116,20 @@ FILE_RESTART = ".pytest.restart" import os, sys, re import os.path as osp from time import process_time, time +from re import Match import warnings import types import inspect import traceback -from inspect import isgeneratorfunction, isclass +from inspect import isgeneratorfunction, isclass, FrameInfo from random import shuffle from itertools import dropwhile +# mypy error: Module 'unittest.runner' has no attribute '_WritelnDecorator' +# but it does +from unittest.runner import _WritelnDecorator # type: ignore +from unittest.suite import TestSuite + +from typing import Callable, Any, Optional, List, Tuple, Generator, Union, Dict from logilab.common.deprecation import deprecated from logilab.common.fileutils import abspath_listdir @@ -141,7 +148,8 @@ if not getattr(unittest_legacy, "__package__", None): except ImportError: sys.exit("You have to install python-unittest2 to use this module") else: - import unittest.suite as unittest_suite + # mypy: Name 'unittest_suite' already defined (possibly by an import)) + import unittest.suite as unittest_suite # type: ignore try: import django @@ -153,12 +161,12 @@ except ImportError: CONF_FILE = 'pytestconf.py' TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$") -def this_is_a_testfile(filename): +def this_is_a_testfile(filename: str) -> Optional[Match]: """returns True if `filename` seems to be a test file""" return TESTFILE_RE.match(osp.basename(filename)) TESTDIR_RE = re.compile("^(unit)?tests?$") -def this_is_a_testdir(dirpath): +def this_is_a_testdir(dirpath: str) -> Optional[Match]: """returns True if `filename` seems to be a test directory""" return TESTDIR_RE.match(osp.basename(dirpath)) @@ -842,10 +850,10 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): self.batchmode = batchmode self.options = options - def _this_is_skipped(self, testedname): + def _this_is_skipped(self, testedname: str) -> bool: return any([(pat in testedname) for pat in self.skipped_patterns]) - def _runcondition(self, test, skipgenerator=True): + def _runcondition(self, test: Callable, skipgenerator: bool = True) -> bool: if isinstance(test, testlib.InnerTest): testname = test.name else: @@ -876,7 +884,7 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): return self.does_match_tags(test) - def does_match_tags(self, test): + def does_match_tags(self, test: Callable) -> bool: if self.options is not None: tags_pattern = getattr(self.options, 'tags_pattern', None) if tags_pattern is not None: @@ -886,7 +894,7 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): return tags.match(tags_pattern) return True # no pattern - def _makeResult(self): + def _makeResult(self) -> 'SkipAwareTestResult': return SkipAwareTestResult(self.stream, self.descriptions, self.verbosity, self.exitfirst, self.pdbmode, self.cvg, self.colorize) @@ -935,14 +943,14 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): class SkipAwareTestResult(unittest._TextTestResult): - def __init__(self, stream, descriptions, verbosity, - exitfirst=False, pdbmode=False, cvg=None, colorize=False): + def __init__(self, stream: _WritelnDecorator, descriptions: bool, verbosity: int, + exitfirst: bool = False, pdbmode: bool = False, cvg: Optional[Any] = None, colorize: bool = False) -> None: super(SkipAwareTestResult, self).__init__(stream, descriptions, verbosity) - self.skipped = [] - self.debuggers = [] - self.fail_descrs = [] - self.error_descrs = [] + self.skipped: List[Tuple[Any, Any]] = [] + self.debuggers: List = [] + self.fail_descrs: List = [] + self.error_descrs: List = [] self.exitfirst = exitfirst self.pdbmode = pdbmode self.cvg = cvg @@ -950,15 +958,15 @@ class SkipAwareTestResult(unittest._TextTestResult): self.pdbclass = Debugger self.verbose = verbosity > 1 - def descrs_for(self, flavour): + def descrs_for(self, flavour: str) -> List[Tuple[int, str]]: return getattr(self, '%s_descrs' % flavour.lower()) - def _create_pdb(self, test_descr, flavour): + def _create_pdb(self, test_descr: str, flavour: str) -> None: self.descrs_for(flavour).append( (len(self.debuggers), test_descr) ) if self.pdbmode: self.debuggers.append(self.pdbclass(sys.exc_info()[2])) - def _iter_valid_frames(self, frames): + def _iter_valid_frames(self, frames: List[FrameInfo]) -> Generator[FrameInfo, Any, None]: """only consider non-testlib frames when formatting traceback""" lgc_testlib = osp.abspath(__file__) std_testlib = osp.abspath(unittest.__file__) @@ -1030,11 +1038,11 @@ class SkipAwareTestResult(unittest._TextTestResult): elif self.dots: self.stream.write('S') - def printErrors(self): + def printErrors(self) -> None: super(SkipAwareTestResult, self).printErrors() self.printSkippedList() - def printSkippedList(self): + def printSkippedList(self) -> None: # format (test, err) compatible with unittest2 for test, err in self.skipped: descr = self.getDescription(test) @@ -1055,9 +1063,11 @@ class SkipAwareTestResult(unittest._TextTestResult): from .decorators import monkeypatch orig_call = testlib.TestCase.__call__ @monkeypatch(testlib.TestCase, '__call__') -def call(self, result=None, runcondition=None, options=None): +def call(self: Any, result: SkipAwareTestResult = None, runcondition: Optional[Callable] = None, options: Optional[Any] = None) -> None: orig_call(self, result=result, runcondition=runcondition, options=options) - if hasattr(options, "exitfirst") and options.exitfirst: + # mypy: Item "None" of "Optional[Any]" has no attribute "exitfirst" + # we check it first in the if + if hasattr(options, "exitfirst") and options.exitfirst: # type: ignore # add this test to restart file try: restartfile = open(FILE_RESTART, 'a') @@ -1103,18 +1113,18 @@ class NonStrictTestLoader(unittest.TestLoader): 'python test_foo.py test_bar' will run FooTC.test_bar1 and BarTC.test_bar2 """ - def __init__(self): + def __init__(self) -> None: self.skipped_patterns = () # some magic here to accept empty list by extending # and to provide callable capability - def loadTestsFromNames(self, names, module=None): + def loadTestsFromNames(self, names: List[str], module: type = None) -> TestSuite: suites = [] for name in names: suites.extend(self.loadTestsFromName(name, module)) return self.suiteClass(suites) - def _collect_tests(self, module): + def _collect_tests(self, module: type) -> Dict[str, Tuple[type, List[str]]]: tests = {} for obj in vars(module).values(): if isclass(obj) and issubclass(obj, unittest.TestCase): @@ -1182,10 +1192,12 @@ class NonStrictTestLoader(unittest.TestLoader): if pattern in methodname] return collected - def _this_is_skipped(self, testedname): - return any([(pat in testedname) for pat in self.skipped_patterns]) + def _this_is_skipped(self, testedname: str) -> bool: + # mypy: Need type annotation for 'pat' + # doc doesn't say how to that in list comprehension + return any([(pat in testedname) for pat in self.skipped_patterns]) # type: ignore - def getTestCaseNames(self, testCaseClass): + def getTestCaseNames(self, testCaseClass: type) -> List[str]: """Return a sorted sequence of method names found within testCaseClass """ is_skipped = self._this_is_skipped @@ -1202,13 +1214,13 @@ class NonStrictTestLoader(unittest.TestLoader): # It is used to monkeypatch the original implementation to support # extra runcondition and options arguments (see in testlib.py) -def _ts_run(self, result, runcondition=None, options=None): +def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: self._wrapped_run(result, runcondition=runcondition, options=options) self._tearDownPreviousClass(None, result) self._handleModuleTearDown(result) return result -def _ts_wrapped_run(self, result, debug=False, runcondition=None, options=None): +def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: for test in self: if result.shouldStop: break @@ -1247,7 +1259,7 @@ if sys.version_info >= (2, 7): # The function below implements a modified version of the # TestSuite.run method that is provided with python 2.7, in # unittest/suite.py - def _ts_run(self, result, debug=False, runcondition=None, options=None): + def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: topLevel = False if getattr(result, '_testRunEntered', False) is False: result._testRunEntered = topLevel = True diff --git a/logilab/common/registry.py b/logilab/common/registry.py index 6f60ef9..d9ae11b 100644 --- a/logilab/common/registry.py +++ b/logilab/common/registry.py @@ -78,6 +78,7 @@ Exceptions from __future__ import print_function + __docformat__ = "restructuredtext en" import sys @@ -87,9 +88,14 @@ import weakref import traceback as tb from os import listdir, stat from os.path import join, isdir, exists -from typing import Dict, Type, Optional, Union +from typing import Dict, Type, Optional, Union, Sequence from logging import getLogger from warnings import warn +from typing import List, Tuple, Any, Iterable, Callable +from types import ModuleType +from typing_extensions import TypedDict +from _frozen_importlib import ModuleSpec +from _frozen_importlib_external import SourceFileLoader from logilab.common.modutils import modpath_from_file from logilab.common.logging_ext import set_log_methods @@ -99,7 +105,7 @@ from logilab.common.deprecation import deprecated # selector base classes and operations ######################################## -def objectify_predicate(selector_func): +def objectify_predicate(selector_func: Callable) -> Any: """Most of the time, a simple score function is enough to build a selector. The :func:`objectify_predicate` decorator turn it into a proper selector class:: @@ -119,7 +125,7 @@ def objectify_predicate(selector_func): _PREDICATES: Dict[int, Type] = {} -def wrap_predicates(decorator): +def wrap_predicates(decorator: Callable) -> None: for predicate in _PREDICATES.values(): if not '_decorators' in predicate.__dict__: predicate._decorators = set() @@ -158,31 +164,36 @@ class Predicate(object, metaclass=PredicateMetaClass): # backward compatibility return self.__class__.__name__ - def search_selector(self, selector): + def search_selector(self, selector: 'Predicate') -> Optional['Predicate']: """search for the given selector, selector instance or tuple of selectors in the selectors tree. Return None if not found. """ if self is selector: return self if (isinstance(selector, type) or isinstance(selector, tuple)) and \ - isinstance(self, selector): + isinstance(self, selector): return self return None def __str__(self): return self.__class__.__name__ - def __and__(self, other): + def __and__(self, other: 'Predicate') -> 'AndPredicate': return AndPredicate(self, other) - def __rand__(self, other): + + def __rand__(self, other: 'Predicate') -> 'AndPredicate': return AndPredicate(other, self) - def __iand__(self, other): + + def __iand__(self, other: 'Predicate') -> 'AndPredicate': return AndPredicate(self, other) - def __or__(self, other): + + def __or__(self, other: 'Predicate') -> 'OrPredicate': return OrPredicate(self, other) - def __ror__(self, other): + + def __ror__(self, other: 'Predicate'): return OrPredicate(other, self) - def __ior__(self, other): + + def __ior__(self, other: 'Predicate') -> 'OrPredicate': return OrPredicate(self, other) def __invert__(self): @@ -201,7 +212,7 @@ class Predicate(object, metaclass=PredicateMetaClass): class MultiPredicate(Predicate): """base class for compound selector classes""" - def __init__(self, *selectors): + def __init__(self, *selectors: Any) -> None: self.selectors = self.merge_selectors(selectors) def __str__(self): @@ -209,7 +220,7 @@ class MultiPredicate(Predicate): ','.join(str(s) for s in self.selectors)) @classmethod - def merge_selectors(cls, selectors): + def merge_selectors(cls, selectors: Sequence[Predicate]) -> List[Predicate]: """deal with selector instanciation when necessary and merge multi-selectors if possible: @@ -231,7 +242,7 @@ class MultiPredicate(Predicate): merged_selectors.append(selector) return merged_selectors - def search_selector(self, selector): + def search_selector(self, selector: Predicate) -> Optional[Predicate]: """search for the given selector or selector instance (or tuple of selectors) in the selectors tree. Return None if not found """ @@ -247,7 +258,7 @@ class MultiPredicate(Predicate): class AndPredicate(MultiPredicate): """and-chained selectors""" - def __call__(self, cls, *args, **kwargs): + def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int: score = 0 for selector in self.selectors: partscore = selector(cls, *args, **kwargs) @@ -259,7 +270,7 @@ class AndPredicate(MultiPredicate): class OrPredicate(MultiPredicate): """or-chained selectors""" - def __call__(self, cls, *args, **kwargs): + def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int: for selector in self.selectors: partscore = selector(cls, *args, **kwargs) if partscore: @@ -288,7 +299,7 @@ class yes(Predicate): # pylint: disable=C0103 Take care, `yes(0)` could be named 'no'... """ - def __init__(self, score=0.5): + def __init__(self, score: float = 0.5) -> None: self.score = score def __call__(self, *args, **kwargs): @@ -338,7 +349,7 @@ class SelectAmbiguity(RegistryException): """ -def _modname_from_path(path, extrapath=None): +def _modname_from_path(path: str, extrapath: Optional[Any] = None) -> str: modpath = modpath_from_file(path, extrapath) # omit '__init__' from package's name to avoid loading that module # once for each name when it is imported by some other object @@ -356,21 +367,26 @@ def _modname_from_path(path, extrapath=None): return '.'.join(modpath) -def _toload_info(path, extrapath, _toload=None): +def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None) -> Tuple[Dict[str, str], List[Tuple[str, str]]]: """Return a dictionary of <modname>: <modpath> and an ordered list of (file, module name) to load """ if _toload is None: assert isinstance(path, list) _toload = {}, [] + for fileordir in path: if isdir(fileordir) and exists(join(fileordir, '__init__.py')): subfiles = [join(fileordir, fname) for fname in listdir(fileordir)] + _toload_info(subfiles, extrapath, _toload) + elif fileordir[-3:] == '.py': modname = _modname_from_path(fileordir, extrapath) _toload[0][modname] = fileordir + _toload[1].append((fileordir, modname)) + return _toload @@ -404,7 +420,7 @@ class RegistrableObject(object): __abstract__ = True # see doc snipppets below (in Registry class) @classproperty - def __registries__(cls): + def __registries__(cls) -> Union[Tuple[str], Tuple]: if cls.__registry__ is None: return () return (cls.__registry__,) @@ -432,10 +448,17 @@ class RegistrableInstance(RegistrableObject): obj.__module__ = module return obj - def __init__(self, __module__=None): + def __init__(self, __module__: Optional[str] = None) -> None: super(RegistrableInstance, self).__init__() +SelectBestReport = TypedDict("SelectBestReport", {"all_objects": List, "end_score": int, + "winners": List, + "winner": Optional[Any], "self": 'Registry', + "args": List, "kwargs": Dict, + "registry": 'Registry'}) + + class Registry(dict): """The registry store a set of implementations associated to identifier: @@ -469,10 +492,10 @@ class Registry(dict): .. automethod:: possible_objects .. automethod:: object_by_id """ - def __init__(self, debugmode): + def __init__(self, debugmode: bool) -> None: super(Registry, self).__init__() self.debugmode = debugmode - self._select_listeners = [] + self._select_listeners: List[Callable[[SelectBestReport], None]] = [] def __getitem__(self, name): """return the registry (list of implementation objects) associated to @@ -486,16 +509,16 @@ class Registry(dict): raise exc @classmethod - def objid(cls, obj): + def objid(cls, obj: Any) -> str: """returns a unique identifier for an object stored in the registry""" return '%s.%s' % (obj.__module__, cls.objname(obj)) @classmethod - def objname(cls, obj): + def objname(cls, obj: Any) -> str: """returns a readable name for an object stored in the registry""" return getattr(obj, '__name__', id(obj)) - def initialization_completed(self): + def initialization_completed(self) -> None: """call method __registered__() on registered objects when the callback is defined""" for objects in self.values(): @@ -506,7 +529,7 @@ class Registry(dict): if self.debugmode: wrap_predicates(_lltrace) - def register(self, obj, oid=None, clear=False): + def register(self, obj: Any, oid: Optional[Any] = None, clear: bool = False) -> None: """base method to add an object in the registry""" assert not '__abstract__' in obj.__dict__, obj assert obj.__select__, obj @@ -632,15 +655,16 @@ class Registry(dict): once the selection is done and will recieve a dict of the following form:: {"all_objects": [], "end_score": 0, "winners": [], "winner": None or winner, - "self": self, "args": args, "kwargs": kwargs, } + "self": self, "args": args, "kwargs": kwargs, "registry": self} """ if self._select_listeners: - select_best_report = { + select_best_report: SelectBestReport = { "registry": self, "all_objects": [], "end_score": 0, "winners": [], + "winner": None, "self": self, "args": args, "kwargs": kwargs, @@ -692,7 +716,7 @@ class Registry(dict): info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None -def obj_registries(cls, registryname=None): +def obj_registries(cls: Any, registryname: Optional[Any] = None) -> Tuple[str]: """return a tuple of registry names (see __registries__)""" if registryname: return (registryname,) @@ -817,18 +841,18 @@ class RegistryStore(dict): key will be the class used when there is no specific class for a name. """ - def __init__(self, debugmode=False): + def __init__(self, debugmode: bool = False) -> None: super(RegistryStore, self).__init__() self.debugmode = debugmode - def reset(self): + def reset(self) -> None: """clear all registries managed by this store""" # don't use self.clear, we want to keep existing subdictionaries for subdict in self.values(): subdict.clear() - self._lastmodifs = {} + self._lastmodifs: Dict[str, int] = {} - def __getitem__(self, name): + def __getitem__(self, name: str) -> Registry: """return the registry (dictionary of class objects) associated to this name """ @@ -842,9 +866,9 @@ class RegistryStore(dict): # methods for explicit (un)registration ################################### # default class, when no specific class set - REGISTRY_FACTORY = {None: Registry} + REGISTRY_FACTORY: Dict[Union[None, str], type] = {None: Registry} - def registry_class(self, regid): + def registry_class(self, regid: str) -> type: """return existing registry named regid or use factory to create one and return it""" try: @@ -852,14 +876,16 @@ class RegistryStore(dict): except KeyError: return self.REGISTRY_FACTORY[None] - def setdefault(self, regid): + # mypy: Signature of "setdefault" incompatible with supertype "MutableMapping"" + # I don't know how to overload signatures of method in mypy + def setdefault(self, regid: str) -> Registry: # type: ignore try: return self[regid] except RegistryNotFound: self[regid] = self.registry_class(regid)(self.debugmode) return self[regid] - def register_all(self, objects, modname, butclasses=()): + def register_all(self, objects: Iterable, modname: str, butclasses: Sequence = ()) -> None: """register registrable objects into `objects`. Registrable objects are properly configured subclasses of @@ -886,7 +912,7 @@ class RegistryStore(dict): else: self.register(obj) - def register(self, obj, registryname=None, oid=None, clear=False): + def register(self, obj: Any, registryname: Optional[Any] = None, oid: Optional[Any] = None, clear: bool = False) -> None: """register `obj` implementation into `registryname` or `obj.__registries__` if not specified, with identifier `oid` or `obj.__regid__` if not specified. @@ -927,7 +953,7 @@ class RegistryStore(dict): # initialization methods ################################################### - def init_registration(self, path, extrapath=None): + def init_registration(self, path: List[str], extrapath: Optional[Any] = None) -> List[Tuple[str, str]]: """reset registry and walk down path to return list of (path, name) file modules to be loaded""" # XXX make this private by renaming it to _init_registration ? @@ -937,11 +963,11 @@ class RegistryStore(dict): # XXX is _loadedmods still necessary ? It seems like it's useful # to avoid loading same module twice, especially with the # _load_ancestors_then_object logic but this needs to be checked - self._loadedmods = {} + self._loadedmods: Dict[str, Dict[str, type]] = {} return filemods @deprecated('use register_modnames() instead') - def register_objects(self, path, extrapath=None): + def register_objects(self, path: List[str], extrapath: Optional[Any] = None) -> None: """register all objects found walking down <path>""" # load views from each directory in the instance's path # XXX inline init_registration ? @@ -950,14 +976,18 @@ class RegistryStore(dict): self.load_file(filepath, modname) self.initialization_completed() - def register_modnames(self, modnames): + def register_modnames(self, modnames: List[str]) -> None: """register all objects found in <modnames>""" self.reset() self._loadedmods = {} self._toloadmods = {} toload = [] for modname in modnames: - filepath = pkgutil.find_loader(modname).get_filename() + loader = pkgutil.find_loader(modname) + assert loader is not None + # mypy: "Loader" has no attribute "get_filename" + # the selected class has one + filepath = loader.get_filename() # type: ignore if filepath[-4:] in ('.pyc', '.pyo'): # The source file *must* exists filepath = filepath[:-1] @@ -967,12 +997,12 @@ class RegistryStore(dict): self.load_file(filepath, modname) self.initialization_completed() - def initialization_completed(self): + def initialization_completed(self) -> None: """call initialization_completed() on all known registries""" for reg in self.values(): reg.initialization_completed() - def _mdate(self, filepath): + def _mdate(self, filepath: str) -> Optional[int]: """ return the modification date of a file path """ try: return stat(filepath)[-2] @@ -1004,7 +1034,7 @@ class RegistryStore(dict): return True return False - def load_file(self, filepath, modname): + def load_file(self, filepath: str, modname: str) -> None: """ load registrable objects (if any) from a python file """ if modname in self._loadedmods: return @@ -1025,7 +1055,7 @@ class RegistryStore(dict): module = __import__(modname, fromlist=modname.split('.')[:-1]) self.load_module(module) - def load_module(self, module): + def load_module(self, module: ModuleType) -> None: """Automatically handle module objects registration. Instances are registered as soon as they are hashable and have the @@ -1046,11 +1076,13 @@ class RegistryStore(dict): """ self.info('loading %s from %s', module.__name__, module.__file__) if hasattr(module, 'registration_callback'): - module.registration_callback(self) + # mypy: Module has no attribute "registration_callback" + # we check that before + module.registration_callback(self) # type: ignore else: self.register_all(vars(module).values(), module.__name__) - def _load_ancestors_then_object(self, modname, objectcls, butclasses=()): + def _load_ancestors_then_object(self, modname: str, objectcls: type, butclasses: Sequence[Any] = ()) -> None: """handle class registration according to rules defined in :meth:`load_module` """ @@ -1092,7 +1124,7 @@ class RegistryStore(dict): self.register(objectcls) @classmethod - def is_registrable(cls, obj): + def is_registrable(cls, obj: Any) -> bool: """ensure `obj` should be registered as arbitrary stuff may be registered, do a lot of check and warn about @@ -1107,25 +1139,33 @@ class RegistryStore(dict): return False elif issubclass(obj, RegistrableInstance): return False + elif not isinstance(obj, RegistrableInstance): return False + if not obj.__regid__: return False # no regid + registries = obj.__registries__ if not registries: return False # no registries + selector = obj.__select__ if not selector: return False # no selector + if obj.__dict__.get('__abstract__', False): return False + # then detect potential problems that should be warned if not isinstance(registries, (tuple, list)): cls.warning('%s has __registries__ which is not a list or tuple', obj) return False + if not callable(selector): cls.warning('%s has not callable __select__', obj) return False + return True # these are overridden by set_log_methods below diff --git a/logilab/common/shellutils.py b/logilab/common/shellutils.py index d03ae16..2764723 100644 --- a/logilab/common/shellutils.py +++ b/logilab/common/shellutils.py @@ -37,6 +37,8 @@ import random import subprocess import warnings from os.path import exists, isdir, islink, basename, join +from _io import StringIO +from typing import Any, Callable, Optional, List, Union, Iterator, Tuple from logilab.common import STD_BLACKLIST, _handle_blacklist from logilab.common.compat import str_to_bytes @@ -130,7 +132,7 @@ def cp(source, destination): """ mv(source, destination, _action=shutil.copy) -def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST): +def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = False, blacklist: Tuple[str, ...] = STD_BLACKLIST) -> List[str]: """Recursively find files ending with the given extensions from the directory. :type directory: str @@ -158,13 +160,13 @@ def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST): if isinstance(exts, str): exts = (exts,) if exclude: - def match(filename, exts): + def match(filename: str, exts: Tuple[str, ...]) -> bool: for ext in exts: if filename.endswith(ext): return False return True else: - def match(filename, exts): + def match(filename: str, exts: Tuple[str, ...]) -> bool: for ext in exts: if filename.endswith(ext): return True @@ -180,7 +182,7 @@ def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST): return files -def globfind(directory, pattern, blacklist=STD_BLACKLIST): +def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST) -> Iterator[str]: """Recursively finds files matching glob `pattern` under `directory`. This is an alternative to `logilab.common.shellutils.find`. @@ -236,7 +238,7 @@ class Execute: class ProgressBar(object): """A simple text progression bar.""" - def __init__(self, nbops, size=20, stream=sys.stdout, title=''): + def __init__(self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = '') -> None: if title: self._fstr = '\r%s [%%-%ss]' % (title, int(size)) else: @@ -262,7 +264,7 @@ class ProgressBar(object): text = property(_get_text, _set_text, _del_text) - def update(self, offset=1, exact=False): + def update(self, offset: int = 1, exact: bool = False) -> None: """Move FORWARD to new cursor position (cursor will never go backward). :offset: fraction of ``size`` @@ -283,7 +285,7 @@ class ProgressBar(object): self._progress = progress self.refresh() - def refresh(self): + def refresh(self) -> None: """Refresh the progression bar display.""" self._stream.write(self._fstr % ('=' * min(self._progress, self._size)) ) if self._last_text_write_size or self._current_text: @@ -338,7 +340,7 @@ class progress(object): class RawInput(object): - def __init__(self, input_function=None, printer=None, **kwargs): + def __init__(self, input_function: Optional[Callable] = None, printer: Optional[Callable] = None, **kwargs: Any) -> None: if 'input' in kwargs: input_function = kwargs.pop('input') warnings.warn( @@ -349,7 +351,7 @@ class RawInput(object): self._input = input_function or input self._print = printer - def ask(self, question, options, default): + def ask(self, question: str, options: Tuple[str, ...], default: str) -> str: assert default in options choices = [] for option in options: @@ -383,7 +385,7 @@ class RawInput(object): tries -= 1 raise Exception('unable to get a sensible answer') - def confirm(self, question, default_is_yes=True): + def confirm(self, question: str, default_is_yes: bool = True) -> bool: default = default_is_yes and 'y' or 'n' answer = self.ask(question, ('y', 'n'), default) return answer == 'y' diff --git a/logilab/common/table.py b/logilab/common/table.py index 1f1101c..e7b9195 100644 --- a/logilab/common/table.py +++ b/logilab/common/table.py @@ -18,6 +18,10 @@ """Table management module.""" from __future__ import print_function +from types import CodeType +from typing import Any, List, Optional, Tuple, Union, Dict, Iterator +from _io import StringIO +from mypy_extensions import NoReturn __docformat__ = "restructuredtext en" @@ -30,51 +34,64 @@ class Table(object): forall(self.data, lambda x: len(x) <= len(self.col_names)) """ - def __init__(self, default_value=0, col_names=None, row_names=None): - self.col_names = [] - self.row_names = [] - self.data = [] - self.default_value = default_value + def __init__(self, default_value: int = 0, col_names: Optional[List[str]] = None, row_names: Optional[Any] = None) -> None: + self.col_names: List = [] + self.row_names: List = [] + self.data: List = [] + self.default_value: int = default_value if col_names: self.create_columns(col_names) if row_names: self.create_rows(row_names) - def _next_row_name(self): + def _next_row_name(self) -> str: return 'row%s' % (len(self.row_names)+1) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.data) - def __eq__(self, other): + # def __eq__(self, other: Union[List[List[int]], List[Tuple[str, str, str, float]]]) -> bool: + def __eq__(self, other: object) -> bool: + def is_iterable(variable: Any) -> bool: + try: + iter(variable) + except TypeError: + return False + else: + return True + if other is None: return False + elif is_iterable(other): + # mypy: No overload variant of "list" matches argument type "object" + # checked before + return list(self) == list(other) # type: ignore else: - return list(self) == list(other) + return False __hash__ = object.__hash__ def __ne__(self, other): return not self == other - def __len__(self): + def __len__(self) -> int: return len(self.row_names) ## Rows / Columns creation ################################################# - def create_rows(self, row_names): + def create_rows(self, row_names: List[str]) -> None: """Appends row_names to the list of existing rows """ self.row_names.extend(row_names) for row_name in row_names: self.data.append([self.default_value]*len(self.col_names)) - def create_columns(self, col_names): + def create_columns(self, col_names: List[str]) -> None: """Appends col_names to the list of existing columns """ for col_name in col_names: self.create_column(col_name) - def create_row(self, row_name=None): + def create_row(self, row_name: str = None) -> None: """Creates a rowname to the row_names list """ row_name = row_name or self._next_row_name() @@ -82,7 +99,7 @@ class Table(object): self.data.append([self.default_value]*len(self.col_names)) - def create_column(self, col_name): + def create_column(self, col_name: str) -> None: """Creates a colname to the col_names list """ self.col_names.append(col_name) @@ -90,7 +107,7 @@ class Table(object): row.append(self.default_value) ## Sort by column ########################################################## - def sort_by_column_id(self, col_id, method = 'asc'): + def sort_by_column_id(self, col_id: str, method: str = 'asc') -> None: """Sorts the table (in-place) according to data stored in col_id """ try: @@ -100,7 +117,7 @@ class Table(object): raise KeyError("Col (%s) not found in table" % (col_id)) - def sort_by_column_index(self, col_index, method = 'asc'): + def sort_by_column_index(self, col_index: int, method: str = 'asc') -> None: """Sorts the table 'in-place' according to data stored in col_index method should be in ('asc', 'desc') @@ -119,29 +136,33 @@ class Table(object): self.data.append(row) self.row_names.append(row_name) - def groupby(self, colname, *others): + def groupby(self, colname: str, *others: str) -> Union[Dict[str, Dict[str, 'Table']], + Dict[str, 'Table']]: """builds indexes of data :returns: nested dictionaries pointing to actual rows """ - groups = {} + groups: Dict = {} colnames = (colname,) + others col_indexes = [self.col_names.index(col_id) for col_id in colnames] for row in self.data: ptr = groups for col_index in col_indexes[:-1]: ptr = ptr.setdefault(row[col_index], {}) - ptr = ptr.setdefault(row[col_indexes[-1]], - Table(default_value=self.default_value, - col_names=self.col_names)) - ptr.append_row(tuple(row)) + table = ptr.setdefault(row[col_indexes[-1]], + Table(default_value=self.default_value, + col_names=self.col_names)) + table.append_row(tuple(row)) return groups - def select(self, colname, value): + def select(self, colname: str, value: str) -> 'Table': grouped = self.groupby(colname) try: - return grouped[value] + # mypy: Incompatible return value type (got "Union[Dict[str, Table], Table]", + # mypy: expected "Table") + # I guess we are sure we'll get a Table here? + return grouped[value] # type: ignore except KeyError: - return [] + return Table() def remove(self, colname, value): col_index = self.col_names.index(colname) @@ -151,13 +172,13 @@ class Table(object): ## The 'setter' part ####################################################### - def set_cell(self, row_index, col_index, data): + def set_cell(self, row_index: int, col_index: int, data: int) -> None: """sets value of cell 'row_indew', 'col_index' to data """ self.data[row_index][col_index] = data - def set_cell_by_ids(self, row_id, col_id, data): + def set_cell_by_ids(self, row_id: str, col_id: str, data: Union[int, str]) -> None: """sets value of cell mapped by row_id and col_id to data Raises a KeyError if row_id or col_id are not found in the table """ @@ -173,7 +194,7 @@ class Table(object): raise KeyError("Column (%s) not found in table" % (col_id)) - def set_row(self, row_index, row_data): + def set_row(self, row_index: int, row_data: Union[List[float], List[int], List[str]]) -> None: """sets the 'row_index' row pre:: @@ -183,7 +204,7 @@ class Table(object): self.data[row_index] = row_data - def set_row_by_id(self, row_id, row_data): + def set_row_by_id(self, row_id: str, row_data: List[str]) -> None: """sets the 'row_id' column pre:: @@ -199,7 +220,7 @@ class Table(object): raise KeyError('Row (%s) not found in table' % (row_id)) - def append_row(self, row_data, row_name=None): + def append_row(self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None) -> int: """Appends a row to the table pre:: @@ -211,7 +232,7 @@ class Table(object): self.data.append(row_data) return len(self.data) - 1 - def insert_row(self, index, row_data, row_name=None): + def insert_row(self, index: int, row_data: List[str], row_name: str = None) -> None: """Appends row_data before 'index' in the table. To make 'insert' behave like 'list.insert', inserting in an out of range index will insert row_data to the end of the list @@ -225,7 +246,7 @@ class Table(object): self.data.insert(index, row_data) - def delete_row(self, index): + def delete_row(self, index: int) -> List[str]: """Deletes the 'index' row in the table, and returns it. Raises an IndexError if index is out of range """ @@ -233,7 +254,7 @@ class Table(object): return self.data.pop(index) - def delete_row_by_id(self, row_id): + def delete_row_by_id(self, row_id: str) -> None: """Deletes the 'row_id' row in the table. Raises a KeyError if row_id was not found. """ @@ -244,7 +265,7 @@ class Table(object): raise KeyError('Row (%s) not found in table' % (row_id)) - def set_column(self, col_index, col_data): + def set_column(self, col_index: int, col_data: Union[List[int], range]) -> None: """sets the 'col_index' column pre:: @@ -256,7 +277,7 @@ class Table(object): self.data[row_index][col_index] = cell_data - def set_column_by_id(self, col_id, col_data): + def set_column_by_id(self, col_id: str, col_data: Union[List[int], range]) -> None: """sets the 'col_id' column pre:: @@ -272,7 +293,7 @@ class Table(object): raise KeyError('Column (%s) not found in table' % (col_id)) - def append_column(self, col_data, col_name): + def append_column(self, col_data: range, col_name: str) -> None: """Appends the 'col_index' column pre:: @@ -284,7 +305,7 @@ class Table(object): self.data[row_index].append(cell_data) - def insert_column(self, index, col_data, col_name): + def insert_column(self, index: int, col_data: range, col_name: str) -> None: """Appends col_data before 'index' in the table. To make 'insert' behave like 'list.insert', inserting in an out of range index will insert col_data to the end of the list @@ -298,7 +319,7 @@ class Table(object): self.data[row_index].insert(index, cell_data) - def delete_column(self, index): + def delete_column(self, index: int) -> List[int]: """Deletes the 'index' column in the table, and returns it. Raises an IndexError if index is out of range """ @@ -306,7 +327,7 @@ class Table(object): return [row.pop(index) for row in self.data] - def delete_column_by_id(self, col_id): + def delete_column_by_id(self, col_id: str) -> None: """Deletes the 'col_id' col in the table. Raises a KeyError if col_id was not found. """ @@ -319,53 +340,68 @@ class Table(object): ## The 'getter' part ####################################################### - def get_shape(self): + def get_shape(self) -> Tuple[int, int]: """Returns a tuple which represents the table's shape """ return len(self.row_names), len(self.col_names) shape = property(get_shape) - def __getitem__(self, indices): + def __getitem__(self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice]) -> Any: """provided for convenience""" - rows, multirows = None, False - cols, multicols = None, False + multirows: bool = False + multicols: bool = False + + rows: slice + cols: slice + + rows_indice: Union[int, slice, str] + cols_indice: Union[int, str, None] = None + if isinstance(indices, tuple): - rows = indices[0] + rows_indice = indices[0] if len(indices) > 1: - cols = indices[1] + cols_indice = indices[1] else: - rows = indices + rows_indice = indices + # define row slice - if isinstance(rows, str): + if isinstance(rows_indice, str): try: - rows = self.row_names.index(rows) + rows_indice = self.row_names.index(rows_indice) except ValueError: - raise KeyError("Row (%s) not found in table" % (rows)) - if isinstance(rows, int): - rows = slice(rows, rows+1) + raise KeyError("Row (%s) not found in table" % (rows_indice)) + + if isinstance(rows_indice, int): + rows = slice(rows_indice, rows_indice + 1) multirows = False else: rows = slice(None) multirows = True + # define col slice - if isinstance(cols, str): + if isinstance(cols_indice, str): try: - cols = self.col_names.index(cols) + cols_indice = self.col_names.index(cols_indice) except ValueError: - raise KeyError("Column (%s) not found in table" % (cols)) - if isinstance(cols, int): - cols = slice(cols, cols+1) + raise KeyError("Column (%s) not found in table" % (cols_indice)) + + if isinstance(cols_indice, int): + cols = slice(cols_indice, cols_indice + 1) multicols = False else: cols = slice(None) multicols = True + # get sub-table tab = Table() tab.default_value = self.default_value + tab.create_rows(self.row_names[rows]) tab.create_columns(self.col_names[cols]) + for idx, row in enumerate(self.data[rows]): tab.set_row(idx, row[cols]) + if multirows : if multicols: return tab @@ -409,7 +445,7 @@ class Table(object): raise KeyError("Column (%s) not found in table" % (col_id)) return self.get_column(col_index, distinct) - def get_columns(self): + def get_columns(self) -> List[List[int]]: """Returns all the columns in the table """ return [self[:, index] for index in range(len(self.col_names))] @@ -421,14 +457,14 @@ class Table(object): col = list(set(col)) return col - def apply_stylesheet(self, stylesheet): + def apply_stylesheet(self, stylesheet: 'TableStyleSheet') -> None: """Applies the stylesheet to this table """ for instruction in stylesheet.instructions: eval(instruction) - def transpose(self): + def transpose(self) -> 'Table': """Keeps the self object intact, and returns the transposed (rotated) table. """ @@ -440,7 +476,7 @@ class Table(object): return transposed - def pprint(self): + def pprint(self) -> str: """returns a string representing the table in a pretty printed 'text' format. """ @@ -482,7 +518,7 @@ class Table(object): return '\n'.join(lines) - def __repr__(self): + def __repr__(self) -> str: return repr(self.data) def as_text(self): @@ -499,7 +535,7 @@ class TableStyle: """Defines a table's style """ - def __init__(self, table): + def __init__(self, table: Table) -> None: self._table = table self.size = dict([(col_name, '1*') for col_name in table.col_names]) @@ -516,12 +552,12 @@ class TableStyle: self.units['__row_column__'] = '' # XXX FIXME : params order should be reversed for all set() methods - def set_size(self, value, col_id): + def set_size(self, value: str, col_id: str) -> None: """sets the size of the specified col_id to value """ self.size[col_id] = value - def set_size_by_index(self, value, col_index): + def set_size_by_index(self, value: str, col_index: int) -> None: """Allows to set the size according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -534,13 +570,13 @@ class TableStyle: self.size[col_id] = value - def set_alignment(self, value, col_id): + def set_alignment(self, value: str, col_id: str) -> None: """sets the alignment of the specified col_id to value """ self.alignment[col_id] = value - def set_alignment_by_index(self, value, col_index): + def set_alignment_by_index(self, value: str, col_index: int) -> None: """Allows to set the alignment according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -553,13 +589,13 @@ class TableStyle: self.alignment[col_id] = value - def set_unit(self, value, col_id): + def set_unit(self, value: str, col_id: str) -> None: """sets the unit of the specified col_id to value """ self.units[col_id] = value - def set_unit_by_index(self, value, col_index): + def set_unit_by_index(self, value: str, col_index: int) -> None: """Allows to set the unit according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -574,13 +610,13 @@ class TableStyle: self.units[col_id] = value - def get_size(self, col_id): + def get_size(self, col_id: str) -> str: """Returns the size of the specified col_id """ return self.size[col_id] - def get_size_by_index(self, col_index): + def get_size_by_index(self, col_index: int) -> str: """Allows to get the size according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -593,13 +629,13 @@ class TableStyle: return self.size[col_id] - def get_alignment(self, col_id): + def get_alignment(self, col_id: str) -> str: """Returns the alignment of the specified col_id """ return self.alignment[col_id] - def get_alignment_by_index(self, col_index): + def get_alignment_by_index(self, col_index: int) -> str: """Allors to get the alignment according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -612,13 +648,13 @@ class TableStyle: return self.alignment[col_id] - def get_unit(self, col_id): + def get_unit(self, col_id: str) -> str: """Returns the unit of the specified col_id """ return self.units[col_id] - def get_unit_by_index(self, col_index): + def get_unit_by_index(self, col_index: int) -> str: """Allors to get the unit according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! @@ -649,28 +685,30 @@ class TableStyleSheet: 2_5 = sqrt(2_3**2 + 2_4**2) """ - def __init__(self, rules = None): + def __init__(self, rules: Optional[List[str]] = None) -> None: rules = rules or [] - self.rules = [] - self.instructions = [] + + self.rules: List[str] = [] + self.instructions: List[CodeType] = [] + for rule in rules: self.add_rule(rule) - def add_rule(self, rule): + def add_rule(self, rule: str) -> None: """Adds a rule to the stylesheet rules """ try: source_code = ['from math import *'] source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule)) self.instructions.append(compile('\n'.join(source_code), - 'table.py', 'exec')) + 'table.py', 'exec')) self.rules.append(rule) except SyntaxError: print("Bad Stylesheet Rule : %s [skipped]" % rule) - def add_rowsum_rule(self, dest_cell, row_index, start_col, end_col): + def add_rowsum_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None: """Creates and adds a rule to sum over the row at row_index from start_col to end_col. dest_cell is a tuple of two elements (x,y) of the destination cell @@ -686,7 +724,7 @@ class TableStyleSheet: self.add_rule(rule) - def add_rowavg_rule(self, dest_cell, row_index, start_col, end_col): + def add_rowavg_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None: """Creates and adds a rule to make the row average (from start_col to end_col) dest_cell is a tuple of two elements (x,y) of the destination cell @@ -703,7 +741,7 @@ class TableStyleSheet: self.add_rule(rule) - def add_colsum_rule(self, dest_cell, col_index, start_row, end_row): + def add_colsum_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None: """Creates and adds a rule to sum over the col at col_index from start_row to end_row. dest_cell is a tuple of two elements (x,y) of the destination cell @@ -719,7 +757,7 @@ class TableStyleSheet: self.add_rule(rule) - def add_colavg_rule(self, dest_cell, col_index, start_row, end_row): + def add_colavg_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None: """Creates and adds a rule to make the col average (from start_row to end_row) dest_cell is a tuple of two elements (x,y) of the destination cell @@ -741,7 +779,7 @@ class TableCellRenderer: """Defines a simple text renderer """ - def __init__(self, **properties): + def __init__(self, **properties: Any) -> None: """keywords should be properties with an associated boolean as value. For example : renderer = TableCellRenderer(units = True, alignment = False) @@ -752,7 +790,7 @@ class TableCellRenderer: self.properties = properties - def render_cell(self, cell_coord, table, table_style): + def render_cell(self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle) -> Union[str, int]: """Renders the cell at 'cell_coord' in the table, using table_style """ row_index, col_index = cell_coord @@ -763,14 +801,14 @@ class TableCellRenderer: table_style, col_index + 1) - def render_row_cell(self, row_name, table, table_style): + def render_row_cell(self, row_name: str, table: Table, table_style: TableStyle) -> Union[str, int]: """Renders the cell for 'row_id' row """ cell_value = row_name return self._render_cell_content(cell_value, table_style, 0) - def render_col_cell(self, col_name, table, table_style): + def render_col_cell(self, col_name: str, table: Table, table_style: TableStyle) -> Union[str, int]: """Renders the cell for 'col_id' row """ cell_value = col_name @@ -779,7 +817,7 @@ class TableCellRenderer: - def _render_cell_content(self, content, table_style, col_index): + def _render_cell_content(self, content: Union[str, int], table_style: TableStyle, col_index: int) -> Union[str, int]: """Makes the appropriate rendering for this cell content. Rendering properties will be searched using the *table_style.get_xxx_by_index(col_index)' methods @@ -789,11 +827,12 @@ class TableCellRenderer: return content - def _make_cell_content(self, cell_content, table_style, col_index): + def _make_cell_content(self, cell_content: int, table_style: TableStyle, col_index: int) -> Union[int, str]: """Makes the cell content (adds decoration data, like units for example) """ - final_content = cell_content + final_content: Union[int, str] = cell_content + if 'skip_zero' in self.properties: replacement_char = self.properties['skip_zero'] else: @@ -812,7 +851,7 @@ class TableCellRenderer: return final_content - def _add_unit(self, cell_content, table_style, col_index): + def _add_unit(self, cell_content: int, table_style: TableStyle, col_index: int) -> str: """Adds unit to the cell_content if needed """ unit = table_style.get_unit_by_index(col_index) @@ -824,7 +863,7 @@ class DocbookRenderer(TableCellRenderer): """Defines how to render a cell for a docboook table """ - def define_col_header(self, col_index, table_style): + def define_col_header(self, col_index: int, table_style: TableStyle) -> str: """Computes the colspec element according to the style """ size = table_style.get_size_by_index(col_index) @@ -832,7 +871,7 @@ class DocbookRenderer(TableCellRenderer): (col_index, size) - def _render_cell_content(self, cell_content, table_style, col_index): + def _render_cell_content(self, cell_content: Union[int, str], table_style: TableStyle, col_index: int) -> str: """Makes the appropriate rendering for this cell content. Rendering properties will be searched using the table_style.get_xxx_by_index(col_index)' methods. @@ -847,17 +886,20 @@ class DocbookRenderer(TableCellRenderer): # KeyError <=> Default alignment return "<entry>%s</entry>\n" % cell_content + # XXX really? + return "" + class TableWriter: """A class to write tables """ - def __init__(self, stream, table, style, **properties): + def __init__(self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any) -> None: self._stream = stream self.style = style or TableStyle(table) self._table = table self.properties = properties - self.renderer = None + self.renderer: Optional[DocbookRenderer] = None def set_style(self, style): @@ -866,7 +908,7 @@ class TableWriter: self.style = style - def set_renderer(self, renderer): + def set_renderer(self, renderer: DocbookRenderer) -> None: """sets the way to render cell """ self.renderer = renderer @@ -878,7 +920,7 @@ class TableWriter: self.properties.update(properties) - def write_table(self, title = ""): + def write_table(self, title: str = "") -> None: """Writes the table """ raise NotImplementedError("write_table must be implemented !") @@ -889,9 +931,11 @@ class DocbookTableWriter(TableWriter): """Defines an implementation of TableWriter to write a table in Docbook """ - def _write_headers(self): + def _write_headers(self) -> None: """Writes col headers """ + assert self.renderer is not None + # Define col_headers (colstpec elements) for col_index in range(len(self._table.col_names)+1): self._stream.write(self.renderer.define_col_header(col_index, @@ -908,9 +952,11 @@ class DocbookTableWriter(TableWriter): self._stream.write("</row>\n</thead>\n") - def _write_body(self): + def _write_body(self) -> None: """Writes the table body """ + assert self.renderer is not None + self._stream.write('<tbody>\n') for row_index, row in enumerate(self._table.data): @@ -931,7 +977,7 @@ class DocbookTableWriter(TableWriter): self._stream.write('</tbody>\n') - def write_table(self, title = ""): + def write_table(self, title: str = "") -> None: """Writes the table """ self._stream.write('<table>\n<title>%s></title>\n'%(title)) diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py index 7d561ca..4e3434e 100644 --- a/logilab/common/tasksqueue.py +++ b/logilab/common/tasksqueue.py @@ -19,8 +19,9 @@ __docformat__ = "restructuredtext en" -from bisect import insort_left +from typing import Iterator, List +from bisect import insort_left import queue LOW = 0 @@ -35,16 +36,15 @@ PRIORITY = { REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items()) - class PrioritizedTasksQueue(queue.Queue): - def _init(self, maxsize): + def _init(self, maxsize: int) -> None: """Initialize the queue representation""" self.maxsize = maxsize # ordered list of task, from the lowest to the highest priority - self.queue = [] + self.queue: List['Task'] = [] # type: ignore - def _put(self, item): + def _put(self, item: 'Task') -> None: """Put a new item in the queue""" for i, task in enumerate(self.queue): # equivalent task @@ -60,14 +60,14 @@ class PrioritizedTasksQueue(queue.Queue): return insort_left(self.queue, item) - def _get(self): + def _get(self) -> 'Task': """Get an item from the queue""" return self.queue.pop() - def __iter__(self): + def __iter__(self) -> Iterator['Task']: return iter(self.queue) - def remove(self, tid): + def remove(self, tid: str) -> None: """remove a specific task from the queue""" # XXX acquire lock for i, task in enumerate(self): @@ -76,26 +76,23 @@ class PrioritizedTasksQueue(queue.Queue): return raise ValueError('not task of id %s in queue' % tid) -class Task(object): - def __init__(self, tid, priority=LOW): +class Task: + def __init__(self, tid: str, priority: int = LOW) -> None: # task id self.id = tid # task priority self.priority = priority - def __repr__(self): + def __repr__(self) -> str: return '<Task %s @%#x>' % (self.id, id(self)) - def __cmp__(self, other): - return cmp(self.priority, other.priority) - - def __lt__(self, other): + def __lt__(self, other: 'Task') -> bool: return self.priority < other.priority - def __eq__(self, other): - return self.id == other.id + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and self.id == other.id __hash__ = object.__hash__ - def merge(self, other): + def merge(self, other: 'Task') -> None: pass diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py index dae1ff5..8348900 100644 --- a/logilab/common/testlib.py +++ b/logilab/common/testlib.py @@ -55,6 +55,8 @@ import warnings from shutil import rmtree from operator import itemgetter from inspect import isgeneratorfunction +from typing import Any, Iterator, Union, Optional, Callable, Dict, List, Tuple +from mypy_extensions import NoReturn import builtins import configparser @@ -69,7 +71,9 @@ if not getattr(unittest_legacy, "__package__", None): except ImportError: raise ImportError("You have to install python-unittest2 to use %s" % __name__) else: - import unittest as unittest + # mypy: Name 'unittest' already defined (possibly by an import) + # compat + import unittest as unittest # type: ignore from unittest import SkipTest from functools import wraps @@ -91,11 +95,11 @@ __unittest = 1 @deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.') -def with_tempdir(callable): +def with_tempdir(callable: Callable) -> Callable: """A decorator ensuring no temporary file left when the function return Work only for temporary file created with the tempfile module""" if isgeneratorfunction(callable): - def proxy(*args, **kwargs): + def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]: old_tmpdir = tempfile.gettempdir() new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") tempfile.tempdir = new_tmpdir @@ -109,20 +113,21 @@ def with_tempdir(callable): tempfile.tempdir = old_tmpdir return proxy - @wraps(callable) - def proxy(*args, **kargs): + else: + @wraps(callable) + def proxy(*args: Any, **kargs: Any) -> Any: - old_tmpdir = tempfile.gettempdir() - new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") - tempfile.tempdir = new_tmpdir - try: - return callable(*args, **kargs) - finally: + old_tmpdir = tempfile.gettempdir() + new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") + tempfile.tempdir = new_tmpdir try: - rmtree(new_tmpdir, ignore_errors=True) + return callable(*args, **kargs) finally: - tempfile.tempdir = old_tmpdir - return proxy + try: + rmtree(new_tmpdir, ignore_errors=True) + finally: + tempfile.tempdir = old_tmpdir + return proxy def in_tempdir(callable): """A decorator moving the enclosed function inside the tempfile.tempfdir @@ -204,7 +209,7 @@ def start_interactive_mode(result): # coverage pausing tools ##################################################### @contextmanager -def replace_trace(trace=None): +def replace_trace(trace: Optional[Callable] = None) -> Iterator: """A context manager that temporary replaces the trace function""" oldtrace = sys.gettrace() sys.settrace(trace) @@ -222,16 +227,20 @@ def replace_trace(trace=None): pause_trace = replace_trace -def nocoverage(func): +def nocoverage(func: Callable) -> Callable: """Function decorator that pauses tracing functions""" if hasattr(func, 'uncovered'): return func - func.uncovered = True + # mypy: "Callable[..., Any]" has no attribute "uncovered" + # dynamic attribute for magic + func.uncovered = True # type: ignore - def not_covered(*args, **kwargs): + def not_covered(*args: Any, **kwargs: Any) -> Any: with pause_trace(): return func(*args, **kwargs) - not_covered.uncovered = True + # mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered" + # dynamic attribute for magic + not_covered.uncovered = True # type: ignore return not_covered @@ -264,10 +273,10 @@ class InnerTestSkipped(SkipTest): """raised when a test is skipped""" pass -def parse_generative_args(params): +def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]: args = [] varargs = () - kwargs = {} + kwargs: Dict = {} flags = 0 # 2 <=> starargs, 4 <=> kwargs for param in params: if isinstance(param, starargs): @@ -298,22 +307,28 @@ class InnerTest(tuple): class Tags(set): """A set of tag able validate an expression""" - def __init__(self, *tags, **kwargs): + def __init__(self, *tags: str, **kwargs: Any) -> None: self.inherit = kwargs.pop('inherit', True) if kwargs: raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) if len(tags) == 1 and not isinstance(tags[0], str): tags = tags[0] - super(Tags, self).__init__(tags, **kwargs) + super(Tags, self).__init__(tags) - def __getitem__(self, key): + def __getitem__(self, key: str) -> bool: return key in self - def match(self, exp): - return eval(exp, {}, self) + def match(self, exp: str) -> bool: + # mypy: Argument 3 to "eval" has incompatible type "Tags"; + # mypy: expected "Optional[Mapping[str, Any]]" + # I'm really not sure here? + return eval(exp, {}, self) # type: ignore - def __or__(self, other): + # mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet"; + # mypy: supertype defines the argument type as "AbstractSet[_T]" + # not sure how to fix this one + def __or__(self, other: 'Tags') -> 'Tags': # type: ignore return Tags(*super(Tags, self).__or__(other)) @@ -331,7 +346,7 @@ class TestCase(unittest.TestCase): maxDiff = None tags = Tags() - def __init__(self, methodName='runTest'): + def __init__(self, methodName: str = 'runTest') -> None: super(TestCase, self).__init__(methodName) self.__exc_info = sys.exc_info self.__testMethodName = self._testMethodName @@ -340,7 +355,7 @@ class TestCase(unittest.TestCase): @classproperty @cached - def datadir(cls): # pylint: disable=E0213 + def datadir(cls) -> str: # pylint: disable=E0213 """helper attribute holding the standard test's data directory NOTE: this is a logilab's standard @@ -351,7 +366,7 @@ class TestCase(unittest.TestCase): # instantiated for each test run) @classmethod - def datapath(cls, *fname): + def datapath(cls, *fname: str) -> str: """joins the object's datadir and `fname`""" return osp.join(cls.datadir, *fname) @@ -363,7 +378,7 @@ class TestCase(unittest.TestCase): self._current_test_descr = descr # override default's unittest.py feature - def shortDescription(self): + def shortDescription(self) -> Optional[Any]: """override default unittest shortDescription to handle correctly generative tests """ @@ -371,7 +386,7 @@ class TestCase(unittest.TestCase): return self._current_test_descr return super(TestCase, self).shortDescription() - def quiet_run(self, result, func, *args, **kwargs): + def quiet_run(self, result: Any, func: Callable, *args: Any, **kwargs: Any) -> bool: try: func(*args, **kwargs) except (KeyboardInterrupt, SystemExit): @@ -389,7 +404,7 @@ class TestCase(unittest.TestCase): return False return True - def _get_test_method(self): + def _get_test_method(self) -> Callable: """return the test method""" return getattr(self, self._testMethodName) @@ -446,7 +461,7 @@ class TestCase(unittest.TestCase): # result.cvg.stop() result.stopTest(self) - def _proceed_generative(self, result, testfunc, runcondition=None): + def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool: # cancel startTest()'s increment result.testsRun -= 1 success = True @@ -485,7 +500,7 @@ class TestCase(unittest.TestCase): success = False return success - def _proceed(self, result, testfunc, args=(), kwargs=None): + def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int: """proceed the actual test returns 0 on success, 1 on failure, 2 on error @@ -512,7 +527,7 @@ class TestCase(unittest.TestCase): return 2 return 0 - def innerSkip(self, msg=None): + def innerSkip(self, msg: str = None) -> NoReturn: """mark a generative test as skipped for the <msg> reason""" msg = msg or 'test was skipped' raise InnerTestSkipped(msg) @@ -549,7 +564,9 @@ class DocTestFinder(doctest.DocTestFinder): globs, source_lines) -class DocTest(TestCase, metaclass=class_deprecated): +# mypy error: Invalid metaclass 'class_deprecated' +# but it works? +class DocTest(TestCase, metaclass=class_deprecated): # type: ignore """trigger module doctest I don't know how to make unittest.main consider the DocTestSuite instance without this hack @@ -610,7 +627,9 @@ class MockConnection: pass -def mock_object(**params): +# mypy error: Name 'Mock' is not defined +# dynamic class created by this class +def mock_object(**params: Any) -> 'Mock': # type: ignore """creates an object using params to set attributes >>> option = mock_object(verbose=False, index=range(5)) >>> option.verbose @@ -621,7 +640,7 @@ def mock_object(**params): return type('Mock', (), params)() -def create_files(paths, chroot): +def create_files(paths: List[str], chroot: str) -> None: """Creates directories and files found in <path>. :param paths: list of relative paths to files or directories @@ -662,19 +681,21 @@ class AttrObject: # XXX cf mock_object def __init__(self, **kwargs): self.__dict__.update(kwargs) -def tag(*args, **kwargs): +def tag(*args: str, **kwargs: Any) -> Callable: """descriptor adding tag to a function""" - def desc(func): + def desc(func: Callable) -> Callable: assert not hasattr(func, 'tags') - func.tags = Tags(*args, **kwargs) + # mypy: "Callable[..., Any]" has no attribute "tags" + # dynamic magic attribute + func.tags = Tags(*args, **kwargs) # type: ignore return func return desc -def require_version(version): +def require_version(version: str) -> Callable: """ Compare version of python interpreter to the given one. Skip the test if older. """ - def check_require_version(f): + def check_require_version(f: Callable) -> Callable: version_elements = version.split('.') try: compare = tuple([int(v) for v in version_elements]) @@ -690,10 +711,10 @@ def require_version(version): return f return check_require_version -def require_module(module): +def require_module(module: str) -> Callable: """ Check if the given module is loaded. Skip the test if not. """ - def check_require_module(f): + def check_require_module(f: Callable) -> Callable: try: __import__(module) return f diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py index 1a5573d..4b6ea98 100644 --- a/logilab/common/textutils.py +++ b/logilab/common/textutils.py @@ -46,8 +46,10 @@ __docformat__ = "restructuredtext en" import sys import re import os.path as osp +from re import Pattern, Match from warnings import warn from unicodedata import normalize as _uninormalize +from typing import Any, Optional, Tuple, List, Callable, Dict, Union try: from os import linesep except ImportError: @@ -74,7 +76,7 @@ MANUAL_UNICODE_MAP = { u'\u2019': u"'", # SIMPLE QUOTE } -def unormalize(ustring, ignorenonascii=None, substitute=None): +def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None) -> str: """replace diacritical characters with their corresponding ascii characters Convert the unicode string to its long normalized form (unicode character @@ -107,7 +109,7 @@ def unormalize(ustring, ignorenonascii=None, substitute=None): res.append(replacement) return u''.join(res) -def unquote(string): +def unquote(string: str) -> str: """remove optional quotes (simple or double) from the string :type string: str or unicode @@ -128,7 +130,7 @@ def unquote(string): _BLANKLINES_RGX = re.compile('\r?\n\r?\n') _NORM_SPACES_RGX = re.compile('\s+') -def normalize_text(text, line_len=80, indent='', rest=False): +def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool = False) -> str: """normalize a text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized but blank lines are kept. The indentation string may be used to insert a @@ -159,7 +161,7 @@ def normalize_text(text, line_len=80, indent='', rest=False): return ('%s%s%s' % (linesep, indent, linesep)).join(result) -def normalize_paragraph(text, line_len=80, indent=''): +def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: """normalize a text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized. The indentation string may be used top insert a comment mark for @@ -188,7 +190,7 @@ def normalize_paragraph(text, line_len=80, indent=''): lines.append(indent + aline) return linesep.join(lines) -def normalize_rest_paragraph(text, line_len=80, indent=''): +def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: """normalize a ReST text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized. The indentation string may be used top insert a comment mark for @@ -229,7 +231,7 @@ def normalize_rest_paragraph(text, line_len=80, indent=''): return linesep.join(lines) -def splittext(text, line_len): +def splittext(text: str, line_len: int) -> Tuple[str, str]: """split the given text on space according to the given max line size return a 2-uple: @@ -248,7 +250,7 @@ def splittext(text, line_len): return text[:pos], text[pos+1:].strip() -def splitstrip(string, sep=','): +def splitstrip(string: str, sep: str = ',') -> List[str]: """return a list of stripped string by splitting the string given as argument on `sep` (',' by default). Empty string are discarded. @@ -337,8 +339,8 @@ TIME_UNITS = { "d": 60 * 60 *24, } -def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE, - value_reg=_VALUE_RE): +def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, type] = None, final: type = float, blank_reg: Pattern = _BLANK_RE, + value_reg: Pattern = _VALUE_RE) -> Union[float, int]: """Parse the string applying the units defined in units (e.g.: "1.5m",{'m',60} -> 80). @@ -379,7 +381,7 @@ def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE, _LINE_RGX = re.compile('\r\n|\r+|\n') -def pretty_match(match, string, underline_char='^'): +def pretty_match(match: Match, string: str, underline_char: str = '^') -> str: """return a string with the match location underlined: >>> import re @@ -424,11 +426,14 @@ def pretty_match(match, string, underline_char='^'): result.append(string) result.append(underline) else: - end = string[end_line_pos + len(linesep):] + # mypy: Incompatible types in assignment (expression has type "str", + # mypy: variable has type "int") + # but it's a str :| + end = string[end_line_pos + len(linesep):] # type: ignore string = string[start_line_pos:end_line_pos] result.append(string) result.append(underline) - result.append(end) + result.append(end) # type: ignore # see previous comment return linesep.join(result).rstrip() @@ -458,7 +463,7 @@ ANSI_COLORS = { 'white': "37", } -def _get_ansi_code(color=None, style=None): +def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) -> str: """return ansi escape code corresponding to color and style :type color: str or None @@ -491,7 +496,7 @@ def _get_ansi_code(color=None, style=None): return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END return '' -def colorize_ansi(msg, color=None, style=None): +def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] = None) -> str: """colorize message by wrapping it with ansi escape codes :type msg: str or unicode diff --git a/logilab/common/tree.py b/logilab/common/tree.py index 885eb0f..1fc5a21 100644 --- a/logilab/common/tree.py +++ b/logilab/common/tree.py @@ -25,29 +25,37 @@ __docformat__ = "restructuredtext en" import sys -from logilab.common import flatten +# from logilab.common import flatten from logilab.common.visitor import VisitedMixIn, FilteredIterator, no_filter +from logilab.common.types import Paragraph, Title +from typing import Optional, Any, Union, List, Callable, TypeVar ## Exceptions ################################################################# class NodeNotFound(Exception): """raised when a node has not been found""" -EX_SIBLING_NOT_FOUND = "No such sibling as '%s'" -EX_CHILD_NOT_FOUND = "No such child as '%s'" -EX_NODE_NOT_FOUND = "No such node as '%s'" +EX_SIBLING_NOT_FOUND: str = "No such sibling as '%s'" +EX_CHILD_NOT_FOUND: str = "No such child as '%s'" +EX_NODE_NOT_FOUND: str = "No such node as '%s'" # Base node ################################################################### +# describe node of current class +NodeType = Any + + class Node(object): """a basic tree node, characterized by an id""" - def __init__(self, nid=None) : + def __init__(self, nid: Optional[str] = None) -> None : self.id = nid # navigation - self.parent = None - self.children = [] + # should be something like Optional[type(self)] for subclasses but that's not possible? + self.parent: Optional[NodeType] = None + # should be something like List[type(self)] for subclasses but that's not possible? + self.children: List[NodeType] = [] def __iter__(self): return iter(self.children) @@ -65,31 +73,35 @@ class Node(object): def is_leaf(self): return not self.children - def append(self, child): + def append(self, child: NodeType) -> None: + # should be child: type(self) but that's not possible """add a node to children""" self.children.append(child) child.parent = self - def remove(self, child): + def remove(self, child: NodeType) -> None: + # should be child: type(self) but that's not possible """remove a child node""" self.children.remove(child) child.parent = None - def insert(self, index, child): + def insert(self, index: int, child: NodeType) -> None: + # should be child: type(self) but that's not possible """insert a child node""" self.children.insert(index, child) child.parent = self - def replace(self, old_child, new_child): + def replace(self, old_child: NodeType, new_child: NodeType) -> None: """replace a child node with another""" i = self.children.index(old_child) self.children.pop(i) self.children.insert(i, new_child) new_child.parent = self - def get_sibling(self, nid): + def get_sibling(self, nid: str) -> NodeType: """return the sibling node that has given id""" try: + assert self.parent is not None return self.parent.get_child_by_id(nid) except NodeNotFound : raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid) @@ -121,7 +133,7 @@ class Node(object): return parent.children[index-1] return None - def get_node_by_id(self, nid): + def get_node_by_id(self, nid: str) -> NodeType: """ return node in whole hierarchy that has given id """ @@ -131,7 +143,7 @@ class Node(object): except NodeNotFound : raise NodeNotFound(EX_NODE_NOT_FOUND % nid) - def get_child_by_id(self, nid, recurse=None): + def get_child_by_id(self, nid: str, recurse: Optional[bool] = None) -> NodeType: """ return child of given id """ @@ -147,7 +159,7 @@ class Node(object): return c raise NodeNotFound(EX_CHILD_NOT_FOUND % nid) - def get_child_by_path(self, path): + def get_child_by_path(self, path: List[str]) -> NodeType: """ return child of given path (path is a list of ids) """ @@ -162,7 +174,7 @@ class Node(object): pass raise NodeNotFound(EX_CHILD_NOT_FOUND % path) - def depth(self): + def depth(self) -> int: """ return depth of this node in the tree """ @@ -171,7 +183,7 @@ class Node(object): else : return 0 - def depth_down(self): + def depth_down(self) -> int: """ return depth of the tree from this node """ @@ -179,13 +191,13 @@ class Node(object): return 1 + max([c.depth_down() for c in self.children]) return 1 - def width(self): + def width(self) -> int: """ return the width of the tree from this node """ return len(self.leaves()) - def root(self): + def root(self) -> NodeType: """ return the root node of the tree """ @@ -193,7 +205,7 @@ class Node(object): return self.parent.root() return self - def leaves(self): + def leaves(self) -> List[NodeType]: """ return a list with all the leaves nodes descendant from this node """ @@ -205,7 +217,7 @@ class Node(object): else: return [self] - def flatten(self, _list=None): + def flatten(self, _list: Optional[List[NodeType]] = None) -> List[NodeType]: """ return a list with all the nodes descendant from this node """ @@ -216,7 +228,7 @@ class Node(object): c.flatten(_list) return _list - def lineage(self): + def lineage(self) -> List[NodeType]: """ return list of parents up to root node """ @@ -226,6 +238,7 @@ class Node(object): return lst class VNode(Node, VisitedMixIn): + # we should probably merge this VisitedMixIn here because it's only used here """a visitable node """ pass @@ -298,7 +311,7 @@ class ListNode(VNode, list_class): # construct list from tree #################################################### -def post_order_list(node, filter_func=no_filter): +def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]: """ create a list with tree nodes for which the <filter> function returned true in a post order fashion @@ -326,7 +339,7 @@ def post_order_list(node, filter_func=no_filter): poped = 1 return l -def pre_order_list(node, filter_func=no_filter): +def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]: """ create a list with tree nodes for which the <filter> function returned true in a pre order fashion @@ -358,12 +371,12 @@ def pre_order_list(node, filter_func=no_filter): class PostfixedDepthFirstIterator(FilteredIterator): """a postfixed depth first iterator, designed to be used with visitors """ - def __init__(self, node, filter_func=None): + def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None: FilteredIterator.__init__(self, node, post_order_list, filter_func) class PrefixedDepthFirstIterator(FilteredIterator): """a prefixed depth first iterator, designed to be used with visitors """ - def __init__(self, node, filter_func=None): + def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None: FilteredIterator.__init__(self, node, pre_order_list, filter_func) diff --git a/logilab/common/types.py b/logilab/common/types.py new file mode 100644 index 0000000..b13f3c7 --- /dev/null +++ b/logilab/common/types.py @@ -0,0 +1,44 @@ +# copyright 2019 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of yams. +# +# yams is free software: you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) +# any later version. +# +# yams is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with yams. If not, see <https://www.gnu.org/licenses/>. + +"""Types declarations for types annotations""" + +from typing import TYPE_CHECKING, TypeVar + + +# to avoid circular imports +if TYPE_CHECKING: + from logilab.common.tree import Node + from logilab.common.ureports.html_writer import HTMLWriter + from logilab.common.ureports.text_writer import TextWriter + from logilab.common.ureports.nodes import Paragraph + from logilab.common.ureports.nodes import Title + from logilab.common.table import Table + from logilab.common.optik_ext import OptionParser + from logilab.common.optik_ext import Option + from logilab.common import attrdict +else: + Node = TypeVar("Node") + HTMLWriter = TypeVar("HTMLWriter") + TextWriter = TypeVar("TextWriter") + Table = TypeVar("Table") + OptionParser = TypeVar("OptionParser") + Option = TypeVar("Option") + attrdict = TypeVar("attrdict") + Paragraph = TypeVar("Paragraph") + Title = TypeVar("Title") diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py index e5a7f4e..77a6272 100644 --- a/logilab/common/umessage.py +++ b/logilab/common/umessage.py @@ -24,8 +24,10 @@ from encodings import search_function import sys from email.utils import parseaddr, parsedate from email.header import decode_header +from email.message import Message from datetime import datetime +from typing import Any, Optional, List, Tuple, Union try: from mx.DateTime import DateTime @@ -35,17 +37,20 @@ except ImportError: import logilab.common as lgc -def decode_QP(string): - parts = [] - for decoded, charset in decode_header(string): +def decode_QP(string: str) -> str: + parts: List[str] = [] + for maybe_decoded, charset in decode_header(string): if not charset : charset = 'iso-8859-15' # python 3 sometimes returns str and sometimes bytes. # the 'official' fix is to use the new 'policy' APIs # https://bugs.python.org/issue24797 # let's just handle this bug ourselves for now - if isinstance(decoded, bytes): - decoded = decoded.decode(charset, 'replace') + if isinstance(maybe_decoded, bytes): + decoded = maybe_decoded.decode(charset, 'replace') + else: + decoded = maybe_decoded + assert isinstance(decoded, str) parts.append(decoded) @@ -53,6 +58,7 @@ def decode_QP(string): # decoding was non-RFC compliant wrt to whitespace handling # see http://bugs.python.org/issue1079 return u' '.join(parts) + return u''.join(parts) def message_from_file(fd): @@ -61,7 +67,7 @@ def message_from_file(fd): except email.errors.MessageParseError: return '' -def message_from_string(string): +def message_from_string(string: str) -> Union['UMessage', str]: try: return UMessage(email.message_from_string(string)) except email.errors.MessageParseError: @@ -71,12 +77,12 @@ class UMessage: """Encapsulates an email.Message instance and returns only unicode objects. """ - def __init__(self, message): + def __init__(self, message: Message) -> None: self.message = message # email.Message interface ################################################# - def get(self, header, default=None): + def get(self, header: str, default: Optional[Any] = None) -> Optional[str]: value = self.message.get(header, default) if value: return decode_QP(value) @@ -85,7 +91,7 @@ class UMessage: def __getitem__(self, header): return self.get(header) - def get_all(self, header, default=()): + def get_all(self, header: str, default: Tuple[()] = ()) -> List[str]: return [decode_QP(val) for val in self.message.get_all(header, default) if val is not None] @@ -99,23 +105,33 @@ class UMessage: for part in self.message.walk(): yield UMessage(part) - def get_payload(self, index=None, decode=False): + def get_payload(self, index: Optional[Any] = None, decode: bool = False) -> Union[str, 'UMessage', List['UMessage']]: message = self.message + if index is None: - payload = message.get_payload(index, decode) + # mypy: Argument 1 to "get_payload" of "Message" has incompatible type "None"; expected "int" + # email.message.Message.get_payload has type signature: + # Message.get_payload(self, i=None, decode=False) + # so None seems to be totally acceptable, I don't understand mypy here + payload = message.get_payload(index, decode) # type: ignore + if isinstance(payload, list): return [UMessage(msg) for msg in payload] + if message.get_content_maintype() != 'text': return payload + if isinstance(payload, str): return payload charset = message.get_content_charset() or 'iso-8859-1' if search_function(charset) is None: charset = 'iso-8859-1' + return str(payload or b'', charset, "replace") else: payload = UMessage(message.get_payload(index, decode)) + return payload def get_content_maintype(self): diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py index 8ce68a0..9c0f1df 100644 --- a/logilab/common/ureports/__init__.py +++ b/logilab/common/ureports/__init__.py @@ -26,6 +26,9 @@ import sys from logilab.common.compat import StringIO from logilab.common.textutils import linesep +from logilab.common.tree import VNode +from logilab.common.ureports.nodes import Table, List as NodeList +from typing import Any, Optional, Union, List, Generator, Tuple, Callable, TextIO def get_nodes(node, klass): @@ -76,7 +79,7 @@ def build_summary(layout, level=1): class BaseWriter(object): """base class for ureport writers""" - def format(self, layout, stream=None, encoding=None): + def format(self, layout: Any, stream: Optional[Union[StringIO, TextIO]] = None, encoding: Optional[Any] = None) -> None: """format and write the given layout into the stream object unicode policy: unicode strings may be found in the layout; @@ -88,87 +91,121 @@ class BaseWriter(object): if not encoding: encoding = getattr(stream, 'encoding', 'UTF-8') self.encoding = encoding or 'UTF-8' - self.__compute_funcs = [] + self.__compute_funcs: List[Tuple[Callable[[str], Any], Callable[[str], Any]]] = [] self.out = stream self.begin_format(layout) layout.accept(self) self.end_format(layout) - def format_children(self, layout): + def format_children(self, layout: Union['Paragraph', 'Section', 'Title']) -> None: """recurse on the layout children and call their accept method (see the Visitor pattern) """ for child in getattr(layout, 'children', ()): child.accept(self) - def writeln(self, string=u''): + def writeln(self, string: str = u'') -> None: """write a line in the output buffer""" self.write(string + linesep) - def write(self, string): + def write(self, string: str) -> None: """write a string in the output buffer""" try: self.out.write(string) except UnicodeEncodeError: - self.out.write(string.encode(self.encoding)) + # mypy: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str" + # probably a python3 port issue? + self.out.write(string.encode(self.encoding)) # type: ignore - def begin_format(self, layout): + def begin_format(self, layout: Any) -> None: """begin to format a layout""" self.section = 0 - def end_format(self, layout): + def end_format(self, layout: Any) -> None: """finished to format a layout""" - def get_table_content(self, table): + def get_table_content(self, table: Table) -> List[List[str]]: """trick to get table content without actually writing it return an aligned list of lists containing table cells values as string """ - result = [[]] - cols = table.cols + result: List[List[str]] = [[]] + # mypy: "Table" has no attribute "cols" + # dynamic attribute + cols = table.cols # type: ignore + for cell in self.compute_content(table): if cols == 0: result.append([]) - cols = table.cols + # mypy: "Table" has no attribute "cols" + # dynamic attribute + cols = table.cols # type: ignore + cols -= 1 result[-1].append(cell) + # fill missing cells while len(result[-1]) < cols: result[-1].append(u'') + return result - def compute_content(self, layout): + def compute_content(self, layout: VNode) -> Generator[str, Any, None]: """trick to compute the formatting of children layout before actually writing it return an iterator on strings (one for each child element) """ # use cells ! - def write(data): + def write(data: str) -> None: try: stream.write(data) except UnicodeEncodeError: - stream.write(data.encode(self.encoding)) - def writeln(data=u''): + # mypy: Argument 1 to "write" of "TextIOWrapper" has incompatible type "bytes"; + # mypy: expected "str" + # error from porting to python3? + stream.write(data.encode(self.encoding)) # type: ignore + + def writeln(data: str = u'') -> None: try: stream.write(data+linesep) except UnicodeEncodeError: - stream.write(data.encode(self.encoding)+linesep) - self.write = write - self.writeln = writeln + # mypy: Unsupported operand types for + ("bytes" and "str") + # error from porting to python3? + stream.write(data.encode(self.encoding)+linesep) # type: ignore + + # mypy: Cannot assign to a method + # this really looks like black dirty magic since self.write is reused elsewhere in the code + # especially since self.write and self.writeln are conditionally + # deleted at the end of this function + self.write = write # type: ignore + self.writeln = writeln # type: ignore + self.__compute_funcs.append((write, writeln)) - for child in layout.children: + + # mypy: Item "Table" of "Union[List[Any], Table, Title]" has no attribute "children" + # dynamic attribute? + for child in layout.children: # type: ignore stream = StringIO() + child.accept(self) + yield stream.getvalue() + self.__compute_funcs.pop() + try: - self.write, self.writeln = self.__compute_funcs[-1] + # mypy: Cannot assign to a method + # even more black dirty magic + self.write, self.writeln = self.__compute_funcs[-1] # type: ignore except IndexError: del self.write del self.writeln - -from logilab.common.ureports.nodes import * +# mypy error: Incompatible import of "Table" (imported name has type +# mypy error: "Type[logilab.common.ureports.nodes.Table]", local name has type +# mypy error: "Type[logilab.common.table.Table]") +# this will be cleaned when the "*" will be removed +from logilab.common.ureports.nodes import * # type: ignore from logilab.common.ureports.text_writer import TextWriter from logilab.common.ureports.html_writer import HTMLWriter diff --git a/logilab/common/ureports/html_writer.py b/logilab/common/ureports/html_writer.py index 989662f..0783075 100644 --- a/logilab/common/ureports/html_writer.py +++ b/logilab/common/ureports/html_writer.py @@ -20,16 +20,19 @@ __docformat__ = "restructuredtext en" from logilab.common.ureports import BaseWriter +from logilab.common.ureports.nodes import (Section, Title, Table, List, + Paragraph, Link, VerbatimText, Text) +from typing import Any class HTMLWriter(BaseWriter): """format layouts as HTML""" - def __init__(self, snippet=None): + def __init__(self, snippet: int = None) -> None: super(HTMLWriter, self).__init__() self.snippet = snippet - def handle_attrs(self, layout): + def handle_attrs(self, layout: Any) -> str: """get an attribute string from layout member attributes""" attrs = u'' klass = getattr(layout, 'klass', None) @@ -40,20 +43,20 @@ class HTMLWriter(BaseWriter): attrs += u' id="%s"' % nid return attrs - def begin_format(self, layout): + def begin_format(self, layout: Any) -> None: """begin to format a layout""" super(HTMLWriter, self).begin_format(layout) if self.snippet is None: self.writeln(u'<html>') self.writeln(u'<body>') - def end_format(self, layout): + def end_format(self, layout: Any) -> None: """finished to format a layout""" if self.snippet is None: self.writeln(u'</body>') self.writeln(u'</html>') - def visit_section(self, layout): + def visit_section(self, layout: Section) -> None: """display a section as html, using div + h[section level]""" self.section += 1 self.writeln(u'<div%s>' % self.handle_attrs(layout)) @@ -61,13 +64,13 @@ class HTMLWriter(BaseWriter): self.writeln(u'</div>') self.section -= 1 - def visit_title(self, layout): + def visit_title(self, layout: Title) -> None: """display a title using <hX>""" self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout))) self.format_children(layout) self.writeln(u'</h%s>' % self.section) - def visit_table(self, layout): + def visit_table(self, layout: Table) -> None: """display a table as html""" self.writeln(u'<table%s>' % self.handle_attrs(layout)) table_content = self.get_table_content(layout) @@ -91,14 +94,14 @@ class HTMLWriter(BaseWriter): self.writeln(u'</tr>') self.writeln(u'</table>') - def visit_list(self, layout): + def visit_list(self, layout: List) -> None: """display a list as html""" self.writeln(u'<ul%s>' % self.handle_attrs(layout)) for row in list(self.compute_content(layout)): self.writeln(u'<li>%s</li>' % row) self.writeln(u'</ul>') - def visit_paragraph(self, layout): + def visit_paragraph(self, layout: Paragraph) -> None: """display links (using <p>)""" self.write(u'<p>') self.format_children(layout) @@ -110,19 +113,19 @@ class HTMLWriter(BaseWriter): self.format_children(layout) self.write(u'</span>') - def visit_link(self, layout): + def visit_link(self, layout: Link) -> None: """display links (using <a>)""" self.write(u' <a href="%s"%s>%s</a>' % (layout.url, self.handle_attrs(layout), layout.label)) - def visit_verbatimtext(self, layout): + def visit_verbatimtext(self, layout: VerbatimText) -> None: """display verbatim text (using <pre>)""" self.write(u'<pre>') self.write(layout.data.replace(u'&', u'&').replace(u'<', u'<')) self.write(u'</pre>') - def visit_text(self, layout): + def visit_text(self, layout: Text) -> None: """add some text""" data = layout.data if layout.escaped: diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py index d1267f5..d086faf 100644 --- a/logilab/common/ureports/nodes.py +++ b/logilab/common/ureports/nodes.py @@ -22,6 +22,14 @@ A micro report is a tree of layout and content objects. __docformat__ = "restructuredtext en" from logilab.common.tree import VNode +from typing import Optional +# from logilab.common.ureports.nodes import List +# from logilab.common.ureports.nodes import Paragraph +# from logilab.common.ureports.nodes import Text +from typing import Any +from typing import List as TypingList +from typing import Tuple +from typing import Union class BaseComponent(VNode): @@ -31,7 +39,7 @@ class BaseComponent(VNode): * id : the component's optional id * klass : the component's optional klass """ - def __init__(self, id=None, klass=None): + def __init__(self, id: Optional[str] = None, klass: Optional[str] = None) -> None: VNode.__init__(self, id) self.klass = klass @@ -43,27 +51,36 @@ class BaseLayout(BaseComponent): * BaseComponent attributes * children : components in this table (i.e. the table's cells) """ - def __init__(self, children=(), **kwargs): + def __init__(self, + children: Union[TypingList['Text'], + Tuple[Union['Paragraph', str], + Union[TypingList, str]], Tuple[str, ...]] = (), + **kwargs: Any) -> None: + super(BaseLayout, self).__init__(**kwargs) + for child in children: if isinstance(child, BaseComponent): self.append(child) else: - self.add_text(child) + # mypy: Argument 1 to "add_text" of "BaseLayout" has incompatible type + # mypy: "Union[str, List[Any]]"; expected "str" + # we check this situation in the if + self.add_text(child) # type: ignore - def append(self, child): + def append(self, child: Any) -> None: """overridden to detect problems easily""" assert child not in self.parents() VNode.append(self, child) - def parents(self): + def parents(self) -> TypingList: """return the ancestor nodes""" assert self.parent is not self if self.parent is None: return [] return [self.parent] + self.parent.parents() - def add_text(self, text): + def add_text(self, text: str) -> None: """shortcut to add text data""" self.children.append(Text(text)) @@ -77,7 +94,7 @@ class Text(BaseComponent): * BaseComponent attributes * data : the text value as an encoded or unicode string """ - def __init__(self, data, escaped=True, **kwargs): + def __init__(self, data: str, escaped: bool = True, **kwargs: Any) -> None: super(Text, self).__init__(**kwargs) # if isinstance(data, unicode): # data = data.encode('ascii') @@ -103,7 +120,7 @@ class Link(BaseComponent): * url : the link's target (REQUIRED) * label : the link's label as a string (use the url by default) """ - def __init__(self, url, label=None, **kwargs): + def __init__(self, url: str, label: str = None, **kwargs: Any) -> None: super(Link, self).__init__(**kwargs) assert url self.url = url @@ -141,7 +158,7 @@ class Section(BaseLayout): a description may also be given to the constructor, it'll be added as a first paragraph """ - def __init__(self, title=None, description=None, **kwargs): + def __init__(self, title: str = None, description: str = None, **kwargs: Any) -> None: super(Section, self).__init__(**kwargs) if description: self.insert(0, Paragraph([Text(description)])) @@ -189,9 +206,9 @@ class Table(BaseLayout): * cheaders : the first col's elements are table's header * title : the table's optional title """ - def __init__(self, cols, title=None, - rheaders=0, cheaders=0, rrheaders=0, rcheaders=0, - **kwargs): + def __init__(self, cols: int, title: Optional[Any] = None, + rheaders: int = 0, cheaders: int = 0, rrheaders: int = 0, rcheaders: int = 0, + **kwargs: Any) -> None: super(Table, self).__init__(**kwargs) assert isinstance(cols, int) self.cols = cols diff --git a/logilab/common/ureports/text_writer.py b/logilab/common/ureports/text_writer.py index ae92774..f75d7c9 100644 --- a/logilab/common/ureports/text_writer.py +++ b/logilab/common/ureports/text_writer.py @@ -18,11 +18,14 @@ """Text formatting drivers for ureports""" from __future__ import print_function +from typing import Any, List, Tuple __docformat__ = "restructuredtext en" from logilab.common.textutils import linesep from logilab.common.ureports import BaseWriter +from logilab.common.ureports.nodes import (Section, Title, Table, List as NodeList, + Paragraph, Link, VerbatimText, Text) TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^'] @@ -33,12 +36,12 @@ class TextWriter(BaseWriter): """format layouts as text (ReStructured inspiration but not totally handled yet) """ - def begin_format(self, layout): + def begin_format(self, layout: Any) -> None: super(TextWriter, self).begin_format(layout) self.list_level = 0 - self.pending_urls = [] + self.pending_urls: List[Tuple[str, str]] = [] - def visit_section(self, layout): + def visit_section(self, layout: Section) -> None: """display a section as text """ self.section += 1 @@ -52,7 +55,7 @@ class TextWriter(BaseWriter): self.section -= 1 self.writeln() - def visit_title(self, layout): + def visit_title(self, layout: Title) -> None: title = u''.join(list(self.compute_content(layout))) self.writeln(title) try: @@ -60,7 +63,7 @@ class TextWriter(BaseWriter): except IndexError: print("FIXME TITLE TOO DEEP. TURNING TITLE INTO TEXT") - def visit_paragraph(self, layout): + def visit_paragraph(self, layout: 'Paragraph') -> None: """enter a paragraph""" self.format_children(layout) self.writeln() @@ -69,7 +72,7 @@ class TextWriter(BaseWriter): """enter a span""" self.format_children(layout) - def visit_table(self, layout): + def visit_table(self, layout: Table) -> None: """display a table as text""" table_content = self.get_table_content(layout) # get columns width @@ -84,36 +87,40 @@ class TextWriter(BaseWriter): self.default_table(layout, table_content, cols_width) self.writeln() - def default_table(self, layout, table_content, cols_width): + def default_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None: """format a table""" cols_width = [size+1 for size in cols_width] + format_strings = u' '.join([u'%%-%ss'] * len(cols_width)) format_strings = format_strings % tuple(cols_width) - format_strings = format_strings.split(' ') + + format_strings_list = format_strings.split(' ') + table_linesep = ( u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n') headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n' + # FIXME: layout.cheaders self.write(table_linesep) for i in range(len(table_content)): self.write(u'|') line = table_content[i] for j in range(len(line)): - self.write(format_strings[j] % line[j]) + self.write(format_strings_list[j] % line[j]) self.write(u'|') if i == 0 and layout.rheaders: self.write(headsep) else: self.write(table_linesep) - def field_table(self, layout, table_content, cols_width): + def field_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None: """special case for field table""" assert layout.cols == 2 format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0]) for field, value in table_content: self.write(format_string % (field, value)) - def visit_list(self, layout): + def visit_list(self, layout: NodeList) -> None: """display a list layout as text""" bullet = BULLETS[self.list_level % len(BULLETS)] indent = ' ' * self.list_level @@ -123,7 +130,7 @@ class TextWriter(BaseWriter): child.accept(self) self.list_level -= 1 - def visit_link(self, layout): + def visit_link(self, layout: Link) -> None: """add a hyperlink""" if layout.label != layout.url: self.write(u'`%s`_' % layout.label) @@ -131,7 +138,7 @@ class TextWriter(BaseWriter): else: self.write(layout.url) - def visit_verbatimtext(self, layout): + def visit_verbatimtext(self, layout: VerbatimText) -> None: """display a verbatim layout as text (so difficult ;) """ self.writeln(u'::\n') @@ -139,6 +146,6 @@ class TextWriter(BaseWriter): self.writeln(u' ' + line) self.writeln() - def visit_text(self, layout): + def visit_text(self, layout: Text) -> None: """add some text""" self.write(u'%s' % layout.data) diff --git a/logilab/common/visitor.py b/logilab/common/visitor.py index ed2b70f..0698bae 100644 --- a/logilab/common/visitor.py +++ b/logilab/common/visitor.py @@ -21,21 +21,24 @@ """ +from typing import Any, Callable, Optional, Union +from logilab.common.types import Node, HTMLWriter, TextWriter __docformat__ = "restructuredtext en" -def no_filter(_): + +def no_filter(_: Node) -> int: return 1 # Iterators ################################################################### class FilteredIterator(object): - def __init__(self, node, list_func, filter_func=None): + def __init__(self, node: Node, list_func: Callable, filter_func: Optional[Any] = None) -> None: self._next = [(node, 0)] if filter_func is None: filter_func = no_filter self._list = list_func(node, filter_func) - def __next__(self): + def __next__(self) -> Optional[Node]: try: return self._list.pop(0) except : @@ -89,18 +92,20 @@ class VisitedMixIn(object): """ Visited interface allow node visitors to use the node """ - def get_visit_name(self): + def get_visit_name(self) -> str: """ return the visit name for the mixed class. When calling 'accept', the method <'visit_' + name returned by this method> will be called on the visitor """ try: - return self.TYPE.replace('-', '_') + # mypy: "VisitedMixIn" has no attribute "TYPE" + # dynamic attribute + return self.TYPE.replace('-', '_') # type: ignore except: return self.__class__.__name__.lower() - def accept(self, visitor, *args, **kwargs): + def accept(self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any) -> Optional[Any]: func = getattr(visitor, 'visit_%s' % self.get_visit_name()) return func(self, *args, **kwargs) diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py index d383b9d..7b12c45 100644 --- a/logilab/common/xmlutils.py +++ b/logilab/common/xmlutils.py @@ -29,11 +29,12 @@ instruction and return a Python dictionary. __docformat__ = "restructuredtext en" import re +from typing import Dict, Optional, Union RE_DOUBLE_QUOTE = re.compile('([\w\-\.]+)="([^"]+)"') RE_SIMPLE_QUOTE = re.compile("([\w\-\.]+)='([^']+)'") -def parse_pi_data(pi_data): +def parse_pi_data(pi_data: str) -> Dict[str, Optional[str]]: """ Utility function that parses the data contained in an XML processing instruction and returns a dictionary of keywords and their @@ -51,10 +52,13 @@ def parse_pi_data(pi_data): """ results = {} for elt in pi_data.split(): - if RE_DOUBLE_QUOTE.match(elt): - kwd, val = RE_DOUBLE_QUOTE.match(elt).groups() - elif RE_SIMPLE_QUOTE.match(elt): - kwd, val = RE_SIMPLE_QUOTE.match(elt).groups() + val: Optional[str] + double_match = RE_DOUBLE_QUOTE.match(elt) + simple_match = RE_SIMPLE_QUOTE.match(elt) + if double_match: + kwd, val = double_match.groups() + elif simple_match: + kwd, val = simple_match.groups() else: kwd, val = elt, None results[kwd] = val @@ -1,5 +1,5 @@ [tox] -envlist=py3 +envlist=py3,mypy [testenv] deps = @@ -14,3 +14,8 @@ deps = -r docs/requirements-doc.txt commands= {envpython} -m sphinx -b html {toxinidir}/docs {toxinidir}/docs/_build/html {posargs} + +[testenv:mypy] +deps = + mypy >= 0.761 +commands = mypy --ignore-missing-imports logilab |