summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Peuch <cortex@worlddomination.be>2020-03-20 19:39:33 +0100
committerLaurent Peuch <cortex@worlddomination.be>2020-03-20 19:39:33 +0100
commit5c3c1a5dd8ccea45cce331e07c5ca39a63b51660 (patch)
treec910a2c64206e460a0f5d0514d8d3b9e8d41c9bf
parent2f92ba46d9801839063d940dfcf1f0d46c576b9d (diff)
downloadlogilab-common-5c3c1a5dd8ccea45cce331e07c5ca39a63b51660.tar.gz
[types] clean type annotations generation from pyannotation
-rw-r--r--__pkginfo__.py2
-rw-r--r--logilab/common/__init__.py9
-rw-r--r--logilab/common/cache.py23
-rw-r--r--logilab/common/changelog.py44
-rw-r--r--logilab/common/compat.py3
-rw-r--r--logilab/common/configuration.py215
-rw-r--r--logilab/common/date.py53
-rw-r--r--logilab/common/debugger.py5
-rw-r--r--logilab/common/decorators.py27
-rw-r--r--logilab/common/fileutils.py30
-rw-r--r--logilab/common/graph.py57
-rw-r--r--logilab/common/interface.py10
-rw-r--r--logilab/common/modutils.py78
-rw-r--r--logilab/common/optik_ext.py133
-rw-r--r--logilab/common/proc.py11
-rw-r--r--logilab/common/pytest.py72
-rw-r--r--logilab/common/registry.py142
-rw-r--r--logilab/common/shellutils.py22
-rw-r--r--logilab/common/table.py246
-rw-r--r--logilab/common/tasksqueue.py33
-rw-r--r--logilab/common/testlib.py113
-rw-r--r--logilab/common/textutils.py33
-rw-r--r--logilab/common/tree.py65
-rw-r--r--logilab/common/types.py44
-rw-r--r--logilab/common/umessage.py38
-rw-r--r--logilab/common/ureports/__init__.py83
-rw-r--r--logilab/common/ureports/html_writer.py27
-rw-r--r--logilab/common/ureports/nodes.py41
-rw-r--r--logilab/common/ureports/text_writer.py35
-rw-r--r--logilab/common/visitor.py17
-rw-r--r--logilab/common/xmlutils.py14
-rw-r--r--tox.ini7
32 files changed, 1089 insertions, 643 deletions
diff --git a/__pkginfo__.py b/__pkginfo__.py
index f1bffe5..6ad6cb6 100644
--- a/__pkginfo__.py
+++ b/__pkginfo__.py
@@ -45,6 +45,8 @@ include_dirs = [join('test', 'data')]
install_requires = [
'setuptools',
+ 'mypy-extensions',
+ 'typing_extensions',
]
tests_require = [
'pytz',
diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py
index bf35711..0d7f183 100644
--- a/logilab/common/__init__.py
+++ b/logilab/common/__init__.py
@@ -29,13 +29,16 @@ __docformat__ = "restructuredtext en"
import sys
import types
import pkg_resources
+from typing import List, Sequence
__version__ = pkg_resources.get_distribution('logilab-common').version
# deprecated, but keep compatibility with pylint < 1.4.4
__pkginfo__ = types.ModuleType('__pkginfo__')
__pkginfo__.__package__ = __name__
-__pkginfo__.version = __version__
+# mypy output: Module has no attribute "version"
+# logilab's magic
+__pkginfo__.version = __version__ # type: ignore
sys.modules['logilab.common.__pkginfo__'] = __pkginfo__
STD_BLACKLIST = ('CVS', '.svn', '.hg', '.git', '.tox', 'debian', 'dist', 'build')
@@ -49,7 +52,7 @@ USE_MX_DATETIME = True
class attrdict(dict):
"""A dictionary for which keys are also accessible as attributes."""
- def __getattr__(self, attr):
+ def __getattr__(self, attr: str) -> str:
try:
return self[attr]
except KeyError:
@@ -170,7 +173,7 @@ def make_domains(lists):
# private stuff ################################################################
-def _handle_blacklist(blacklist, dirnames, filenames):
+def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames: List[str]) -> None:
"""remove files/directories in the black list
dirnames/filenames are usually from os.walk
diff --git a/logilab/common/cache.py b/logilab/common/cache.py
index 11ed137..c47f481 100644
--- a/logilab/common/cache.py
+++ b/logilab/common/cache.py
@@ -27,9 +27,14 @@ __docformat__ = "restructuredtext en"
from threading import Lock
from logilab.common.decorators import locked
+from typing import Union, TypeVar, List
_marker = object()
+
+_KeyType = TypeVar("_KeyType")
+
+
class Cache(dict):
"""A dictionary like cache.
@@ -38,23 +43,23 @@ class Cache(dict):
len(self.data) <= self.size
"""
- def __init__(self, size=100):
+ def __init__(self, size: int = 100) -> None:
""" Warning : Cache.__init__() != dict.__init__().
Constructor does not take any arguments beside size.
"""
assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)'
self.size = size
- self._usage = []
+ self._usage: List = []
self._lock = Lock()
super(Cache, self).__init__()
- def _acquire(self):
+ def _acquire(self) -> None:
self._lock.acquire()
- def _release(self):
+ def _release(self) -> None:
self._lock.release()
- def _update_usage(self, key):
+ def _update_usage(self, key: _KeyType) -> None:
if not self._usage:
self._usage.append(key)
elif self._usage[-1] != key:
@@ -71,20 +76,20 @@ class Cache(dict):
else:
pass # key is already the most recently used key
- def __getitem__(self, key):
+ def __getitem__(self, key: _KeyType):
value = super(Cache, self).__getitem__(key)
self._update_usage(key)
return value
__getitem__ = locked(_acquire, _release)(__getitem__)
- def __setitem__(self, key, item):
+ def __setitem__(self, key: _KeyType, item):
# Just make sure that size > 0 before inserting a new item in the cache
if self.size > 0:
super(Cache, self).__setitem__(key, item)
self._update_usage(key)
__setitem__ = locked(_acquire, _release)(__setitem__)
- def __delitem__(self, key):
+ def __delitem__(self, key: _KeyType):
super(Cache, self).__delitem__(key)
self._usage.remove(key)
__delitem__ = locked(_acquire, _release)(__delitem__)
@@ -94,7 +99,7 @@ class Cache(dict):
self._usage = []
clear = locked(_acquire, _release)(clear)
- def pop(self, key, default=_marker):
+ def pop(self, key: _KeyType, default=_marker):
if key in self:
self._usage.remove(key)
#if default is _marker:
diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py
index 6eb8432..c128eb7 100644
--- a/logilab/common/changelog.py
+++ b/logilab/common/changelog.py
@@ -49,6 +49,8 @@ __docformat__ = "restructuredtext en"
import sys
from stat import S_IWRITE
import codecs
+from typing import List, Any, Optional, Tuple
+from _io import StringIO
BULLET = '*'
SUBBULLET = '-'
@@ -76,7 +78,7 @@ class Version(tuple):
return tuple.__new__(cls, parsed)
@classmethod
- def parse(cls, versionstr):
+ def parse(cls, versionstr: str) -> List[int]:
versionstr = versionstr.strip(' :')
try:
return [int(i) for i in versionstr.split('.')]
@@ -84,7 +86,7 @@ class Version(tuple):
raise ValueError("invalid literal for version '%s' (%s)" %
(versionstr, ex))
- def __str__(self):
+ def __str__(self) -> str:
return '.'.join([str(i) for i in self])
@@ -96,20 +98,21 @@ class ChangeLogEntry(object):
"""
version_class = Version
- def __init__(self, date=None, version=None, **kwargs):
+ def __init__(self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any) -> None:
self.__dict__.update(kwargs)
+ self.version: Optional[Version]
if version:
self.version = self.version_class(version)
else:
self.version = None
self.date = date
- self.messages = []
+ self.messages: List[Tuple[List[str], List[List[str]]]] = []
- def add_message(self, msg):
+ def add_message(self, msg: str) -> None:
"""add a new message"""
self.messages.append(([msg], []))
- def complete_latest_message(self, msg_suite):
+ def complete_latest_message(self, msg_suite: str) -> None:
"""complete the latest added message
"""
if not self.messages:
@@ -120,7 +123,7 @@ class ChangeLogEntry(object):
else: # message
self.messages[-1][0].append(msg_suite)
- def add_sub_message(self, sub_msg, key=None):
+ def add_sub_message(self, sub_msg: str, key: Optional[Any] = None) -> None:
if not self.messages:
raise ValueError('unable to complete last message as '
'there is no previous message)')
@@ -130,7 +133,7 @@ class ChangeLogEntry(object):
raise NotImplementedError('sub message to specific key '
'are not implemented yet')
- def write(self, stream=sys.stdout):
+ def write(self, stream: StringIO = sys.stdout) -> None:
"""write the entry to file """
stream.write(u'%s -- %s\n' % (self.date or '', self.version or ''))
for msg, sub_msgs in self.messages:
@@ -152,19 +155,19 @@ class ChangeLog(object):
entry_class = ChangeLogEntry
- def __init__(self, changelog_file, title=u''):
+ def __init__(self, changelog_file: str, title: str = u'') -> None:
self.file = changelog_file
assert isinstance(title, type(u'')), 'title must be a unicode object'
self.title = title
self.additional_content = u''
- self.entries = []
+ self.entries: List[ChangeLogEntry] = []
self.load()
def __repr__(self):
return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self),
len(self.entries))
- def add_entry(self, entry):
+ def add_entry(self, entry: ChangeLogEntry) -> None:
"""add a new entry to the change log"""
self.entries.append(entry)
@@ -191,26 +194,31 @@ class ChangeLog(object):
entry = self.get_entry(create=create)
entry.add_message(msg)
- def load(self):
+ def load(self) -> None:
""" read a logilab's ChangeLog from file """
try:
stream = codecs.open(self.file, encoding='utf-8')
except IOError:
return
- last = None
+
+ last: Optional[ChangeLogEntry] = None
expect_sub = False
+
for line in stream:
sline = line.strip()
words = sline.split()
+
# if new entry
if len(words) == 1 and words[0] == '--':
expect_sub = False
last = self.entry_class()
+
self.add_entry(last)
# if old entry
elif len(words) == 3 and words[1] == '--':
expect_sub = False
last = self.entry_class(words[0], words[2])
+
self.add_entry(last)
# if title
elif sline and last is None:
@@ -218,19 +226,23 @@ class ChangeLog(object):
# if new entry
elif sline and sline[0] == BULLET:
expect_sub = False
+
+ assert last is not None
last.add_message(sline[1:].strip())
# if new sub_entry
elif expect_sub and sline and sline[0] == SUBBULLET:
+ assert last is not None
last.add_sub_message(sline[1:].strip())
# if new line for current entry
- elif sline and last.messages:
+ elif sline and (last and last.messages):
last.complete_latest_message(line)
else:
expect_sub = True
self.additional_content += line
+
stream.close()
- def format_title(self):
+ def format_title(self) -> str:
return u'%s\n\n' % self.title.strip()
def save(self):
@@ -240,7 +252,7 @@ class ChangeLog(object):
ensure_fs_mode(self.file, S_IWRITE)
self.write(codecs.open(self.file, 'w', encoding='utf-8'))
- def write(self, stream=sys.stdout):
+ def write(self, stream: StringIO = sys.stdout) -> None:
"""write changelog to stream"""
stream.write(self.format_title())
for entry in self.entries:
diff --git a/logilab/common/compat.py b/logilab/common/compat.py
index eee0a61..4ca540b 100644
--- a/logilab/common/compat.py
+++ b/logilab/common/compat.py
@@ -33,6 +33,7 @@ import os
import sys
import types
from warnings import warn
+from typing import Union
# not used here, but imported to preserve API
import builtins
@@ -41,7 +42,7 @@ def str_to_bytes(string):
return str.encode(string)
# we have to ignore the encoding in py3k to be able to write a string into a
# TextIOWrapper or like object (which expect an unicode string)
-def str_encode(string, encoding):
+def str_encode(string: Union[int, str], encoding: str) -> str:
return str(string)
# See also http://bugs.python.org/issue11776
diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py
index 8489de5..61c2e97 100644
--- a/logilab/common/configuration.py
+++ b/logilab/common/configuration.py
@@ -119,25 +119,31 @@ import os
import sys
import re
from os.path import exists, expanduser
+from optparse import OptionGroup
from copy import copy
+from _io import StringIO, TextIOWrapper
+from mypy_extensions import NoReturn
+from typing import Any, Optional, Union, Dict, List, Tuple, Iterator, Callable, Sequence
from warnings import warn
import configparser as cp
+from logilab.common.types import OptionParser, Option, attrdict
from logilab.common.compat import str_encode as _encode
from logilab.common.deprecation import deprecated
from logilab.common.textutils import normalize_text, unquote
from logilab.common import optik_ext
+
OptionError = optik_ext.OptionError
-REQUIRED = []
+REQUIRED: List = []
class UnsupportedAction(Exception):
"""raised by set_option when it doesn't know what to do for an action"""
-def _get_encoding(encoding, stream):
+def _get_encoding(encoding: Optional[str], stream: Union[StringIO, TextIOWrapper]) -> str:
encoding = encoding or getattr(stream, 'encoding', None)
if not encoding:
import locale
@@ -145,12 +151,14 @@ def _get_encoding(encoding, stream):
return encoding
+_ValueType = Union[List[str], Tuple[str, ...], str]
+
# validation functions ########################################################
# validators will return the validated value or raise optparse.OptionValueError
# XXX add to documentation
-def choice_validator(optdict, name, value):
+def choice_validator(optdict: Dict[str, Any], name: str, value: str) -> str:
"""validate and return a converted value for option of type 'choice'
"""
if not value in optdict['choices']:
@@ -158,7 +166,8 @@ def choice_validator(optdict, name, value):
raise optik_ext.OptionValueError(msg % (name, value, optdict['choices']))
return value
-def multiple_choice_validator(optdict, name, value):
+
+def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType:
"""validate and return a converted value for option of type 'choice'
"""
choices = optdict['choices']
@@ -169,17 +178,17 @@ def multiple_choice_validator(optdict, name, value):
raise optik_ext.OptionValueError(msg % (name, value, choices))
return values
-def csv_validator(optdict, name, value):
+def csv_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType:
"""validate and return a converted value for option of type 'csv'
"""
return optik_ext.check_csv(None, name, value)
-def yn_validator(optdict, name, value):
+def yn_validator(optdict: Dict[str, Any], name: str, value: Union[bool, str]) -> bool:
"""validate and return a converted value for option of type 'yn'
"""
return optik_ext.check_yn(None, name, value)
-def named_validator(optdict, name, value):
+def named_validator(optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str]) -> Dict[str, str]:
"""validate and return a converted value for option of type 'named'
"""
return optik_ext.check_named(None, name, value)
@@ -204,31 +213,32 @@ def time_validator(optdict, name, value):
"""validate and return a time object for option of type 'time'"""
return optik_ext.check_time(None, name, value)
-def bytes_validator(optdict, name, value):
+def bytes_validator(optdict: Dict[str, str], name: str, value: Union[int, str]) -> int:
"""validate and return an integer for option of type 'bytes'"""
return optik_ext.check_bytes(None, name, value)
-VALIDATORS = {'string': unquote,
- 'int': int,
- 'float': float,
- 'file': file_validator,
- 'font': unquote,
- 'color': color_validator,
- 'regexp': re.compile,
- 'csv': csv_validator,
- 'yn': yn_validator,
- 'bool': yn_validator,
- 'named': named_validator,
- 'password': password_validator,
- 'date': date_validator,
- 'time': time_validator,
- 'bytes': bytes_validator,
- 'choice': choice_validator,
- 'multiple_choice': multiple_choice_validator,
- }
-
-def _call_validator(opttype, optdict, option, value):
+VALIDATORS: Dict[str, Callable] = {
+ 'string': unquote,
+ 'int': int,
+ 'float': float,
+ 'file': file_validator,
+ 'font': unquote,
+ 'color': color_validator,
+ 'regexp': re.compile,
+ 'csv': csv_validator,
+ 'yn': yn_validator,
+ 'bool': yn_validator,
+ 'named': named_validator,
+ 'password': password_validator,
+ 'date': date_validator,
+ 'time': time_validator,
+ 'bytes': bytes_validator,
+ 'choice': choice_validator,
+ 'multiple_choice': multiple_choice_validator,
+}
+
+def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str]) -> Union[List[str], int, str]:
if opttype not in VALIDATORS:
raise Exception('Unsupported type "%s"' % opttype)
try:
@@ -274,10 +284,11 @@ def _make_input_function(opttype):
print('bad value: %s' % msg)
return input_validator
-INPUT_FUNCTIONS = {
+
+INPUT_FUNCTIONS: Dict[str, Callable] = {
'string': input_string,
'password': input_password,
- }
+}
for opttype in VALIDATORS.keys():
INPUT_FUNCTIONS.setdefault(opttype, _make_input_function(opttype))
@@ -306,7 +317,7 @@ def expand_default(self, option):
return option.help.replace(self.default_tag, str(value))
-def _validate(value, optdict, name=''):
+def _validate(value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = '') -> Union[List[str], int, str]:
"""return a validated value for an option according to its type
optional argument name is only used for error message formatting
@@ -343,7 +354,7 @@ def format_time(value):
return '%sh' % nbhour
return '%sd' % nbday
-def format_bytes(value):
+def format_bytes(value: int) -> str:
if not value:
return '0'
if value != int(value):
@@ -358,7 +369,7 @@ def format_bytes(value):
value = next
return '%s%s' % (value, unit)
-def format_option_value(optdict, value):
+def format_option_value(optdict: Dict[str, Any], value: Any) -> Union[None, int, str]:
"""return the user input's value from a 'compiled' value"""
if isinstance(value, (list, tuple)):
value = ','.join(value)
@@ -377,7 +388,7 @@ def format_option_value(optdict, value):
value = format_bytes(value)
return value
-def ini_format_section(stream, section, options, encoding=None, doc=None):
+def ini_format_section(stream: Union[StringIO, TextIOWrapper], section: str, options: Any, encoding: str = None, doc: Optional[Any] = None) -> None:
"""format an options section using the INI format"""
encoding = _get_encoding(encoding, stream)
if doc:
@@ -385,7 +396,7 @@ def ini_format_section(stream, section, options, encoding=None, doc=None):
print('[%s]' % section, file=stream)
ini_format(stream, options, encoding)
-def ini_format(stream, options, encoding):
+def ini_format(stream: Union[StringIO, TextIOWrapper], options: Any, encoding: str) -> None:
"""format options using the INI format"""
for optname, optdict, value in options:
value = format_option_value(optdict, value)
@@ -428,34 +439,37 @@ def rest_format_section(stream, section, options, encoding=None, doc=None):
# Options Manager ##############################################################
+
class OptionsManagerMixIn(object):
"""MixIn to handle a configuration from both a configuration file and
command line options
"""
- def __init__(self, usage, config_file=None, version=None, quiet=0):
+ def __init__(self, usage: Optional[str], config_file: Optional[Any] = None, version: Optional[Any] = None, quiet: int = 0) -> None:
self.config_file = config_file
self.reset_parsers(usage, version=version)
# list of registered options providers
- self.options_providers = []
+ self.options_providers: List[ConfigurationMixIn] = []
# dictionary associating option name to checker
- self._all_options = {}
- self._short_options = {}
- self._nocallback_options = {}
- self._mygroups = dict()
+ self._all_options: Dict[str, ConfigurationMixIn] = {}
+ self._short_options: Dict[str, str] = {}
+ self._nocallback_options: Dict[ConfigurationMixIn, str] = {}
+ self._mygroups: Dict[str, optik_ext.OptionGroup] = {}
# verbosity
self.quiet = quiet
self._maxlevel = 0
- def reset_parsers(self, usage='', version=None):
+ def reset_parsers(self, usage: Optional[str] = '', version: Optional[Any] = None) -> None:
# configuration file parser
self.cfgfile_parser = cp.ConfigParser()
# command line parser
self.cmdline_parser = optik_ext.OptionParser(usage=usage, version=version)
- self.cmdline_parser.options_manager = self
+ # mypy: "OptionParser" has no attribute "options_manager"
+ # dynamic attribute?
+ self.cmdline_parser.options_manager = self # type: ignore
self._optik_option_attrs = set(self.cmdline_parser.option_class.ATTRS)
- def register_options_provider(self, provider, own_group=True):
+ def register_options_provider(self, provider: 'ConfigurationMixIn', own_group: bool = True) -> None:
"""register an options provider"""
assert provider.priority <= 0, "provider's priority can't be >= 0"
for i in range(len(self.options_providers)):
@@ -464,8 +478,12 @@ class OptionsManagerMixIn(object):
break
else:
self.options_providers.append(provider)
- non_group_spec_options = [option for option in provider.options
- if 'group' not in option[1]]
+
+ # mypy: Need type annotation for 'option'
+ # you can't type variable of a list comprehension, right?
+ non_group_spec_options: List = [option for option in provider.options # type: ignore
+ if 'group' not in option[1]] # type: ignore
+
groups = getattr(provider, 'option_groups', ())
if own_group and non_group_spec_options:
self.add_option_group(provider.name.upper(), provider.__doc__,
@@ -475,11 +493,14 @@ class OptionsManagerMixIn(object):
self.add_optik_option(provider, self.cmdline_parser, opt, optdict)
for gname, gdoc in groups:
gname = gname.upper()
- goptions = [option for option in provider.options
- if option[1].get('group', '').upper() == gname]
+
+ # mypy: Need type annotation for 'option'
+ # you can't type variable of a list comprehension, right?
+ goptions: List = [option for option in provider.options # type: ignore
+ if option[1].get('group', '').upper() == gname] # type: ignore
self.add_option_group(gname, gdoc, goptions, provider)
- def add_option_group(self, group_name, doc, options, provider):
+ def add_option_group(self, group_name: str, doc: Optional[str], options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]], provider: 'ConfigurationMixIn') -> None:
"""add an option group including the listed options
"""
assert options
@@ -490,7 +511,9 @@ class OptionsManagerMixIn(object):
group = optik_ext.OptionGroup(self.cmdline_parser,
title=group_name.capitalize())
self.cmdline_parser.add_option_group(group)
- group.level = provider.level
+ # mypy: "OptionGroup" has no attribute "level"
+ # dynamic attribute
+ group.level = provider.level # type: ignore
self._mygroups[group_name] = group
# add section to the config file
if group_name != "DEFAULT":
@@ -499,7 +522,7 @@ class OptionsManagerMixIn(object):
for opt, optdict in options:
self.add_optik_option(provider, group, opt, optdict)
- def add_optik_option(self, provider, optikcontainer, opt, optdict):
+ def add_optik_option(self, provider: 'ConfigurationMixIn', optikcontainer: Union[OptionParser, OptionGroup], opt: str, optdict: Dict[str, Any]) -> None:
if 'inputlevel' in optdict:
warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,'
' use "level"' % opt, DeprecationWarning)
@@ -509,12 +532,11 @@ class OptionsManagerMixIn(object):
self._all_options[opt] = provider
self._maxlevel = max(self._maxlevel, option.level or 0)
- def optik_option(self, provider, opt, optdict):
+ def optik_option(self, provider: 'ConfigurationMixIn', opt: str, optdict: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any]]:
"""get our personal option definition and return a suitable form for
use with optik/optparse
"""
optdict = copy(optdict)
- others = {}
if 'action' in optdict:
self._nocallback_options[provider] = opt
else:
@@ -539,7 +561,7 @@ class OptionsManagerMixIn(object):
optdict.pop(key)
return args, optdict
- def cb_set_provider_option(self, option, opt, value, parser):
+ def cb_set_provider_option(self, option: 'Option', opt: str, value: Union[List[str], int, str], parser: 'OptionParser') -> None:
"""optik callback for option setting"""
if opt.startswith('--'):
# remove -- on long option
@@ -552,16 +574,17 @@ class OptionsManagerMixIn(object):
value = 1
self.global_set_option(opt, value)
- def global_set_option(self, opt, value):
+ def global_set_option(self, opt: str, value: Union[List[str], int, str]) -> None:
"""set option on the correct option provider"""
self._all_options[opt].set_option(opt, value)
- def generate_config(self, stream=None, skipsections=(), encoding=None):
+ def generate_config(self, stream: Union[StringIO, TextIOWrapper] = None, skipsections: Tuple[()] = (), encoding: Optional[Any] = None) -> None:
"""write a configuration file according to the current configuration
into the given stream or stdout
"""
- options_by_section = {}
+ options_by_section: Dict[Any, List] = {}
sections = []
+
for provider in self.options_providers:
for section, options in provider.options_by_section():
if section is None:
@@ -586,7 +609,7 @@ class OptionsManagerMixIn(object):
encoding)
printed = True
- def generate_manpage(self, pkginfo, section=1, stream=None):
+ def generate_manpage(self, pkginfo: attrdict, section: int = 1, stream: StringIO = None) -> None:
"""write a man page for the current configuration into the given
stream or stdout
"""
@@ -600,17 +623,17 @@ class OptionsManagerMixIn(object):
# initialization methods ##################################################
- def load_provider_defaults(self):
+ def load_provider_defaults(self) -> None:
"""initialize configuration using default values"""
for provider in self.options_providers:
provider.load_defaults()
- def load_file_configuration(self, config_file=None):
+ def load_file_configuration(self, config_file: str = None) -> None:
"""load the configuration from file"""
self.read_config_file(config_file)
self.load_config_file()
- def read_config_file(self, config_file=None):
+ def read_config_file(self, config_file: str = None) -> None:
"""read the configuration file but do not load it (i.e. dispatching
values to each options provider)
"""
@@ -637,9 +660,11 @@ class OptionsManagerMixIn(object):
parser = self.cfgfile_parser
parser.read([config_file])
# normalize sections'title
- for sect, values in list(parser._sections.items()):
+ # mypy: "ConfigParser" has no attribute "_sections"
+ # dynamic attribute?
+ for sect, values in list(parser._sections.items()): # type: ignore
if not sect.isupper() and values:
- parser._sections[sect.upper()] = values
+ parser._sections[sect.upper()] = values # type: ignore
elif not self.quiet:
msg = 'No config file found, using default configuration'
print(msg, file=sys.stderr)
@@ -663,7 +688,7 @@ class OptionsManagerMixIn(object):
if stream is not None:
self.generate_config(stream)
- def load_config_file(self):
+ def load_config_file(self) -> None:
"""dispatch values previously read from a configuration file to each
options provider)
"""
@@ -676,7 +701,7 @@ class OptionsManagerMixIn(object):
# TODO handle here undeclared options appearing in the config file
continue
- def load_configuration(self, **kwargs):
+ def load_configuration(self, **kwargs: Any) -> None:
"""override configuration according to given parameters
"""
for opt, opt_value in kwargs.items():
@@ -684,12 +709,13 @@ class OptionsManagerMixIn(object):
provider = self._all_options[opt]
provider.set_option(opt, opt_value)
- def load_command_line_configuration(self, args=None):
+ def load_command_line_configuration(self, args: List[str] = None) -> List[str]:
"""override configuration according to command line parameters
return additional arguments
"""
self._monkeypatch_expand_default()
+
try:
if args is None:
args = sys.argv[1:]
@@ -710,32 +736,41 @@ class OptionsManagerMixIn(object):
# help methods ############################################################
- def add_help_section(self, title, description, level=0):
+ def add_help_section(self, title: str, description: str, level: int = 0) -> None:
"""add a dummy option section for help purpose """
group = optik_ext.OptionGroup(self.cmdline_parser,
title=title.capitalize(),
description=description)
- group.level = level
+ # mypy: "OptionGroup" has no attribute "level"
+ # it does, it is set in the optik_ext module
+ group.level = level # type: ignore
self._maxlevel = max(self._maxlevel, level)
self.cmdline_parser.add_option_group(group)
- def _monkeypatch_expand_default(self):
+ def _monkeypatch_expand_default(self) -> None:
# monkey patch optik_ext to deal with our default values
try:
self.__expand_default_backup = optik_ext.HelpFormatter.expand_default
- optik_ext.HelpFormatter.expand_default = expand_default
+ # mypy: Cannot assign to a method
+ # it's dirty but you can
+ optik_ext.HelpFormatter.expand_default = expand_default # type: ignore
except AttributeError:
# python < 2.4: nothing to be done
pass
- def _unmonkeypatch_expand_default(self):
+ def _unmonkeypatch_expand_default(self) -> None:
# remove monkey patch
if hasattr(optik_ext.HelpFormatter, 'expand_default'):
+ # mypy: Cannot assign to a method
+ # it's dirty but you can
+
# unpatch optik_ext to avoid side effects
- optik_ext.HelpFormatter.expand_default = self.__expand_default_backup
+ optik_ext.HelpFormatter.expand_default = self.__expand_default_backup # type: ignore
- def help(self, level=0):
+ def help(self, level: int = 0) -> str:
"""return the usage string for available options """
- self.cmdline_parser.formatter.output_level = level
+ # mypy: "HelpFormatter" has no attribute "output_level"
+ # set in optik_ext
+ self.cmdline_parser.formatter.output_level = level # type: ignore
self._monkeypatch_expand_default()
try:
return self.cmdline_parser.format_help()
@@ -751,12 +786,12 @@ class Method(object):
self.method = methname
self._inst = None
- def bind(self, instance):
+ def bind(self, instance: 'Configuration') -> None:
"""bind the method to its instance"""
if self._inst is None:
self._inst = instance
- def __call__(self, *args, **kwargs):
+ def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, str]:
assert self._inst, 'unbound method'
return getattr(self._inst, self.method)(*args, **kwargs)
@@ -768,23 +803,23 @@ class OptionsProviderMixIn(object):
# those attributes should be overridden
priority = -1
name = 'default'
- options = ()
+ options: Tuple = ()
level = 0
- def __init__(self):
+ def __init__(self) -> None:
self.config = optik_ext.Values()
- for option in self.options:
+ for option_tuple in self.options:
try:
- option, optdict = option
+ option, optdict = option_tuple
except ValueError:
- raise Exception('Bad option: %r' % option)
+ raise Exception('Bad option: %s' % str(option_tuple))
if isinstance(optdict.get('default'), Method):
optdict['default'].bind(self)
elif isinstance(optdict.get('callback'), Method):
optdict['callback'].bind(self)
self.load_defaults()
- def load_defaults(self):
+ def load_defaults(self) -> None:
"""initialize the provider using default values"""
for opt, optdict in self.options:
action = optdict.get('action')
@@ -883,8 +918,10 @@ class OptionsProviderMixIn(object):
for option in self.options:
if option[0] == opt:
return option[1]
+ # mypy: Argument 2 to "OptionError" has incompatible type "str"; expected "Option"
+ # seems to be working?
raise OptionError('no such option %s in section %r'
- % (opt, self.name), opt)
+ % (opt, self.name), opt) # type: ignore
def all_options(self):
@@ -900,17 +937,19 @@ class OptionsProviderMixIn(object):
for option, optiondict, value in options:
yield section, option, optiondict
- def options_by_section(self):
+ def options_by_section(self) -> Iterator[Any]:
"""return an iterator on options grouped by section
(section, [list of (optname, optdict, optvalue)])
"""
- sections = {}
+ sections: Dict[str, List[Tuple[str, Dict[str, Any], Any]]] = {}
for optname, optdict in self.options:
sections.setdefault(optdict.get('group'), []).append(
(optname, optdict, self.option_value(optname)))
if None in sections:
- yield None, sections.pop(None)
+ # mypy: No overload variant of "pop" of "MutableMapping" matches argument type "None"
+ # it actually works
+ yield None, sections.pop(None) # type: ignore
for section, options in sorted(sections.items()):
yield section.upper(), options
@@ -926,14 +965,14 @@ class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn):
"""basic mixin for simple configurations which don't need the
manager / providers model
"""
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
if not args:
kwargs.setdefault('usage', '')
kwargs.setdefault('quiet', 1)
OptionsManagerMixIn.__init__(self, *args, **kwargs)
OptionsProviderMixIn.__init__(self)
if not getattr(self, 'option_groups', None):
- self.option_groups = []
+ self.option_groups: List[Tuple[Any, str]] = []
for option, optdict in self.options:
try:
gdef = (optdict['group'].upper(), '')
diff --git a/logilab/common/date.py b/logilab/common/date.py
index cdf2317..2d2ed22 100644
--- a/logilab/common/date.py
+++ b/logilab/common/date.py
@@ -27,6 +27,8 @@ from locale import getlocale, LC_TIME
from datetime import date, time, datetime, timedelta
from time import strptime as time_strptime
from calendar import monthrange, timegm
+from typing import Union, List, Any, Iterator, Optional, Generator
+
try:
from mx.DateTime import RelativeDateTime, Date, DateTimeType
@@ -90,13 +92,13 @@ FRENCH_MOBILE_HOLIDAYS = {
# XXX this implementation cries for multimethod dispatching
-def get_step(dateobj, nbdays=1):
+def get_step(dateobj: Union[date, datetime], nbdays: int = 1) -> timedelta:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(dateobj, date):
return ONEDAY * nbdays
return nbdays # mx.DateTime is ok with integers
-def datefactory(year, month, day, sampledate):
+def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetime]) -> Union[date, datetime]:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(sampledate, datetime):
return datetime(year, month, day)
@@ -104,20 +106,23 @@ def datefactory(year, month, day, sampledate):
return date(year, month, day)
return Date(year, month, day)
-def weekday(dateobj):
+def weekday(dateobj: Union[date, datetime]) -> int:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(dateobj, date):
return dateobj.weekday()
return dateobj.day_of_week
-def str2date(datestr, sampledate):
+def str2date(datestr: str, sampledate: Union[date, datetime]) -> Union[date, datetime]:
# NOTE: datetime.strptime is not an option until we drop py2.4 compat
year, month, day = [int(chunk) for chunk in datestr.split('-')]
return datefactory(year, month, day, sampledate)
-def days_between(start, end):
+def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> int:
if isinstance(start, date):
- delta = end - start
+ # mypy: No overload variant of "__sub__" of "datetime" matches argument type "date"
+ # we ensure that end is a date
+ assert isinstance(end, date)
+ delta = end - start # type: ignore
# datetime.timedelta.days is always an integer (floored)
if delta.seconds:
return delta.days + 1
@@ -125,7 +130,7 @@ def days_between(start, end):
else:
return int(math.ceil((end - start).days))
-def get_national_holidays(begin, end):
+def get_national_holidays(begin: Union[date, datetime], end: Union[date, datetime]) -> Union[List[date], List[datetime]]:
"""return french national days off between begin and end"""
begin = datefactory(begin.year, begin.month, begin.day, begin)
end = datefactory(end.year, end.month, end.day, end)
@@ -138,7 +143,7 @@ def get_national_holidays(begin, end):
holidays.append(date)
return [day for day in holidays if begin <= day < end]
-def add_days_worked(start, days):
+def add_days_worked(start: date, days: int) -> date:
"""adds date but try to only take days worked into account"""
step = get_step(start)
weeks, plus = divmod(days, 5)
@@ -151,7 +156,7 @@ def add_days_worked(start, days):
end += (2 * step)
return end
-def nb_open_days(start, end):
+def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> int:
assert start <= end
step = get_step(start)
days = days_between(start, end)
@@ -168,7 +173,8 @@ def nb_open_days(start, end):
return 0
return open_days
-def date_range(begin, end, incday=None, incmonth=None):
+
+def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None) -> Generator[date, Any, None]:
"""yields each date between begin and end
:param begin: the start date
@@ -193,13 +199,13 @@ def date_range(begin, end, incday=None, incmonth=None):
else:
incr = get_step(begin, incday or 1)
while begin < end:
- yield begin
- begin += incr
+ yield begin
+ begin += incr
# makes py datetime usable #####################################################
-ONEDAY = timedelta(days=1)
-ONEWEEK = timedelta(days=7)
+ONEDAY: timedelta = timedelta(days=1)
+ONEWEEK: timedelta = timedelta(days=7)
try:
strptime = datetime.strptime
@@ -211,7 +217,7 @@ except AttributeError: # py < 2.5
def strptime_time(value, format='%H:%M'):
return time(*time_strptime(value, format)[3:6])
-def todate(somedate):
+def todate(somedate: date) -> date:
"""return a date from a date (leaving unchanged) or a datetime"""
if isinstance(somedate, datetime):
return date(somedate.year, somedate.month, somedate.day)
@@ -234,10 +240,10 @@ def todatetime(somedate):
assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
return datetime(somedate.year, somedate.month, somedate.day)
-def datetime2ticks(somedate):
+def datetime2ticks(somedate: Union[date, datetime]) -> int:
return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, 'microsecond', 0) / 1000)
-def ticks2datetime(ticks):
+def ticks2datetime(ticks: int) -> datetime:
miliseconds, microseconds = divmod(ticks, 1000)
try:
return datetime.fromtimestamp(miliseconds)
@@ -250,7 +256,7 @@ def ticks2datetime(ticks):
except (ValueError, OverflowError):
raise
-def days_in_month(somedate):
+def days_in_month(somedate: date) -> int:
return monthrange(somedate.year, somedate.month)[1]
def days_in_year(somedate):
@@ -266,7 +272,7 @@ def previous_month(somedate, nbmonth=1):
nbmonth -= 1
return somedate
-def next_month(somedate, nbmonth=1):
+def next_month(somedate: date, nbmonth: int = 1) -> date:
while nbmonth:
somedate = last_day(somedate) + ONEDAY
nbmonth -= 1
@@ -275,10 +281,10 @@ def next_month(somedate, nbmonth=1):
def first_day(somedate):
return date(somedate.year, somedate.month, 1)
-def last_day(somedate):
+def last_day(somedate: date) -> date:
return date(somedate.year, somedate.month, days_in_month(somedate))
-def ustrftime(somedate, fmt='%Y-%m-%d'):
+def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str:
"""like strftime, but returns a unicode string instead of an encoded
string which may be problematic with localized date.
"""
@@ -309,10 +315,11 @@ def ustrftime(somedate, fmt='%Y-%m-%d'):
fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt)
return unicode(fmt) % fields
-def utcdatetime(dt):
+def utcdatetime(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt
- return (dt.replace(tzinfo=None) - dt.utcoffset())
+ # mypy: No overload variant of "__sub__" of "datetime" matches argument type "None"
+ return (dt.replace(tzinfo=None) - dt.utcoffset()) # type: ignore
def utctime(dt):
if dt.tzinfo is None:
diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py
index 909169c..2df84ad 100644
--- a/logilab/common/debugger.py
+++ b/logilab/common/debugger.py
@@ -34,7 +34,10 @@ __docformat__ = "restructuredtext en"
try:
import readline
except ImportError:
- readline = None
+ # mypy: Incompatible types in assignment (expression has type "None",
+ # mypy: variable has type Module))
+ # conditional import
+ readline = None # type: ignore
import os
import os.path as osp
import sys
diff --git a/logilab/common/decorators.py b/logilab/common/decorators.py
index 8dd2e7f..27ed7ee 100644
--- a/logilab/common/decorators.py
+++ b/logilab/common/decorators.py
@@ -25,6 +25,8 @@ import sys
import types
from time import process_time, time
from inspect import isgeneratorfunction
+from mypy_extensions import NoReturn
+from typing import Any, Optional, Callable, Union
from inspect import getfullargspec
@@ -33,12 +35,13 @@ from logilab.common.compat import method_type
# XXX rewrite so we can use the decorator syntax when keyarg has to be specified
class cached_decorator(object):
- def __init__(self, cacheattr=None, keyarg=None):
+ def __init__(self, cacheattr: Optional[str] = None, keyarg: Optional[int] = None) -> None:
self.cacheattr = cacheattr
self.keyarg = keyarg
- def __call__(self, callableobj=None):
+ def __call__(self, callableobj: Optional[Callable] = None) -> Callable:
assert not isgeneratorfunction(callableobj), \
'cannot cache generator function: %s' % callableobj
+ assert callableobj is not None
if len(getfullargspec(callableobj).args) == 1 or self.keyarg == 0:
cache = _SingleValueCache(callableobj, self.cacheattr)
elif self.keyarg:
@@ -48,7 +51,7 @@ class cached_decorator(object):
return cache.closure()
class _SingleValueCache(object):
- def __init__(self, callableobj, cacheattr=None):
+ def __init__(self, callableobj: Callable, cacheattr: Optional[str] = None) -> None:
self.callable = callableobj
if cacheattr is None:
self.cacheattr = '_%s_cache_' % callableobj.__name__
@@ -64,10 +67,12 @@ class _SingleValueCache(object):
setattr(self, __me.cacheattr, value)
return value
- def closure(self):
+ def closure(self) -> Callable:
def wrapped(*args, **kwargs):
return self.__call__(*args, **kwargs)
- wrapped.cache_obj = self
+ # mypy: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no attribute "cache_obj"
+ # dynamic attribute for magic
+ wrapped.cache_obj = self # type: ignore
try:
wrapped.__doc__ = self.callable.__doc__
wrapped.__name__ = self.callable.__name__
@@ -97,7 +102,7 @@ class _MultiValuesCache(_SingleValueCache):
return _cache[args]
class _MultiValuesKeyArgCache(_MultiValuesCache):
- def __init__(self, callableobj, keyarg, cacheattr=None):
+ def __init__(self, callableobj: Callable, keyarg: int, cacheattr: Optional[str] = None) -> None:
super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr)
self.keyarg = keyarg
@@ -111,7 +116,7 @@ class _MultiValuesKeyArgCache(_MultiValuesCache):
return _cache[key]
-def cached(callableobj=None, keyarg=None, **kwargs):
+def cached(callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any) -> Union[Callable, cached_decorator]:
"""Simple decorator to cache result of method call."""
kwargs['keyarg'] = keyarg
decorator = cached_decorator(**kwargs)
@@ -145,8 +150,10 @@ class cachedproperty(object):
wrapped)
self.wrapped = wrapped
+ # mypy: Signature of "__doc__" incompatible with supertype "object"
+ # but this works?
@property
- def __doc__(self):
+ def __doc__(self) -> str: # type: ignore
doc = getattr(self.wrapped, '__doc__', None)
return ('<wrapped by the cachedproperty decorator>%s'
% ('\n%s' % doc if doc else ''))
@@ -241,7 +248,7 @@ def locked(acquire, release):
having called acquire(self) et will call release(self) afterwards.
"""
def decorator(f):
- def wrapper(self, *args, **kwargs):
+ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
acquire(self)
try:
return f(self, *args, **kwargs)
@@ -251,7 +258,7 @@ def locked(acquire, release):
return decorator
-def monkeypatch(klass, methodname=None):
+def monkeypatch(klass: type, methodname: Optional[str] = None) -> Callable:
"""Decorator extending class with the decorated callable. This is basically
a syntactic sugar vs class assignment.
diff --git a/logilab/common/fileutils.py b/logilab/common/fileutils.py
index 93439d3..102cd7c 100644
--- a/logilab/common/fileutils.py
+++ b/logilab/common/fileutils.py
@@ -36,13 +36,15 @@ from os.path import isabs, isdir, islink, split, exists, normpath, join
from os.path import abspath
from os import sep, mkdir, remove, listdir, stat, chmod, walk
from stat import ST_MODE, S_IWRITE
+from typing import Optional, List, Tuple
+from _io import TextIOWrapper
from logilab.common import STD_BLACKLIST as BASE_BLACKLIST, IGNORED_EXTENSIONS
from logilab.common.shellutils import find
from logilab.common.deprecation import deprecated
from logilab.common.compat import FileIO
-def first_level_directory(path):
+def first_level_directory(path: str) -> str:
"""Return the first level directory of a path.
>>> first_level_directory('home/syt/work')
@@ -73,7 +75,7 @@ def abspath_listdir(path):
return [join(path, filename) for filename in listdir(path)]
-def is_binary(filename):
+def is_binary(filename: str) -> int:
"""Return true if filename may be a binary file, according to it's
extension.
@@ -86,12 +88,14 @@ def is_binary(filename):
isn't beginning by text/)
"""
try:
- return not mimetypes.guess_type(filename)[0].startswith('text')
+ # mypy: Item "None" of "Optional[str]" has no attribute "startswith"
+ # it's handle by the exception
+ return not mimetypes.guess_type(filename)[0].startswith('text') # type: ignore
except AttributeError:
return 1
-def write_open_mode(filename):
+def write_open_mode(filename: str) -> str:
"""Return the write mode that should used to open file.
:type filename: str
@@ -143,7 +147,7 @@ class ProtectedFile(FileIO):
- on close()/del(), write/append the StringIO content to the file and
do the chmod only once
"""
- def __init__(self, filepath, mode):
+ def __init__(self, filepath: str, mode: str) -> None:
self.original_mode = stat(filepath)[ST_MODE]
self.mode_changed = False
if mode in ('w', 'a', 'wb', 'ab'):
@@ -152,19 +156,19 @@ class ProtectedFile(FileIO):
self.mode_changed = True
FileIO.__init__(self, filepath, mode)
- def _restore_mode(self):
+ def _restore_mode(self) -> None:
"""restores the original mode if needed"""
if self.mode_changed:
chmod(self.name, self.original_mode)
# Don't re-chmod in case of several restore
self.mode_changed = False
- def close(self):
+ def close(self) -> None:
"""restore mode before closing"""
self._restore_mode()
FileIO.close(self)
- def __del__(self):
+ def __del__(self) -> None:
if not self.closed:
self.close()
@@ -265,7 +269,7 @@ def norm_open(path):
return open(path, 'U')
norm_open = deprecated("use \"open(path, 'U')\"")(norm_open)
-def lines(path, comments=None):
+def lines(path: str, comments: Optional[str] = None) -> List[str]:
"""Return a list of non empty lines in the file located at `path`.
:type path: str
@@ -287,7 +291,7 @@ def lines(path, comments=None):
return stream_lines(stream, comments)
-def stream_lines(stream, comments=None):
+def stream_lines(stream: TextIOWrapper, comments: Optional[str] = None) -> List[str]:
"""Return a list of non empty lines in the given `stream`.
:type stream: object implementing 'xreadlines' or 'readlines'
@@ -317,9 +321,9 @@ def stream_lines(stream, comments=None):
return result
-def export(from_dir, to_dir,
- blacklist=BASE_BLACKLIST, ignore_ext=IGNORED_EXTENSIONS,
- verbose=0):
+def export(from_dir: str, to_dir: str,
+ blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST, ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS,
+ verbose: int = 0) -> None:
"""Make a mirror of `from_dir` in `to_dir`, omitting directories and
files listed in the black list or ending with one of the given
extensions.
diff --git a/logilab/common/graph.py b/logilab/common/graph.py
index cef1c98..fffa172 100644
--- a/logilab/common/graph.py
+++ b/logilab/common/graph.py
@@ -30,6 +30,7 @@ import sys
import tempfile
import codecs
import errno
+from typing import Dict, List, Tuple, Union, Any, Optional, Set, TypeVar, Iterable
def escape(value):
"""Make <value> usable in a dot file."""
@@ -176,7 +177,12 @@ class GraphGenerator:
class UnorderableGraph(Exception):
pass
-def ordered_nodes(graph):
+
+V = TypeVar("V")
+_Graph = Dict[V, List[V]]
+
+
+def ordered_nodes(graph: _Graph) -> Tuple[V, ...]:
"""takes a dependency graph dict as arguments and return an ordered tuple of
nodes starting with nodes without dependencies and up to the outermost node.
@@ -185,84 +191,115 @@ def ordered_nodes(graph):
Also the given graph dict will be emptied.
"""
# check graph consistency
- cycles = get_cycles(graph)
+ cycles: List[List[V]] = get_cycles(graph)
+
if cycles:
- cycles = '\n'.join([' -> '.join(cycle) for cycle in cycles])
- raise UnorderableGraph('cycles in graph: %s' % cycles)
+ bad_cycles = '\n'.join([' -> '.join(map(str, cycle)) for cycle in cycles])
+ raise UnorderableGraph('cycles in graph: %s' % bad_cycles)
+
vertices = set(graph)
to_vertices = set()
+
for edges in graph.values():
to_vertices |= set(edges)
+
missing_vertices = to_vertices - vertices
if missing_vertices:
raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices))
+
# order vertices
order = []
order_set = set()
old_len = None
+
while graph:
if old_len == len(graph):
raise UnorderableGraph('unknown problem with %s' % graph)
+
old_len = len(graph)
deps_ok = []
+
for node, node_deps in graph.items():
for dep in node_deps:
if dep not in order_set:
break
else:
deps_ok.append(node)
+
order.append(deps_ok)
order_set |= set(deps_ok)
+
for node in deps_ok:
del graph[node]
+
result = []
+
for grp in reversed(order):
result.extend(sorted(grp))
+
return tuple(result)
-def get_cycles(graph_dict, vertices=None):
+def get_cycles(graph_dict: _Graph,
+ vertices: Optional[Iterable] = None) -> List[List]:
'''given a dictionary representing an ordered graph (i.e. key are vertices
and values is a list of destination vertices representing edges), return a
list of detected cycles
'''
if not graph_dict:
- return ()
- result = []
+ return []
+
+ result: List[List] = []
if vertices is None:
vertices = graph_dict.keys()
+
for vertice in vertices:
_get_cycles(graph_dict, [], set(), result, vertice)
+
return result
-def _get_cycles(graph_dict, path, visited, result, vertice):
+
+def _get_cycles(graph_dict: _Graph,
+ path: List,
+ visited: Set,
+ result: List[List],
+ vertice: V) -> None:
"""recursive function doing the real work for get_cycles"""
if vertice in path:
cycle = [vertice]
+
for node in path[::-1]:
if node == vertice:
break
+
cycle.insert(0, node)
+
# make a canonical representation
start_from = min(cycle)
index = cycle.index(start_from)
cycle = cycle[index:] + cycle[0:index]
+
# append it to result if not already in
- if not cycle in result:
+ if cycle not in result:
result.append(cycle)
return
+
path.append(vertice)
+
try:
for node in graph_dict[vertice]:
# don't check already visited nodes again
if node not in visited:
_get_cycles(graph_dict, path, visited, result, node)
visited.add(node)
+
except KeyError:
pass
+
path.pop()
-def has_path(graph_dict, fromnode, tonode, path=None):
+
+def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None) -> Optional[List[str]]:
"""generic function taking a simple graph definition as a dictionary, with
node has key associated to a list of nodes directly reachable from it.
diff --git a/logilab/common/interface.py b/logilab/common/interface.py
index ab0e529..8248a27 100644
--- a/logilab/common/interface.py
+++ b/logilab/common/interface.py
@@ -29,11 +29,11 @@ __docformat__ = "restructuredtext en"
class Interface(object):
"""Base class for interfaces."""
@classmethod
- def is_implemented_by(cls, instance):
+ def is_implemented_by(cls, instance: type) -> bool:
return implements(instance, cls)
-def implements(obj, interface):
+def implements(obj: type, interface: type) -> bool:
"""Return true if the give object (maybe an instance or class) implements
the interface.
"""
@@ -46,7 +46,7 @@ def implements(obj, interface):
return False
-def extend(klass, interface, _recurs=False):
+def extend(klass: type, interface: type, _recurs: bool = False) -> None:
"""Add interface to klass'__implements__ if not already implemented in.
If klass is subclassed, ensure subclasses __implements__ it as well.
@@ -55,14 +55,14 @@ def extend(klass, interface, _recurs=False):
"""
if not implements(klass, interface):
try:
- kimplements = klass.__implements__
+ kimplements = klass.__implements__ # type: ignore
kimplementsklass = type(kimplements)
kimplements = list(kimplements)
except AttributeError:
kimplementsklass = tuple
kimplements = []
kimplements.append(interface)
- klass.__implements__ = kimplementsklass(kimplements)
+ klass.__implements__ = kimplementsklass(kimplements) #type: ignore
for subklass in klass.__subclasses__():
extend(subklass, interface, _recurs=True)
elif _recurs:
diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py
index 34419bd..76c4ac4 100644
--- a/logilab/common/modutils.py
+++ b/logilab/common/modutils.py
@@ -35,13 +35,18 @@ import os
from os.path import (splitext, join, abspath, isdir, dirname, exists,
basename, expanduser, normcase, realpath)
from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY
-from distutils.sysconfig import get_config_var, get_python_lib, get_python_version
+from distutils.sysconfig import get_config_var, get_python_lib
from distutils.errors import DistutilsPlatformError
+from typing import Dict, List, Optional, Any, Tuple, Union, Sequence
+from types import ModuleType
+from _frozen_importlib_external import FileFinder
try:
import zipimport
except ImportError:
- zipimport = None
+ # mypy: Incompatible types in assignment (expression has type "None", variable has type Module)
+ # conditional import
+ zipimport = None # type: ignore
ZIPFILE = object()
@@ -86,8 +91,8 @@ class LazyObject(object):
def _getobj(self):
if self._imported is None:
- self._imported = getattr(load_module_from_name(self.module),
- self.obj)
+ self._imported = getattr(load_module_from_name(self.module),
+ self.obj)
return self._imported
def __getattribute__(self, attr):
@@ -100,7 +105,7 @@ class LazyObject(object):
return self._getobj()(*args, **kwargs)
-def load_module_from_name(dotted_name, path=None, use_sys=True):
+def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys: int = True) -> ModuleType:
"""Load a Python module from its name.
:type dotted_name: str
@@ -122,10 +127,13 @@ def load_module_from_name(dotted_name, path=None, use_sys=True):
:rtype: module
:return: the loaded module
"""
- return load_module_from_modpath(dotted_name.split('.'), path, use_sys)
+ module = load_module_from_modpath(dotted_name.split('.'), path, use_sys)
+ if module is None:
+ raise ImportError("module %s doesn't exist" % dotted_name)
+ return module
-def load_module_from_modpath(parts, path=None, use_sys=True):
+def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_sys: int = True) -> Optional[ModuleType]:
"""Load a python module from its splitted name.
:type parts: list(str) or tuple(str)
@@ -208,7 +216,7 @@ def load_module_from_file(filepath, path=None, use_sys=True, extrapath=None):
return load_module_from_modpath(modpath, path, use_sys)
-def _check_init(path, mod_path):
+def _check_init(path: str, mod_path: List[str]) -> bool:
"""check there are some __init__.py all along the way"""
modpath = []
for part in mod_path:
@@ -219,13 +227,13 @@ def _check_init(path, mod_path):
return True
-def _canonicalize_path(path):
+def _canonicalize_path(path: str) -> str:
return realpath(expanduser(path))
@deprecated('you should avoid using modpath_from_file()')
-def modpath_from_file(filename, extrapath=None):
+def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None) -> List[str]:
"""DEPRECATED: doens't play well with symlinks and sys.meta_path
Given a file path return the corresponding splitted module's name
@@ -269,7 +277,7 @@ def modpath_from_file(filename, extrapath=None):
filename, ', \n'.join(sys.path)))
-def file_from_modpath(modpath, path=None, context_file=None):
+def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None) -> Optional[str]:
"""given a mod path (i.e. splitted module / package name), return the
corresponding file, giving priority to source file over precompiled
file if it exists
@@ -299,6 +307,7 @@ def file_from_modpath(modpath, path=None, context_file=None):
the path to the module's file or None if it's an integrated
builtin module such as 'sys'
"""
+ context: Optional[str]
if context_file is not None:
context = dirname(context_file)
else:
@@ -316,7 +325,7 @@ def file_from_modpath(modpath, path=None, context_file=None):
-def get_module_part(dotted_name, context_file=None):
+def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str:
"""given a dotted name return the module part of the name :
>>> get_module_part('logilab.common.modutils.get_module_part')
@@ -354,7 +363,7 @@ def get_module_part(dotted_name, context_file=None):
raise ImportError(dotted_name)
return parts[0]
# don't use += or insert, we want a new list to be created !
- path = None
+ path: Optional[List] = None
starti = 0
if parts[0] == '':
assert context_file is not None, \
@@ -363,6 +372,7 @@ def get_module_part(dotted_name, context_file=None):
starti = 1
while parts[starti] == '': # for all further dots: change context
starti += 1
+ assert context_file is not None
context_file = dirname(context_file)
for i in range(starti, len(parts)):
try:
@@ -375,7 +385,7 @@ def get_module_part(dotted_name, context_file=None):
return dotted_name
-def get_modules(package, src_directory, blacklist=STD_BLACKLIST):
+def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]:
"""given a package directory return a list of all available python
modules in the package and its subpackages
@@ -415,7 +425,7 @@ def get_modules(package, src_directory, blacklist=STD_BLACKLIST):
-def get_module_files(src_directory, blacklist=STD_BLACKLIST):
+def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]:
"""given a package directory return a list of all available python
module's files in the package and its subpackages
@@ -447,7 +457,7 @@ def get_module_files(src_directory, blacklist=STD_BLACKLIST):
return files
-def get_source_file(filename, include_no_ext=False):
+def get_source_file(filename: str, include_no_ext: bool = False) -> str:
"""given a python module's file name return the matching source file
name (the filename will be returned identically if it's a already an
absolute path to a python source file...)
@@ -505,7 +515,7 @@ def is_python_source(filename):
return splitext(filename)[1][1:] in PY_SOURCE_EXTS
-def is_standard_module(modname, std_path=(STD_LIB_DIR,)):
+def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,)) -> bool:
"""try to guess if a module is a standard python module (by default,
see `std_path` parameter's description)
@@ -547,7 +557,7 @@ def is_standard_module(modname, std_path=(STD_LIB_DIR,)):
-def is_relative(modname, from_file):
+def is_relative(modname: str, from_file: str) -> bool:
"""return true if the given module name is relative to the given
file name
@@ -575,7 +585,7 @@ def is_relative(modname, from_file):
# internal only functions #####################################################
-def _file_from_modpath(modpath, path=None, context=None):
+def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None) -> Optional[str]:
"""given a mod path (i.e. splitted module / package name), return the
corresponding file
@@ -592,6 +602,7 @@ def _file_from_modpath(modpath, path=None, context=None):
mtype, mp_filename = _module_file(modpath, path)
if mtype == PY_COMPILED:
try:
+ assert mp_filename is not None
return get_source_file(mp_filename)
except NoSourceFile:
return mp_filename
@@ -599,10 +610,11 @@ def _file_from_modpath(modpath, path=None, context=None):
# integrated builtin module
return None
elif mtype == PKG_DIRECTORY:
+ assert mp_filename is not None
mp_filename = _has_init(mp_filename)
return mp_filename
-def _search_zip(modpath, pic):
+def _search_zip(modpath: List[str], pic: Dict[str, Optional[FileFinder]]) -> Tuple[object, str, str]:
for filepath, importer in pic.items():
if importer is not None:
if importer.find_module(modpath[0]):
@@ -615,15 +627,19 @@ def _search_zip(modpath, pic):
try:
import pkg_resources
except ImportError:
- pkg_resources = None
+ # mypy: Incompatible types in assignment (expression has type "None", variable has type Module)
+ # conditional import
+ pkg_resources = None # type: ignore
-def _is_namespace(modname):
+def _is_namespace(modname: str) -> bool:
+ # mypy: Module has no attribute "_namespace_packages"; maybe "fixup_namespace_packages"?"
+ # but is still has? or is it a failure from python3 port?
return (pkg_resources is not None
- and modname in pkg_resources._namespace_packages)
+ and modname in pkg_resources._namespace_packages) # type: ignore
-def _module_file(modpath, path=None):
+def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[Union[int, object], Optional[str]]:
"""get a module type / file path
:type modpath: list or tuple
@@ -643,7 +659,7 @@ def _module_file(modpath, path=None):
# egg support compat
try:
pic = sys.path_importer_cache
- _path = (path is None and sys.path or path)
+ _path = path if path is not None else sys.path
for __path in _path:
if not __path in pic:
try:
@@ -660,10 +676,14 @@ def _module_file(modpath, path=None):
module = sys.modules[modpath.pop(0)]
# use list() to protect against _NamespacePath instance we get with python 3, which
# find_module later doesn't like
- path = list(module.__path__)
+ # mypy: Module has no attribute "__path__"
+ # I guess it does thanks to logilab's magic?
+ path = list(module.__path__) # type: ignore
if not modpath:
return C_BUILTIN, None
+
imported = []
+
while modpath:
modname = modpath[0]
# take care to changes in find_module implementation wrt builtin modules
@@ -719,7 +739,7 @@ def _module_file(modpath, path=None):
path = [mp_filename]
return mtype, mp_filename
-def _is_python_file(filename):
+def _is_python_file(filename: str) -> bool:
"""return true if the given filename should be considered as a python file
.pyc and .pyo are ignored
@@ -730,12 +750,14 @@ def _is_python_file(filename):
return False
-def _has_init(directory):
+def _has_init(directory: str) -> Optional[str]:
"""if the given directory has a valid __init__ file, return its path,
else return None
"""
mod_or_pack = join(directory, '__init__')
+
for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'):
if exists(mod_or_pack + '.' + ext):
return mod_or_pack + '.' + ext
+
return None
diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py
index 07365a7..11e2155 100644
--- a/logilab/common/optik_ext.py
+++ b/logilab/common/optik_ext.py
@@ -55,6 +55,11 @@ import sys
import time
from copy import copy
from os.path import exists
+from logilab.common import attrdict
+
+from typing import Any, Union, List, Optional, Tuple, Dict
+from optparse import Values, IndentedHelpFormatter, OptionGroup
+from _io import StringIO
# python >= 2.3
from optparse import OptionParser as BaseParser, Option as BaseOption, \
@@ -83,7 +88,7 @@ def check_regexp(option, opt, value):
raise OptionValueError(
"option %s: invalid regexp value: %r" % (opt, value))
-def check_csv(option, opt, value):
+def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tuple[str, ...], str]) -> Union[List[str], Tuple[str, ...]]:
"""check a csv value by trying to split it
return the list of separated values
"""
@@ -95,7 +100,7 @@ def check_csv(option, opt, value):
raise OptionValueError(
"option %s: invalid csv value: %r" % (opt, value))
-def check_yn(option, opt, value):
+def check_yn(option: Optional['Option'], opt: str, value: Union[bool, str]) -> bool:
"""check a yn value
return true for yes and false for no
"""
@@ -108,18 +113,21 @@ def check_yn(option, opt, value):
msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)"
raise OptionValueError(msg % (opt, value))
-def check_named(option, opt, value):
+def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], str]) -> Dict[str, str]:
"""check a named value
return a dictionary containing (name, value) associations
"""
if isinstance(value, dict):
return value
- values = []
+ values: List[Tuple[str, str]] = []
for value in check_csv(option, opt, value):
+ # mypy: Argument 1 to "append" of "list" has incompatible type "List[str]";
+ # mypy: expected "Tuple[str, str]"
+ # we know that the split will give a 2 items list
if value.find('=') != -1:
- values.append(value.split('=', 1))
+ values.append(value.split('=', 1)) # type: ignore
elif value.find(':') != -1:
- values.append(value.split(':', 1))
+ values.append(value.split(':', 1)) # type: ignore
if values:
return dict(values)
msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \
@@ -173,10 +181,12 @@ def check_time(option, opt, value):
return value
return apply_units(value, TIME_UNITS)
-def check_bytes(option, opt, value):
+def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int:
if hasattr(value, '__int__'):
return value
- return apply_units(value, BYTE_UNITS, final=int)
+ # mypy: Incompatible return value type (got "Union[float, int]", expected "int")
+ # we force "int" using "final=int"
+ return apply_units(value, BYTE_UNITS, final=int) # type: ignore
class Option(BaseOption):
@@ -201,50 +211,62 @@ class Option(BaseOption):
TYPES += ('date',)
TYPE_CHECKER['date'] = check_date
- def __init__(self, *opts, **attrs):
+ def __init__(self, *opts: str, **attrs: Any) -> None:
BaseOption.__init__(self, *opts, **attrs)
- if hasattr(self, "hide") and self.hide:
+ # mypy: "Option" has no attribute "hide"
+ # we test that in the if
+ if hasattr(self, "hide") and self.hide: # type: ignore
self.help = SUPPRESS_HELP
- def _check_choice(self):
+ def _check_choice(self) -> None:
"""FIXME: need to override this due to optik misdesign"""
if self.type in ("choice", "multiple_choice"):
- if self.choices is None:
+ # mypy: "Option" has no attribute "choices"
+ # we know that option of this type has this attribute
+ if self.choices is None: # type: ignore
raise OptionError(
"must supply a list of choices for type 'choice'", self)
- elif not isinstance(self.choices, (tuple, list)):
+ elif not isinstance(self.choices, (tuple, list)): # type: ignore
raise OptionError(
"choices must be a list of strings ('%s' supplied)"
- % str(type(self.choices)).split("'")[1], self)
- elif self.choices is not None:
+ % str(type(self.choices)).split("'")[1], self) # type: ignore
+ elif self.choices is not None: # type: ignore
raise OptionError(
"must not supply choices for type %r" % self.type, self)
- BaseOption.CHECK_METHODS[2] = _check_choice
-
+ # mypy: Unsupported target for indexed assignment
+ # black magic?
+ BaseOption.CHECK_METHODS[2] = _check_choice # type: ignore
- def process(self, opt, value, values, parser):
+ def process(self, opt: str, value: str, values: Values, parser: BaseParser) -> int:
# First, convert the value(s) to the right type. Howl if any
# value(s) are bogus.
value = self.convert_value(opt, value)
if self.type == 'named':
+ assert self.dest is not None
existant = getattr(values, self.dest)
if existant:
existant.update(value)
value = existant
- # And then take whatever action is expected of us.
+ # And then take whatever action is expected of us.
# This is a separate method to make life easier for
# subclasses to add new actions.
+ # mypy: Argument 2 to "take_action" of "Option" has incompatible type "Optional[str]";
+ # mypy: expected "str"
+ # is it ok?
return self.take_action(
- self.action, self.dest, opt, value, values, parser)
+ self.action, self.dest, opt, value, values, parser) # type: ignore
class OptionParser(BaseParser):
"""override optik.OptionParser to use our Option class
"""
- def __init__(self, option_class=Option, *args, **kwargs):
- BaseParser.__init__(self, option_class=Option, *args, **kwargs)
+ def __init__(self, option_class: type = Option, *args: Any, **kwargs: Any) -> None:
+ # mypy: Argument "option_class" to "__init__" of "OptionParser" has incompatible type
+ # mypy: "type"; expected "Option"
+ # mypy is doing really weird things with *args/**kwargs and looks buggy
+ BaseParser.__init__(self, option_class=option_class, *args, **kwargs) # type: ignore
- def format_option_help(self, formatter=None):
+ def format_option_help(self, formatter: Optional[HelpFormatter] = None) -> str:
if formatter is None:
formatter = self.formatter
outputlevel = getattr(formatter, 'output_level', 0)
@@ -256,7 +278,9 @@ class OptionParser(BaseParser):
result.append(OptionContainer.format_option_help(self, formatter))
result.append("\n")
for group in self.option_groups:
- if group.level <= outputlevel and (
+ # mypy: "OptionParser" has no attribute "level"
+ # but it has one no?
+ if group.level <= outputlevel and ( # type: ignore
group.description or level_options(group, outputlevel)):
result.append(group.format_help(formatter))
result.append("\n")
@@ -265,12 +289,16 @@ class OptionParser(BaseParser):
return "".join(result[:-1])
-OptionGroup.level = 0
+# mypy error: error: "Type[OptionGroup]" has no attribute "level"
+# monkeypatching
+OptionGroup.level = 0 # type: ignore
-def level_options(group, outputlevel):
+def level_options(group: BaseParser, outputlevel: int) -> List[BaseOption]:
+ # mypy: "Option" has no attribute "help"
+ # but it does
return [option for option in group.option_list
if (getattr(option, 'level', 0) or 0) <= outputlevel
- and not option.help is SUPPRESS_HELP]
+ and not option.help is SUPPRESS_HELP] # type: ignore
def format_option_help(self, formatter):
result = []
@@ -278,33 +306,42 @@ def format_option_help(self, formatter):
for option in level_options(self, outputlevel):
result.append(formatter.format_option(option))
return "".join(result)
-OptionContainer.format_option_help = format_option_help
+# mypy error: Cannot assign to a method
+# but we still do it because magic
+OptionContainer.format_option_help = format_option_help # type: ignore
class ManHelpFormatter(HelpFormatter):
"""Format help using man pages ROFF format"""
def __init__ (self,
- indent_increment=0,
- max_help_position=24,
- width=79,
- short_first=0):
+ indent_increment: int = 0,
+ max_help_position: int = 24,
+ width: int = 79,
+ short_first: int = 0) -> None:
HelpFormatter.__init__ (
self, indent_increment, max_help_position, width, short_first)
- def format_heading(self, heading):
+ def format_heading(self, heading: str) -> str:
return '.SH %s\n' % heading.upper()
def format_description(self, description):
return description
- def format_option(self, option):
+ def format_option(self, option: BaseParser) -> str:
try:
- optstring = option.option_strings
+ # mypy: "Option" has no attribute "option_strings"
+ # we handle if it doesn't
+ optstring = option.option_strings # type: ignore
except AttributeError:
optstring = self.format_option_strings(option)
- if option.help:
- help_text = self.expand_default(option)
+ # mypy: "OptionParser" has no attribute "help"
+ # it does
+ if option.help: # type: ignore
+ # mypy: Argument 1 to "expand_default" of "HelpFormatter" has incompatible type
+ # mypy: "OptionParser"; expected "Option"
+ # it still works?
+ help_text = self.expand_default(option) # type: ignore
help = ' '.join([l.strip() for l in help_text.splitlines()])
else:
help = ''
@@ -312,13 +349,9 @@ class ManHelpFormatter(HelpFormatter):
%s
''' % (optstring, help)
- def format_head(self, optparser, pkginfo, section=1):
+ def format_head(self, optparser: OptionParser, pkginfo: attrdict, section: int = 1) -> str:
long_desc = ""
- try:
- pgm = optparser._get_prog_name()
- except AttributeError:
- # py >= 2.4.X (dunno which X exactly, at least 2)
- pgm = optparser.get_prog_name()
+ pgm = optparser.get_prog_name()
short_desc = self.format_short_description(pgm, pkginfo.description)
if hasattr(pkginfo, "long_desc"):
long_desc = self.format_long_description(pgm, pkginfo.long_desc)
@@ -326,17 +359,17 @@ class ManHelpFormatter(HelpFormatter):
short_desc, self.format_synopsis(pgm),
long_desc)
- def format_title(self, pgm, section):
+ def format_title(self, pgm: str, section: int) -> str:
date = '-'.join([str(num) for num in time.localtime()[:3]])
return '.TH %s %s "%s" %s' % (pgm, section, date, pgm)
- def format_short_description(self, pgm, short_desc):
+ def format_short_description(self, pgm: str, short_desc: str) -> str:
return '''.SH NAME
.B %s
\- %s
''' % (pgm, short_desc.strip())
- def format_synopsis(self, pgm):
+ def format_synopsis(self, pgm: str) -> str:
return '''.SH SYNOPSIS
.B %s
[
@@ -357,7 +390,7 @@ class ManHelpFormatter(HelpFormatter):
%s
''' % (pgm, long_desc.strip())
- def format_tail(self, pkginfo):
+ def format_tail(self, pkginfo: attrdict) -> str:
tail = '''.SH SEE ALSO
/usr/share/doc/pythonX.Y-%s/
@@ -378,10 +411,12 @@ Please report bugs on the project\'s mailing list:
return tail
-def generate_manpage(optparser, pkginfo, section=1, stream=sys.stdout, level=0):
+def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int = 1, stream: StringIO = sys.stdout, level: int = 0) -> None:
"""generate a man page from an optik parser"""
formatter = ManHelpFormatter()
- formatter.output_level = level
+ # mypy: "ManHelpFormatter" has no attribute "output_level"
+ # dynamic attribute?
+ formatter.output_level = level # type: ignore
formatter.parser = optparser
print(formatter.format_head(optparser, pkginfo, section), file=stream)
print(optparser.format_option_help(formatter), file=stream)
diff --git a/logilab/common/proc.py b/logilab/common/proc.py
index c27356c..30e9494 100644
--- a/logilab/common/proc.py
+++ b/logilab/common/proc.py
@@ -133,14 +133,9 @@ class ProcInfoLoader:
pass
-try:
- class ResourceError(BaseException):
- """Error raise when resource limit is reached"""
- limit = "Unknown Resource Limit"
-except NameError:
- class ResourceError(Exception):
- """Error raise when resource limit is reached"""
- limit = "Unknown Resource Limit"
+class ResourceError(Exception):
+ """Error raise when resource limit is reached"""
+ limit = "Unknown Resource Limit"
class XCPUError(ResourceError):
diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py
index 5c62816..6819c01 100644
--- a/logilab/common/pytest.py
+++ b/logilab/common/pytest.py
@@ -116,13 +116,20 @@ FILE_RESTART = ".pytest.restart"
import os, sys, re
import os.path as osp
from time import process_time, time
+from re import Match
import warnings
import types
import inspect
import traceback
-from inspect import isgeneratorfunction, isclass
+from inspect import isgeneratorfunction, isclass, FrameInfo
from random import shuffle
from itertools import dropwhile
+# mypy error: Module 'unittest.runner' has no attribute '_WritelnDecorator'
+# but it does
+from unittest.runner import _WritelnDecorator # type: ignore
+from unittest.suite import TestSuite
+
+from typing import Callable, Any, Optional, List, Tuple, Generator, Union, Dict
from logilab.common.deprecation import deprecated
from logilab.common.fileutils import abspath_listdir
@@ -141,7 +148,8 @@ if not getattr(unittest_legacy, "__package__", None):
except ImportError:
sys.exit("You have to install python-unittest2 to use this module")
else:
- import unittest.suite as unittest_suite
+ # mypy: Name 'unittest_suite' already defined (possibly by an import))
+ import unittest.suite as unittest_suite # type: ignore
try:
import django
@@ -153,12 +161,12 @@ except ImportError:
CONF_FILE = 'pytestconf.py'
TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$")
-def this_is_a_testfile(filename):
+def this_is_a_testfile(filename: str) -> Optional[Match]:
"""returns True if `filename` seems to be a test file"""
return TESTFILE_RE.match(osp.basename(filename))
TESTDIR_RE = re.compile("^(unit)?tests?$")
-def this_is_a_testdir(dirpath):
+def this_is_a_testdir(dirpath: str) -> Optional[Match]:
"""returns True if `filename` seems to be a test directory"""
return TESTDIR_RE.match(osp.basename(dirpath))
@@ -842,10 +850,10 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
self.batchmode = batchmode
self.options = options
- def _this_is_skipped(self, testedname):
+ def _this_is_skipped(self, testedname: str) -> bool:
return any([(pat in testedname) for pat in self.skipped_patterns])
- def _runcondition(self, test, skipgenerator=True):
+ def _runcondition(self, test: Callable, skipgenerator: bool = True) -> bool:
if isinstance(test, testlib.InnerTest):
testname = test.name
else:
@@ -876,7 +884,7 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
return self.does_match_tags(test)
- def does_match_tags(self, test):
+ def does_match_tags(self, test: Callable) -> bool:
if self.options is not None:
tags_pattern = getattr(self.options, 'tags_pattern', None)
if tags_pattern is not None:
@@ -886,7 +894,7 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
return tags.match(tags_pattern)
return True # no pattern
- def _makeResult(self):
+ def _makeResult(self) -> 'SkipAwareTestResult':
return SkipAwareTestResult(self.stream, self.descriptions,
self.verbosity, self.exitfirst,
self.pdbmode, self.cvg, self.colorize)
@@ -935,14 +943,14 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
class SkipAwareTestResult(unittest._TextTestResult):
- def __init__(self, stream, descriptions, verbosity,
- exitfirst=False, pdbmode=False, cvg=None, colorize=False):
+ def __init__(self, stream: _WritelnDecorator, descriptions: bool, verbosity: int,
+ exitfirst: bool = False, pdbmode: bool = False, cvg: Optional[Any] = None, colorize: bool = False) -> None:
super(SkipAwareTestResult, self).__init__(stream,
descriptions, verbosity)
- self.skipped = []
- self.debuggers = []
- self.fail_descrs = []
- self.error_descrs = []
+ self.skipped: List[Tuple[Any, Any]] = []
+ self.debuggers: List = []
+ self.fail_descrs: List = []
+ self.error_descrs: List = []
self.exitfirst = exitfirst
self.pdbmode = pdbmode
self.cvg = cvg
@@ -950,15 +958,15 @@ class SkipAwareTestResult(unittest._TextTestResult):
self.pdbclass = Debugger
self.verbose = verbosity > 1
- def descrs_for(self, flavour):
+ def descrs_for(self, flavour: str) -> List[Tuple[int, str]]:
return getattr(self, '%s_descrs' % flavour.lower())
- def _create_pdb(self, test_descr, flavour):
+ def _create_pdb(self, test_descr: str, flavour: str) -> None:
self.descrs_for(flavour).append( (len(self.debuggers), test_descr) )
if self.pdbmode:
self.debuggers.append(self.pdbclass(sys.exc_info()[2]))
- def _iter_valid_frames(self, frames):
+ def _iter_valid_frames(self, frames: List[FrameInfo]) -> Generator[FrameInfo, Any, None]:
"""only consider non-testlib frames when formatting traceback"""
lgc_testlib = osp.abspath(__file__)
std_testlib = osp.abspath(unittest.__file__)
@@ -1030,11 +1038,11 @@ class SkipAwareTestResult(unittest._TextTestResult):
elif self.dots:
self.stream.write('S')
- def printErrors(self):
+ def printErrors(self) -> None:
super(SkipAwareTestResult, self).printErrors()
self.printSkippedList()
- def printSkippedList(self):
+ def printSkippedList(self) -> None:
# format (test, err) compatible with unittest2
for test, err in self.skipped:
descr = self.getDescription(test)
@@ -1055,9 +1063,11 @@ class SkipAwareTestResult(unittest._TextTestResult):
from .decorators import monkeypatch
orig_call = testlib.TestCase.__call__
@monkeypatch(testlib.TestCase, '__call__')
-def call(self, result=None, runcondition=None, options=None):
+def call(self: Any, result: SkipAwareTestResult = None, runcondition: Optional[Callable] = None, options: Optional[Any] = None) -> None:
orig_call(self, result=result, runcondition=runcondition, options=options)
- if hasattr(options, "exitfirst") and options.exitfirst:
+ # mypy: Item "None" of "Optional[Any]" has no attribute "exitfirst"
+ # we check it first in the if
+ if hasattr(options, "exitfirst") and options.exitfirst: # type: ignore
# add this test to restart file
try:
restartfile = open(FILE_RESTART, 'a')
@@ -1103,18 +1113,18 @@ class NonStrictTestLoader(unittest.TestLoader):
'python test_foo.py test_bar' will run FooTC.test_bar1 and BarTC.test_bar2
"""
- def __init__(self):
+ def __init__(self) -> None:
self.skipped_patterns = ()
# some magic here to accept empty list by extending
# and to provide callable capability
- def loadTestsFromNames(self, names, module=None):
+ def loadTestsFromNames(self, names: List[str], module: type = None) -> TestSuite:
suites = []
for name in names:
suites.extend(self.loadTestsFromName(name, module))
return self.suiteClass(suites)
- def _collect_tests(self, module):
+ def _collect_tests(self, module: type) -> Dict[str, Tuple[type, List[str]]]:
tests = {}
for obj in vars(module).values():
if isclass(obj) and issubclass(obj, unittest.TestCase):
@@ -1182,10 +1192,12 @@ class NonStrictTestLoader(unittest.TestLoader):
if pattern in methodname]
return collected
- def _this_is_skipped(self, testedname):
- return any([(pat in testedname) for pat in self.skipped_patterns])
+ def _this_is_skipped(self, testedname: str) -> bool:
+ # mypy: Need type annotation for 'pat'
+ # doc doesn't say how to that in list comprehension
+ return any([(pat in testedname) for pat in self.skipped_patterns]) # type: ignore
- def getTestCaseNames(self, testCaseClass):
+ def getTestCaseNames(self, testCaseClass: type) -> List[str]:
"""Return a sorted sequence of method names found within testCaseClass
"""
is_skipped = self._this_is_skipped
@@ -1202,13 +1214,13 @@ class NonStrictTestLoader(unittest.TestLoader):
# It is used to monkeypatch the original implementation to support
# extra runcondition and options arguments (see in testlib.py)
-def _ts_run(self, result, runcondition=None, options=None):
+def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
self._wrapped_run(result, runcondition=runcondition, options=options)
self._tearDownPreviousClass(None, result)
self._handleModuleTearDown(result)
return result
-def _ts_wrapped_run(self, result, debug=False, runcondition=None, options=None):
+def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
for test in self:
if result.shouldStop:
break
@@ -1247,7 +1259,7 @@ if sys.version_info >= (2, 7):
# The function below implements a modified version of the
# TestSuite.run method that is provided with python 2.7, in
# unittest/suite.py
- def _ts_run(self, result, debug=False, runcondition=None, options=None):
+ def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
topLevel = False
if getattr(result, '_testRunEntered', False) is False:
result._testRunEntered = topLevel = True
diff --git a/logilab/common/registry.py b/logilab/common/registry.py
index 6f60ef9..d9ae11b 100644
--- a/logilab/common/registry.py
+++ b/logilab/common/registry.py
@@ -78,6 +78,7 @@ Exceptions
from __future__ import print_function
+
__docformat__ = "restructuredtext en"
import sys
@@ -87,9 +88,14 @@ import weakref
import traceback as tb
from os import listdir, stat
from os.path import join, isdir, exists
-from typing import Dict, Type, Optional, Union
+from typing import Dict, Type, Optional, Union, Sequence
from logging import getLogger
from warnings import warn
+from typing import List, Tuple, Any, Iterable, Callable
+from types import ModuleType
+from typing_extensions import TypedDict
+from _frozen_importlib import ModuleSpec
+from _frozen_importlib_external import SourceFileLoader
from logilab.common.modutils import modpath_from_file
from logilab.common.logging_ext import set_log_methods
@@ -99,7 +105,7 @@ from logilab.common.deprecation import deprecated
# selector base classes and operations ########################################
-def objectify_predicate(selector_func):
+def objectify_predicate(selector_func: Callable) -> Any:
"""Most of the time, a simple score function is enough to build a selector.
The :func:`objectify_predicate` decorator turn it into a proper selector
class::
@@ -119,7 +125,7 @@ def objectify_predicate(selector_func):
_PREDICATES: Dict[int, Type] = {}
-def wrap_predicates(decorator):
+def wrap_predicates(decorator: Callable) -> None:
for predicate in _PREDICATES.values():
if not '_decorators' in predicate.__dict__:
predicate._decorators = set()
@@ -158,31 +164,36 @@ class Predicate(object, metaclass=PredicateMetaClass):
# backward compatibility
return self.__class__.__name__
- def search_selector(self, selector):
+ def search_selector(self, selector: 'Predicate') -> Optional['Predicate']:
"""search for the given selector, selector instance or tuple of
selectors in the selectors tree. Return None if not found.
"""
if self is selector:
return self
if (isinstance(selector, type) or isinstance(selector, tuple)) and \
- isinstance(self, selector):
+ isinstance(self, selector):
return self
return None
def __str__(self):
return self.__class__.__name__
- def __and__(self, other):
+ def __and__(self, other: 'Predicate') -> 'AndPredicate':
return AndPredicate(self, other)
- def __rand__(self, other):
+
+ def __rand__(self, other: 'Predicate') -> 'AndPredicate':
return AndPredicate(other, self)
- def __iand__(self, other):
+
+ def __iand__(self, other: 'Predicate') -> 'AndPredicate':
return AndPredicate(self, other)
- def __or__(self, other):
+
+ def __or__(self, other: 'Predicate') -> 'OrPredicate':
return OrPredicate(self, other)
- def __ror__(self, other):
+
+ def __ror__(self, other: 'Predicate'):
return OrPredicate(other, self)
- def __ior__(self, other):
+
+ def __ior__(self, other: 'Predicate') -> 'OrPredicate':
return OrPredicate(self, other)
def __invert__(self):
@@ -201,7 +212,7 @@ class Predicate(object, metaclass=PredicateMetaClass):
class MultiPredicate(Predicate):
"""base class for compound selector classes"""
- def __init__(self, *selectors):
+ def __init__(self, *selectors: Any) -> None:
self.selectors = self.merge_selectors(selectors)
def __str__(self):
@@ -209,7 +220,7 @@ class MultiPredicate(Predicate):
','.join(str(s) for s in self.selectors))
@classmethod
- def merge_selectors(cls, selectors):
+ def merge_selectors(cls, selectors: Sequence[Predicate]) -> List[Predicate]:
"""deal with selector instanciation when necessary and merge
multi-selectors if possible:
@@ -231,7 +242,7 @@ class MultiPredicate(Predicate):
merged_selectors.append(selector)
return merged_selectors
- def search_selector(self, selector):
+ def search_selector(self, selector: Predicate) -> Optional[Predicate]:
"""search for the given selector or selector instance (or tuple of
selectors) in the selectors tree. Return None if not found
"""
@@ -247,7 +258,7 @@ class MultiPredicate(Predicate):
class AndPredicate(MultiPredicate):
"""and-chained selectors"""
- def __call__(self, cls, *args, **kwargs):
+ def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int:
score = 0
for selector in self.selectors:
partscore = selector(cls, *args, **kwargs)
@@ -259,7 +270,7 @@ class AndPredicate(MultiPredicate):
class OrPredicate(MultiPredicate):
"""or-chained selectors"""
- def __call__(self, cls, *args, **kwargs):
+ def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int:
for selector in self.selectors:
partscore = selector(cls, *args, **kwargs)
if partscore:
@@ -288,7 +299,7 @@ class yes(Predicate): # pylint: disable=C0103
Take care, `yes(0)` could be named 'no'...
"""
- def __init__(self, score=0.5):
+ def __init__(self, score: float = 0.5) -> None:
self.score = score
def __call__(self, *args, **kwargs):
@@ -338,7 +349,7 @@ class SelectAmbiguity(RegistryException):
"""
-def _modname_from_path(path, extrapath=None):
+def _modname_from_path(path: str, extrapath: Optional[Any] = None) -> str:
modpath = modpath_from_file(path, extrapath)
# omit '__init__' from package's name to avoid loading that module
# once for each name when it is imported by some other object
@@ -356,21 +367,26 @@ def _modname_from_path(path, extrapath=None):
return '.'.join(modpath)
-def _toload_info(path, extrapath, _toload=None):
+def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
"""Return a dictionary of <modname>: <modpath> and an ordered list of
(file, module name) to load
"""
if _toload is None:
assert isinstance(path, list)
_toload = {}, []
+
for fileordir in path:
if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
subfiles = [join(fileordir, fname) for fname in listdir(fileordir)]
+
_toload_info(subfiles, extrapath, _toload)
+
elif fileordir[-3:] == '.py':
modname = _modname_from_path(fileordir, extrapath)
_toload[0][modname] = fileordir
+
_toload[1].append((fileordir, modname))
+
return _toload
@@ -404,7 +420,7 @@ class RegistrableObject(object):
__abstract__ = True # see doc snipppets below (in Registry class)
@classproperty
- def __registries__(cls):
+ def __registries__(cls) -> Union[Tuple[str], Tuple]:
if cls.__registry__ is None:
return ()
return (cls.__registry__,)
@@ -432,10 +448,17 @@ class RegistrableInstance(RegistrableObject):
obj.__module__ = module
return obj
- def __init__(self, __module__=None):
+ def __init__(self, __module__: Optional[str] = None) -> None:
super(RegistrableInstance, self).__init__()
+SelectBestReport = TypedDict("SelectBestReport", {"all_objects": List, "end_score": int,
+ "winners": List,
+ "winner": Optional[Any], "self": 'Registry',
+ "args": List, "kwargs": Dict,
+ "registry": 'Registry'})
+
+
class Registry(dict):
"""The registry store a set of implementations associated to identifier:
@@ -469,10 +492,10 @@ class Registry(dict):
.. automethod:: possible_objects
.. automethod:: object_by_id
"""
- def __init__(self, debugmode):
+ def __init__(self, debugmode: bool) -> None:
super(Registry, self).__init__()
self.debugmode = debugmode
- self._select_listeners = []
+ self._select_listeners: List[Callable[[SelectBestReport], None]] = []
def __getitem__(self, name):
"""return the registry (list of implementation objects) associated to
@@ -486,16 +509,16 @@ class Registry(dict):
raise exc
@classmethod
- def objid(cls, obj):
+ def objid(cls, obj: Any) -> str:
"""returns a unique identifier for an object stored in the registry"""
return '%s.%s' % (obj.__module__, cls.objname(obj))
@classmethod
- def objname(cls, obj):
+ def objname(cls, obj: Any) -> str:
"""returns a readable name for an object stored in the registry"""
return getattr(obj, '__name__', id(obj))
- def initialization_completed(self):
+ def initialization_completed(self) -> None:
"""call method __registered__() on registered objects when the callback
is defined"""
for objects in self.values():
@@ -506,7 +529,7 @@ class Registry(dict):
if self.debugmode:
wrap_predicates(_lltrace)
- def register(self, obj, oid=None, clear=False):
+ def register(self, obj: Any, oid: Optional[Any] = None, clear: bool = False) -> None:
"""base method to add an object in the registry"""
assert not '__abstract__' in obj.__dict__, obj
assert obj.__select__, obj
@@ -632,15 +655,16 @@ class Registry(dict):
once the selection is done and will recieve a dict of the following form::
{"all_objects": [], "end_score": 0, "winners": [], "winner": None or winner,
- "self": self, "args": args, "kwargs": kwargs, }
+ "self": self, "args": args, "kwargs": kwargs, "registry": self}
"""
if self._select_listeners:
- select_best_report = {
+ select_best_report: SelectBestReport = {
"registry": self,
"all_objects": [],
"end_score": 0,
"winners": [],
+ "winner": None,
"self": self,
"args": args,
"kwargs": kwargs,
@@ -692,7 +716,7 @@ class Registry(dict):
info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None
-def obj_registries(cls, registryname=None):
+def obj_registries(cls: Any, registryname: Optional[Any] = None) -> Tuple[str]:
"""return a tuple of registry names (see __registries__)"""
if registryname:
return (registryname,)
@@ -817,18 +841,18 @@ class RegistryStore(dict):
key will be the class used when there is no specific class for a name.
"""
- def __init__(self, debugmode=False):
+ def __init__(self, debugmode: bool = False) -> None:
super(RegistryStore, self).__init__()
self.debugmode = debugmode
- def reset(self):
+ def reset(self) -> None:
"""clear all registries managed by this store"""
# don't use self.clear, we want to keep existing subdictionaries
for subdict in self.values():
subdict.clear()
- self._lastmodifs = {}
+ self._lastmodifs: Dict[str, int] = {}
- def __getitem__(self, name):
+ def __getitem__(self, name: str) -> Registry:
"""return the registry (dictionary of class objects) associated to
this name
"""
@@ -842,9 +866,9 @@ class RegistryStore(dict):
# methods for explicit (un)registration ###################################
# default class, when no specific class set
- REGISTRY_FACTORY = {None: Registry}
+ REGISTRY_FACTORY: Dict[Union[None, str], type] = {None: Registry}
- def registry_class(self, regid):
+ def registry_class(self, regid: str) -> type:
"""return existing registry named regid or use factory to create one and
return it"""
try:
@@ -852,14 +876,16 @@ class RegistryStore(dict):
except KeyError:
return self.REGISTRY_FACTORY[None]
- def setdefault(self, regid):
+ # mypy: Signature of "setdefault" incompatible with supertype "MutableMapping""
+ # I don't know how to overload signatures of method in mypy
+ def setdefault(self, regid: str) -> Registry: # type: ignore
try:
return self[regid]
except RegistryNotFound:
self[regid] = self.registry_class(regid)(self.debugmode)
return self[regid]
- def register_all(self, objects, modname, butclasses=()):
+ def register_all(self, objects: Iterable, modname: str, butclasses: Sequence = ()) -> None:
"""register registrable objects into `objects`.
Registrable objects are properly configured subclasses of
@@ -886,7 +912,7 @@ class RegistryStore(dict):
else:
self.register(obj)
- def register(self, obj, registryname=None, oid=None, clear=False):
+ def register(self, obj: Any, registryname: Optional[Any] = None, oid: Optional[Any] = None, clear: bool = False) -> None:
"""register `obj` implementation into `registryname` or
`obj.__registries__` if not specified, with identifier `oid` or
`obj.__regid__` if not specified.
@@ -927,7 +953,7 @@ class RegistryStore(dict):
# initialization methods ###################################################
- def init_registration(self, path, extrapath=None):
+ def init_registration(self, path: List[str], extrapath: Optional[Any] = None) -> List[Tuple[str, str]]:
"""reset registry and walk down path to return list of (path, name)
file modules to be loaded"""
# XXX make this private by renaming it to _init_registration ?
@@ -937,11 +963,11 @@ class RegistryStore(dict):
# XXX is _loadedmods still necessary ? It seems like it's useful
# to avoid loading same module twice, especially with the
# _load_ancestors_then_object logic but this needs to be checked
- self._loadedmods = {}
+ self._loadedmods: Dict[str, Dict[str, type]] = {}
return filemods
@deprecated('use register_modnames() instead')
- def register_objects(self, path, extrapath=None):
+ def register_objects(self, path: List[str], extrapath: Optional[Any] = None) -> None:
"""register all objects found walking down <path>"""
# load views from each directory in the instance's path
# XXX inline init_registration ?
@@ -950,14 +976,18 @@ class RegistryStore(dict):
self.load_file(filepath, modname)
self.initialization_completed()
- def register_modnames(self, modnames):
+ def register_modnames(self, modnames: List[str]) -> None:
"""register all objects found in <modnames>"""
self.reset()
self._loadedmods = {}
self._toloadmods = {}
toload = []
for modname in modnames:
- filepath = pkgutil.find_loader(modname).get_filename()
+ loader = pkgutil.find_loader(modname)
+ assert loader is not None
+ # mypy: "Loader" has no attribute "get_filename"
+ # the selected class has one
+ filepath = loader.get_filename() # type: ignore
if filepath[-4:] in ('.pyc', '.pyo'):
# The source file *must* exists
filepath = filepath[:-1]
@@ -967,12 +997,12 @@ class RegistryStore(dict):
self.load_file(filepath, modname)
self.initialization_completed()
- def initialization_completed(self):
+ def initialization_completed(self) -> None:
"""call initialization_completed() on all known registries"""
for reg in self.values():
reg.initialization_completed()
- def _mdate(self, filepath):
+ def _mdate(self, filepath: str) -> Optional[int]:
""" return the modification date of a file path """
try:
return stat(filepath)[-2]
@@ -1004,7 +1034,7 @@ class RegistryStore(dict):
return True
return False
- def load_file(self, filepath, modname):
+ def load_file(self, filepath: str, modname: str) -> None:
""" load registrable objects (if any) from a python file """
if modname in self._loadedmods:
return
@@ -1025,7 +1055,7 @@ class RegistryStore(dict):
module = __import__(modname, fromlist=modname.split('.')[:-1])
self.load_module(module)
- def load_module(self, module):
+ def load_module(self, module: ModuleType) -> None:
"""Automatically handle module objects registration.
Instances are registered as soon as they are hashable and have the
@@ -1046,11 +1076,13 @@ class RegistryStore(dict):
"""
self.info('loading %s from %s', module.__name__, module.__file__)
if hasattr(module, 'registration_callback'):
- module.registration_callback(self)
+ # mypy: Module has no attribute "registration_callback"
+ # we check that before
+ module.registration_callback(self) # type: ignore
else:
self.register_all(vars(module).values(), module.__name__)
- def _load_ancestors_then_object(self, modname, objectcls, butclasses=()):
+ def _load_ancestors_then_object(self, modname: str, objectcls: type, butclasses: Sequence[Any] = ()) -> None:
"""handle class registration according to rules defined in
:meth:`load_module`
"""
@@ -1092,7 +1124,7 @@ class RegistryStore(dict):
self.register(objectcls)
@classmethod
- def is_registrable(cls, obj):
+ def is_registrable(cls, obj: Any) -> bool:
"""ensure `obj` should be registered
as arbitrary stuff may be registered, do a lot of check and warn about
@@ -1107,25 +1139,33 @@ class RegistryStore(dict):
return False
elif issubclass(obj, RegistrableInstance):
return False
+
elif not isinstance(obj, RegistrableInstance):
return False
+
if not obj.__regid__:
return False # no regid
+
registries = obj.__registries__
if not registries:
return False # no registries
+
selector = obj.__select__
if not selector:
return False # no selector
+
if obj.__dict__.get('__abstract__', False):
return False
+
# then detect potential problems that should be warned
if not isinstance(registries, (tuple, list)):
cls.warning('%s has __registries__ which is not a list or tuple', obj)
return False
+
if not callable(selector):
cls.warning('%s has not callable __select__', obj)
return False
+
return True
# these are overridden by set_log_methods below
diff --git a/logilab/common/shellutils.py b/logilab/common/shellutils.py
index d03ae16..2764723 100644
--- a/logilab/common/shellutils.py
+++ b/logilab/common/shellutils.py
@@ -37,6 +37,8 @@ import random
import subprocess
import warnings
from os.path import exists, isdir, islink, basename, join
+from _io import StringIO
+from typing import Any, Callable, Optional, List, Union, Iterator, Tuple
from logilab.common import STD_BLACKLIST, _handle_blacklist
from logilab.common.compat import str_to_bytes
@@ -130,7 +132,7 @@ def cp(source, destination):
"""
mv(source, destination, _action=shutil.copy)
-def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST):
+def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = False, blacklist: Tuple[str, ...] = STD_BLACKLIST) -> List[str]:
"""Recursively find files ending with the given extensions from the directory.
:type directory: str
@@ -158,13 +160,13 @@ def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST):
if isinstance(exts, str):
exts = (exts,)
if exclude:
- def match(filename, exts):
+ def match(filename: str, exts: Tuple[str, ...]) -> bool:
for ext in exts:
if filename.endswith(ext):
return False
return True
else:
- def match(filename, exts):
+ def match(filename: str, exts: Tuple[str, ...]) -> bool:
for ext in exts:
if filename.endswith(ext):
return True
@@ -180,7 +182,7 @@ def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST):
return files
-def globfind(directory, pattern, blacklist=STD_BLACKLIST):
+def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST) -> Iterator[str]:
"""Recursively finds files matching glob `pattern` under `directory`.
This is an alternative to `logilab.common.shellutils.find`.
@@ -236,7 +238,7 @@ class Execute:
class ProgressBar(object):
"""A simple text progression bar."""
- def __init__(self, nbops, size=20, stream=sys.stdout, title=''):
+ def __init__(self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = '') -> None:
if title:
self._fstr = '\r%s [%%-%ss]' % (title, int(size))
else:
@@ -262,7 +264,7 @@ class ProgressBar(object):
text = property(_get_text, _set_text, _del_text)
- def update(self, offset=1, exact=False):
+ def update(self, offset: int = 1, exact: bool = False) -> None:
"""Move FORWARD to new cursor position (cursor will never go backward).
:offset: fraction of ``size``
@@ -283,7 +285,7 @@ class ProgressBar(object):
self._progress = progress
self.refresh()
- def refresh(self):
+ def refresh(self) -> None:
"""Refresh the progression bar display."""
self._stream.write(self._fstr % ('=' * min(self._progress, self._size)) )
if self._last_text_write_size or self._current_text:
@@ -338,7 +340,7 @@ class progress(object):
class RawInput(object):
- def __init__(self, input_function=None, printer=None, **kwargs):
+ def __init__(self, input_function: Optional[Callable] = None, printer: Optional[Callable] = None, **kwargs: Any) -> None:
if 'input' in kwargs:
input_function = kwargs.pop('input')
warnings.warn(
@@ -349,7 +351,7 @@ class RawInput(object):
self._input = input_function or input
self._print = printer
- def ask(self, question, options, default):
+ def ask(self, question: str, options: Tuple[str, ...], default: str) -> str:
assert default in options
choices = []
for option in options:
@@ -383,7 +385,7 @@ class RawInput(object):
tries -= 1
raise Exception('unable to get a sensible answer')
- def confirm(self, question, default_is_yes=True):
+ def confirm(self, question: str, default_is_yes: bool = True) -> bool:
default = default_is_yes and 'y' or 'n'
answer = self.ask(question, ('y', 'n'), default)
return answer == 'y'
diff --git a/logilab/common/table.py b/logilab/common/table.py
index 1f1101c..e7b9195 100644
--- a/logilab/common/table.py
+++ b/logilab/common/table.py
@@ -18,6 +18,10 @@
"""Table management module."""
from __future__ import print_function
+from types import CodeType
+from typing import Any, List, Optional, Tuple, Union, Dict, Iterator
+from _io import StringIO
+from mypy_extensions import NoReturn
__docformat__ = "restructuredtext en"
@@ -30,51 +34,64 @@ class Table(object):
forall(self.data, lambda x: len(x) <= len(self.col_names))
"""
- def __init__(self, default_value=0, col_names=None, row_names=None):
- self.col_names = []
- self.row_names = []
- self.data = []
- self.default_value = default_value
+ def __init__(self, default_value: int = 0, col_names: Optional[List[str]] = None, row_names: Optional[Any] = None) -> None:
+ self.col_names: List = []
+ self.row_names: List = []
+ self.data: List = []
+ self.default_value: int = default_value
if col_names:
self.create_columns(col_names)
if row_names:
self.create_rows(row_names)
- def _next_row_name(self):
+ def _next_row_name(self) -> str:
return 'row%s' % (len(self.row_names)+1)
- def __iter__(self):
+ def __iter__(self) -> Iterator:
return iter(self.data)
- def __eq__(self, other):
+ # def __eq__(self, other: Union[List[List[int]], List[Tuple[str, str, str, float]]]) -> bool:
+ def __eq__(self, other: object) -> bool:
+ def is_iterable(variable: Any) -> bool:
+ try:
+ iter(variable)
+ except TypeError:
+ return False
+ else:
+ return True
+
if other is None:
return False
+ elif is_iterable(other):
+ # mypy: No overload variant of "list" matches argument type "object"
+ # checked before
+ return list(self) == list(other) # type: ignore
else:
- return list(self) == list(other)
+ return False
__hash__ = object.__hash__
def __ne__(self, other):
return not self == other
- def __len__(self):
+ def __len__(self) -> int:
return len(self.row_names)
## Rows / Columns creation #################################################
- def create_rows(self, row_names):
+ def create_rows(self, row_names: List[str]) -> None:
"""Appends row_names to the list of existing rows
"""
self.row_names.extend(row_names)
for row_name in row_names:
self.data.append([self.default_value]*len(self.col_names))
- def create_columns(self, col_names):
+ def create_columns(self, col_names: List[str]) -> None:
"""Appends col_names to the list of existing columns
"""
for col_name in col_names:
self.create_column(col_name)
- def create_row(self, row_name=None):
+ def create_row(self, row_name: str = None) -> None:
"""Creates a rowname to the row_names list
"""
row_name = row_name or self._next_row_name()
@@ -82,7 +99,7 @@ class Table(object):
self.data.append([self.default_value]*len(self.col_names))
- def create_column(self, col_name):
+ def create_column(self, col_name: str) -> None:
"""Creates a colname to the col_names list
"""
self.col_names.append(col_name)
@@ -90,7 +107,7 @@ class Table(object):
row.append(self.default_value)
## Sort by column ##########################################################
- def sort_by_column_id(self, col_id, method = 'asc'):
+ def sort_by_column_id(self, col_id: str, method: str = 'asc') -> None:
"""Sorts the table (in-place) according to data stored in col_id
"""
try:
@@ -100,7 +117,7 @@ class Table(object):
raise KeyError("Col (%s) not found in table" % (col_id))
- def sort_by_column_index(self, col_index, method = 'asc'):
+ def sort_by_column_index(self, col_index: int, method: str = 'asc') -> None:
"""Sorts the table 'in-place' according to data stored in col_index
method should be in ('asc', 'desc')
@@ -119,29 +136,33 @@ class Table(object):
self.data.append(row)
self.row_names.append(row_name)
- def groupby(self, colname, *others):
+ def groupby(self, colname: str, *others: str) -> Union[Dict[str, Dict[str, 'Table']],
+ Dict[str, 'Table']]:
"""builds indexes of data
:returns: nested dictionaries pointing to actual rows
"""
- groups = {}
+ groups: Dict = {}
colnames = (colname,) + others
col_indexes = [self.col_names.index(col_id) for col_id in colnames]
for row in self.data:
ptr = groups
for col_index in col_indexes[:-1]:
ptr = ptr.setdefault(row[col_index], {})
- ptr = ptr.setdefault(row[col_indexes[-1]],
- Table(default_value=self.default_value,
- col_names=self.col_names))
- ptr.append_row(tuple(row))
+ table = ptr.setdefault(row[col_indexes[-1]],
+ Table(default_value=self.default_value,
+ col_names=self.col_names))
+ table.append_row(tuple(row))
return groups
- def select(self, colname, value):
+ def select(self, colname: str, value: str) -> 'Table':
grouped = self.groupby(colname)
try:
- return grouped[value]
+ # mypy: Incompatible return value type (got "Union[Dict[str, Table], Table]",
+ # mypy: expected "Table")
+ # I guess we are sure we'll get a Table here?
+ return grouped[value] # type: ignore
except KeyError:
- return []
+ return Table()
def remove(self, colname, value):
col_index = self.col_names.index(colname)
@@ -151,13 +172,13 @@ class Table(object):
## The 'setter' part #######################################################
- def set_cell(self, row_index, col_index, data):
+ def set_cell(self, row_index: int, col_index: int, data: int) -> None:
"""sets value of cell 'row_indew', 'col_index' to data
"""
self.data[row_index][col_index] = data
- def set_cell_by_ids(self, row_id, col_id, data):
+ def set_cell_by_ids(self, row_id: str, col_id: str, data: Union[int, str]) -> None:
"""sets value of cell mapped by row_id and col_id to data
Raises a KeyError if row_id or col_id are not found in the table
"""
@@ -173,7 +194,7 @@ class Table(object):
raise KeyError("Column (%s) not found in table" % (col_id))
- def set_row(self, row_index, row_data):
+ def set_row(self, row_index: int, row_data: Union[List[float], List[int], List[str]]) -> None:
"""sets the 'row_index' row
pre::
@@ -183,7 +204,7 @@ class Table(object):
self.data[row_index] = row_data
- def set_row_by_id(self, row_id, row_data):
+ def set_row_by_id(self, row_id: str, row_data: List[str]) -> None:
"""sets the 'row_id' column
pre::
@@ -199,7 +220,7 @@ class Table(object):
raise KeyError('Row (%s) not found in table' % (row_id))
- def append_row(self, row_data, row_name=None):
+ def append_row(self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None) -> int:
"""Appends a row to the table
pre::
@@ -211,7 +232,7 @@ class Table(object):
self.data.append(row_data)
return len(self.data) - 1
- def insert_row(self, index, row_data, row_name=None):
+ def insert_row(self, index: int, row_data: List[str], row_name: str = None) -> None:
"""Appends row_data before 'index' in the table. To make 'insert'
behave like 'list.insert', inserting in an out of range index will
insert row_data to the end of the list
@@ -225,7 +246,7 @@ class Table(object):
self.data.insert(index, row_data)
- def delete_row(self, index):
+ def delete_row(self, index: int) -> List[str]:
"""Deletes the 'index' row in the table, and returns it.
Raises an IndexError if index is out of range
"""
@@ -233,7 +254,7 @@ class Table(object):
return self.data.pop(index)
- def delete_row_by_id(self, row_id):
+ def delete_row_by_id(self, row_id: str) -> None:
"""Deletes the 'row_id' row in the table.
Raises a KeyError if row_id was not found.
"""
@@ -244,7 +265,7 @@ class Table(object):
raise KeyError('Row (%s) not found in table' % (row_id))
- def set_column(self, col_index, col_data):
+ def set_column(self, col_index: int, col_data: Union[List[int], range]) -> None:
"""sets the 'col_index' column
pre::
@@ -256,7 +277,7 @@ class Table(object):
self.data[row_index][col_index] = cell_data
- def set_column_by_id(self, col_id, col_data):
+ def set_column_by_id(self, col_id: str, col_data: Union[List[int], range]) -> None:
"""sets the 'col_id' column
pre::
@@ -272,7 +293,7 @@ class Table(object):
raise KeyError('Column (%s) not found in table' % (col_id))
- def append_column(self, col_data, col_name):
+ def append_column(self, col_data: range, col_name: str) -> None:
"""Appends the 'col_index' column
pre::
@@ -284,7 +305,7 @@ class Table(object):
self.data[row_index].append(cell_data)
- def insert_column(self, index, col_data, col_name):
+ def insert_column(self, index: int, col_data: range, col_name: str) -> None:
"""Appends col_data before 'index' in the table. To make 'insert'
behave like 'list.insert', inserting in an out of range index will
insert col_data to the end of the list
@@ -298,7 +319,7 @@ class Table(object):
self.data[row_index].insert(index, cell_data)
- def delete_column(self, index):
+ def delete_column(self, index: int) -> List[int]:
"""Deletes the 'index' column in the table, and returns it.
Raises an IndexError if index is out of range
"""
@@ -306,7 +327,7 @@ class Table(object):
return [row.pop(index) for row in self.data]
- def delete_column_by_id(self, col_id):
+ def delete_column_by_id(self, col_id: str) -> None:
"""Deletes the 'col_id' col in the table.
Raises a KeyError if col_id was not found.
"""
@@ -319,53 +340,68 @@ class Table(object):
## The 'getter' part #######################################################
- def get_shape(self):
+ def get_shape(self) -> Tuple[int, int]:
"""Returns a tuple which represents the table's shape
"""
return len(self.row_names), len(self.col_names)
shape = property(get_shape)
- def __getitem__(self, indices):
+ def __getitem__(self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice]) -> Any:
"""provided for convenience"""
- rows, multirows = None, False
- cols, multicols = None, False
+ multirows: bool = False
+ multicols: bool = False
+
+ rows: slice
+ cols: slice
+
+ rows_indice: Union[int, slice, str]
+ cols_indice: Union[int, str, None] = None
+
if isinstance(indices, tuple):
- rows = indices[0]
+ rows_indice = indices[0]
if len(indices) > 1:
- cols = indices[1]
+ cols_indice = indices[1]
else:
- rows = indices
+ rows_indice = indices
+
# define row slice
- if isinstance(rows, str):
+ if isinstance(rows_indice, str):
try:
- rows = self.row_names.index(rows)
+ rows_indice = self.row_names.index(rows_indice)
except ValueError:
- raise KeyError("Row (%s) not found in table" % (rows))
- if isinstance(rows, int):
- rows = slice(rows, rows+1)
+ raise KeyError("Row (%s) not found in table" % (rows_indice))
+
+ if isinstance(rows_indice, int):
+ rows = slice(rows_indice, rows_indice + 1)
multirows = False
else:
rows = slice(None)
multirows = True
+
# define col slice
- if isinstance(cols, str):
+ if isinstance(cols_indice, str):
try:
- cols = self.col_names.index(cols)
+ cols_indice = self.col_names.index(cols_indice)
except ValueError:
- raise KeyError("Column (%s) not found in table" % (cols))
- if isinstance(cols, int):
- cols = slice(cols, cols+1)
+ raise KeyError("Column (%s) not found in table" % (cols_indice))
+
+ if isinstance(cols_indice, int):
+ cols = slice(cols_indice, cols_indice + 1)
multicols = False
else:
cols = slice(None)
multicols = True
+
# get sub-table
tab = Table()
tab.default_value = self.default_value
+
tab.create_rows(self.row_names[rows])
tab.create_columns(self.col_names[cols])
+
for idx, row in enumerate(self.data[rows]):
tab.set_row(idx, row[cols])
+
if multirows :
if multicols:
return tab
@@ -409,7 +445,7 @@ class Table(object):
raise KeyError("Column (%s) not found in table" % (col_id))
return self.get_column(col_index, distinct)
- def get_columns(self):
+ def get_columns(self) -> List[List[int]]:
"""Returns all the columns in the table
"""
return [self[:, index] for index in range(len(self.col_names))]
@@ -421,14 +457,14 @@ class Table(object):
col = list(set(col))
return col
- def apply_stylesheet(self, stylesheet):
+ def apply_stylesheet(self, stylesheet: 'TableStyleSheet') -> None:
"""Applies the stylesheet to this table
"""
for instruction in stylesheet.instructions:
eval(instruction)
- def transpose(self):
+ def transpose(self) -> 'Table':
"""Keeps the self object intact, and returns the transposed (rotated)
table.
"""
@@ -440,7 +476,7 @@ class Table(object):
return transposed
- def pprint(self):
+ def pprint(self) -> str:
"""returns a string representing the table in a pretty
printed 'text' format.
"""
@@ -482,7 +518,7 @@ class Table(object):
return '\n'.join(lines)
- def __repr__(self):
+ def __repr__(self) -> str:
return repr(self.data)
def as_text(self):
@@ -499,7 +535,7 @@ class TableStyle:
"""Defines a table's style
"""
- def __init__(self, table):
+ def __init__(self, table: Table) -> None:
self._table = table
self.size = dict([(col_name, '1*') for col_name in table.col_names])
@@ -516,12 +552,12 @@ class TableStyle:
self.units['__row_column__'] = ''
# XXX FIXME : params order should be reversed for all set() methods
- def set_size(self, value, col_id):
+ def set_size(self, value: str, col_id: str) -> None:
"""sets the size of the specified col_id to value
"""
self.size[col_id] = value
- def set_size_by_index(self, value, col_index):
+ def set_size_by_index(self, value: str, col_index: int) -> None:
"""Allows to set the size according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -534,13 +570,13 @@ class TableStyle:
self.size[col_id] = value
- def set_alignment(self, value, col_id):
+ def set_alignment(self, value: str, col_id: str) -> None:
"""sets the alignment of the specified col_id to value
"""
self.alignment[col_id] = value
- def set_alignment_by_index(self, value, col_index):
+ def set_alignment_by_index(self, value: str, col_index: int) -> None:
"""Allows to set the alignment according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -553,13 +589,13 @@ class TableStyle:
self.alignment[col_id] = value
- def set_unit(self, value, col_id):
+ def set_unit(self, value: str, col_id: str) -> None:
"""sets the unit of the specified col_id to value
"""
self.units[col_id] = value
- def set_unit_by_index(self, value, col_index):
+ def set_unit_by_index(self, value: str, col_index: int) -> None:
"""Allows to set the unit according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -574,13 +610,13 @@ class TableStyle:
self.units[col_id] = value
- def get_size(self, col_id):
+ def get_size(self, col_id: str) -> str:
"""Returns the size of the specified col_id
"""
return self.size[col_id]
- def get_size_by_index(self, col_index):
+ def get_size_by_index(self, col_index: int) -> str:
"""Allows to get the size according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -593,13 +629,13 @@ class TableStyle:
return self.size[col_id]
- def get_alignment(self, col_id):
+ def get_alignment(self, col_id: str) -> str:
"""Returns the alignment of the specified col_id
"""
return self.alignment[col_id]
- def get_alignment_by_index(self, col_index):
+ def get_alignment_by_index(self, col_index: int) -> str:
"""Allors to get the alignment according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -612,13 +648,13 @@ class TableStyle:
return self.alignment[col_id]
- def get_unit(self, col_id):
+ def get_unit(self, col_id: str) -> str:
"""Returns the unit of the specified col_id
"""
return self.units[col_id]
- def get_unit_by_index(self, col_index):
+ def get_unit_by_index(self, col_index: int) -> str:
"""Allors to get the unit according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
@@ -649,28 +685,30 @@ class TableStyleSheet:
2_5 = sqrt(2_3**2 + 2_4**2)
"""
- def __init__(self, rules = None):
+ def __init__(self, rules: Optional[List[str]] = None) -> None:
rules = rules or []
- self.rules = []
- self.instructions = []
+
+ self.rules: List[str] = []
+ self.instructions: List[CodeType] = []
+
for rule in rules:
self.add_rule(rule)
- def add_rule(self, rule):
+ def add_rule(self, rule: str) -> None:
"""Adds a rule to the stylesheet rules
"""
try:
source_code = ['from math import *']
source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule))
self.instructions.append(compile('\n'.join(source_code),
- 'table.py', 'exec'))
+ 'table.py', 'exec'))
self.rules.append(rule)
except SyntaxError:
print("Bad Stylesheet Rule : %s [skipped]" % rule)
- def add_rowsum_rule(self, dest_cell, row_index, start_col, end_col):
+ def add_rowsum_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None:
"""Creates and adds a rule to sum over the row at row_index from
start_col to end_col.
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -686,7 +724,7 @@ class TableStyleSheet:
self.add_rule(rule)
- def add_rowavg_rule(self, dest_cell, row_index, start_col, end_col):
+ def add_rowavg_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None:
"""Creates and adds a rule to make the row average (from start_col
to end_col)
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -703,7 +741,7 @@ class TableStyleSheet:
self.add_rule(rule)
- def add_colsum_rule(self, dest_cell, col_index, start_row, end_row):
+ def add_colsum_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None:
"""Creates and adds a rule to sum over the col at col_index from
start_row to end_row.
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -719,7 +757,7 @@ class TableStyleSheet:
self.add_rule(rule)
- def add_colavg_rule(self, dest_cell, col_index, start_row, end_row):
+ def add_colavg_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None:
"""Creates and adds a rule to make the col average (from start_row
to end_row)
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -741,7 +779,7 @@ class TableCellRenderer:
"""Defines a simple text renderer
"""
- def __init__(self, **properties):
+ def __init__(self, **properties: Any) -> None:
"""keywords should be properties with an associated boolean as value.
For example :
renderer = TableCellRenderer(units = True, alignment = False)
@@ -752,7 +790,7 @@ class TableCellRenderer:
self.properties = properties
- def render_cell(self, cell_coord, table, table_style):
+ def render_cell(self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle) -> Union[str, int]:
"""Renders the cell at 'cell_coord' in the table, using table_style
"""
row_index, col_index = cell_coord
@@ -763,14 +801,14 @@ class TableCellRenderer:
table_style, col_index + 1)
- def render_row_cell(self, row_name, table, table_style):
+ def render_row_cell(self, row_name: str, table: Table, table_style: TableStyle) -> Union[str, int]:
"""Renders the cell for 'row_id' row
"""
cell_value = row_name
return self._render_cell_content(cell_value, table_style, 0)
- def render_col_cell(self, col_name, table, table_style):
+ def render_col_cell(self, col_name: str, table: Table, table_style: TableStyle) -> Union[str, int]:
"""Renders the cell for 'col_id' row
"""
cell_value = col_name
@@ -779,7 +817,7 @@ class TableCellRenderer:
- def _render_cell_content(self, content, table_style, col_index):
+ def _render_cell_content(self, content: Union[str, int], table_style: TableStyle, col_index: int) -> Union[str, int]:
"""Makes the appropriate rendering for this cell content.
Rendering properties will be searched using the
*table_style.get_xxx_by_index(col_index)' methods
@@ -789,11 +827,12 @@ class TableCellRenderer:
return content
- def _make_cell_content(self, cell_content, table_style, col_index):
+ def _make_cell_content(self, cell_content: int, table_style: TableStyle, col_index: int) -> Union[int, str]:
"""Makes the cell content (adds decoration data, like units for
example)
"""
- final_content = cell_content
+ final_content: Union[int, str] = cell_content
+
if 'skip_zero' in self.properties:
replacement_char = self.properties['skip_zero']
else:
@@ -812,7 +851,7 @@ class TableCellRenderer:
return final_content
- def _add_unit(self, cell_content, table_style, col_index):
+ def _add_unit(self, cell_content: int, table_style: TableStyle, col_index: int) -> str:
"""Adds unit to the cell_content if needed
"""
unit = table_style.get_unit_by_index(col_index)
@@ -824,7 +863,7 @@ class DocbookRenderer(TableCellRenderer):
"""Defines how to render a cell for a docboook table
"""
- def define_col_header(self, col_index, table_style):
+ def define_col_header(self, col_index: int, table_style: TableStyle) -> str:
"""Computes the colspec element according to the style
"""
size = table_style.get_size_by_index(col_index)
@@ -832,7 +871,7 @@ class DocbookRenderer(TableCellRenderer):
(col_index, size)
- def _render_cell_content(self, cell_content, table_style, col_index):
+ def _render_cell_content(self, cell_content: Union[int, str], table_style: TableStyle, col_index: int) -> str:
"""Makes the appropriate rendering for this cell content.
Rendering properties will be searched using the
table_style.get_xxx_by_index(col_index)' methods.
@@ -847,17 +886,20 @@ class DocbookRenderer(TableCellRenderer):
# KeyError <=> Default alignment
return "<entry>%s</entry>\n" % cell_content
+ # XXX really?
+ return ""
+
class TableWriter:
"""A class to write tables
"""
- def __init__(self, stream, table, style, **properties):
+ def __init__(self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any) -> None:
self._stream = stream
self.style = style or TableStyle(table)
self._table = table
self.properties = properties
- self.renderer = None
+ self.renderer: Optional[DocbookRenderer] = None
def set_style(self, style):
@@ -866,7 +908,7 @@ class TableWriter:
self.style = style
- def set_renderer(self, renderer):
+ def set_renderer(self, renderer: DocbookRenderer) -> None:
"""sets the way to render cell
"""
self.renderer = renderer
@@ -878,7 +920,7 @@ class TableWriter:
self.properties.update(properties)
- def write_table(self, title = ""):
+ def write_table(self, title: str = "") -> None:
"""Writes the table
"""
raise NotImplementedError("write_table must be implemented !")
@@ -889,9 +931,11 @@ class DocbookTableWriter(TableWriter):
"""Defines an implementation of TableWriter to write a table in Docbook
"""
- def _write_headers(self):
+ def _write_headers(self) -> None:
"""Writes col headers
"""
+ assert self.renderer is not None
+
# Define col_headers (colstpec elements)
for col_index in range(len(self._table.col_names)+1):
self._stream.write(self.renderer.define_col_header(col_index,
@@ -908,9 +952,11 @@ class DocbookTableWriter(TableWriter):
self._stream.write("</row>\n</thead>\n")
- def _write_body(self):
+ def _write_body(self) -> None:
"""Writes the table body
"""
+ assert self.renderer is not None
+
self._stream.write('<tbody>\n')
for row_index, row in enumerate(self._table.data):
@@ -931,7 +977,7 @@ class DocbookTableWriter(TableWriter):
self._stream.write('</tbody>\n')
- def write_table(self, title = ""):
+ def write_table(self, title: str = "") -> None:
"""Writes the table
"""
self._stream.write('<table>\n<title>%s></title>\n'%(title))
diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py
index 7d561ca..4e3434e 100644
--- a/logilab/common/tasksqueue.py
+++ b/logilab/common/tasksqueue.py
@@ -19,8 +19,9 @@
__docformat__ = "restructuredtext en"
-from bisect import insort_left
+from typing import Iterator, List
+from bisect import insort_left
import queue
LOW = 0
@@ -35,16 +36,15 @@ PRIORITY = {
REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items())
-
class PrioritizedTasksQueue(queue.Queue):
- def _init(self, maxsize):
+ def _init(self, maxsize: int) -> None:
"""Initialize the queue representation"""
self.maxsize = maxsize
# ordered list of task, from the lowest to the highest priority
- self.queue = []
+ self.queue: List['Task'] = [] # type: ignore
- def _put(self, item):
+ def _put(self, item: 'Task') -> None:
"""Put a new item in the queue"""
for i, task in enumerate(self.queue):
# equivalent task
@@ -60,14 +60,14 @@ class PrioritizedTasksQueue(queue.Queue):
return
insort_left(self.queue, item)
- def _get(self):
+ def _get(self) -> 'Task':
"""Get an item from the queue"""
return self.queue.pop()
- def __iter__(self):
+ def __iter__(self) -> Iterator['Task']:
return iter(self.queue)
- def remove(self, tid):
+ def remove(self, tid: str) -> None:
"""remove a specific task from the queue"""
# XXX acquire lock
for i, task in enumerate(self):
@@ -76,26 +76,23 @@ class PrioritizedTasksQueue(queue.Queue):
return
raise ValueError('not task of id %s in queue' % tid)
-class Task(object):
- def __init__(self, tid, priority=LOW):
+class Task:
+ def __init__(self, tid: str, priority: int = LOW) -> None:
# task id
self.id = tid
# task priority
self.priority = priority
- def __repr__(self):
+ def __repr__(self) -> str:
return '<Task %s @%#x>' % (self.id, id(self))
- def __cmp__(self, other):
- return cmp(self.priority, other.priority)
-
- def __lt__(self, other):
+ def __lt__(self, other: 'Task') -> bool:
return self.priority < other.priority
- def __eq__(self, other):
- return self.id == other.id
+ def __eq__(self, other: object) -> bool:
+ return isinstance(other, type(self)) and self.id == other.id
__hash__ = object.__hash__
- def merge(self, other):
+ def merge(self, other: 'Task') -> None:
pass
diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py
index dae1ff5..8348900 100644
--- a/logilab/common/testlib.py
+++ b/logilab/common/testlib.py
@@ -55,6 +55,8 @@ import warnings
from shutil import rmtree
from operator import itemgetter
from inspect import isgeneratorfunction
+from typing import Any, Iterator, Union, Optional, Callable, Dict, List, Tuple
+from mypy_extensions import NoReturn
import builtins
import configparser
@@ -69,7 +71,9 @@ if not getattr(unittest_legacy, "__package__", None):
except ImportError:
raise ImportError("You have to install python-unittest2 to use %s" % __name__)
else:
- import unittest as unittest
+ # mypy: Name 'unittest' already defined (possibly by an import)
+ # compat
+ import unittest as unittest # type: ignore
from unittest import SkipTest
from functools import wraps
@@ -91,11 +95,11 @@ __unittest = 1
@deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.')
-def with_tempdir(callable):
+def with_tempdir(callable: Callable) -> Callable:
"""A decorator ensuring no temporary file left when the function return
Work only for temporary file created with the tempfile module"""
if isgeneratorfunction(callable):
- def proxy(*args, **kwargs):
+ def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]:
old_tmpdir = tempfile.gettempdir()
new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
tempfile.tempdir = new_tmpdir
@@ -109,20 +113,21 @@ def with_tempdir(callable):
tempfile.tempdir = old_tmpdir
return proxy
- @wraps(callable)
- def proxy(*args, **kargs):
+ else:
+ @wraps(callable)
+ def proxy(*args: Any, **kargs: Any) -> Any:
- old_tmpdir = tempfile.gettempdir()
- new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
- tempfile.tempdir = new_tmpdir
- try:
- return callable(*args, **kargs)
- finally:
+ old_tmpdir = tempfile.gettempdir()
+ new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
+ tempfile.tempdir = new_tmpdir
try:
- rmtree(new_tmpdir, ignore_errors=True)
+ return callable(*args, **kargs)
finally:
- tempfile.tempdir = old_tmpdir
- return proxy
+ try:
+ rmtree(new_tmpdir, ignore_errors=True)
+ finally:
+ tempfile.tempdir = old_tmpdir
+ return proxy
def in_tempdir(callable):
"""A decorator moving the enclosed function inside the tempfile.tempfdir
@@ -204,7 +209,7 @@ def start_interactive_mode(result):
# coverage pausing tools #####################################################
@contextmanager
-def replace_trace(trace=None):
+def replace_trace(trace: Optional[Callable] = None) -> Iterator:
"""A context manager that temporary replaces the trace function"""
oldtrace = sys.gettrace()
sys.settrace(trace)
@@ -222,16 +227,20 @@ def replace_trace(trace=None):
pause_trace = replace_trace
-def nocoverage(func):
+def nocoverage(func: Callable) -> Callable:
"""Function decorator that pauses tracing functions"""
if hasattr(func, 'uncovered'):
return func
- func.uncovered = True
+ # mypy: "Callable[..., Any]" has no attribute "uncovered"
+ # dynamic attribute for magic
+ func.uncovered = True # type: ignore
- def not_covered(*args, **kwargs):
+ def not_covered(*args: Any, **kwargs: Any) -> Any:
with pause_trace():
return func(*args, **kwargs)
- not_covered.uncovered = True
+ # mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered"
+ # dynamic attribute for magic
+ not_covered.uncovered = True # type: ignore
return not_covered
@@ -264,10 +273,10 @@ class InnerTestSkipped(SkipTest):
"""raised when a test is skipped"""
pass
-def parse_generative_args(params):
+def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]:
args = []
varargs = ()
- kwargs = {}
+ kwargs: Dict = {}
flags = 0 # 2 <=> starargs, 4 <=> kwargs
for param in params:
if isinstance(param, starargs):
@@ -298,22 +307,28 @@ class InnerTest(tuple):
class Tags(set):
"""A set of tag able validate an expression"""
- def __init__(self, *tags, **kwargs):
+ def __init__(self, *tags: str, **kwargs: Any) -> None:
self.inherit = kwargs.pop('inherit', True)
if kwargs:
raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
if len(tags) == 1 and not isinstance(tags[0], str):
tags = tags[0]
- super(Tags, self).__init__(tags, **kwargs)
+ super(Tags, self).__init__(tags)
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> bool:
return key in self
- def match(self, exp):
- return eval(exp, {}, self)
+ def match(self, exp: str) -> bool:
+ # mypy: Argument 3 to "eval" has incompatible type "Tags";
+ # mypy: expected "Optional[Mapping[str, Any]]"
+ # I'm really not sure here?
+ return eval(exp, {}, self) # type: ignore
- def __or__(self, other):
+ # mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet";
+ # mypy: supertype defines the argument type as "AbstractSet[_T]"
+ # not sure how to fix this one
+ def __or__(self, other: 'Tags') -> 'Tags': # type: ignore
return Tags(*super(Tags, self).__or__(other))
@@ -331,7 +346,7 @@ class TestCase(unittest.TestCase):
maxDiff = None
tags = Tags()
- def __init__(self, methodName='runTest'):
+ def __init__(self, methodName: str = 'runTest') -> None:
super(TestCase, self).__init__(methodName)
self.__exc_info = sys.exc_info
self.__testMethodName = self._testMethodName
@@ -340,7 +355,7 @@ class TestCase(unittest.TestCase):
@classproperty
@cached
- def datadir(cls): # pylint: disable=E0213
+ def datadir(cls) -> str: # pylint: disable=E0213
"""helper attribute holding the standard test's data directory
NOTE: this is a logilab's standard
@@ -351,7 +366,7 @@ class TestCase(unittest.TestCase):
# instantiated for each test run)
@classmethod
- def datapath(cls, *fname):
+ def datapath(cls, *fname: str) -> str:
"""joins the object's datadir and `fname`"""
return osp.join(cls.datadir, *fname)
@@ -363,7 +378,7 @@ class TestCase(unittest.TestCase):
self._current_test_descr = descr
# override default's unittest.py feature
- def shortDescription(self):
+ def shortDescription(self) -> Optional[Any]:
"""override default unittest shortDescription to handle correctly
generative tests
"""
@@ -371,7 +386,7 @@ class TestCase(unittest.TestCase):
return self._current_test_descr
return super(TestCase, self).shortDescription()
- def quiet_run(self, result, func, *args, **kwargs):
+ def quiet_run(self, result: Any, func: Callable, *args: Any, **kwargs: Any) -> bool:
try:
func(*args, **kwargs)
except (KeyboardInterrupt, SystemExit):
@@ -389,7 +404,7 @@ class TestCase(unittest.TestCase):
return False
return True
- def _get_test_method(self):
+ def _get_test_method(self) -> Callable:
"""return the test method"""
return getattr(self, self._testMethodName)
@@ -446,7 +461,7 @@ class TestCase(unittest.TestCase):
# result.cvg.stop()
result.stopTest(self)
- def _proceed_generative(self, result, testfunc, runcondition=None):
+ def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool:
# cancel startTest()'s increment
result.testsRun -= 1
success = True
@@ -485,7 +500,7 @@ class TestCase(unittest.TestCase):
success = False
return success
- def _proceed(self, result, testfunc, args=(), kwargs=None):
+ def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int:
"""proceed the actual test
returns 0 on success, 1 on failure, 2 on error
@@ -512,7 +527,7 @@ class TestCase(unittest.TestCase):
return 2
return 0
- def innerSkip(self, msg=None):
+ def innerSkip(self, msg: str = None) -> NoReturn:
"""mark a generative test as skipped for the <msg> reason"""
msg = msg or 'test was skipped'
raise InnerTestSkipped(msg)
@@ -549,7 +564,9 @@ class DocTestFinder(doctest.DocTestFinder):
globs, source_lines)
-class DocTest(TestCase, metaclass=class_deprecated):
+# mypy error: Invalid metaclass 'class_deprecated'
+# but it works?
+class DocTest(TestCase, metaclass=class_deprecated): # type: ignore
"""trigger module doctest
I don't know how to make unittest.main consider the DocTestSuite instance
without this hack
@@ -610,7 +627,9 @@ class MockConnection:
pass
-def mock_object(**params):
+# mypy error: Name 'Mock' is not defined
+# dynamic class created by this class
+def mock_object(**params: Any) -> 'Mock': # type: ignore
"""creates an object using params to set attributes
>>> option = mock_object(verbose=False, index=range(5))
>>> option.verbose
@@ -621,7 +640,7 @@ def mock_object(**params):
return type('Mock', (), params)()
-def create_files(paths, chroot):
+def create_files(paths: List[str], chroot: str) -> None:
"""Creates directories and files found in <path>.
:param paths: list of relative paths to files or directories
@@ -662,19 +681,21 @@ class AttrObject: # XXX cf mock_object
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
-def tag(*args, **kwargs):
+def tag(*args: str, **kwargs: Any) -> Callable:
"""descriptor adding tag to a function"""
- def desc(func):
+ def desc(func: Callable) -> Callable:
assert not hasattr(func, 'tags')
- func.tags = Tags(*args, **kwargs)
+ # mypy: "Callable[..., Any]" has no attribute "tags"
+ # dynamic magic attribute
+ func.tags = Tags(*args, **kwargs) # type: ignore
return func
return desc
-def require_version(version):
+def require_version(version: str) -> Callable:
""" Compare version of python interpreter to the given one. Skip the test
if older.
"""
- def check_require_version(f):
+ def check_require_version(f: Callable) -> Callable:
version_elements = version.split('.')
try:
compare = tuple([int(v) for v in version_elements])
@@ -690,10 +711,10 @@ def require_version(version):
return f
return check_require_version
-def require_module(module):
+def require_module(module: str) -> Callable:
""" Check if the given module is loaded. Skip the test if not.
"""
- def check_require_module(f):
+ def check_require_module(f: Callable) -> Callable:
try:
__import__(module)
return f
diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py
index 1a5573d..4b6ea98 100644
--- a/logilab/common/textutils.py
+++ b/logilab/common/textutils.py
@@ -46,8 +46,10 @@ __docformat__ = "restructuredtext en"
import sys
import re
import os.path as osp
+from re import Pattern, Match
from warnings import warn
from unicodedata import normalize as _uninormalize
+from typing import Any, Optional, Tuple, List, Callable, Dict, Union
try:
from os import linesep
except ImportError:
@@ -74,7 +76,7 @@ MANUAL_UNICODE_MAP = {
u'\u2019': u"'", # SIMPLE QUOTE
}
-def unormalize(ustring, ignorenonascii=None, substitute=None):
+def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None) -> str:
"""replace diacritical characters with their corresponding ascii characters
Convert the unicode string to its long normalized form (unicode character
@@ -107,7 +109,7 @@ def unormalize(ustring, ignorenonascii=None, substitute=None):
res.append(replacement)
return u''.join(res)
-def unquote(string):
+def unquote(string: str) -> str:
"""remove optional quotes (simple or double) from the string
:type string: str or unicode
@@ -128,7 +130,7 @@ def unquote(string):
_BLANKLINES_RGX = re.compile('\r?\n\r?\n')
_NORM_SPACES_RGX = re.compile('\s+')
-def normalize_text(text, line_len=80, indent='', rest=False):
+def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool = False) -> str:
"""normalize a text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized but blank
lines are kept. The indentation string may be used to insert a
@@ -159,7 +161,7 @@ def normalize_text(text, line_len=80, indent='', rest=False):
return ('%s%s%s' % (linesep, indent, linesep)).join(result)
-def normalize_paragraph(text, line_len=80, indent=''):
+def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
"""normalize a text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized. The
indentation string may be used top insert a comment mark for
@@ -188,7 +190,7 @@ def normalize_paragraph(text, line_len=80, indent=''):
lines.append(indent + aline)
return linesep.join(lines)
-def normalize_rest_paragraph(text, line_len=80, indent=''):
+def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
"""normalize a ReST text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized. The
indentation string may be used top insert a comment mark for
@@ -229,7 +231,7 @@ def normalize_rest_paragraph(text, line_len=80, indent=''):
return linesep.join(lines)
-def splittext(text, line_len):
+def splittext(text: str, line_len: int) -> Tuple[str, str]:
"""split the given text on space according to the given max line size
return a 2-uple:
@@ -248,7 +250,7 @@ def splittext(text, line_len):
return text[:pos], text[pos+1:].strip()
-def splitstrip(string, sep=','):
+def splitstrip(string: str, sep: str = ',') -> List[str]:
"""return a list of stripped string by splitting the string given as
argument on `sep` (',' by default). Empty string are discarded.
@@ -337,8 +339,8 @@ TIME_UNITS = {
"d": 60 * 60 *24,
}
-def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE,
- value_reg=_VALUE_RE):
+def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, type] = None, final: type = float, blank_reg: Pattern = _BLANK_RE,
+ value_reg: Pattern = _VALUE_RE) -> Union[float, int]:
"""Parse the string applying the units defined in units
(e.g.: "1.5m",{'m',60} -> 80).
@@ -379,7 +381,7 @@ def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE,
_LINE_RGX = re.compile('\r\n|\r+|\n')
-def pretty_match(match, string, underline_char='^'):
+def pretty_match(match: Match, string: str, underline_char: str = '^') -> str:
"""return a string with the match location underlined:
>>> import re
@@ -424,11 +426,14 @@ def pretty_match(match, string, underline_char='^'):
result.append(string)
result.append(underline)
else:
- end = string[end_line_pos + len(linesep):]
+ # mypy: Incompatible types in assignment (expression has type "str",
+ # mypy: variable has type "int")
+ # but it's a str :|
+ end = string[end_line_pos + len(linesep):] # type: ignore
string = string[start_line_pos:end_line_pos]
result.append(string)
result.append(underline)
- result.append(end)
+ result.append(end) # type: ignore # see previous comment
return linesep.join(result).rstrip()
@@ -458,7 +463,7 @@ ANSI_COLORS = {
'white': "37",
}
-def _get_ansi_code(color=None, style=None):
+def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) -> str:
"""return ansi escape code corresponding to color and style
:type color: str or None
@@ -491,7 +496,7 @@ def _get_ansi_code(color=None, style=None):
return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END
return ''
-def colorize_ansi(msg, color=None, style=None):
+def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] = None) -> str:
"""colorize message by wrapping it with ansi escape codes
:type msg: str or unicode
diff --git a/logilab/common/tree.py b/logilab/common/tree.py
index 885eb0f..1fc5a21 100644
--- a/logilab/common/tree.py
+++ b/logilab/common/tree.py
@@ -25,29 +25,37 @@ __docformat__ = "restructuredtext en"
import sys
-from logilab.common import flatten
+# from logilab.common import flatten
from logilab.common.visitor import VisitedMixIn, FilteredIterator, no_filter
+from logilab.common.types import Paragraph, Title
+from typing import Optional, Any, Union, List, Callable, TypeVar
## Exceptions #################################################################
class NodeNotFound(Exception):
"""raised when a node has not been found"""
-EX_SIBLING_NOT_FOUND = "No such sibling as '%s'"
-EX_CHILD_NOT_FOUND = "No such child as '%s'"
-EX_NODE_NOT_FOUND = "No such node as '%s'"
+EX_SIBLING_NOT_FOUND: str = "No such sibling as '%s'"
+EX_CHILD_NOT_FOUND: str = "No such child as '%s'"
+EX_NODE_NOT_FOUND: str = "No such node as '%s'"
# Base node ###################################################################
+# describe node of current class
+NodeType = Any
+
+
class Node(object):
"""a basic tree node, characterized by an id"""
- def __init__(self, nid=None) :
+ def __init__(self, nid: Optional[str] = None) -> None :
self.id = nid
# navigation
- self.parent = None
- self.children = []
+ # should be something like Optional[type(self)] for subclasses but that's not possible?
+ self.parent: Optional[NodeType] = None
+ # should be something like List[type(self)] for subclasses but that's not possible?
+ self.children: List[NodeType] = []
def __iter__(self):
return iter(self.children)
@@ -65,31 +73,35 @@ class Node(object):
def is_leaf(self):
return not self.children
- def append(self, child):
+ def append(self, child: NodeType) -> None:
+ # should be child: type(self) but that's not possible
"""add a node to children"""
self.children.append(child)
child.parent = self
- def remove(self, child):
+ def remove(self, child: NodeType) -> None:
+ # should be child: type(self) but that's not possible
"""remove a child node"""
self.children.remove(child)
child.parent = None
- def insert(self, index, child):
+ def insert(self, index: int, child: NodeType) -> None:
+ # should be child: type(self) but that's not possible
"""insert a child node"""
self.children.insert(index, child)
child.parent = self
- def replace(self, old_child, new_child):
+ def replace(self, old_child: NodeType, new_child: NodeType) -> None:
"""replace a child node with another"""
i = self.children.index(old_child)
self.children.pop(i)
self.children.insert(i, new_child)
new_child.parent = self
- def get_sibling(self, nid):
+ def get_sibling(self, nid: str) -> NodeType:
"""return the sibling node that has given id"""
try:
+ assert self.parent is not None
return self.parent.get_child_by_id(nid)
except NodeNotFound :
raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid)
@@ -121,7 +133,7 @@ class Node(object):
return parent.children[index-1]
return None
- def get_node_by_id(self, nid):
+ def get_node_by_id(self, nid: str) -> NodeType:
"""
return node in whole hierarchy that has given id
"""
@@ -131,7 +143,7 @@ class Node(object):
except NodeNotFound :
raise NodeNotFound(EX_NODE_NOT_FOUND % nid)
- def get_child_by_id(self, nid, recurse=None):
+ def get_child_by_id(self, nid: str, recurse: Optional[bool] = None) -> NodeType:
"""
return child of given id
"""
@@ -147,7 +159,7 @@ class Node(object):
return c
raise NodeNotFound(EX_CHILD_NOT_FOUND % nid)
- def get_child_by_path(self, path):
+ def get_child_by_path(self, path: List[str]) -> NodeType:
"""
return child of given path (path is a list of ids)
"""
@@ -162,7 +174,7 @@ class Node(object):
pass
raise NodeNotFound(EX_CHILD_NOT_FOUND % path)
- def depth(self):
+ def depth(self) -> int:
"""
return depth of this node in the tree
"""
@@ -171,7 +183,7 @@ class Node(object):
else :
return 0
- def depth_down(self):
+ def depth_down(self) -> int:
"""
return depth of the tree from this node
"""
@@ -179,13 +191,13 @@ class Node(object):
return 1 + max([c.depth_down() for c in self.children])
return 1
- def width(self):
+ def width(self) -> int:
"""
return the width of the tree from this node
"""
return len(self.leaves())
- def root(self):
+ def root(self) -> NodeType:
"""
return the root node of the tree
"""
@@ -193,7 +205,7 @@ class Node(object):
return self.parent.root()
return self
- def leaves(self):
+ def leaves(self) -> List[NodeType]:
"""
return a list with all the leaves nodes descendant from this node
"""
@@ -205,7 +217,7 @@ class Node(object):
else:
return [self]
- def flatten(self, _list=None):
+ def flatten(self, _list: Optional[List[NodeType]] = None) -> List[NodeType]:
"""
return a list with all the nodes descendant from this node
"""
@@ -216,7 +228,7 @@ class Node(object):
c.flatten(_list)
return _list
- def lineage(self):
+ def lineage(self) -> List[NodeType]:
"""
return list of parents up to root node
"""
@@ -226,6 +238,7 @@ class Node(object):
return lst
class VNode(Node, VisitedMixIn):
+ # we should probably merge this VisitedMixIn here because it's only used here
"""a visitable node
"""
pass
@@ -298,7 +311,7 @@ class ListNode(VNode, list_class):
# construct list from tree ####################################################
-def post_order_list(node, filter_func=no_filter):
+def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]:
"""
create a list with tree nodes for which the <filter> function returned true
in a post order fashion
@@ -326,7 +339,7 @@ def post_order_list(node, filter_func=no_filter):
poped = 1
return l
-def pre_order_list(node, filter_func=no_filter):
+def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]:
"""
create a list with tree nodes for which the <filter> function returned true
in a pre order fashion
@@ -358,12 +371,12 @@ def pre_order_list(node, filter_func=no_filter):
class PostfixedDepthFirstIterator(FilteredIterator):
"""a postfixed depth first iterator, designed to be used with visitors
"""
- def __init__(self, node, filter_func=None):
+ def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None:
FilteredIterator.__init__(self, node, post_order_list, filter_func)
class PrefixedDepthFirstIterator(FilteredIterator):
"""a prefixed depth first iterator, designed to be used with visitors
"""
- def __init__(self, node, filter_func=None):
+ def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None:
FilteredIterator.__init__(self, node, pre_order_list, filter_func)
diff --git a/logilab/common/types.py b/logilab/common/types.py
new file mode 100644
index 0000000..b13f3c7
--- /dev/null
+++ b/logilab/common/types.py
@@ -0,0 +1,44 @@
+# copyright 2019 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of yams.
+#
+# yams is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option)
+# any later version.
+#
+# yams is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
+# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with yams. If not, see <https://www.gnu.org/licenses/>.
+
+"""Types declarations for types annotations"""
+
+from typing import TYPE_CHECKING, TypeVar
+
+
+# to avoid circular imports
+if TYPE_CHECKING:
+ from logilab.common.tree import Node
+ from logilab.common.ureports.html_writer import HTMLWriter
+ from logilab.common.ureports.text_writer import TextWriter
+ from logilab.common.ureports.nodes import Paragraph
+ from logilab.common.ureports.nodes import Title
+ from logilab.common.table import Table
+ from logilab.common.optik_ext import OptionParser
+ from logilab.common.optik_ext import Option
+ from logilab.common import attrdict
+else:
+ Node = TypeVar("Node")
+ HTMLWriter = TypeVar("HTMLWriter")
+ TextWriter = TypeVar("TextWriter")
+ Table = TypeVar("Table")
+ OptionParser = TypeVar("OptionParser")
+ Option = TypeVar("Option")
+ attrdict = TypeVar("attrdict")
+ Paragraph = TypeVar("Paragraph")
+ Title = TypeVar("Title")
diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py
index e5a7f4e..77a6272 100644
--- a/logilab/common/umessage.py
+++ b/logilab/common/umessage.py
@@ -24,8 +24,10 @@ from encodings import search_function
import sys
from email.utils import parseaddr, parsedate
from email.header import decode_header
+from email.message import Message
from datetime import datetime
+from typing import Any, Optional, List, Tuple, Union
try:
from mx.DateTime import DateTime
@@ -35,17 +37,20 @@ except ImportError:
import logilab.common as lgc
-def decode_QP(string):
- parts = []
- for decoded, charset in decode_header(string):
+def decode_QP(string: str) -> str:
+ parts: List[str] = []
+ for maybe_decoded, charset in decode_header(string):
if not charset :
charset = 'iso-8859-15'
# python 3 sometimes returns str and sometimes bytes.
# the 'official' fix is to use the new 'policy' APIs
# https://bugs.python.org/issue24797
# let's just handle this bug ourselves for now
- if isinstance(decoded, bytes):
- decoded = decoded.decode(charset, 'replace')
+ if isinstance(maybe_decoded, bytes):
+ decoded = maybe_decoded.decode(charset, 'replace')
+ else:
+ decoded = maybe_decoded
+
assert isinstance(decoded, str)
parts.append(decoded)
@@ -53,6 +58,7 @@ def decode_QP(string):
# decoding was non-RFC compliant wrt to whitespace handling
# see http://bugs.python.org/issue1079
return u' '.join(parts)
+
return u''.join(parts)
def message_from_file(fd):
@@ -61,7 +67,7 @@ def message_from_file(fd):
except email.errors.MessageParseError:
return ''
-def message_from_string(string):
+def message_from_string(string: str) -> Union['UMessage', str]:
try:
return UMessage(email.message_from_string(string))
except email.errors.MessageParseError:
@@ -71,12 +77,12 @@ class UMessage:
"""Encapsulates an email.Message instance and returns only unicode objects.
"""
- def __init__(self, message):
+ def __init__(self, message: Message) -> None:
self.message = message
# email.Message interface #################################################
- def get(self, header, default=None):
+ def get(self, header: str, default: Optional[Any] = None) -> Optional[str]:
value = self.message.get(header, default)
if value:
return decode_QP(value)
@@ -85,7 +91,7 @@ class UMessage:
def __getitem__(self, header):
return self.get(header)
- def get_all(self, header, default=()):
+ def get_all(self, header: str, default: Tuple[()] = ()) -> List[str]:
return [decode_QP(val) for val in self.message.get_all(header, default)
if val is not None]
@@ -99,23 +105,33 @@ class UMessage:
for part in self.message.walk():
yield UMessage(part)
- def get_payload(self, index=None, decode=False):
+ def get_payload(self, index: Optional[Any] = None, decode: bool = False) -> Union[str, 'UMessage', List['UMessage']]:
message = self.message
+
if index is None:
- payload = message.get_payload(index, decode)
+ # mypy: Argument 1 to "get_payload" of "Message" has incompatible type "None"; expected "int"
+ # email.message.Message.get_payload has type signature:
+ # Message.get_payload(self, i=None, decode=False)
+ # so None seems to be totally acceptable, I don't understand mypy here
+ payload = message.get_payload(index, decode) # type: ignore
+
if isinstance(payload, list):
return [UMessage(msg) for msg in payload]
+
if message.get_content_maintype() != 'text':
return payload
+
if isinstance(payload, str):
return payload
charset = message.get_content_charset() or 'iso-8859-1'
if search_function(charset) is None:
charset = 'iso-8859-1'
+
return str(payload or b'', charset, "replace")
else:
payload = UMessage(message.get_payload(index, decode))
+
return payload
def get_content_maintype(self):
diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py
index 8ce68a0..9c0f1df 100644
--- a/logilab/common/ureports/__init__.py
+++ b/logilab/common/ureports/__init__.py
@@ -26,6 +26,9 @@ import sys
from logilab.common.compat import StringIO
from logilab.common.textutils import linesep
+from logilab.common.tree import VNode
+from logilab.common.ureports.nodes import Table, List as NodeList
+from typing import Any, Optional, Union, List, Generator, Tuple, Callable, TextIO
def get_nodes(node, klass):
@@ -76,7 +79,7 @@ def build_summary(layout, level=1):
class BaseWriter(object):
"""base class for ureport writers"""
- def format(self, layout, stream=None, encoding=None):
+ def format(self, layout: Any, stream: Optional[Union[StringIO, TextIO]] = None, encoding: Optional[Any] = None) -> None:
"""format and write the given layout into the stream object
unicode policy: unicode strings may be found in the layout;
@@ -88,87 +91,121 @@ class BaseWriter(object):
if not encoding:
encoding = getattr(stream, 'encoding', 'UTF-8')
self.encoding = encoding or 'UTF-8'
- self.__compute_funcs = []
+ self.__compute_funcs: List[Tuple[Callable[[str], Any], Callable[[str], Any]]] = []
self.out = stream
self.begin_format(layout)
layout.accept(self)
self.end_format(layout)
- def format_children(self, layout):
+ def format_children(self, layout: Union['Paragraph', 'Section', 'Title']) -> None:
"""recurse on the layout children and call their accept method
(see the Visitor pattern)
"""
for child in getattr(layout, 'children', ()):
child.accept(self)
- def writeln(self, string=u''):
+ def writeln(self, string: str = u'') -> None:
"""write a line in the output buffer"""
self.write(string + linesep)
- def write(self, string):
+ def write(self, string: str) -> None:
"""write a string in the output buffer"""
try:
self.out.write(string)
except UnicodeEncodeError:
- self.out.write(string.encode(self.encoding))
+ # mypy: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str"
+ # probably a python3 port issue?
+ self.out.write(string.encode(self.encoding)) # type: ignore
- def begin_format(self, layout):
+ def begin_format(self, layout: Any) -> None:
"""begin to format a layout"""
self.section = 0
- def end_format(self, layout):
+ def end_format(self, layout: Any) -> None:
"""finished to format a layout"""
- def get_table_content(self, table):
+ def get_table_content(self, table: Table) -> List[List[str]]:
"""trick to get table content without actually writing it
return an aligned list of lists containing table cells values as string
"""
- result = [[]]
- cols = table.cols
+ result: List[List[str]] = [[]]
+ # mypy: "Table" has no attribute "cols"
+ # dynamic attribute
+ cols = table.cols # type: ignore
+
for cell in self.compute_content(table):
if cols == 0:
result.append([])
- cols = table.cols
+ # mypy: "Table" has no attribute "cols"
+ # dynamic attribute
+ cols = table.cols # type: ignore
+
cols -= 1
result[-1].append(cell)
+
# fill missing cells
while len(result[-1]) < cols:
result[-1].append(u'')
+
return result
- def compute_content(self, layout):
+ def compute_content(self, layout: VNode) -> Generator[str, Any, None]:
"""trick to compute the formatting of children layout before actually
writing it
return an iterator on strings (one for each child element)
"""
# use cells !
- def write(data):
+ def write(data: str) -> None:
try:
stream.write(data)
except UnicodeEncodeError:
- stream.write(data.encode(self.encoding))
- def writeln(data=u''):
+ # mypy: Argument 1 to "write" of "TextIOWrapper" has incompatible type "bytes";
+ # mypy: expected "str"
+ # error from porting to python3?
+ stream.write(data.encode(self.encoding)) # type: ignore
+
+ def writeln(data: str = u'') -> None:
try:
stream.write(data+linesep)
except UnicodeEncodeError:
- stream.write(data.encode(self.encoding)+linesep)
- self.write = write
- self.writeln = writeln
+ # mypy: Unsupported operand types for + ("bytes" and "str")
+ # error from porting to python3?
+ stream.write(data.encode(self.encoding)+linesep) # type: ignore
+
+ # mypy: Cannot assign to a method
+ # this really looks like black dirty magic since self.write is reused elsewhere in the code
+ # especially since self.write and self.writeln are conditionally
+ # deleted at the end of this function
+ self.write = write # type: ignore
+ self.writeln = writeln # type: ignore
+
self.__compute_funcs.append((write, writeln))
- for child in layout.children:
+
+ # mypy: Item "Table" of "Union[List[Any], Table, Title]" has no attribute "children"
+ # dynamic attribute?
+ for child in layout.children: # type: ignore
stream = StringIO()
+
child.accept(self)
+
yield stream.getvalue()
+
self.__compute_funcs.pop()
+
try:
- self.write, self.writeln = self.__compute_funcs[-1]
+ # mypy: Cannot assign to a method
+ # even more black dirty magic
+ self.write, self.writeln = self.__compute_funcs[-1] # type: ignore
except IndexError:
del self.write
del self.writeln
-
-from logilab.common.ureports.nodes import *
+# mypy error: Incompatible import of "Table" (imported name has type
+# mypy error: "Type[logilab.common.ureports.nodes.Table]", local name has type
+# mypy error: "Type[logilab.common.table.Table]")
+# this will be cleaned when the "*" will be removed
+from logilab.common.ureports.nodes import * # type: ignore
from logilab.common.ureports.text_writer import TextWriter
from logilab.common.ureports.html_writer import HTMLWriter
diff --git a/logilab/common/ureports/html_writer.py b/logilab/common/ureports/html_writer.py
index 989662f..0783075 100644
--- a/logilab/common/ureports/html_writer.py
+++ b/logilab/common/ureports/html_writer.py
@@ -20,16 +20,19 @@ __docformat__ = "restructuredtext en"
from logilab.common.ureports import BaseWriter
+from logilab.common.ureports.nodes import (Section, Title, Table, List,
+ Paragraph, Link, VerbatimText, Text)
+from typing import Any
class HTMLWriter(BaseWriter):
"""format layouts as HTML"""
- def __init__(self, snippet=None):
+ def __init__(self, snippet: int = None) -> None:
super(HTMLWriter, self).__init__()
self.snippet = snippet
- def handle_attrs(self, layout):
+ def handle_attrs(self, layout: Any) -> str:
"""get an attribute string from layout member attributes"""
attrs = u''
klass = getattr(layout, 'klass', None)
@@ -40,20 +43,20 @@ class HTMLWriter(BaseWriter):
attrs += u' id="%s"' % nid
return attrs
- def begin_format(self, layout):
+ def begin_format(self, layout: Any) -> None:
"""begin to format a layout"""
super(HTMLWriter, self).begin_format(layout)
if self.snippet is None:
self.writeln(u'<html>')
self.writeln(u'<body>')
- def end_format(self, layout):
+ def end_format(self, layout: Any) -> None:
"""finished to format a layout"""
if self.snippet is None:
self.writeln(u'</body>')
self.writeln(u'</html>')
- def visit_section(self, layout):
+ def visit_section(self, layout: Section) -> None:
"""display a section as html, using div + h[section level]"""
self.section += 1
self.writeln(u'<div%s>' % self.handle_attrs(layout))
@@ -61,13 +64,13 @@ class HTMLWriter(BaseWriter):
self.writeln(u'</div>')
self.section -= 1
- def visit_title(self, layout):
+ def visit_title(self, layout: Title) -> None:
"""display a title using <hX>"""
self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout)))
self.format_children(layout)
self.writeln(u'</h%s>' % self.section)
- def visit_table(self, layout):
+ def visit_table(self, layout: Table) -> None:
"""display a table as html"""
self.writeln(u'<table%s>' % self.handle_attrs(layout))
table_content = self.get_table_content(layout)
@@ -91,14 +94,14 @@ class HTMLWriter(BaseWriter):
self.writeln(u'</tr>')
self.writeln(u'</table>')
- def visit_list(self, layout):
+ def visit_list(self, layout: List) -> None:
"""display a list as html"""
self.writeln(u'<ul%s>' % self.handle_attrs(layout))
for row in list(self.compute_content(layout)):
self.writeln(u'<li>%s</li>' % row)
self.writeln(u'</ul>')
- def visit_paragraph(self, layout):
+ def visit_paragraph(self, layout: Paragraph) -> None:
"""display links (using <p>)"""
self.write(u'<p>')
self.format_children(layout)
@@ -110,19 +113,19 @@ class HTMLWriter(BaseWriter):
self.format_children(layout)
self.write(u'</span>')
- def visit_link(self, layout):
+ def visit_link(self, layout: Link) -> None:
"""display links (using <a>)"""
self.write(u' <a href="%s"%s>%s</a>' % (layout.url,
self.handle_attrs(layout),
layout.label))
- def visit_verbatimtext(self, layout):
+ def visit_verbatimtext(self, layout: VerbatimText) -> None:
"""display verbatim text (using <pre>)"""
self.write(u'<pre>')
self.write(layout.data.replace(u'&', u'&amp;').replace(u'<', u'&lt;'))
self.write(u'</pre>')
- def visit_text(self, layout):
+ def visit_text(self, layout: Text) -> None:
"""add some text"""
data = layout.data
if layout.escaped:
diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py
index d1267f5..d086faf 100644
--- a/logilab/common/ureports/nodes.py
+++ b/logilab/common/ureports/nodes.py
@@ -22,6 +22,14 @@ A micro report is a tree of layout and content objects.
__docformat__ = "restructuredtext en"
from logilab.common.tree import VNode
+from typing import Optional
+# from logilab.common.ureports.nodes import List
+# from logilab.common.ureports.nodes import Paragraph
+# from logilab.common.ureports.nodes import Text
+from typing import Any
+from typing import List as TypingList
+from typing import Tuple
+from typing import Union
class BaseComponent(VNode):
@@ -31,7 +39,7 @@ class BaseComponent(VNode):
* id : the component's optional id
* klass : the component's optional klass
"""
- def __init__(self, id=None, klass=None):
+ def __init__(self, id: Optional[str] = None, klass: Optional[str] = None) -> None:
VNode.__init__(self, id)
self.klass = klass
@@ -43,27 +51,36 @@ class BaseLayout(BaseComponent):
* BaseComponent attributes
* children : components in this table (i.e. the table's cells)
"""
- def __init__(self, children=(), **kwargs):
+ def __init__(self,
+ children: Union[TypingList['Text'],
+ Tuple[Union['Paragraph', str],
+ Union[TypingList, str]], Tuple[str, ...]] = (),
+ **kwargs: Any) -> None:
+
super(BaseLayout, self).__init__(**kwargs)
+
for child in children:
if isinstance(child, BaseComponent):
self.append(child)
else:
- self.add_text(child)
+ # mypy: Argument 1 to "add_text" of "BaseLayout" has incompatible type
+ # mypy: "Union[str, List[Any]]"; expected "str"
+ # we check this situation in the if
+ self.add_text(child) # type: ignore
- def append(self, child):
+ def append(self, child: Any) -> None:
"""overridden to detect problems easily"""
assert child not in self.parents()
VNode.append(self, child)
- def parents(self):
+ def parents(self) -> TypingList:
"""return the ancestor nodes"""
assert self.parent is not self
if self.parent is None:
return []
return [self.parent] + self.parent.parents()
- def add_text(self, text):
+ def add_text(self, text: str) -> None:
"""shortcut to add text data"""
self.children.append(Text(text))
@@ -77,7 +94,7 @@ class Text(BaseComponent):
* BaseComponent attributes
* data : the text value as an encoded or unicode string
"""
- def __init__(self, data, escaped=True, **kwargs):
+ def __init__(self, data: str, escaped: bool = True, **kwargs: Any) -> None:
super(Text, self).__init__(**kwargs)
# if isinstance(data, unicode):
# data = data.encode('ascii')
@@ -103,7 +120,7 @@ class Link(BaseComponent):
* url : the link's target (REQUIRED)
* label : the link's label as a string (use the url by default)
"""
- def __init__(self, url, label=None, **kwargs):
+ def __init__(self, url: str, label: str = None, **kwargs: Any) -> None:
super(Link, self).__init__(**kwargs)
assert url
self.url = url
@@ -141,7 +158,7 @@ class Section(BaseLayout):
a description may also be given to the constructor, it'll be added
as a first paragraph
"""
- def __init__(self, title=None, description=None, **kwargs):
+ def __init__(self, title: str = None, description: str = None, **kwargs: Any) -> None:
super(Section, self).__init__(**kwargs)
if description:
self.insert(0, Paragraph([Text(description)]))
@@ -189,9 +206,9 @@ class Table(BaseLayout):
* cheaders : the first col's elements are table's header
* title : the table's optional title
"""
- def __init__(self, cols, title=None,
- rheaders=0, cheaders=0, rrheaders=0, rcheaders=0,
- **kwargs):
+ def __init__(self, cols: int, title: Optional[Any] = None,
+ rheaders: int = 0, cheaders: int = 0, rrheaders: int = 0, rcheaders: int = 0,
+ **kwargs: Any) -> None:
super(Table, self).__init__(**kwargs)
assert isinstance(cols, int)
self.cols = cols
diff --git a/logilab/common/ureports/text_writer.py b/logilab/common/ureports/text_writer.py
index ae92774..f75d7c9 100644
--- a/logilab/common/ureports/text_writer.py
+++ b/logilab/common/ureports/text_writer.py
@@ -18,11 +18,14 @@
"""Text formatting drivers for ureports"""
from __future__ import print_function
+from typing import Any, List, Tuple
__docformat__ = "restructuredtext en"
from logilab.common.textutils import linesep
from logilab.common.ureports import BaseWriter
+from logilab.common.ureports.nodes import (Section, Title, Table, List as NodeList,
+ Paragraph, Link, VerbatimText, Text)
TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^']
@@ -33,12 +36,12 @@ class TextWriter(BaseWriter):
"""format layouts as text
(ReStructured inspiration but not totally handled yet)
"""
- def begin_format(self, layout):
+ def begin_format(self, layout: Any) -> None:
super(TextWriter, self).begin_format(layout)
self.list_level = 0
- self.pending_urls = []
+ self.pending_urls: List[Tuple[str, str]] = []
- def visit_section(self, layout):
+ def visit_section(self, layout: Section) -> None:
"""display a section as text
"""
self.section += 1
@@ -52,7 +55,7 @@ class TextWriter(BaseWriter):
self.section -= 1
self.writeln()
- def visit_title(self, layout):
+ def visit_title(self, layout: Title) -> None:
title = u''.join(list(self.compute_content(layout)))
self.writeln(title)
try:
@@ -60,7 +63,7 @@ class TextWriter(BaseWriter):
except IndexError:
print("FIXME TITLE TOO DEEP. TURNING TITLE INTO TEXT")
- def visit_paragraph(self, layout):
+ def visit_paragraph(self, layout: 'Paragraph') -> None:
"""enter a paragraph"""
self.format_children(layout)
self.writeln()
@@ -69,7 +72,7 @@ class TextWriter(BaseWriter):
"""enter a span"""
self.format_children(layout)
- def visit_table(self, layout):
+ def visit_table(self, layout: Table) -> None:
"""display a table as text"""
table_content = self.get_table_content(layout)
# get columns width
@@ -84,36 +87,40 @@ class TextWriter(BaseWriter):
self.default_table(layout, table_content, cols_width)
self.writeln()
- def default_table(self, layout, table_content, cols_width):
+ def default_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None:
"""format a table"""
cols_width = [size+1 for size in cols_width]
+
format_strings = u' '.join([u'%%-%ss'] * len(cols_width))
format_strings = format_strings % tuple(cols_width)
- format_strings = format_strings.split(' ')
+
+ format_strings_list = format_strings.split(' ')
+
table_linesep = (
u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n')
headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n'
+
# FIXME: layout.cheaders
self.write(table_linesep)
for i in range(len(table_content)):
self.write(u'|')
line = table_content[i]
for j in range(len(line)):
- self.write(format_strings[j] % line[j])
+ self.write(format_strings_list[j] % line[j])
self.write(u'|')
if i == 0 and layout.rheaders:
self.write(headsep)
else:
self.write(table_linesep)
- def field_table(self, layout, table_content, cols_width):
+ def field_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None:
"""special case for field table"""
assert layout.cols == 2
format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0])
for field, value in table_content:
self.write(format_string % (field, value))
- def visit_list(self, layout):
+ def visit_list(self, layout: NodeList) -> None:
"""display a list layout as text"""
bullet = BULLETS[self.list_level % len(BULLETS)]
indent = ' ' * self.list_level
@@ -123,7 +130,7 @@ class TextWriter(BaseWriter):
child.accept(self)
self.list_level -= 1
- def visit_link(self, layout):
+ def visit_link(self, layout: Link) -> None:
"""add a hyperlink"""
if layout.label != layout.url:
self.write(u'`%s`_' % layout.label)
@@ -131,7 +138,7 @@ class TextWriter(BaseWriter):
else:
self.write(layout.url)
- def visit_verbatimtext(self, layout):
+ def visit_verbatimtext(self, layout: VerbatimText) -> None:
"""display a verbatim layout as text (so difficult ;)
"""
self.writeln(u'::\n')
@@ -139,6 +146,6 @@ class TextWriter(BaseWriter):
self.writeln(u' ' + line)
self.writeln()
- def visit_text(self, layout):
+ def visit_text(self, layout: Text) -> None:
"""add some text"""
self.write(u'%s' % layout.data)
diff --git a/logilab/common/visitor.py b/logilab/common/visitor.py
index ed2b70f..0698bae 100644
--- a/logilab/common/visitor.py
+++ b/logilab/common/visitor.py
@@ -21,21 +21,24 @@
"""
+from typing import Any, Callable, Optional, Union
+from logilab.common.types import Node, HTMLWriter, TextWriter
__docformat__ = "restructuredtext en"
-def no_filter(_):
+
+def no_filter(_: Node) -> int:
return 1
# Iterators ###################################################################
class FilteredIterator(object):
- def __init__(self, node, list_func, filter_func=None):
+ def __init__(self, node: Node, list_func: Callable, filter_func: Optional[Any] = None) -> None:
self._next = [(node, 0)]
if filter_func is None:
filter_func = no_filter
self._list = list_func(node, filter_func)
- def __next__(self):
+ def __next__(self) -> Optional[Node]:
try:
return self._list.pop(0)
except :
@@ -89,18 +92,20 @@ class VisitedMixIn(object):
"""
Visited interface allow node visitors to use the node
"""
- def get_visit_name(self):
+ def get_visit_name(self) -> str:
"""
return the visit name for the mixed class. When calling 'accept', the
method <'visit_' + name returned by this method> will be called on the
visitor
"""
try:
- return self.TYPE.replace('-', '_')
+ # mypy: "VisitedMixIn" has no attribute "TYPE"
+ # dynamic attribute
+ return self.TYPE.replace('-', '_') # type: ignore
except:
return self.__class__.__name__.lower()
- def accept(self, visitor, *args, **kwargs):
+ def accept(self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any) -> Optional[Any]:
func = getattr(visitor, 'visit_%s' % self.get_visit_name())
return func(self, *args, **kwargs)
diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py
index d383b9d..7b12c45 100644
--- a/logilab/common/xmlutils.py
+++ b/logilab/common/xmlutils.py
@@ -29,11 +29,12 @@ instruction and return a Python dictionary.
__docformat__ = "restructuredtext en"
import re
+from typing import Dict, Optional, Union
RE_DOUBLE_QUOTE = re.compile('([\w\-\.]+)="([^"]+)"')
RE_SIMPLE_QUOTE = re.compile("([\w\-\.]+)='([^']+)'")
-def parse_pi_data(pi_data):
+def parse_pi_data(pi_data: str) -> Dict[str, Optional[str]]:
"""
Utility function that parses the data contained in an XML
processing instruction and returns a dictionary of keywords and their
@@ -51,10 +52,13 @@ def parse_pi_data(pi_data):
"""
results = {}
for elt in pi_data.split():
- if RE_DOUBLE_QUOTE.match(elt):
- kwd, val = RE_DOUBLE_QUOTE.match(elt).groups()
- elif RE_SIMPLE_QUOTE.match(elt):
- kwd, val = RE_SIMPLE_QUOTE.match(elt).groups()
+ val: Optional[str]
+ double_match = RE_DOUBLE_QUOTE.match(elt)
+ simple_match = RE_SIMPLE_QUOTE.match(elt)
+ if double_match:
+ kwd, val = double_match.groups()
+ elif simple_match:
+ kwd, val = simple_match.groups()
else:
kwd, val = elt, None
results[kwd] = val
diff --git a/tox.ini b/tox.ini
index 341b3ae..fb24e02 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist=py3
+envlist=py3,mypy
[testenv]
deps =
@@ -14,3 +14,8 @@ deps =
-r docs/requirements-doc.txt
commands=
{envpython} -m sphinx -b html {toxinidir}/docs {toxinidir}/docs/_build/html {posargs}
+
+[testenv:mypy]
+deps =
+ mypy >= 0.761
+commands = mypy --ignore-missing-imports logilab