diff options
author | Laurent Peuch <cortex@worlddomination.be> | 2020-04-01 00:11:10 +0200 |
---|---|---|
committer | Laurent Peuch <cortex@worlddomination.be> | 2020-04-01 00:11:10 +0200 |
commit | b8899451fa861b04568e2a0bb4e3fe4acc0daee3 (patch) | |
tree | 1446c809e19b5571b31b1999246aa0e50b19f5c8 /logilab/common | |
parent | 32cd73810056594f55eff0ffafebbdeb50c7a860 (diff) | |
download | logilab-common-b8899451fa861b04568e2a0bb4e3fe4acc0daee3.tar.gz |
Black the whole code base
Diffstat (limited to 'logilab/common')
39 files changed, 2586 insertions, 1817 deletions
diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py index 0d7f183..34e6c1b 100644 --- a/logilab/common/__init__.py +++ b/logilab/common/__init__.py @@ -31,19 +31,19 @@ import types import pkg_resources from typing import List, Sequence -__version__ = pkg_resources.get_distribution('logilab-common').version +__version__ = pkg_resources.get_distribution("logilab-common").version # deprecated, but keep compatibility with pylint < 1.4.4 -__pkginfo__ = types.ModuleType('__pkginfo__') +__pkginfo__ = types.ModuleType("__pkginfo__") __pkginfo__.__package__ = __name__ # mypy output: Module has no attribute "version" # logilab's magic __pkginfo__.version = __version__ # type: ignore -sys.modules['logilab.common.__pkginfo__'] = __pkginfo__ +sys.modules["logilab.common.__pkginfo__"] = __pkginfo__ -STD_BLACKLIST = ('CVS', '.svn', '.hg', '.git', '.tox', 'debian', 'dist', 'build') +STD_BLACKLIST = ("CVS", ".svn", ".hg", ".git", ".tox", "debian", "dist", "build") -IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc', '~', '.swp', '.orig') +IGNORED_EXTENSIONS = (".pyc", ".pyo", ".elc", "~", ".swp", ".orig") # set this to False if you've mx DateTime installed but you don't want your db # adapter to use it (should be set before you got a connection) @@ -52,12 +52,14 @@ USE_MX_DATETIME = True class attrdict(dict): """A dictionary for which keys are also accessible as attributes.""" + def __getattr__(self, attr: str) -> str: try: return self[attr] except KeyError: raise AttributeError(attr) + class dictattr(dict): def __init__(self, proxy): self.__proxy = proxy @@ -68,13 +70,17 @@ class dictattr(dict): except AttributeError: raise KeyError(attr) + class nullobject(object): def __repr__(self): - return '<nullobject>' + return "<nullobject>" + def __bool__(self): return False + __nonzero__ = __bool__ + class tempattr(object): def __init__(self, obj, attr, value): self.obj = obj @@ -90,7 +96,6 @@ class tempattr(object): setattr(self.obj, self.attr, self.oldvalue) - # flatten ----- # XXX move in a specific module and use yield instead # do not mix flatten and translate @@ -105,10 +110,10 @@ class tempattr(object): # except (TypeError, ValueError): return False # return True # -#def is_scalar(obj): +# def is_scalar(obj): # return is_string_like(obj) or not iterable(obj) # -#def flatten(seq): +# def flatten(seq): # for item in seq: # if is_scalar(item): # yield item @@ -116,6 +121,7 @@ class tempattr(object): # for subitem in flatten(item): # yield subitem + def flatten(iterable, tr_func=None, results=None): """Flatten a list of list with any level. @@ -141,6 +147,7 @@ def flatten(iterable, tr_func=None, results=None): # XXX is function below still used ? + def make_domains(lists): """ Given a list of lists, return a list of domain for each list to produce all @@ -157,7 +164,7 @@ def make_domains(lists): for iterable in lists: new_domain = iterable[:] for i in range(len(domains)): - domains[i] = domains[i]*len(iterable) + domains[i] = domains[i] * len(iterable) if domains: missing = (len(domains[0]) - len(iterable)) / len(iterable) i = 0 @@ -173,6 +180,7 @@ def make_domains(lists): # private stuff ################################################################ + def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames: List[str]) -> None: """remove files/directories in the black list @@ -183,4 +191,3 @@ def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames: dirnames.remove(norecurs) elif norecurs in filenames: filenames.remove(norecurs) - diff --git a/logilab/common/cache.py b/logilab/common/cache.py index c47f481..6b673a2 100644 --- a/logilab/common/cache.py +++ b/logilab/common/cache.py @@ -47,7 +47,7 @@ class Cache(dict): """ Warning : Cache.__init__() != dict.__init__(). Constructor does not take any arguments beside size. """ - assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)' + assert size >= 0, "cache size must be >= 0 (0 meaning no caching)" self.size = size self._usage: List = [] self._lock = Lock() @@ -74,12 +74,13 @@ class Cache(dict): del self._usage[0] self._usage.append(key) else: - pass # key is already the most recently used key + pass # key is already the most recently used key def __getitem__(self, key: _KeyType): value = super(Cache, self).__getitem__(key) self._update_usage(key) return value + __getitem__ = locked(_acquire, _release)(__getitem__) def __setitem__(self, key: _KeyType, item): @@ -87,24 +88,28 @@ class Cache(dict): if self.size > 0: super(Cache, self).__setitem__(key, item) self._update_usage(key) + __setitem__ = locked(_acquire, _release)(__setitem__) def __delitem__(self, key: _KeyType): super(Cache, self).__delitem__(key) self._usage.remove(key) + __delitem__ = locked(_acquire, _release)(__delitem__) def clear(self): super(Cache, self).clear() self._usage = [] + clear = locked(_acquire, _release)(clear) def pop(self, key: _KeyType, default=_marker): if key in self: self._usage.remove(key) - #if default is _marker: + # if default is _marker: # return super(Cache, self).pop(key) return super(Cache, self).pop(key, default) + pop = locked(_acquire, _release)(pop) def popitem(self): @@ -115,5 +120,3 @@ class Cache(dict): def update(self, other): raise NotImplementedError() - - diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py index c128eb7..cec1b5e 100644 --- a/logilab/common/changelog.py +++ b/logilab/common/changelog.py @@ -52,9 +52,9 @@ import codecs from typing import List, Any, Optional, Tuple from _io import StringIO -BULLET = '*' -SUBBULLET = '-' -INDENT = ' ' * 4 +BULLET = "*" +SUBBULLET = "-" +INDENT = " " * 4 class NoEntry(Exception): @@ -69,9 +69,10 @@ class Version(tuple): """simple class to handle soft version number has a tuple while correctly printing it as X.Y.Z """ + def __new__(cls, versionstr): if isinstance(versionstr, str): - versionstr = versionstr.strip(' :') # XXX (syt) duh? + versionstr = versionstr.strip(" :") # XXX (syt) duh? parsed = cls.parse(versionstr) else: parsed = versionstr @@ -79,26 +80,29 @@ class Version(tuple): @classmethod def parse(cls, versionstr: str) -> List[int]: - versionstr = versionstr.strip(' :') + versionstr = versionstr.strip(" :") try: - return [int(i) for i in versionstr.split('.')] + return [int(i) for i in versionstr.split(".")] except ValueError as ex: - raise ValueError("invalid literal for version '%s' (%s)" % - (versionstr, ex)) + raise ValueError("invalid literal for version '%s' (%s)" % (versionstr, ex)) def __str__(self) -> str: - return '.'.join([str(i) for i in self]) + return ".".join([str(i) for i in self]) # upstream change log ######################################################### + class ChangeLogEntry(object): """a change log entry, i.e. a set of messages associated to a version and its release date """ + version_class = Version - def __init__(self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any) -> None: + def __init__( + self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any + ) -> None: self.__dict__.update(kwargs) self.version: Optional[Version] if version: @@ -116,8 +120,7 @@ class ChangeLogEntry(object): """complete the latest added message """ if not self.messages: - raise ValueError('unable to complete last message as ' - 'there is no previous message)') + raise ValueError("unable to complete last message as " "there is no previous message)") if self.messages[-1][1]: # sub messages self.messages[-1][1][-1].append(msg_suite) else: # message @@ -125,29 +128,26 @@ class ChangeLogEntry(object): def add_sub_message(self, sub_msg: str, key: Optional[Any] = None) -> None: if not self.messages: - raise ValueError('unable to complete last message as ' - 'there is no previous message)') + raise ValueError("unable to complete last message as " "there is no previous message)") if key is None: self.messages[-1][1].append([sub_msg]) else: - raise NotImplementedError('sub message to specific key ' - 'are not implemented yet') + raise NotImplementedError("sub message to specific key " "are not implemented yet") def write(self, stream: StringIO = sys.stdout) -> None: """write the entry to file """ - stream.write(u'%s -- %s\n' % (self.date or '', self.version or '')) + stream.write("%s -- %s\n" % (self.date or "", self.version or "")) for msg, sub_msgs in self.messages: - stream.write(u'%s%s %s\n' % (INDENT, BULLET, msg[0])) - stream.write(u''.join(msg[1:])) + stream.write("%s%s %s\n" % (INDENT, BULLET, msg[0])) + stream.write("".join(msg[1:])) if sub_msgs: - stream.write(u'\n') + stream.write("\n") for sub_msg in sub_msgs: - stream.write(u'%s%s %s\n' % - (INDENT * 2, SUBBULLET, sub_msg[0])) - stream.write(u''.join(sub_msg[1:])) - stream.write(u'\n') + stream.write("%s%s %s\n" % (INDENT * 2, SUBBULLET, sub_msg[0])) + stream.write("".join(sub_msg[1:])) + stream.write("\n") - stream.write(u'\n\n') + stream.write("\n\n") class ChangeLog(object): @@ -155,23 +155,22 @@ class ChangeLog(object): entry_class = ChangeLogEntry - def __init__(self, changelog_file: str, title: str = u'') -> None: + def __init__(self, changelog_file: str, title: str = "") -> None: self.file = changelog_file - assert isinstance(title, type(u'')), 'title must be a unicode object' + assert isinstance(title, type("")), "title must be a unicode object" self.title = title - self.additional_content = u'' + self.additional_content = "" self.entries: List[ChangeLogEntry] = [] self.load() def __repr__(self): - return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self), - len(self.entries)) + return "<ChangeLog %s at %s (%s entries)>" % (self.file, id(self), len(self.entries)) def add_entry(self, entry: ChangeLogEntry) -> None: """add a new entry to the change log""" self.entries.append(entry) - def get_entry(self, version='', create=None): + def get_entry(self, version="", create=None): """ return a given changelog entry if version is omitted, return the current entry """ @@ -197,7 +196,7 @@ class ChangeLog(object): def load(self) -> None: """ read a logilab's ChangeLog from file """ try: - stream = codecs.open(self.file, encoding='utf-8') + stream = codecs.open(self.file, encoding="utf-8") except IOError: return @@ -209,20 +208,20 @@ class ChangeLog(object): words = sline.split() # if new entry - if len(words) == 1 and words[0] == '--': + if len(words) == 1 and words[0] == "--": expect_sub = False last = self.entry_class() self.add_entry(last) # if old entry - elif len(words) == 3 and words[1] == '--': + elif len(words) == 3 and words[1] == "--": expect_sub = False last = self.entry_class(words[0], words[2]) self.add_entry(last) # if title elif sline and last is None: - self.title = '%s%s' % (self.title, line) + self.title = "%s%s" % (self.title, line) # if new entry elif sline and sline[0] == BULLET: expect_sub = False @@ -243,14 +242,15 @@ class ChangeLog(object): stream.close() def format_title(self) -> str: - return u'%s\n\n' % self.title.strip() + return "%s\n\n" % self.title.strip() def save(self): """write back change log""" # filetutils isn't importable in appengine, so import locally from logilab.common.fileutils import ensure_fs_mode + ensure_fs_mode(self.file, S_IWRITE) - self.write(codecs.open(self.file, 'w', encoding='utf-8')) + self.write(codecs.open(self.file, "w", encoding="utf-8")) def write(self, stream: StringIO = sys.stdout) -> None: """write changelog to stream""" diff --git a/logilab/common/clcommands.py b/logilab/common/clcommands.py index 4778b99..f89a4b4 100644 --- a/logilab/common/clcommands.py +++ b/logilab/common/clcommands.py @@ -42,6 +42,7 @@ class BadCommandUsage(Exception): Trigger display of command usage. """ + class CommandError(Exception): """Raised when a command can't be processed and we want to display it and exit, without traceback nor usage displayed. @@ -50,6 +51,7 @@ class CommandError(Exception): # command line access point #################################################### + class CommandLine(dict): """Usage: @@ -77,9 +79,17 @@ class CommandLine(dict): * `logger`, logger to propagate to commands, default to `logging.getLogger(self.pgm))` """ - def __init__(self, pgm=None, doc=None, copyright=None, version=None, - rcfile=None, logthreshold=logging.ERROR, - check_duplicated_command=True): + + def __init__( + self, + pgm=None, + doc=None, + copyright=None, + version=None, + rcfile=None, + logthreshold=logging.ERROR, + check_duplicated_command=True, + ): if pgm is None: pgm = basename(sys.argv[0]) self.pgm = pgm @@ -93,8 +103,9 @@ class CommandLine(dict): def register(self, cls, force=False): """register the given :class:`Command` subclass""" - assert not self.check_duplicated_command or force or not cls.name in self, \ - 'a command %s is already defined' % cls.name + assert not self.check_duplicated_command or force or not cls.name in self, ( + "a command %s is already defined" % cls.name + ) self[cls.name] = cls return cls @@ -107,20 +118,22 @@ class CommandLine(dict): Terminate by :exc:`SystemExit` """ - init_log(debug=True, # so that we use StreamHandler - logthreshold=self.logthreshold, - logformat='%(levelname)s: %(message)s') + init_log( + debug=True, # so that we use StreamHandler + logthreshold=self.logthreshold, + logformat="%(levelname)s: %(message)s", + ) try: arg = args.pop(0) except IndexError: self.usage_and_exit(1) - if arg in ('-h', '--help'): + if arg in ("-h", "--help"): self.usage_and_exit(0) - if self.version is not None and arg in ('--version'): + if self.version is not None and arg in ("--version"): print(self.version) sys.exit(0) rcfile = self.rcfile - if rcfile is not None and arg in ('-C', '--rc-file'): + if rcfile is not None and arg in ("-C", "--rc-file"): try: rcfile = args.pop(0) arg = args.pop(0) @@ -129,19 +142,19 @@ class CommandLine(dict): try: command = self.get_command(arg) except KeyError: - print('ERROR: no %s command' % arg) + print("ERROR: no %s command" % arg) print() self.usage_and_exit(1) try: sys.exit(command.main_run(args, rcfile)) except KeyboardInterrupt as exc: - print('Interrupted', end=' ') + print("Interrupted", end=" ") if str(exc): - print(': %s' % exc, end=' ') + print(": %s" % exc, end=" ") print() sys.exit(4) except BadCommandUsage as err: - print('ERROR:', err) + print("ERROR:", err) print() print(command.help()) sys.exit(1) @@ -166,32 +179,44 @@ class CommandLine(dict): """display usage for the main program (i.e. when no command supplied) and exit """ - print('usage:', self.pgm, end=' ') + print("usage:", self.pgm, end=" ") if self.rcfile: - print('[--rc-file=<configuration file>]', end=' ') - print('<command> [options] <command argument>...') + print("[--rc-file=<configuration file>]", end=" ") + print("<command> [options] <command argument>...") if self.doc: - print('\n%s' % self.doc) - print(''' + print("\n%s" % self.doc) + print( + """ Type "%(pgm)s <command> --help" for more information about a specific -command. Available commands are :\n''' % self.__dict__) +command. Available commands are :\n""" + % self.__dict__ + ) max_len = max([len(cmd) for cmd in self]) - padding = ' ' * max_len + padding = " " * max_len for cmdname, cmd in sorted(self.items()): if not cmd.hidden: - print(' ', (cmdname + padding)[:max_len], cmd.short_description()) + print(" ", (cmdname + padding)[:max_len], cmd.short_description()) if self.rcfile: - print(''' + print( + """ Use --rc-file=<configuration file> / -C <configuration file> before the command to specify a configuration file. Default to %s. -''' % self.rcfile) - print('''%(pgm)s -h/--help - display this usage information and exit''' % self.__dict__) +""" + % self.rcfile + ) + print( + """%(pgm)s -h/--help + display this usage information and exit""" + % self.__dict__ + ) if self.version: - print('''%(pgm)s -v/--version - display version configuration and exit''' % self.__dict__) + print( + """%(pgm)s -v/--version + display version configuration and exit""" + % self.__dict__ + ) if self.copyright: - print('\n', self.copyright) + print("\n", self.copyright) def usage_and_exit(self, status): self.usage() @@ -200,6 +225,7 @@ to specify a configuration file. Default to %s. # base command classes ######################################################### + class Command(Configuration): """Base class for command line commands. @@ -219,8 +245,8 @@ class Command(Configuration): * `options`, options list, as allowed by :mod:configuration """ - arguments = '' - name = '' + arguments = "" + name = "" # hidden from help ? hidden = False # max/min args, None meaning unspecified @@ -229,24 +255,23 @@ class Command(Configuration): @classmethod def description(cls): - return cls.__doc__.replace(' ', '') + return cls.__doc__.replace(" ", "") @classmethod def short_description(cls): - return cls.description().split('.')[0] + return cls.description().split(".")[0] def __init__(self, logger): - usage = '%%prog %s %s\n\n%s' % (self.name, self.arguments, - self.description()) + usage = "%%prog %s %s\n\n%s" % (self.name, self.arguments, self.description()) Configuration.__init__(self, usage=usage) self.logger = logger def check_args(self, args): """check command's arguments are provided""" if self.min_args is not None and len(args) < self.min_args: - raise BadCommandUsage('missing argument') + raise BadCommandUsage("missing argument") if self.max_args is not None and len(args) > self.max_args: - raise BadCommandUsage('too many arguments') + raise BadCommandUsage("too many arguments") def main_run(self, args, rcfile=None): """Run the command and return status 0 if everything went fine. @@ -275,8 +300,9 @@ class Command(Configuration): class ListCommandsCommand(Command): """list available commands, useful for bash completion.""" - name = 'listcommands' - arguments = '[command]' + + name = "listcommands" + arguments = "[command]" hidden = True def run(self, args): @@ -285,8 +311,8 @@ class ListCommandsCommand(Command): command = args.pop() cmd = _COMMANDS[command] for optname, optdict in cmd.options: - print('--help') - print('--' + optname) + print("--help") + print("--" + optname) else: commands = sorted(_COMMANDS.keys()) for command in commands: @@ -299,17 +325,19 @@ class ListCommandsCommand(Command): _COMMANDS = CommandLine() -DEFAULT_COPYRIGHT = '''\ +DEFAULT_COPYRIGHT = """\ Copyright (c) 2004-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. -http://www.logilab.fr/ -- mailto:contact@logilab.fr''' +http://www.logilab.fr/ -- mailto:contact@logilab.fr""" + -@deprecated('use cls.register(cli)') +@deprecated("use cls.register(cli)") def register_commands(commands): """register existing commands""" for command_klass in commands: _COMMANDS.register(command_klass) -@deprecated('use args.pop(0)') + +@deprecated("use args.pop(0)") def main_run(args, doc=None, copyright=None, version=None): """command line tool: run command specified by argument list (without the program name). Raise SystemExit with status 0 if everything went fine. @@ -321,7 +349,8 @@ def main_run(args, doc=None, copyright=None, version=None): _COMMANDS.version = version _COMMANDS.run(args) -@deprecated('use args.pop(0)') + +@deprecated("use args.pop(0)") def pop_arg(args_list, expected_size_after=None, msg="Missing argument"): """helper function to get and check command line arguments""" try: @@ -329,6 +358,5 @@ def pop_arg(args_list, expected_size_after=None, msg="Missing argument"): except IndexError: raise BadCommandUsage(msg) if expected_size_after is not None and len(args_list) > expected_size_after: - raise BadCommandUsage('too many arguments') + raise BadCommandUsage("too many arguments") return value - diff --git a/logilab/common/compat.py b/logilab/common/compat.py index 4ca540b..e601b26 100644 --- a/logilab/common/compat.py +++ b/logilab/common/compat.py @@ -38,18 +38,25 @@ from typing import Union # not used here, but imported to preserve API import builtins + def str_to_bytes(string): return str.encode(string) + + # we have to ignore the encoding in py3k to be able to write a string into a # TextIOWrapper or like object (which expect an unicode string) def str_encode(string: Union[int, str], encoding: str) -> str: return str(string) + # See also http://bugs.python.org/issue11776 if sys.version_info[0] == 3: + def method_type(callable, instance, klass): # api change. klass is no more considered return types.MethodType(callable, instance) + + else: # alias types otherwise method_type = types.MethodType @@ -57,6 +64,7 @@ else: # Pythons 2 and 3 differ on where to get StringIO if sys.version_info < (3, 0): from cStringIO import StringIO + FileIO = file BytesIO = StringIO reload = reload diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py index 61c2e97..4c83030 100644 --- a/logilab/common/configuration.py +++ b/logilab/common/configuration.py @@ -111,9 +111,13 @@ from __future__ import print_function __docformat__ = "restructuredtext en" -__all__ = ('OptionsManagerMixIn', 'OptionsProviderMixIn', - 'ConfigurationMixIn', 'Configuration', - 'OptionsManager2ConfigurationAdapter') +__all__ = ( + "OptionsManagerMixIn", + "OptionsProviderMixIn", + "ConfigurationMixIn", + "Configuration", + "OptionsManager2ConfigurationAdapter", +) import os import sys @@ -139,14 +143,16 @@ OptionError = optik_ext.OptionError REQUIRED: List = [] + class UnsupportedAction(Exception): """raised by set_option when it doesn't know what to do for an action""" def _get_encoding(encoding: Optional[str], stream: Union[StringIO, TextIOWrapper]) -> str: - encoding = encoding or getattr(stream, 'encoding', None) + encoding = encoding or getattr(stream, "encoding", None) if not encoding: import locale + encoding = locale.getpreferredencoding() return encoding @@ -158,19 +164,20 @@ _ValueType = Union[List[str], Tuple[str, ...], str] # validators will return the validated value or raise optparse.OptionValueError # XXX add to documentation + def choice_validator(optdict: Dict[str, Any], name: str, value: str) -> str: """validate and return a converted value for option of type 'choice' """ - if not value in optdict['choices']: + if not value in optdict["choices"]: msg = "option %s: invalid value: %r, should be in %s" - raise optik_ext.OptionValueError(msg % (name, value, optdict['choices'])) + raise optik_ext.OptionValueError(msg % (name, value, optdict["choices"])) return value def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType: """validate and return a converted value for option of type 'choice' """ - choices = optdict['choices'] + choices = optdict["choices"] values = optik_ext.check_csv(None, name, value) for value in values: if not value in choices: @@ -178,67 +185,81 @@ def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueT raise optik_ext.OptionValueError(msg % (name, value, choices)) return values + def csv_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType: """validate and return a converted value for option of type 'csv' """ return optik_ext.check_csv(None, name, value) + def yn_validator(optdict: Dict[str, Any], name: str, value: Union[bool, str]) -> bool: """validate and return a converted value for option of type 'yn' """ return optik_ext.check_yn(None, name, value) -def named_validator(optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str]) -> Dict[str, str]: + +def named_validator( + optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str] +) -> Dict[str, str]: """validate and return a converted value for option of type 'named' """ return optik_ext.check_named(None, name, value) + def file_validator(optdict, name, value): """validate and return a filepath for option of type 'file'""" return optik_ext.check_file(None, name, value) + def color_validator(optdict, name, value): """validate and return a valid color for option of type 'color'""" return optik_ext.check_color(None, name, value) + def password_validator(optdict, name, value): """validate and return a string for option of type 'password'""" return optik_ext.check_password(None, name, value) + def date_validator(optdict, name, value): """validate and return a mx DateTime object for option of type 'date'""" return optik_ext.check_date(None, name, value) + def time_validator(optdict, name, value): """validate and return a time object for option of type 'time'""" return optik_ext.check_time(None, name, value) + def bytes_validator(optdict: Dict[str, str], name: str, value: Union[int, str]) -> int: """validate and return an integer for option of type 'bytes'""" return optik_ext.check_bytes(None, name, value) VALIDATORS: Dict[str, Callable] = { - 'string': unquote, - 'int': int, - 'float': float, - 'file': file_validator, - 'font': unquote, - 'color': color_validator, - 'regexp': re.compile, - 'csv': csv_validator, - 'yn': yn_validator, - 'bool': yn_validator, - 'named': named_validator, - 'password': password_validator, - 'date': date_validator, - 'time': time_validator, - 'bytes': bytes_validator, - 'choice': choice_validator, - 'multiple_choice': multiple_choice_validator, + "string": unquote, + "int": int, + "float": float, + "file": file_validator, + "font": unquote, + "color": color_validator, + "regexp": re.compile, + "csv": csv_validator, + "yn": yn_validator, + "bool": yn_validator, + "named": named_validator, + "password": password_validator, + "date": date_validator, + "time": time_validator, + "bytes": bytes_validator, + "choice": choice_validator, + "multiple_choice": multiple_choice_validator, } -def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str]) -> Union[List[str], int, str]: + +def _call_validator( + opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str] +) -> Union[List[str], int, str]: if opttype not in VALIDATORS: raise Exception('Unsupported type "%s"' % opttype) try: @@ -249,8 +270,10 @@ def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: U except optik_ext.OptionValueError: raise except: - raise optik_ext.OptionValueError('%s value (%r) should be of type %s' % - (option, value, opttype)) + raise optik_ext.OptionValueError( + "%s value (%r) should be of type %s" % (option, value, opttype) + ) + # user input functions ######################################################## @@ -258,19 +281,23 @@ def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: U # the result and return the validated value or raise optparse.OptionValueError # XXX add to documentation -def input_password(optdict, question='password:'): + +def input_password(optdict, question="password:"): from getpass import getpass + while True: value = getpass(question) - value2 = getpass('confirm: ') + value2 = getpass("confirm: ") if value == value2: return value - print('password mismatch, try again') + print("password mismatch, try again") + def input_string(optdict, question): value = input(question).strip() return value or None + def _make_input_function(opttype): def input_validator(optdict, question): while True: @@ -280,14 +307,15 @@ def _make_input_function(opttype): try: return _call_validator(opttype, optdict, None, value) except optik_ext.OptionValueError as ex: - msg = str(ex).split(':', 1)[-1].strip() - print('bad value: %s' % msg) + msg = str(ex).split(":", 1)[-1].strip() + print("bad value: %s" % msg) + return input_validator INPUT_FUNCTIONS: Dict[str, Callable] = { - 'string': input_string, - 'password': input_password, + "string": input_string, + "password": input_password, } for opttype in VALIDATORS.keys(): @@ -295,6 +323,7 @@ for opttype in VALIDATORS.keys(): # utility functions ############################################################ + def expand_default(self, option): """monkey patch OptionParser.expand_default since we have a particular way to handle defaults to avoid overriding values in the configuration @@ -317,125 +346,144 @@ def expand_default(self, option): return option.help.replace(self.default_tag, str(value)) -def _validate(value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = '') -> Union[List[str], int, str]: +def _validate( + value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = "" +) -> Union[List[str], int, str]: """return a validated value for an option according to its type optional argument name is only used for error message formatting """ try: - _type = optdict['type'] + _type = optdict["type"] except KeyError: # FIXME return value return _call_validator(_type, optdict, name, value) -convert = deprecated('[0.60] convert() was renamed _validate()')(_validate) + + +convert = deprecated("[0.60] convert() was renamed _validate()")(_validate) # format and output functions ################################################## + def comment(string): """return string as a comment""" lines = [line.strip() for line in string.splitlines()] - return '# ' + ('%s# ' % os.linesep).join(lines) + return "# " + ("%s# " % os.linesep).join(lines) + def format_time(value): if not value: - return '0' + return "0" if value != int(value): - return '%.2fs' % value + return "%.2fs" % value value = int(value) nbmin, nbsec = divmod(value, 60) if nbsec: - return '%ss' % value + return "%ss" % value nbhour, nbmin_ = divmod(nbmin, 60) if nbmin_: - return '%smin' % nbmin + return "%smin" % nbmin nbday, nbhour_ = divmod(nbhour, 24) if nbhour_: - return '%sh' % nbhour - return '%sd' % nbday + return "%sh" % nbhour + return "%sd" % nbday + def format_bytes(value: int) -> str: if not value: - return '0' + return "0" if value != int(value): - return '%.2fB' % value + return "%.2fB" % value value = int(value) - prevunit = 'B' - for unit in ('KB', 'MB', 'GB', 'TB'): + prevunit = "B" + for unit in ("KB", "MB", "GB", "TB"): next, remain = divmod(value, 1024) if remain: - return '%s%s' % (value, prevunit) + return "%s%s" % (value, prevunit) prevunit = unit value = next - return '%s%s' % (value, unit) + return "%s%s" % (value, unit) + def format_option_value(optdict: Dict[str, Any], value: Any) -> Union[None, int, str]: """return the user input's value from a 'compiled' value""" if isinstance(value, (list, tuple)): - value = ','.join(value) + value = ",".join(value) elif isinstance(value, dict): - value = ','.join(['%s:%s' % (k, v) for k, v in value.items()]) - elif hasattr(value, 'match'): # optdict.get('type') == 'regexp' + value = ",".join(["%s:%s" % (k, v) for k, v in value.items()]) + elif hasattr(value, "match"): # optdict.get('type') == 'regexp' # compiled regexp value = value.pattern - elif optdict.get('type') == 'yn': - value = value and 'yes' or 'no' + elif optdict.get("type") == "yn": + value = value and "yes" or "no" elif isinstance(value, str) and value.isspace(): value = "'%s'" % value - elif optdict.get('type') == 'time' and isinstance(value, (float, int)): + elif optdict.get("type") == "time" and isinstance(value, (float, int)): value = format_time(value) - elif optdict.get('type') == 'bytes' and hasattr(value, '__int__'): + elif optdict.get("type") == "bytes" and hasattr(value, "__int__"): value = format_bytes(value) return value -def ini_format_section(stream: Union[StringIO, TextIOWrapper], section: str, options: Any, encoding: str = None, doc: Optional[Any] = None) -> None: + +def ini_format_section( + stream: Union[StringIO, TextIOWrapper], + section: str, + options: Any, + encoding: str = None, + doc: Optional[Any] = None, +) -> None: """format an options section using the INI format""" encoding = _get_encoding(encoding, stream) if doc: print(_encode(comment(doc), encoding), file=stream) - print('[%s]' % section, file=stream) + print("[%s]" % section, file=stream) ini_format(stream, options, encoding) + def ini_format(stream: Union[StringIO, TextIOWrapper], options: Any, encoding: str) -> None: """format options using the INI format""" for optname, optdict, value in options: value = format_option_value(optdict, value) - help = optdict.get('help') + help = optdict.get("help") if help: - help = normalize_text(help, line_len=79, indent='# ') + help = normalize_text(help, line_len=79, indent="# ") print(file=stream) print(_encode(help, encoding), file=stream) else: print(file=stream) if value is None: - print('#%s=' % optname, file=stream) + print("#%s=" % optname, file=stream) else: value = _encode(value, encoding).strip() - if optdict.get('type') == 'string' and '\n' in value: - prefix = '\n ' - value = prefix + prefix.join(value.split('\n')) - print('%s=%s' % (optname, value), file=stream) + if optdict.get("type") == "string" and "\n" in value: + prefix = "\n " + value = prefix + prefix.join(value.split("\n")) + print("%s=%s" % (optname, value), file=stream) + format_section = ini_format_section + def rest_format_section(stream, section, options, encoding=None, doc=None): """format an options section using as ReST formatted output""" encoding = _get_encoding(encoding, stream) if section: - print('%s\n%s' % (section, "'"*len(section)), file=stream) + print("%s\n%s" % (section, "'" * len(section)), file=stream) if doc: - print(_encode(normalize_text(doc, line_len=79, indent=''), encoding), file=stream) + print(_encode(normalize_text(doc, line_len=79, indent=""), encoding), file=stream) print(file=stream) for optname, optdict, value in options: - help = optdict.get('help') - print(':%s:' % optname, file=stream) + help = optdict.get("help") + print(":%s:" % optname, file=stream) if help: - help = normalize_text(help, line_len=79, indent=' ') + help = normalize_text(help, line_len=79, indent=" ") print(_encode(help, encoding), file=stream) if value: value = _encode(format_option_value(optdict, value), encoding) print(file=stream) - print(' Default: ``%s``' % value.replace("`` ", "```` ``"), file=stream) + print(" Default: ``%s``" % value.replace("`` ", "```` ``"), file=stream) + # Options Manager ############################################################## @@ -445,7 +493,13 @@ class OptionsManagerMixIn(object): command line options """ - def __init__(self, usage: Optional[str], config_file: Optional[Any] = None, version: Optional[Any] = None, quiet: int = 0) -> None: + def __init__( + self, + usage: Optional[str], + config_file: Optional[Any] = None, + version: Optional[Any] = None, + quiet: int = 0, + ) -> None: self.config_file = config_file self.reset_parsers(usage, version=version) # list of registered options providers @@ -459,7 +513,7 @@ class OptionsManagerMixIn(object): self.quiet = quiet self._maxlevel = 0 - def reset_parsers(self, usage: Optional[str] = '', version: Optional[Any] = None) -> None: + def reset_parsers(self, usage: Optional[str] = "", version: Optional[Any] = None) -> None: # configuration file parser self.cfgfile_parser = cp.ConfigParser() # command line parser @@ -469,7 +523,9 @@ class OptionsManagerMixIn(object): self.cmdline_parser.options_manager = self # type: ignore self._optik_option_attrs = set(self.cmdline_parser.option_class.ATTRS) - def register_options_provider(self, provider: 'ConfigurationMixIn', own_group: bool = True) -> None: + def register_options_provider( + self, provider: "ConfigurationMixIn", own_group: bool = True + ) -> None: """register an options provider""" assert provider.priority <= 0, "provider's priority can't be >= 0" for i in range(len(self.options_providers)): @@ -481,13 +537,17 @@ class OptionsManagerMixIn(object): # mypy: Need type annotation for 'option' # you can't type variable of a list comprehension, right? - non_group_spec_options: List = [option for option in provider.options # type: ignore - if 'group' not in option[1]] # type: ignore + non_group_spec_options: List = [ + option + for option in provider.options # type: ignore + if "group" not in option[1] + ] # type: ignore - groups = getattr(provider, 'option_groups', ()) + groups = getattr(provider, "option_groups", ()) if own_group and non_group_spec_options: - self.add_option_group(provider.name.upper(), provider.__doc__, - non_group_spec_options, provider) + self.add_option_group( + provider.name.upper(), provider.__doc__, non_group_spec_options, provider + ) else: for opt, optdict in non_group_spec_options: self.add_optik_option(provider, self.cmdline_parser, opt, optdict) @@ -496,11 +556,20 @@ class OptionsManagerMixIn(object): # mypy: Need type annotation for 'option' # you can't type variable of a list comprehension, right? - goptions: List = [option for option in provider.options # type: ignore - if option[1].get('group', '').upper() == gname] # type: ignore + goptions: List = [ + option + for option in provider.options # type: ignore + if option[1].get("group", "").upper() == gname + ] # type: ignore self.add_option_group(gname, gdoc, goptions, provider) - def add_option_group(self, group_name: str, doc: Optional[str], options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]], provider: 'ConfigurationMixIn') -> None: + def add_option_group( + self, + group_name: str, + doc: Optional[str], + options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]], + provider: "ConfigurationMixIn", + ) -> None: """add an option group including the listed options """ assert options @@ -508,8 +577,7 @@ class OptionsManagerMixIn(object): if group_name in self._mygroups: group = self._mygroups[group_name] else: - group = optik_ext.OptionGroup(self.cmdline_parser, - title=group_name.capitalize()) + group = optik_ext.OptionGroup(self.cmdline_parser, title=group_name.capitalize()) self.cmdline_parser.add_option_group(group) # mypy: "OptionGroup" has no attribute "level" # dynamic attribute @@ -522,48 +590,63 @@ class OptionsManagerMixIn(object): for opt, optdict in options: self.add_optik_option(provider, group, opt, optdict) - def add_optik_option(self, provider: 'ConfigurationMixIn', optikcontainer: Union[OptionParser, OptionGroup], opt: str, optdict: Dict[str, Any]) -> None: - if 'inputlevel' in optdict: - warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,' - ' use "level"' % opt, DeprecationWarning) - optdict['level'] = optdict.pop('inputlevel') + def add_optik_option( + self, + provider: "ConfigurationMixIn", + optikcontainer: Union[OptionParser, OptionGroup], + opt: str, + optdict: Dict[str, Any], + ) -> None: + if "inputlevel" in optdict: + warn( + '[0.50] "inputlevel" in option dictionary for %s is deprecated,' + ' use "level"' % opt, + DeprecationWarning, + ) + optdict["level"] = optdict.pop("inputlevel") args, optdict = self.optik_option(provider, opt, optdict) option = optikcontainer.add_option(*args, **optdict) self._all_options[opt] = provider self._maxlevel = max(self._maxlevel, option.level or 0) - def optik_option(self, provider: 'ConfigurationMixIn', opt: str, optdict: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any]]: + def optik_option( + self, provider: "ConfigurationMixIn", opt: str, optdict: Dict[str, Any] + ) -> Tuple[List[str], Dict[str, Any]]: """get our personal option definition and return a suitable form for use with optik/optparse """ optdict = copy(optdict) - if 'action' in optdict: + if "action" in optdict: self._nocallback_options[provider] = opt else: - optdict['action'] = 'callback' - optdict['callback'] = self.cb_set_provider_option + optdict["action"] = "callback" + optdict["callback"] = self.cb_set_provider_option # default is handled here and *must not* be given to optik if you # want the whole machinery to work - if 'default' in optdict: - if ('help' in optdict - and optdict.get('default') is not None - and not optdict['action'] in ('store_true', 'store_false')): - optdict['help'] += ' [current: %default]' - del optdict['default'] - args = ['--' + str(opt)] - if 'short' in optdict: - self._short_options[optdict['short']] = opt - args.append('-' + optdict['short']) - del optdict['short'] + if "default" in optdict: + if ( + "help" in optdict + and optdict.get("default") is not None + and not optdict["action"] in ("store_true", "store_false") + ): + optdict["help"] += " [current: %default]" + del optdict["default"] + args = ["--" + str(opt)] + if "short" in optdict: + self._short_options[optdict["short"]] = opt + args.append("-" + optdict["short"]) + del optdict["short"] # cleanup option definition dict before giving it to optik for key in list(optdict.keys()): if not key in self._optik_option_attrs: optdict.pop(key) return args, optdict - def cb_set_provider_option(self, option: 'Option', opt: str, value: Union[List[str], int, str], parser: 'OptionParser') -> None: + def cb_set_provider_option( + self, option: "Option", opt: str, value: Union[List[str], int, str], parser: "OptionParser" + ) -> None: """optik callback for option setting""" - if opt.startswith('--'): + if opt.startswith("--"): # remove -- on long option opt = opt[2:] else: @@ -578,7 +661,12 @@ class OptionsManagerMixIn(object): """set option on the correct option provider""" self._all_options[opt].set_option(opt, value) - def generate_config(self, stream: Union[StringIO, TextIOWrapper] = None, skipsections: Tuple[()] = (), encoding: Optional[Any] = None) -> None: + def generate_config( + self, + stream: Union[StringIO, TextIOWrapper] = None, + skipsections: Tuple[()] = (), + encoding: Optional[Any] = None, + ) -> None: """write a configuration file according to the current configuration into the given stream or stdout """ @@ -591,8 +679,7 @@ class OptionsManagerMixIn(object): section = provider.name if section in skipsections: continue - options = [(n, d, v) for (n, d, v) in options - if d.get('type') is not None] + options = [(n, d, v) for (n, d, v) in options if d.get("type") is not None] if not options: continue if not section in sections: @@ -604,20 +691,25 @@ class OptionsManagerMixIn(object): printed = False for section in sections: if printed: - print('\n', file=stream) - format_section(stream, section.upper(), options_by_section[section], - encoding) + print("\n", file=stream) + format_section(stream, section.upper(), options_by_section[section], encoding) printed = True - def generate_manpage(self, pkginfo: attrdict, section: int = 1, stream: StringIO = None) -> None: + def generate_manpage( + self, pkginfo: attrdict, section: int = 1, stream: StringIO = None + ) -> None: """write a man page for the current configuration into the given stream or stdout """ self._monkeypatch_expand_default() try: - optik_ext.generate_manpage(self.cmdline_parser, pkginfo, - section, stream=stream or sys.stdout, - level=self._maxlevel) + optik_ext.generate_manpage( + self.cmdline_parser, + pkginfo, + section, + stream=stream or sys.stdout, + level=self._maxlevel, + ) finally: self._unmonkeypatch_expand_default() @@ -639,18 +731,19 @@ class OptionsManagerMixIn(object): """ helplevel = 1 while helplevel <= self._maxlevel: - opt = '-'.join(['long'] * helplevel) + '-help' + opt = "-".join(["long"] * helplevel) + "-help" if opt in self._all_options: - break # already processed + break # already processed + def helpfunc(option, opt, val, p, level=helplevel): print(self.help(level)) sys.exit(0) - helpmsg = '%s verbose help.' % ' '.join(['more'] * helplevel) - optdict = {'action' : 'callback', 'callback' : helpfunc, - 'help' : helpmsg} + + helpmsg = "%s verbose help." % " ".join(["more"] * helplevel) + optdict = {"action": "callback", "callback": helpfunc, "help": helpmsg} provider = self.options_providers[0] self.add_optik_option(provider, self.cmdline_parser, opt, optdict) - provider.options += ( (opt, optdict), ) + provider.options += ((opt, optdict),) helplevel += 1 if config_file is None: config_file = self.config_file @@ -666,7 +759,7 @@ class OptionsManagerMixIn(object): if not sect.isupper() and values: parser._sections[sect.upper()] = values # type: ignore elif not self.quiet: - msg = 'No config file found, using default configuration' + msg = "No config file found, using default configuration" print(msg, file=sys.stderr) return @@ -680,7 +773,7 @@ class OptionsManagerMixIn(object): for section, option, optdict in provider.all_options(): if onlysection is not None and section != onlysection: continue - if not 'type' in optdict: + if not "type" in optdict: # ignore action without type (callback, store_true...) continue provider.input_option(option, optdict, inputlevel) @@ -694,18 +787,18 @@ class OptionsManagerMixIn(object): """ parser = self.cfgfile_parser for section in parser.sections(): - for option, value in parser.items(section): - try: - self.global_set_option(option, value) - except (KeyError, OptionError): - # TODO handle here undeclared options appearing in the config file - continue + for option, value in parser.items(section): + try: + self.global_set_option(option, value) + except (KeyError, OptionError): + # TODO handle here undeclared options appearing in the config file + continue def load_configuration(self, **kwargs: Any) -> None: """override configuration according to given parameters """ for opt, opt_value in kwargs.items(): - opt = opt.replace('_', '-') + opt = opt.replace("_", "-") provider = self._all_options[opt] provider.set_option(opt, opt_value) @@ -733,14 +826,13 @@ class OptionsManagerMixIn(object): finally: self._unmonkeypatch_expand_default() - # help methods ############################################################ def add_help_section(self, title: str, description: str, level: int = 0) -> None: """add a dummy option section for help purpose """ - group = optik_ext.OptionGroup(self.cmdline_parser, - title=title.capitalize(), - description=description) + group = optik_ext.OptionGroup( + self.cmdline_parser, title=title.capitalize(), description=description + ) # mypy: "OptionGroup" has no attribute "level" # it does, it is set in the optik_ext module group.level = level # type: ignore @@ -757,9 +849,10 @@ class OptionsManagerMixIn(object): except AttributeError: # python < 2.4: nothing to be done pass + def _unmonkeypatch_expand_default(self) -> None: # remove monkey patch - if hasattr(optik_ext.HelpFormatter, 'expand_default'): + if hasattr(optik_ext.HelpFormatter, "expand_default"): # mypy: Cannot assign to a method # it's dirty but you can @@ -782,27 +875,30 @@ class Method(object): """used to ease late binding of default method (so you can define options on the class using default methods on the configuration instance) """ + def __init__(self, methname): self.method = methname self._inst = None - def bind(self, instance: 'Configuration') -> None: + def bind(self, instance: "Configuration") -> None: """bind the method to its instance""" if self._inst is None: self._inst = instance def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, str]: - assert self._inst, 'unbound method' + assert self._inst, "unbound method" return getattr(self._inst, self.method)(*args, **kwargs) + # Options Provider ############################################################# + class OptionsProviderMixIn(object): """Mixin to provide options to an OptionsManager""" # those attributes should be overridden priority = -1 - name = 'default' + name = "default" options: Tuple = () level = 0 @@ -812,18 +908,18 @@ class OptionsProviderMixIn(object): try: option, optdict = option_tuple except ValueError: - raise Exception('Bad option: %s' % str(option_tuple)) - if isinstance(optdict.get('default'), Method): - optdict['default'].bind(self) - elif isinstance(optdict.get('callback'), Method): - optdict['callback'].bind(self) + raise Exception("Bad option: %s" % str(option_tuple)) + if isinstance(optdict.get("default"), Method): + optdict["default"].bind(self) + elif isinstance(optdict.get("callback"), Method): + optdict["callback"].bind(self) self.load_defaults() def load_defaults(self) -> None: """initialize the provider using default values""" for opt, optdict in self.options: - action = optdict.get('action') - if action != 'callback': + action = optdict.get("action") + if action != "callback": # callback action have no default default = self.option_default(opt, optdict) if default is REQUIRED: @@ -834,7 +930,7 @@ class OptionsProviderMixIn(object): """return the default value for an option""" if optdict is None: optdict = self.get_option_def(opt) - default = optdict.get('default') + default = optdict.get("default") if callable(default): default = default() return default @@ -844,8 +940,11 @@ class OptionsProviderMixIn(object): """ if optdict is None: optdict = self.get_option_def(opt) - return optdict.get('dest', opt.replace('-', '_')) - option_name = deprecated('[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()')(option_attrname) + return optdict.get("dest", opt.replace("-", "_")) + + option_name = deprecated( + "[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()" + )(option_attrname) def option_value(self, opt): """get the current value for the given option""" @@ -859,20 +958,20 @@ class OptionsProviderMixIn(object): if value is not None: value = _validate(value, optdict, opt) if action is None: - action = optdict.get('action', 'store') - if optdict.get('type') == 'named': # XXX need specific handling + action = optdict.get("action", "store") + if optdict.get("type") == "named": # XXX need specific handling optname = self.option_attrname(opt, optdict) currentvalue = getattr(self.config, optname, None) if currentvalue: currentvalue.update(value) value = currentvalue - if action == 'store': + if action == "store": setattr(self.config, self.option_attrname(opt, optdict), value) - elif action in ('store_true', 'count'): + elif action in ("store_true", "count"): setattr(self.config, self.option_attrname(opt, optdict), 0) - elif action == 'store_false': + elif action == "store_false": setattr(self.config, self.option_attrname(opt, optdict), 1) - elif action == 'append': + elif action == "append": opt = self.option_attrname(opt, optdict) _list = getattr(self.config, opt, None) if _list is None: @@ -886,28 +985,28 @@ class OptionsProviderMixIn(object): setattr(self.config, opt, _list + (value,)) else: _list.append(value) - elif action == 'callback': - optdict['callback'](None, opt, value, None) + elif action == "callback": + optdict["callback"](None, opt, value, None) else: raise UnsupportedAction(action) def input_option(self, option, optdict, inputlevel=99): default = self.option_default(option, optdict) if default is REQUIRED: - defaultstr = '(required): ' - elif optdict.get('level', 0) > inputlevel: + defaultstr = "(required): " + elif optdict.get("level", 0) > inputlevel: return - elif optdict['type'] == 'password' or default is None: - defaultstr = ': ' + elif optdict["type"] == "password" or default is None: + defaultstr = ": " else: - defaultstr = '(default: %s): ' % format_option_value(optdict, default) - print(':%s:' % option) - print(optdict.get('help') or option) - inputfunc = INPUT_FUNCTIONS[optdict['type']] + defaultstr = "(default: %s): " % format_option_value(optdict, default) + print(":%s:" % option) + print(optdict.get("help") or option) + inputfunc = INPUT_FUNCTIONS[optdict["type"]] value = inputfunc(optdict, defaultstr) while default is REQUIRED and not value: - print('please specify a value') - value = inputfunc(optdict, '%s: ' % option) + print("please specify a value") + value = inputfunc(optdict, "%s: " % option) if value is None and default is not None: value = default self.set_option(option, value, optdict=optdict) @@ -920,9 +1019,7 @@ class OptionsProviderMixIn(object): return option[1] # mypy: Argument 2 to "OptionError" has incompatible type "str"; expected "Option" # seems to be working? - raise OptionError('no such option %s in section %r' - % (opt, self.name), opt) # type: ignore - + raise OptionError("no such option %s in section %r" % (opt, self.name), opt) # type: ignore def all_options(self): """return an iterator on available options for this provider @@ -944,8 +1041,9 @@ class OptionsProviderMixIn(object): """ sections: Dict[str, List[Tuple[str, Dict[str, Any], Any]]] = {} for optname, optdict in self.options: - sections.setdefault(optdict.get('group'), []).append( - (optname, optdict, self.option_value(optname))) + sections.setdefault(optdict.get("group"), []).append( + (optname, optdict, self.option_value(optname)) + ) if None in sections: # mypy: No overload variant of "pop" of "MutableMapping" matches argument type "None" # it actually works @@ -959,23 +1057,26 @@ class OptionsProviderMixIn(object): for optname, optdict in options: yield (optname, optdict, self.option_value(optname)) + # configuration ################################################################ + class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn): """basic mixin for simple configurations which don't need the manager / providers model """ + def __init__(self, *args: Any, **kwargs: Any) -> None: if not args: - kwargs.setdefault('usage', '') - kwargs.setdefault('quiet', 1) + kwargs.setdefault("usage", "") + kwargs.setdefault("quiet", 1) OptionsManagerMixIn.__init__(self, *args, **kwargs) OptionsProviderMixIn.__init__(self) - if not getattr(self, 'option_groups', None): + if not getattr(self, "option_groups", None): self.option_groups: List[Tuple[Any, str]] = [] for option, optdict in self.options: try: - gdef = (optdict['group'].upper(), '') + gdef = (optdict["group"].upper(), "") except KeyError: continue if not gdef in self.option_groups: @@ -986,7 +1087,9 @@ class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn): """add some options to the configuration""" options_by_group = {} for optname, optdict in options: - options_by_group.setdefault(optdict.get('group', self.name.upper()), []).append((optname, optdict)) + options_by_group.setdefault(optdict.get("group", self.name.upper()), []).append( + (optname, optdict) + ) for group, group_options in options_by_group.items(): self.add_option_group(group, None, group_options, self) self.options += tuple(options) @@ -1020,8 +1123,9 @@ class Configuration(ConfigurationMixIn): configuration values are accessible through a dict like interface """ - def __init__(self, config_file=None, options=None, name=None, - usage=None, doc=None, version=None): + def __init__( + self, config_file=None, options=None, name=None, usage=None, doc=None, version=None + ): if options is not None: self.options = options if name is not None: @@ -1035,6 +1139,7 @@ class OptionsManager2ConfigurationAdapter(object): """Adapt an option manager to behave like a `logilab.common.configuration.Configuration` instance """ + def __init__(self, provider): self.config = provider @@ -1057,8 +1162,10 @@ class OptionsManager2ConfigurationAdapter(object): except KeyError: return default + # other functions ############################################################## + def read_old_config(newconfig, changes, configfile): """initialize newconfig from a deprecated configuration file @@ -1070,38 +1177,38 @@ def read_old_config(newconfig, changes, configfile): # build an index of changes changesindex = {} for action in changes: - if action[0] == 'moved': + if action[0] == "moved": option, oldgroup, newgroup = action[1:] changesindex.setdefault(option, []).append((action[0], oldgroup, newgroup)) continue - if action[0] == 'renamed': + if action[0] == "renamed": oldname, newname = action[1:] changesindex.setdefault(newname, []).append((action[0], oldname)) continue - if action[0] == 'typechanged': + if action[0] == "typechanged": option, oldtype, newvalue = action[1:] changesindex.setdefault(option, []).append((action[0], oldtype, newvalue)) continue - if action[0] in ('added', 'removed'): - continue # nothing to do here - raise Exception('unknown change %s' % action[0]) + if action[0] in ("added", "removed"): + continue # nothing to do here + raise Exception("unknown change %s" % action[0]) # build a config object able to read the old config options = [] for optname, optdef in newconfig.options: for action in changesindex.pop(optname, ()): - if action[0] == 'moved': + if action[0] == "moved": oldgroup, newgroup = action[1:] optdef = optdef.copy() - optdef['group'] = oldgroup - elif action[0] == 'renamed': + optdef["group"] = oldgroup + elif action[0] == "renamed": optname = action[1] - elif action[0] == 'typechanged': + elif action[0] == "typechanged": oldtype = action[1] optdef = optdef.copy() - optdef['type'] = oldtype + optdef["type"] = oldtype options.append((optname, optdef)) if changesindex: - raise Exception('unapplied changes: %s' % changesindex) + raise Exception("unapplied changes: %s" % changesindex) oldconfig = Configuration(options=options, name=newconfig.name) # read the old config oldconfig.load_file_configuration(configfile) @@ -1109,16 +1216,16 @@ def read_old_config(newconfig, changes, configfile): changes.reverse() done = set() for action in changes: - if action[0] == 'renamed': + if action[0] == "renamed": oldname, newname = action[1:] newconfig[newname] = oldconfig[oldname] done.add(newname) - elif action[0] == 'typechanged': + elif action[0] == "typechanged": optname, oldtype, newvalue = action[1:] newconfig[optname] = newvalue done.add(optname) for optname, optdef in newconfig.options: - if optdef.get('type') and not optname in done: + if optdef.get("type") and not optname in done: newconfig.set_option(optname, oldconfig[optname], optdict=optdef) @@ -1131,7 +1238,7 @@ def merge_options(options, optgroup=None): """ alloptions = {} options = list(options) - for i in range(len(options)-1, -1, -1): + for i in range(len(options) - 1, -1, -1): optname, optdict = options[i] if optname in alloptions: options.pop(i) @@ -1141,5 +1248,5 @@ def merge_options(options, optgroup=None): options[i] = (optname, optdict) alloptions[optname] = optdict if optgroup is not None: - alloptions[optname]['group'] = optgroup + alloptions[optname]["group"] = optgroup return tuple(options) diff --git a/logilab/common/daemon.py b/logilab/common/daemon.py index 78e4743..c4c8d93 100644 --- a/logilab/common/daemon.py +++ b/logilab/common/daemon.py @@ -33,21 +33,24 @@ def setugid(user): Argument is a numeric user id or a user name""" try: from pwd import getpwuid + passwd = getpwuid(int(user)) except ValueError: from pwd import getpwnam + passwd = getpwnam(user) - if hasattr(os, 'initgroups'): # python >= 2.7 + if hasattr(os, "initgroups"): # python >= 2.7 os.initgroups(passwd.pw_name, passwd.pw_gid) else: import ctypes + if ctypes.CDLL(None).initgroups(passwd.pw_name, passwd.pw_gid) < 0: - err = ctypes.c_int.in_dll(ctypes.pythonapi,"errno").value - raise OSError(err, os.strerror(err), 'initgroups') + err = ctypes.c_int.in_dll(ctypes.pythonapi, "errno").value + raise OSError(err, os.strerror(err), "initgroups") os.setgid(passwd.pw_gid) os.setuid(passwd.pw_uid) - os.environ['HOME'] = passwd.pw_dir + os.environ["HOME"] = passwd.pw_dir def daemonize(pidfile=None, uid=None, umask=0o77): @@ -59,19 +62,19 @@ def daemonize(pidfile=None, uid=None, umask=0o77): # http://www.faqs.org/faqs/unix-faq/programmer/faq/ # # fork so the parent can exit - if os.fork(): # launch child and... + if os.fork(): # launch child and... return 1 # disconnect from tty and create a new session os.setsid() # fork again so the parent, (the session group leader), can exit. # as a non-session group leader, we can never regain a controlling # terminal. - if os.fork(): # launch child again. + if os.fork(): # launch child again. return 2 # move to the root to avoit mount pb - os.chdir('/') + os.chdir("/") # redirect standard descriptors - null = os.open('/dev/null', os.O_RDWR) + null = os.open("/dev/null", os.O_RDWR) for i in range(3): try: os.dup2(null, i) @@ -80,7 +83,7 @@ def daemonize(pidfile=None, uid=None, umask=0o77): raise os.close(null) # filter warnings - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") # write pid in a file if pidfile: # ensure the directory where the pid-file should be set exists (for @@ -88,7 +91,7 @@ def daemonize(pidfile=None, uid=None, umask=0o77): piddir = os.path.dirname(pidfile) if not os.path.exists(piddir): os.makedirs(piddir) - f = file(pidfile, 'w') + f = file(pidfile, "w") f.write(str(os.getpid())) f.close() # set umask if specified diff --git a/logilab/common/date.py b/logilab/common/date.py index 2d2ed22..5f43d3e 100644 --- a/logilab/common/date.py +++ b/logilab/common/date.py @@ -42,63 +42,59 @@ else: # as we have in lgc.db ? FRENCH_FIXED_HOLIDAYS = { - 'jour_an': '%s-01-01', - 'fete_travail': '%s-05-01', - 'armistice1945': '%s-05-08', - 'fete_nat': '%s-07-14', - 'assomption': '%s-08-15', - 'toussaint': '%s-11-01', - 'armistice1918': '%s-11-11', - 'noel': '%s-12-25', - } + "jour_an": "%s-01-01", + "fete_travail": "%s-05-01", + "armistice1945": "%s-05-08", + "fete_nat": "%s-07-14", + "assomption": "%s-08-15", + "toussaint": "%s-11-01", + "armistice1918": "%s-11-11", + "noel": "%s-12-25", +} FRENCH_MOBILE_HOLIDAYS = { - 'paques2004': '2004-04-12', - 'ascension2004': '2004-05-20', - 'pentecote2004': '2004-05-31', - - 'paques2005': '2005-03-28', - 'ascension2005': '2005-05-05', - 'pentecote2005': '2005-05-16', - - 'paques2006': '2006-04-17', - 'ascension2006': '2006-05-25', - 'pentecote2006': '2006-06-05', - - 'paques2007': '2007-04-09', - 'ascension2007': '2007-05-17', - 'pentecote2007': '2007-05-28', - - 'paques2008': '2008-03-24', - 'ascension2008': '2008-05-01', - 'pentecote2008': '2008-05-12', - - 'paques2009': '2009-04-13', - 'ascension2009': '2009-05-21', - 'pentecote2009': '2009-06-01', - - 'paques2010': '2010-04-05', - 'ascension2010': '2010-05-13', - 'pentecote2010': '2010-05-24', - - 'paques2011': '2011-04-25', - 'ascension2011': '2011-06-02', - 'pentecote2011': '2011-06-13', - - 'paques2012': '2012-04-09', - 'ascension2012': '2012-05-17', - 'pentecote2012': '2012-05-28', - } + "paques2004": "2004-04-12", + "ascension2004": "2004-05-20", + "pentecote2004": "2004-05-31", + "paques2005": "2005-03-28", + "ascension2005": "2005-05-05", + "pentecote2005": "2005-05-16", + "paques2006": "2006-04-17", + "ascension2006": "2006-05-25", + "pentecote2006": "2006-06-05", + "paques2007": "2007-04-09", + "ascension2007": "2007-05-17", + "pentecote2007": "2007-05-28", + "paques2008": "2008-03-24", + "ascension2008": "2008-05-01", + "pentecote2008": "2008-05-12", + "paques2009": "2009-04-13", + "ascension2009": "2009-05-21", + "pentecote2009": "2009-06-01", + "paques2010": "2010-04-05", + "ascension2010": "2010-05-13", + "pentecote2010": "2010-05-24", + "paques2011": "2011-04-25", + "ascension2011": "2011-06-02", + "pentecote2011": "2011-06-13", + "paques2012": "2012-04-09", + "ascension2012": "2012-05-17", + "pentecote2012": "2012-05-28", +} # XXX this implementation cries for multimethod dispatching + def get_step(dateobj: Union[date, datetime], nbdays: int = 1) -> timedelta: # assume date is either a python datetime or a mx.DateTime object if isinstance(dateobj, date): return ONEDAY * nbdays - return nbdays # mx.DateTime is ok with integers + return nbdays # mx.DateTime is ok with integers -def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetime]) -> Union[date, datetime]: + +def datefactory( + year: int, month: int, day: int, sampledate: Union[date, datetime] +) -> Union[date, datetime]: # assume date is either a python datetime or a mx.DateTime object if isinstance(sampledate, datetime): return datetime(year, month, day) @@ -106,17 +102,20 @@ def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetim return date(year, month, day) return Date(year, month, day) + def weekday(dateobj: Union[date, datetime]) -> int: # assume date is either a python datetime or a mx.DateTime object if isinstance(dateobj, date): return dateobj.weekday() return dateobj.day_of_week + def str2date(datestr: str, sampledate: Union[date, datetime]) -> Union[date, datetime]: # NOTE: datetime.strptime is not an option until we drop py2.4 compat - year, month, day = [int(chunk) for chunk in datestr.split('-')] + year, month, day = [int(chunk) for chunk in datestr.split("-")] return datefactory(year, month, day, sampledate) + def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> int: if isinstance(start, date): # mypy: No overload variant of "__sub__" of "datetime" matches argument type "date" @@ -130,32 +129,35 @@ def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> in else: return int(math.ceil((end - start).days)) -def get_national_holidays(begin: Union[date, datetime], end: Union[date, datetime]) -> Union[List[date], List[datetime]]: + +def get_national_holidays( + begin: Union[date, datetime], end: Union[date, datetime] +) -> Union[List[date], List[datetime]]: """return french national days off between begin and end""" begin = datefactory(begin.year, begin.month, begin.day, begin) end = datefactory(end.year, end.month, end.day, end) - holidays = [str2date(datestr, begin) - for datestr in FRENCH_MOBILE_HOLIDAYS.values()] - for year in range(begin.year, end.year+1): + holidays = [str2date(datestr, begin) for datestr in FRENCH_MOBILE_HOLIDAYS.values()] + for year in range(begin.year, end.year + 1): for datestr in FRENCH_FIXED_HOLIDAYS.values(): date = str2date(datestr % year, begin) if date not in holidays: holidays.append(date) return [day for day in holidays if begin <= day < end] + def add_days_worked(start: date, days: int) -> date: """adds date but try to only take days worked into account""" step = get_step(start) weeks, plus = divmod(days, 5) end = start + ((weeks * 7) + plus) * step - if weekday(end) >= 5: # saturday or sunday - end += (2 * step) - end += len([x for x in get_national_holidays(start, end + step) - if weekday(x) < 5]) * step - if weekday(end) >= 5: # saturday or sunday - end += (2 * step) + if weekday(end) >= 5: # saturday or sunday + end += 2 * step + end += len([x for x in get_national_holidays(start, end + step) if weekday(x) < 5]) * step + if weekday(end) >= 5: # saturday or sunday + end += 2 * step return end + def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> int: assert start <= end step = get_step(start) @@ -166,15 +168,18 @@ def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> in elif weekday(end) == 6: plus -= 1 open_days = weeks * 5 + plus - nb_week_holidays = len([x for x in get_national_holidays(start, end+step) - if weekday(x) < 5 and x < end]) + nb_week_holidays = len( + [x for x in get_national_holidays(start, end + step) if weekday(x) < 5 and x < end] + ) open_days -= nb_week_holidays if open_days < 0: return 0 return open_days -def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None) -> Generator[date, Any, None]: +def date_range( + begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None +) -> Generator[date, Any, None]: """yields each date between begin and end :param begin: the start date @@ -202,6 +207,7 @@ def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: O yield begin begin += incr + # makes py datetime usable ##################################################### ONEDAY: timedelta = timedelta(days=1) @@ -209,14 +215,17 @@ ONEWEEK: timedelta = timedelta(days=7) try: strptime = datetime.strptime -except AttributeError: # py < 2.5 +except AttributeError: # py < 2.5 from time import strptime as time_strptime + def strptime(value, format): return datetime(*time_strptime(value, format)[:6]) -def strptime_time(value, format='%H:%M'): + +def strptime_time(value, format="%H:%M"): return time(*time_strptime(value, format)[3:6]) + def todate(somedate: date) -> date: """return a date from a date (leaving unchanged) or a datetime""" if isinstance(somedate, datetime): @@ -224,6 +233,7 @@ def todate(somedate: date) -> date: assert isinstance(somedate, (date, DateTimeType)), repr(somedate) return somedate + def totime(somedate): """return a time from a time (leaving unchanged), date or datetime""" # XXX mx compat @@ -232,6 +242,7 @@ def totime(somedate): assert isinstance(somedate, (time)), repr(somedate) return somedate + def todatetime(somedate): """return a date from a date (leaving unchanged) or a datetime""" # take care, datetime is a subclass of date @@ -240,8 +251,10 @@ def todatetime(somedate): assert isinstance(somedate, (date, DateTimeType)), repr(somedate) return datetime(somedate.year, somedate.month, somedate.day) + def datetime2ticks(somedate: Union[date, datetime]) -> int: - return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, 'microsecond', 0) / 1000) + return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, "microsecond", 0) / 1000) + def ticks2datetime(ticks: int) -> datetime: miliseconds, microseconds = divmod(ticks, 1000) @@ -256,9 +269,11 @@ def ticks2datetime(ticks: int) -> datetime: except (ValueError, OverflowError): raise + def days_in_month(somedate: date) -> int: return monthrange(somedate.year, somedate.month)[1] + def days_in_year(somedate): feb = date(somedate.year, 2, 1) if days_in_month(feb) == 29: @@ -266,25 +281,30 @@ def days_in_year(somedate): else: return 365 + def previous_month(somedate, nbmonth=1): while nbmonth: somedate = first_day(somedate) - ONEDAY nbmonth -= 1 return somedate + def next_month(somedate: date, nbmonth: int = 1) -> date: while nbmonth: somedate = last_day(somedate) + ONEDAY nbmonth -= 1 return somedate + def first_day(somedate): return date(somedate.year, somedate.month, 1) + def last_day(somedate: date) -> date: return date(somedate.year, somedate.month, days_in_month(somedate)) -def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str: + +def ustrftime(somedate: datetime, fmt: str = "%Y-%m-%d") -> str: """like strftime, but returns a unicode string instead of an encoded string which may be problematic with localized date. """ @@ -294,7 +314,7 @@ def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str: else: try: if sys.version_info < (3, 0): - encoding = getlocale(LC_TIME)[1] or 'ascii' + encoding = getlocale(LC_TIME)[1] or "ascii" return unicode(somedate.strftime(str(fmt)), encoding) else: return somedate.strftime(fmt) @@ -304,37 +324,41 @@ def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str: # datetime is not happy with dates before 1900 # we try to work around this, assuming a simple # format string - fields = {'Y': somedate.year, - 'm': somedate.month, - 'd': somedate.day, - } + fields = { + "Y": somedate.year, + "m": somedate.month, + "d": somedate.day, + } if isinstance(somedate, datetime): - fields.update({'H': somedate.hour, - 'M': somedate.minute, - 'S': somedate.second}) - fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt) + fields.update({"H": somedate.hour, "M": somedate.minute, "S": somedate.second}) + fmt = re.sub("%([YmdHMS])", r"%(\1)02d", fmt) return unicode(fmt) % fields + def utcdatetime(dt: datetime) -> datetime: if dt.tzinfo is None: return dt # mypy: No overload variant of "__sub__" of "datetime" matches argument type "None" - return (dt.replace(tzinfo=None) - dt.utcoffset()) # type: ignore + return dt.replace(tzinfo=None) - dt.utcoffset() # type: ignore + def utctime(dt): if dt.tzinfo is None: return dt return (dt + dt.utcoffset() + dt.dst()).replace(tzinfo=None) + def datetime_to_seconds(date): """return the number of seconds since the begining of the day for that date """ - return date.second+60*date.minute + 3600*date.hour + return date.second + 60 * date.minute + 3600 * date.hour + def timedelta_to_days(delta): """return the time delta as a number of seconds""" - return delta.days + delta.seconds / (3600*24) + return delta.days + delta.seconds / (3600 * 24) + def timedelta_to_seconds(delta): """return the time delta as a fraction of days""" - return delta.days*(3600*24) + delta.seconds + return delta.days * (3600 * 24) + delta.seconds diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py index 2df84ad..6553557 100644 --- a/logilab/common/debugger.py +++ b/logilab/common/debugger.py @@ -49,12 +49,17 @@ from logilab.common.compat import StringIO try: from IPython import PyColorize except ImportError: + def colorize(source, start_lineno, curlineno): """fallback colorize function""" return source + def colorize_source(source): return source + + else: + def colorize(source, start_lineno, curlineno): """colorize and annotate source with linenos (as in pdb's list command) @@ -66,10 +71,10 @@ else: for index, line in enumerate(output.getvalue().splitlines()): lineno = index + start_lineno if lineno == curlineno: - annotated.append('%4s\t->\t%s' % (lineno, line)) + annotated.append("%4s\t->\t%s" % (lineno, line)) else: - annotated.append('%4s\t\t%s' % (lineno, line)) - return '\n'.join(annotated) + annotated.append("%4s\t\t%s" % (lineno, line)) + return "\n".join(annotated) def colorize_source(source): """colorize given source""" @@ -86,7 +91,7 @@ def getsource(obj): or code object. The source code is returned as a single string. An IOError is raised if the source code cannot be retrieved.""" lines, lnum = inspect.getsourcelines(obj) - return ''.join(lines), lnum + return "".join(lines), lnum ################################################################ @@ -98,6 +103,7 @@ class Debugger(Pdb): - overrides list command to search for current block instead of using 5 lines of context """ + def __init__(self, tcbk=None): Pdb.__init__(self) self.reset() @@ -137,11 +143,10 @@ class Debugger(Pdb): """provide variable names completion for the ``p`` command""" namespace = dict(self.curframe.f_globals) namespace.update(self.curframe.f_locals) - if '.' in text: + if "." in text: return self.attr_matches(text, namespace) return [varname for varname in namespace if varname.startswith(text)] - def attr_matches(self, text, namespace): """implementation coming from rlcompleter.Completer.attr_matches Compute matches when text contains a dot. @@ -156,14 +161,15 @@ class Debugger(Pdb): """ import re + m = re.match(r"(\w+(\.\w+)*)\.(\w*)", text) if not m: return expr, attr = m.group(1, 3) object = eval(expr, namespace) words = dir(object) - if hasattr(object, '__class__'): - words.append('__class__') + if hasattr(object, "__class__"): + words.append("__class__") words = words + self.get_class_members(object.__class__) matches = [] n = len(attr) @@ -175,7 +181,7 @@ class Debugger(Pdb): def get_class_members(self, klass): """implementation coming from rlcompleter.get_class_members""" ret = dir(klass) - if hasattr(klass, '__bases__'): + if hasattr(klass, "__bases__"): for base in klass.__bases__: ret = ret + self.get_class_members(base) return ret @@ -185,33 +191,35 @@ class Debugger(Pdb): """overrides default list command to display the surrounding block instead of 5 lines of context """ - self.lastcmd = 'list' + self.lastcmd = "list" if not arg: try: source, start_lineno = getsource(self.curframe) - print(colorize(''.join(source), start_lineno, - self.curframe.f_lineno)) + print(colorize("".join(source), start_lineno, self.curframe.f_lineno)) except KeyboardInterrupt: pass except IOError: Pdb.do_list(self, arg) else: Pdb.do_list(self, arg) + do_l = do_list def do_open(self, arg): """opens source file corresponding to the current stack level""" filename = self.curframe.f_code.co_filename lineno = self.curframe.f_lineno - cmd = 'emacsclient --no-wait +%s %s' % (lineno, filename) + cmd = "emacsclient --no-wait +%s %s" % (lineno, filename) os.system(cmd) do_o = do_open + def pm(): """use our custom debugger""" dbg = Debugger(sys.last_traceback) dbg.start() + def set_trace(): Debugger().set_trace(sys._getframe().f_back) diff --git a/logilab/common/decorators.py b/logilab/common/decorators.py index 27ed7ee..a471353 100644 --- a/logilab/common/decorators.py +++ b/logilab/common/decorators.py @@ -34,13 +34,16 @@ from logilab.common.compat import method_type # XXX rewrite so we can use the decorator syntax when keyarg has to be specified + class cached_decorator(object): def __init__(self, cacheattr: Optional[str] = None, keyarg: Optional[int] = None) -> None: self.cacheattr = cacheattr self.keyarg = keyarg + def __call__(self, callableobj: Optional[Callable] = None) -> Callable: - assert not isgeneratorfunction(callableobj), \ - 'cannot cache generator function: %s' % callableobj + assert not isgeneratorfunction(callableobj), ( + "cannot cache generator function: %s" % callableobj + ) assert callableobj is not None if len(getfullargspec(callableobj).args) == 1 or self.keyarg == 0: cache = _SingleValueCache(callableobj, self.cacheattr) @@ -50,11 +53,12 @@ class cached_decorator(object): cache = _MultiValuesCache(callableobj, self.cacheattr) return cache.closure() + class _SingleValueCache(object): def __init__(self, callableobj: Callable, cacheattr: Optional[str] = None) -> None: self.callable = callableobj if cacheattr is None: - self.cacheattr = '_%s_cache_' % callableobj.__name__ + self.cacheattr = "_%s_cache_" % callableobj.__name__ else: assert cacheattr != callableobj.__name__ self.cacheattr = cacheattr @@ -70,6 +74,7 @@ class _SingleValueCache(object): def closure(self) -> Callable: def wrapped(*args, **kwargs): return self.__call__(*args, **kwargs) + # mypy: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no attribute "cache_obj" # dynamic attribute for magic wrapped.cache_obj = self # type: ignore @@ -101,6 +106,7 @@ class _MultiValuesCache(_SingleValueCache): _cache[args] = __me.callable(self, *args) return _cache[args] + class _MultiValuesKeyArgCache(_MultiValuesCache): def __init__(self, callableobj: Callable, keyarg: int, cacheattr: Optional[str] = None) -> None: super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr) @@ -108,7 +114,7 @@ class _MultiValuesKeyArgCache(_MultiValuesCache): def __call__(__me, self, *args, **kwargs): _cache = __me._get_cache(self) - key = args[__me.keyarg-1] + key = args[__me.keyarg - 1] try: return _cache[key] except KeyError: @@ -116,9 +122,11 @@ class _MultiValuesKeyArgCache(_MultiValuesCache): return _cache[key] -def cached(callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any) -> Union[Callable, cached_decorator]: +def cached( + callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any +) -> Union[Callable, cached_decorator]: """Simple decorator to cache result of method call.""" - kwargs['keyarg'] = keyarg + kwargs["keyarg"] = keyarg decorator = cached_decorator(**kwargs) if callableobj is None: return decorator @@ -140,23 +148,22 @@ class cachedproperty(object): .. _pyramid: http://pypi.python.org/pypi/pyramid .. _mercurial: http://pypi.python.org/pypi/Mercurial """ - __slots__ = ('wrapped',) + + __slots__ = ("wrapped",) def __init__(self, wrapped): try: wrapped.__name__ except AttributeError: - raise TypeError('%s must have a __name__ attribute' % - wrapped) + raise TypeError("%s must have a __name__ attribute" % wrapped) self.wrapped = wrapped # mypy: Signature of "__doc__" incompatible with supertype "object" # but this works? @property def __doc__(self) -> str: # type: ignore - doc = getattr(self.wrapped, '__doc__', None) - return ('<wrapped by the cachedproperty decorator>%s' - % ('\n%s' % doc if doc else '')) + doc = getattr(self.wrapped, "__doc__", None) + return "<wrapped by the cachedproperty decorator>%s" % ("\n%s" % doc if doc else "") def __get__(self, inst, objtype=None): if inst is None: @@ -173,6 +180,7 @@ def get_cache_impl(obj, funcname): member = member.fget return member.cache_obj + def clear_cache(obj, funcname): """Clear a cache handled by the :func:`cached` decorator. If 'x' class has @cached on its method `foo`, type @@ -183,6 +191,7 @@ def clear_cache(obj, funcname): """ get_cache_impl(obj, funcname).clear(obj) + def copy_cache(obj, funcname, cacheobj): """Copy cache for <funcname> from cacheobj to obj.""" cacheattr = get_cache_impl(obj, funcname).cacheattr @@ -196,9 +205,10 @@ class wproperty(object): """Simple descriptor expecting to take a modifier function as first argument and looking for a _<function name> to retrieve the attribute. """ + def __init__(self, setfunc): self.setfunc = setfunc - self.attrname = '_%s' % setfunc.__name__ + self.attrname = "_%s" % setfunc.__name__ def __set__(self, obj, value): self.setfunc(obj, value) @@ -211,22 +221,27 @@ class wproperty(object): class classproperty(object): """this is a simple property-like class but for class attributes. """ + def __init__(self, get): self.get = get + def __get__(self, inst, cls): return self.get(cls) class iclassmethod(object): - '''Descriptor for method which should be available as class method if called + """Descriptor for method which should be available as class method if called on the class or instance method if called on an instance. - ''' + """ + def __init__(self, func): self.func = func + def __get__(self, instance, objtype): if instance is None: return method_type(self.func, objtype, objtype.__class__) return method_type(self.func, instance, objtype) + def __set__(self, instance, value): raise AttributeError("can't set attribute") @@ -236,9 +251,9 @@ def timed(f): t = time() c = process_time() res = f(*args, **kwargs) - print('%s clock: %.9f / time: %.9f' % (f.__name__, - process_time() - c, time() - t)) + print("%s clock: %.9f / time: %.9f" % (f.__name__, process_time() - c, time() - t)) return res + return wrap @@ -247,6 +262,7 @@ def locked(acquire, release): returning a decorator function which will call the inner method after having called acquire(self) et will call release(self) afterwards. """ + def decorator(f): def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: acquire(self) @@ -254,7 +270,9 @@ def locked(acquire, release): return f(self, *args, **kwargs) finally: release(self) + return wrapper + return decorator @@ -278,13 +296,16 @@ def monkeypatch(klass: type, methodname: Optional[str] = None) -> Callable: >>> a.foo() 12 """ + def decorator(func): try: name = methodname or func.__name__ except AttributeError: - raise AttributeError('%s has no __name__ attribute: ' - 'you should provide an explicit `methodname`' - % func) + raise AttributeError( + "%s has no __name__ attribute: " + "you should provide an explicit `methodname`" % func + ) setattr(klass, name, func) return func + return decorator diff --git a/logilab/common/deprecation.py b/logilab/common/deprecation.py index b147b43..15f8087 100644 --- a/logilab/common/deprecation.py +++ b/logilab/common/deprecation.py @@ -38,7 +38,7 @@ class DeprecationWrapper(object): return getattr(self._proxied, attr) def __setattr__(self, attr, value): - if attr in ('_proxied', '_msg'): + if attr in ("_proxied", "_msg"): self.__dict__[attr] = value else: send_warning(self._msg, stacklevel=3, version=self.version) diff --git a/logilab/common/fileutils.py b/logilab/common/fileutils.py index 102cd7c..1b1ed5b 100644 --- a/logilab/common/fileutils.py +++ b/logilab/common/fileutils.py @@ -44,6 +44,7 @@ from logilab.common.shellutils import find from logilab.common.deprecation import deprecated from logilab.common.compat import FileIO + def first_level_directory(path: str) -> str: """Return the first level directory of a path. @@ -69,6 +70,7 @@ def first_level_directory(path: str) -> str: # path was absolute, head is the fs root return head + def abspath_listdir(path): """Lists path's content using absolute paths.""" path = abspath(path) @@ -90,7 +92,7 @@ def is_binary(filename: str) -> int: try: # mypy: Item "None" of "Optional[str]" has no attribute "startswith" # it's handle by the exception - return not mimetypes.guess_type(filename)[0].startswith('text') # type: ignore + return not mimetypes.guess_type(filename)[0].startswith("text") # type: ignore except AttributeError: return 1 @@ -105,8 +107,8 @@ def write_open_mode(filename: str) -> str: :return: the mode that should be use to open the file ('w' or 'wb') """ if is_binary(filename): - return 'wb' - return 'w' + return "wb" + return "w" def ensure_fs_mode(filepath, desired_mode=S_IWRITE): @@ -147,10 +149,11 @@ class ProtectedFile(FileIO): - on close()/del(), write/append the StringIO content to the file and do the chmod only once """ + def __init__(self, filepath: str, mode: str) -> None: self.original_mode = stat(filepath)[ST_MODE] self.mode_changed = False - if mode in ('w', 'a', 'wb', 'ab'): + if mode in ("w", "a", "wb", "ab"): if not self.original_mode & S_IWRITE: chmod(filepath, self.original_mode | S_IWRITE) self.mode_changed = True @@ -178,6 +181,7 @@ class UnresolvableError(Exception): path between two paths. """ + def relative_path(from_file, to_file): """Try to get a relative path from `from_file` to `to_file` (path will be absolute if to_file is an absolute file). This function @@ -224,7 +228,7 @@ def relative_path(from_file, to_file): from_file = normpath(from_file) to_file = normpath(to_file) if from_file == to_file: - return '' + return "" if isabs(to_file): if not isabs(from_file): return to_file @@ -240,7 +244,7 @@ def relative_path(from_file, to_file): to_parts.pop(0) else: idem = 0 - result.append('..') + result.append("..") result += to_parts return sep.join(result) @@ -254,9 +258,12 @@ def norm_read(path): :rtype: str :return: the content of the file with normalized line feeds """ - return open(path, 'U').read() + return open(path, "U").read() + + norm_read = deprecated("use \"open(path, 'U').read()\"")(norm_read) + def norm_open(path): """Return a stream for a file with content with normalized line feeds. @@ -266,9 +273,12 @@ def norm_open(path): :rtype: file or StringIO :return: the opened file with normalized line feeds """ - return open(path, 'U') + return open(path, "U") + + norm_open = deprecated("use \"open(path, 'U')\"")(norm_open) + def lines(path: str, comments: Optional[str] = None) -> List[str]: """Return a list of non empty lines in the file located at `path`. @@ -321,9 +331,13 @@ def stream_lines(stream: TextIOWrapper, comments: Optional[str] = None) -> List[ return result -def export(from_dir: str, to_dir: str, - blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST, ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS, - verbose: int = 0) -> None: +def export( + from_dir: str, + to_dir: str, + blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST, + ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS, + verbose: int = 0, +) -> None: """Make a mirror of `from_dir` in `to_dir`, omitting directories and files listed in the black list or ending with one of the given extensions. @@ -352,8 +366,8 @@ def export(from_dir: str, to_dir: str, try: mkdir(to_dir) except OSError: - pass # FIXME we should use "exists" if the point is about existing dir - # else (permission problems?) shouldn't return / raise ? + pass # FIXME we should use "exists" if the point is about existing dir + # else (permission problems?) shouldn't return / raise ? for directory, dirnames, filenames in walk(from_dir): for norecurs in blacklist: try: @@ -362,7 +376,7 @@ def export(from_dir: str, to_dir: str, continue for dirname in dirnames: src = join(directory, dirname) - dest = to_dir + src[len(from_dir):] + dest = to_dir + src[len(from_dir) :] if isdir(src): if not exists(dest): mkdir(dest) @@ -372,9 +386,9 @@ def export(from_dir: str, to_dir: str, if any([filename.endswith(ext) for ext in ignore_ext]): continue src = join(directory, filename) - dest = to_dir + src[len(from_dir):] + dest = to_dir + src[len(from_dir) :] if verbose: - print(src, '->', dest, file=sys.stderr) + print(src, "->", dest, file=sys.stderr) if exists(dest): remove(dest) shutil.copy2(src, dest) @@ -396,6 +410,5 @@ def remove_dead_links(directory, verbose=0): src = join(dirpath, filename) if islink(src) and not exists(src): if verbose: - print('remove dead link', src) + print("remove dead link", src) remove(src) - diff --git a/logilab/common/graph.py b/logilab/common/graph.py index fffa172..82c3a32 100644 --- a/logilab/common/graph.py +++ b/logilab/common/graph.py @@ -32,47 +32,59 @@ import codecs import errno from typing import Dict, List, Tuple, Union, Any, Optional, Set, TypeVar, Iterable + def escape(value): """Make <value> usable in a dot file.""" - lines = [line.replace('"', '\\"') for line in value.split('\n')] - data = '\\l'.join(lines) - return '\\n' + data + lines = [line.replace('"', '\\"') for line in value.split("\n")] + data = "\\l".join(lines) + return "\\n" + data + def target_info_from_filename(filename): """Transforms /some/path/foo.png into ('/some/path', 'foo.png', 'png').""" basename = osp.basename(filename) storedir = osp.dirname(osp.abspath(filename)) - target = filename.split('.')[-1] + target = filename.split(".")[-1] return storedir, basename, target class DotBackend: """Dot File backend.""" - def __init__(self, graphname, rankdir=None, size=None, ratio=None, - charset='utf-8', renderer='dot', additionnal_param={}): + + def __init__( + self, + graphname, + rankdir=None, + size=None, + ratio=None, + charset="utf-8", + renderer="dot", + additionnal_param={}, + ): self.graphname = graphname self.renderer = renderer self.lines = [] self._source = None self.emit("digraph %s {" % normalize_node_id(graphname)) if rankdir: - self.emit('rankdir=%s' % rankdir) + self.emit("rankdir=%s" % rankdir) if ratio: - self.emit('ratio=%s' % ratio) + self.emit("ratio=%s" % ratio) if size: self.emit('size="%s"' % size) if charset: - assert charset.lower() in ('utf-8', 'iso-8859-1', 'latin1'), \ - 'unsupported charset %s' % charset + assert charset.lower() in ("utf-8", "iso-8859-1", "latin1"), ( + "unsupported charset %s" % charset + ) self.emit('charset="%s"' % charset) for param in sorted(additionnal_param.items()): - self.emit('='.join(param)) + self.emit("=".join(param)) def get_source(self): """returns self._source""" if self._source is None: self.emit("}\n") - self._source = '\n'.join(self.lines) + self._source = "\n".join(self.lines) del self.lines return self._source @@ -87,14 +99,15 @@ class DotBackend: :rtype: str :return: a path to the generated file """ - import subprocess # introduced in py 2.4 + import subprocess # introduced in py 2.4 + name = self.graphname if not dotfile: # if 'outputfile' is a dot file use it as 'dotfile' if outputfile and outputfile.endswith(".dot"): dotfile = outputfile else: - dotfile = '%s.dot' % name + dotfile = "%s.dot" % name if outputfile is not None: storedir, basename, target = target_info_from_filename(outputfile) if target != "dot": @@ -103,30 +116,43 @@ class DotBackend: else: dot_sourcepath = osp.join(storedir, dotfile) else: - target = 'png' + target = "png" pdot, dot_sourcepath = tempfile.mkstemp(".dot", name) ppng, outputfile = tempfile.mkstemp(".png", name) os.close(pdot) os.close(ppng) - pdot = codecs.open(dot_sourcepath, 'w', encoding='utf8') + pdot = codecs.open(dot_sourcepath, "w", encoding="utf8") pdot.write(self.source) pdot.close() - if target != 'dot': - if sys.platform == 'win32': + if target != "dot": + if sys.platform == "win32": use_shell = True else: use_shell = False try: if mapfile: - subprocess.call([self.renderer, '-Tcmapx', '-o', mapfile, '-T', target, dot_sourcepath, '-o', outputfile], - shell=use_shell) + subprocess.call( + [ + self.renderer, + "-Tcmapx", + "-o", + mapfile, + "-T", + target, + dot_sourcepath, + "-o", + outputfile, + ], + shell=use_shell, + ) else: - subprocess.call([self.renderer, '-T', target, - dot_sourcepath, '-o', outputfile], - shell=use_shell) + subprocess.call( + [self.renderer, "-T", target, dot_sourcepath, "-o", outputfile], + shell=use_shell, + ) except OSError as e: if e.errno == errno.ENOENT: - e.strerror = 'File not found: {0}'.format(self.renderer) + e.strerror = "File not found: {0}".format(self.renderer) raise os.unlink(dot_sourcepath) return outputfile @@ -141,19 +167,21 @@ class DotBackend: """ attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()] n_from, n_to = normalize_node_id(name1), normalize_node_id(name2) - self.emit('%s -> %s [%s];' % (n_from, n_to, ', '.join(sorted(attrs))) ) + self.emit("%s -> %s [%s];" % (n_from, n_to, ", ".join(sorted(attrs)))) def emit_node(self, name, **props): """emit a node with given properties. node properties: see http://www.graphviz.org/doc/info/attrs.html """ attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()] - self.emit('%s [%s];' % (normalize_node_id(name), ', '.join(sorted(attrs)))) + self.emit("%s [%s];" % (normalize_node_id(name), ", ".join(sorted(attrs)))) + def normalize_node_id(nid): """Returns a suitable DOT node id for `nid`.""" return '"%s"' % nid + class GraphGenerator: def __init__(self, backend): # the backend is responsible to output the graph in a particular format @@ -194,8 +222,8 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]: cycles: List[List[V]] = get_cycles(graph) if cycles: - bad_cycles = '\n'.join([' -> '.join(map(str, cycle)) for cycle in cycles]) - raise UnorderableGraph('cycles in graph: %s' % bad_cycles) + bad_cycles = "\n".join([" -> ".join(map(str, cycle)) for cycle in cycles]) + raise UnorderableGraph("cycles in graph: %s" % bad_cycles) vertices = set(graph) to_vertices = set() @@ -205,7 +233,7 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]: missing_vertices = to_vertices - vertices if missing_vertices: - raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices)) + raise UnorderableGraph("missing vertices: %s" % ", ".join(missing_vertices)) # order vertices order = [] @@ -214,7 +242,7 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]: while graph: if old_len == len(graph): - raise UnorderableGraph('unknown problem with %s' % graph) + raise UnorderableGraph("unknown problem with %s" % graph) old_len = len(graph) deps_ok = [] @@ -240,12 +268,11 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]: return tuple(result) -def get_cycles(graph_dict: _Graph, - vertices: Optional[Iterable] = None) -> List[List]: - '''given a dictionary representing an ordered graph (i.e. key are vertices +def get_cycles(graph_dict: _Graph, vertices: Optional[Iterable] = None) -> List[List]: + """given a dictionary representing an ordered graph (i.e. key are vertices and values is a list of destination vertices representing edges), return a list of detected cycles - ''' + """ if not graph_dict: return [] @@ -259,11 +286,9 @@ def get_cycles(graph_dict: _Graph, return result -def _get_cycles(graph_dict: _Graph, - path: List, - visited: Set, - result: List[List], - vertice: V) -> None: +def _get_cycles( + graph_dict: _Graph, path: List, visited: Set, result: List[List], vertice: V +) -> None: """recursive function doing the real work for get_cycles""" if vertice in path: cycle = [vertice] @@ -299,7 +324,9 @@ def _get_cycles(graph_dict: _Graph, path.pop() -def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None) -> Optional[List[str]]: +def has_path( + graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None +) -> Optional[List[str]]: """generic function taking a simple graph definition as a dictionary, with node has key associated to a list of nodes directly reachable from it. @@ -316,4 +343,3 @@ def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: return path[1:] + [tonode] path.pop() return None - diff --git a/logilab/common/interface.py b/logilab/common/interface.py index 8248a27..4d4b92d 100644 --- a/logilab/common/interface.py +++ b/logilab/common/interface.py @@ -28,6 +28,7 @@ __docformat__ = "restructuredtext en" class Interface(object): """Base class for interfaces.""" + @classmethod def is_implemented_by(cls, instance: type) -> bool: return implements(instance, cls) @@ -37,7 +38,7 @@ def implements(obj: type, interface: type) -> bool: """Return true if the give object (maybe an instance or class) implements the interface. """ - kimplements = getattr(obj, '__implements__', ()) + kimplements = getattr(obj, "__implements__", ()) if not isinstance(kimplements, (list, tuple)): kimplements = (kimplements,) for implementedinterface in kimplements: @@ -62,7 +63,7 @@ def extend(klass: type, interface: type, _recurs: bool = False) -> None: kimplementsklass = tuple kimplements = [] kimplements.append(interface) - klass.__implements__ = kimplementsklass(kimplements) #type: ignore + klass.__implements__ = kimplementsklass(kimplements) # type: ignore for subklass in klass.__subclasses__(): extend(subklass, interface, _recurs=True) elif _recurs: diff --git a/logilab/common/logging_ext.py b/logilab/common/logging_ext.py index 9657581..e1df45d 100644 --- a/logilab/common/logging_ext.py +++ b/logilab/common/logging_ext.py @@ -30,13 +30,14 @@ from logilab.common.textutils import colorize_ansi def set_log_methods(cls, logger): """bind standard logger's methods as methods on the class""" cls.__logger = logger - for attr in ('debug', 'info', 'warning', 'error', 'critical', 'exception'): + for attr in ("debug", "info", "warning", "error", "critical", "exception"): setattr(cls, attr, getattr(logger, attr)) def xxx_cyan(record): - if 'XXX' in record.message: - return 'cyan' + if "XXX" in record.message: + return "cyan" + class ColorFormatter(logging.Formatter): """ @@ -54,12 +55,13 @@ class ColorFormatter(logging.Formatter): def __init__(self, fmt=None, datefmt=None, colors=None): logging.Formatter.__init__(self, fmt, datefmt) self.colorfilters = [] - self.colors = {'CRITICAL': 'red', - 'ERROR': 'red', - 'WARNING': 'magenta', - 'INFO': 'green', - 'DEBUG': 'yellow', - } + self.colors = { + "CRITICAL": "red", + "ERROR": "red", + "WARNING": "magenta", + "INFO": "green", + "DEBUG": "yellow", + } if colors is not None: assert isinstance(colors, dict) self.colors.update(colors) @@ -76,6 +78,7 @@ class ColorFormatter(logging.Formatter): return colorize_ansi(msg, color) return msg + def set_color_formatter(logger=None, **kw): """ Install a color formatter on the 'logger'. If not given, it will @@ -94,37 +97,41 @@ def set_color_formatter(logger=None, **kw): logger.handlers[0].setFormatter(fmt) -LOG_FORMAT = '%(asctime)s - (%(name)s) %(levelname)s: %(message)s' -LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S' +LOG_FORMAT = "%(asctime)s - (%(name)s) %(levelname)s: %(message)s" +LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + def get_handler(debug=False, syslog=False, logfile=None, rotation_parameters=None): """get an apropriate handler according to given parameters""" - if os.environ.get('APYCOT_ROOT'): + if os.environ.get("APYCOT_ROOT"): handler = logging.StreamHandler(sys.stdout) if debug: handler = logging.StreamHandler() elif logfile is None: if syslog: from logging import handlers + handler = handlers.SysLogHandler() else: handler = logging.StreamHandler() else: try: if rotation_parameters is None: - if os.name == 'posix' and sys.version_info >= (2, 6): + if os.name == "posix" and sys.version_info >= (2, 6): from logging.handlers import WatchedFileHandler + handler = WatchedFileHandler(logfile) else: handler = logging.FileHandler(logfile) else: from logging.handlers import TimedRotatingFileHandler - handler = TimedRotatingFileHandler( - logfile, **rotation_parameters) + + handler = TimedRotatingFileHandler(logfile, **rotation_parameters) except IOError: handler = logging.StreamHandler() return handler + def get_threshold(debug=False, logthreshold=None): if logthreshold is None: if debug: @@ -132,15 +139,15 @@ def get_threshold(debug=False, logthreshold=None): else: logthreshold = logging.ERROR elif isinstance(logthreshold, str): - logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold, - logthreshold)) + logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold, logthreshold)) return logthreshold + def _colorable_terminal(): - isatty = hasattr(sys.__stdout__, 'isatty') and sys.__stdout__.isatty() + isatty = hasattr(sys.__stdout__, "isatty") and sys.__stdout__.isatty() if not isatty: return False - if os.name == 'nt': + if os.name == "nt": try: from colorama import init as init_win32_colors except ImportError: @@ -148,22 +155,34 @@ def _colorable_terminal(): init_win32_colors() return True + def get_formatter(logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT): if _colorable_terminal(): fmt = ColorFormatter(logformat, logdateformat) + def col_fact(record): - if 'XXX' in record.message: - return 'cyan' - if 'kick' in record.message: - return 'red' + if "XXX" in record.message: + return "cyan" + if "kick" in record.message: + return "red" + fmt.colorfilters.append(col_fact) else: fmt = logging.Formatter(logformat, logdateformat) return fmt -def init_log(debug=False, syslog=False, logthreshold=None, logfile=None, - logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT, fmt=None, - rotation_parameters=None, handler=None): + +def init_log( + debug=False, + syslog=False, + logthreshold=None, + logfile=None, + logformat=LOG_FORMAT, + logdateformat=LOG_DATE_FORMAT, + fmt=None, + rotation_parameters=None, + handler=None, +): """init the log service""" logger = logging.getLogger() if handler is None: @@ -181,13 +200,15 @@ def init_log(debug=False, syslog=False, logthreshold=None, logfile=None, handler.setFormatter(fmt) return handler + # map logilab.common.logger thresholds to logging thresholds -THRESHOLD_MAP = {'LOG_DEBUG': 'DEBUG', - 'LOG_INFO': 'INFO', - 'LOG_NOTICE': 'INFO', - 'LOG_WARN': 'WARNING', - 'LOG_WARNING': 'WARNING', - 'LOG_ERR': 'ERROR', - 'LOG_ERROR': 'ERROR', - 'LOG_CRIT': 'CRITICAL', - } +THRESHOLD_MAP = { + "LOG_DEBUG": "DEBUG", + "LOG_INFO": "INFO", + "LOG_NOTICE": "INFO", + "LOG_WARN": "WARNING", + "LOG_WARNING": "WARNING", + "LOG_ERR": "ERROR", + "LOG_ERROR": "ERROR", + "LOG_CRIT": "CRITICAL", +} diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py index 76c4ac4..9ca4c81 100644 --- a/logilab/common/modutils.py +++ b/logilab/common/modutils.py @@ -32,8 +32,18 @@ __docformat__ = "restructuredtext en" import sys import os -from os.path import (splitext, join, abspath, isdir, dirname, exists, - basename, expanduser, normcase, realpath) +from os.path import ( + splitext, + join, + abspath, + isdir, + dirname, + exists, + basename, + expanduser, + normcase, + realpath, +) from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY from distutils.sysconfig import get_config_var, get_python_lib from distutils.errors import DistutilsPlatformError @@ -59,19 +69,19 @@ from logilab.common.deprecation import deprecated # # :see: `Problems with /usr/lib64 builds <http://bugs.python.org/issue1294959>`_ # :see: `FHS <http://www.pathname.com/fhs/pub/fhs-2.3.html#LIBLTQUALGTALTERNATEFORMATESSENTIAL>`_ -if sys.platform.startswith('win'): - PY_SOURCE_EXTS = ('py', 'pyw') - PY_COMPILED_EXTS = ('dll', 'pyd') +if sys.platform.startswith("win"): + PY_SOURCE_EXTS = ("py", "pyw") + PY_COMPILED_EXTS = ("dll", "pyd") else: - PY_SOURCE_EXTS = ('py',) - PY_COMPILED_EXTS = ('so',) + PY_SOURCE_EXTS = ("py",) + PY_COMPILED_EXTS = ("so",) try: STD_LIB_DIR = get_python_lib(standard_lib=True) # get_python_lib(standard_lib=1) is not available on pypy, set STD_LIB_DIR to # non-valid path, see https://bugs.pypy.org/issue1164 except DistutilsPlatformError: - STD_LIB_DIR = '//' + STD_LIB_DIR = "//" EXT_LIB_DIR = get_python_lib() @@ -83,6 +93,7 @@ class NoSourceFile(Exception): source file for a precompiled file """ + class LazyObject(object): def __init__(self, module, obj): self.module = module @@ -91,8 +102,7 @@ class LazyObject(object): def _getobj(self): if self._imported is None: - self._imported = getattr(load_module_from_name(self.module), - self.obj) + self._imported = getattr(load_module_from_name(self.module), self.obj) return self._imported def __getattribute__(self, attr): @@ -105,7 +115,9 @@ class LazyObject(object): return self._getobj()(*args, **kwargs) -def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys: int = True) -> ModuleType: +def load_module_from_name( + dotted_name: str, path: Optional[Any] = None, use_sys: int = True +) -> ModuleType: """Load a Python module from its name. :type dotted_name: str @@ -127,13 +139,15 @@ def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys: :rtype: module :return: the loaded module """ - module = load_module_from_modpath(dotted_name.split('.'), path, use_sys) + module = load_module_from_modpath(dotted_name.split("."), path, use_sys) if module is None: raise ImportError("module %s doesn't exist" % dotted_name) return module -def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_sys: int = True) -> Optional[ModuleType]: +def load_module_from_modpath( + parts: List[str], path: Optional[Any] = None, use_sys: int = True +) -> Optional[ModuleType]: """Load a python module from its splitted name. :type parts: list(str) or tuple(str) @@ -156,14 +170,14 @@ def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_s """ if use_sys: try: - return sys.modules['.'.join(parts)] + return sys.modules[".".join(parts)] except KeyError: pass modpath = [] prevmodule = None for part in parts: modpath.append(part) - curname = '.'.join(modpath) + curname = ".".join(modpath) module = None if len(modpath) != len(parts): # even with use_sys=False, should try to get outer packages from sys.modules @@ -180,13 +194,13 @@ def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_s mp_file.close() if prevmodule: setattr(prevmodule, part, module) - _file = getattr(module, '__file__', '') + _file = getattr(module, "__file__", "") prevmodule = module if not _file and _is_namespace(curname): continue if not _file and len(modpath) != len(parts): - raise ImportError('no module in %s' % '.'.join(parts[len(modpath):]) ) - path = [dirname( _file )] + raise ImportError("no module in %s" % ".".join(parts[len(modpath) :])) + path = [dirname(_file)] return module @@ -222,7 +236,7 @@ def _check_init(path: str, mod_path: List[str]) -> bool: for part in mod_path: modpath.append(part) path = join(path, part) - if not _is_namespace('.'.join(modpath)) and not _has_init(path): + if not _is_namespace(".".join(modpath)) and not _has_init(path): return False return True @@ -231,8 +245,7 @@ def _canonicalize_path(path: str) -> str: return realpath(expanduser(path)) - -@deprecated('you should avoid using modpath_from_file()') +@deprecated("you should avoid using modpath_from_file()") def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None) -> List[str]: """DEPRECATED: doens't play well with symlinks and sys.meta_path @@ -261,23 +274,23 @@ def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None) if extrapath is not None: for path_ in map(_canonicalize_path, extrapath): path = abspath(path_) - if path and normcase(base[:len(path)]) == normcase(path): - submodpath = [pkg for pkg in base[len(path):].split(os.sep) - if pkg] + if path and normcase(base[: len(path)]) == normcase(path): + submodpath = [pkg for pkg in base[len(path) :].split(os.sep) if pkg] if _check_init(path, submodpath[:-1]): - return extrapath[path_].split('.') + submodpath + return extrapath[path_].split(".") + submodpath for path in map(_canonicalize_path, sys.path): if path and normcase(base).startswith(path): - modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg] + modpath = [pkg for pkg in base[len(path) :].split(os.sep) if pkg] if _check_init(path, modpath[:-1]): return modpath - raise ImportError('Unable to find module for %s in %s' % ( - filename, ', \n'.join(sys.path))) + raise ImportError("Unable to find module for %s in %s" % (filename, ", \n".join(sys.path))) -def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None) -> Optional[str]: +def file_from_modpath( + modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None +) -> Optional[str]: """given a mod path (i.e. splitted module / package name), return the corresponding file, giving priority to source file over precompiled file if it exists @@ -312,19 +325,18 @@ def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_fi context = dirname(context_file) else: context = context_file - if modpath[0] == 'xml': + if modpath[0] == "xml": # handle _xmlplus try: - return _file_from_modpath(['_xmlplus'] + modpath[1:], path, context) + return _file_from_modpath(["_xmlplus"] + modpath[1:], path, context) except ImportError: return _file_from_modpath(modpath, path, context) - elif modpath == ['os', 'path']: + elif modpath == ["os", "path"]: # FIXME: currently ignoring search_path... return os.path.__file__ return _file_from_modpath(modpath, path, context) - def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str: """given a dotted name return the module part of the name : @@ -352,9 +364,9 @@ def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str (see #10066) """ # os.path trick - if dotted_name.startswith('os.path'): - return 'os.path' - parts = dotted_name.split('.') + if dotted_name.startswith("os.path"): + return "os.path" + parts = dotted_name.split(".") if context_file is not None: # first check for builtin module which won't be considered latter # in that case (path != None) @@ -365,27 +377,27 @@ def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str # don't use += or insert, we want a new list to be created ! path: Optional[List] = None starti = 0 - if parts[0] == '': - assert context_file is not None, \ - 'explicit relative import, but no context_file?' - path = [] # prevent resolving the import non-relatively + if parts[0] == "": + assert context_file is not None, "explicit relative import, but no context_file?" + path = [] # prevent resolving the import non-relatively starti = 1 - while parts[starti] == '': # for all further dots: change context + while parts[starti] == "": # for all further dots: change context starti += 1 assert context_file is not None context_file = dirname(context_file) for i in range(starti, len(parts)): try: - file_from_modpath(parts[starti:i+1], - path=path, context_file=context_file) + file_from_modpath(parts[starti : i + 1], path=path, context_file=context_file) except ImportError: if not i >= max(1, len(parts) - 2): raise - return '.'.join(parts[:i]) + return ".".join(parts[:i]) return dotted_name -def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]: +def get_modules( + package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST +) -> List[str]: """given a package directory return a list of all available python modules in the package and its subpackages @@ -410,21 +422,20 @@ def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD for directory, dirnames, filenames in os.walk(src_directory): _handle_blacklist(blacklist, dirnames, filenames) # check for __init__.py - if not '__init__.py' in filenames: + if not "__init__.py" in filenames: dirnames[:] = () continue if directory != src_directory: - dir_package = directory[len(src_directory):].replace(os.sep, '.') + dir_package = directory[len(src_directory) :].replace(os.sep, ".") modules.append(package + dir_package) for filename in filenames: - if _is_python_file(filename) and filename != '__init__.py': + if _is_python_file(filename) and filename != "__init__.py": src = join(directory, filename) - module = package + src[len(src_directory):-3] - modules.append(module.replace(os.sep, '.')) + module = package + src[len(src_directory) : -3] + modules.append(module.replace(os.sep, ".")) return modules - def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]: """given a package directory return a list of all available python module's files in the package and its subpackages @@ -447,7 +458,7 @@ def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIS for directory, dirnames, filenames in os.walk(src_directory): _handle_blacklist(blacklist, dirnames, filenames) # check for __init__.py - if not '__init__.py' in filenames: + if not "__init__.py" in filenames: dirnames[:] = () continue for filename in filenames: @@ -473,7 +484,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str: """ base, orig_ext = splitext(abspath(filename)) for ext in PY_SOURCE_EXTS: - source_path = '%s.%s' % (base, ext) + source_path = "%s.%s" % (base, ext) if exists(source_path): return source_path if include_no_ext and not orig_ext and exists(base): @@ -485,7 +496,7 @@ def cleanup_sys_modules(directories): """remove submodules of `directories` from `sys.modules`""" cleaned = [] for modname, module in list(sys.modules.items()): - modfile = getattr(module, '__file__', None) + modfile = getattr(module, "__file__", None) if modfile: for directory in directories: if modfile.startswith(directory): @@ -515,7 +526,9 @@ def is_python_source(filename): return splitext(filename)[1][1:] in PY_SOURCE_EXTS -def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,)) -> bool: +def is_standard_module( + modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,) +) -> bool: """try to guess if a module is a standard python module (by default, see `std_path` parameter's description) @@ -535,7 +548,7 @@ def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (S Note: this function is known to return wrong values when inside virtualenv. See https://www.logilab.org/ticket/294756. """ - modname = modname.split('.')[0] + modname = modname.split(".")[0] try: filename = file_from_modpath([modname]) except ImportError as ex: @@ -556,7 +569,6 @@ def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (S return False - def is_relative(modname: str, from_file: str) -> bool: """return true if the given module name is relative to the given file name @@ -577,7 +589,7 @@ def is_relative(modname: str, from_file: str) -> bool: if from_file in sys.path: return False try: - find_module(modname.split('.')[0], [from_file]) + find_module(modname.split(".")[0], [from_file]) return True except ImportError: return False @@ -585,7 +597,10 @@ def is_relative(modname: str, from_file: str) -> bool: # internal only functions ##################################################### -def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None) -> Optional[str]: + +def _file_from_modpath( + modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None +) -> Optional[str]: """given a mod path (i.e. splitted module / package name), return the corresponding file @@ -614,15 +629,20 @@ def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context: mp_filename = _has_init(mp_filename) return mp_filename -def _search_zip(modpath: List[str], pic: Dict[str, Optional[FileFinder]]) -> Tuple[object, str, str]: + +def _search_zip( + modpath: List[str], pic: Dict[str, Optional[FileFinder]] +) -> Tuple[object, str, str]: for filepath, importer in pic.items(): if importer is not None: if importer.find_module(modpath[0]): - if not importer.find_module('/'.join(modpath)): - raise ImportError('No module named %s in %s/%s' % ( - '.'.join(modpath[1:]), filepath, modpath)) - return ZIPFILE, abspath(filepath) + '/' + '/'.join(modpath), filepath - raise ImportError('No module named %s' % '.'.join(modpath)) + if not importer.find_module("/".join(modpath)): + raise ImportError( + "No module named %s in %s/%s" % (".".join(modpath[1:]), filepath, modpath) + ) + return ZIPFILE, abspath(filepath) + "/" + "/".join(modpath), filepath + raise ImportError("No module named %s" % ".".join(modpath)) + try: import pkg_resources @@ -635,11 +655,14 @@ except ImportError: def _is_namespace(modname: str) -> bool: # mypy: Module has no attribute "_namespace_packages"; maybe "fixup_namespace_packages"?" # but is still has? or is it a failure from python3 port? - return (pkg_resources is not None - and modname in pkg_resources._namespace_packages) # type: ignore + return ( + pkg_resources is not None and modname in pkg_resources._namespace_packages + ) # type: ignore -def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[Union[int, object], Optional[str]]: +def _module_file( + modpath: List[str], path: Optional[List[str]] = None +) -> Tuple[Union[int, object], Optional[str]]: """get a module type / file path :type modpath: list or tuple @@ -670,7 +693,7 @@ def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[ except AttributeError: checkeggs = False # pkg_resources support (aka setuptools namespace packages) - if (_is_namespace(modpath[0]) and modpath[0] in sys.modules): + if _is_namespace(modpath[0]) and modpath[0] in sys.modules: # setuptools has added into sys.modules a module object with proper # __path__, get back information from there module = sys.modules[modpath.pop(0)] @@ -720,31 +743,30 @@ def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[ mtype = mp_desc[2] if modpath: if mtype != PKG_DIRECTORY: - raise ImportError('No module %s in %s' % ('.'.join(modpath), - '.'.join(imported))) + raise ImportError("No module %s in %s" % (".".join(modpath), ".".join(imported))) # XXX guess if package is using pkgutil.extend_path by looking for # those keywords in the first four Kbytes try: - with open(join(mp_filename, '__init__.py')) as stream: + with open(join(mp_filename, "__init__.py")) as stream: data = stream.read(4096) except IOError: path = [mp_filename] else: - if 'pkgutil' in data and 'extend_path' in data: + if "pkgutil" in data and "extend_path" in data: # extend_path is called, search sys.path for module/packages # of this name see pkgutil.extend_path documentation - path = [join(p, *imported) for p in sys.path - if isdir(join(p, *imported))] + path = [join(p, *imported) for p in sys.path if isdir(join(p, *imported))] else: path = [mp_filename] return mtype, mp_filename + def _is_python_file(filename: str) -> bool: """return true if the given filename should be considered as a python file .pyc and .pyo are ignored """ - for ext in ('.py', '.so', '.pyd', '.pyw'): + for ext in (".py", ".so", ".pyd", ".pyw"): if filename.endswith(ext): return True return False @@ -754,10 +776,10 @@ def _has_init(directory: str) -> Optional[str]: """if the given directory has a valid __init__ file, return its path, else return None """ - mod_or_pack = join(directory, '__init__') + mod_or_pack = join(directory, "__init__") - for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'): - if exists(mod_or_pack + '.' + ext): - return mod_or_pack + '.' + ext + for ext in PY_SOURCE_EXTS + ("pyc", "pyo"): + if exists(mod_or_pack + "." + ext): + return mod_or_pack + "." + ext return None diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py index 11e2155..3f321b5 100644 --- a/logilab/common/optik_ext.py +++ b/logilab/common/optik_ext.py @@ -62,33 +62,44 @@ from optparse import Values, IndentedHelpFormatter, OptionGroup from _io import StringIO # python >= 2.3 -from optparse import OptionParser as BaseParser, Option as BaseOption, \ - OptionGroup, OptionContainer, OptionValueError, OptionError, \ - Values, HelpFormatter, NO_DEFAULT, SUPPRESS_HELP +from optparse import ( + OptionParser as BaseParser, + Option as BaseOption, + OptionGroup, + OptionContainer, + OptionValueError, + OptionError, + Values, + HelpFormatter, + NO_DEFAULT, + SUPPRESS_HELP, +) try: from mx import DateTime + HAS_MX_DATETIME = True except ImportError: HAS_MX_DATETIME = False -from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, \ - apply_units +from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, apply_units def check_regexp(option, opt, value): """check a regexp value by trying to compile it return the compiled regexp """ - if hasattr(value, 'pattern'): + if hasattr(value, "pattern"): return value try: return re.compile(value) except ValueError: - raise OptionValueError( - "option %s: invalid regexp value: %r" % (opt, value)) + raise OptionValueError("option %s: invalid regexp value: %r" % (opt, value)) + -def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tuple[str, ...], str]) -> Union[List[str], Tuple[str, ...]]: +def check_csv( + option: Optional["Option"], opt: str, value: Union[List[str], Tuple[str, ...], str] +) -> Union[List[str], Tuple[str, ...]]: """check a csv value by trying to split it return the list of separated values """ @@ -97,23 +108,26 @@ def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tupl try: return splitstrip(value) except ValueError: - raise OptionValueError( - "option %s: invalid csv value: %r" % (opt, value)) + raise OptionValueError("option %s: invalid csv value: %r" % (opt, value)) + -def check_yn(option: Optional['Option'], opt: str, value: Union[bool, str]) -> bool: +def check_yn(option: Optional["Option"], opt: str, value: Union[bool, str]) -> bool: """check a yn value return true for yes and false for no """ if isinstance(value, int): return bool(value) - if value in ('y', 'yes'): + if value in ("y", "yes"): return True - if value in ('n', 'no'): + if value in ("n", "no"): return False msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)" raise OptionValueError(msg % (opt, value)) -def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], str]) -> Dict[str, str]: + +def check_named( + option: Optional[Any], opt: str, value: Union[Dict[str, str], str] +) -> Dict[str, str]: """check a named value return a dictionary containing (name, value) associations """ @@ -124,22 +138,24 @@ def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], st # mypy: Argument 1 to "append" of "list" has incompatible type "List[str]"; # mypy: expected "Tuple[str, str]" # we know that the split will give a 2 items list - if value.find('=') != -1: - values.append(value.split('=', 1)) # type: ignore - elif value.find(':') != -1: - values.append(value.split(':', 1)) # type: ignore + if value.find("=") != -1: + values.append(value.split("=", 1)) # type: ignore + elif value.find(":") != -1: + values.append(value.split(":", 1)) # type: ignore if values: return dict(values) msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \ <NAME>:<VALUE>" raise OptionValueError(msg % (opt, value)) + def check_password(option, opt, value): """check a password value (can't be empty) """ # no actual checking, monkey patch if you want more return value + def check_file(option, opt, value): """check a file value return the filepath @@ -149,6 +165,7 @@ def check_file(option, opt, value): msg = "option %s: file %r does not exist" raise OptionValueError(msg % (opt, value)) + # XXX use python datetime def check_date(option, opt, value): """check a file value @@ -156,9 +173,9 @@ def check_date(option, opt, value): """ try: return DateTime.strptime(value, "%Y/%m/%d") - except DateTime.Error : - raise OptionValueError( - "expected format of %s is yyyy/mm/dd" % opt) + except DateTime.Error: + raise OptionValueError("expected format of %s is yyyy/mm/dd" % opt) + def check_color(option, opt, value): """check a color value and returns it @@ -166,23 +183,25 @@ def check_color(option, opt, value): checks hexadecimal forms """ # Case (1) : color label, we trust the end-user - if re.match('[a-z0-9 ]+$', value, re.I): + if re.match("[a-z0-9 ]+$", value, re.I): return value # Case (2) : only accepts hexadecimal forms - if re.match('#[a-f0-9]{6}', value, re.I): + if re.match("#[a-f0-9]{6}", value, re.I): return value # Else : not a color label neither a valid hexadecimal form => error msg = "option %s: invalid color : %r, should be either hexadecimal \ value or predefined color" raise OptionValueError(msg % (opt, value)) + def check_time(option, opt, value): if isinstance(value, (int, float)): return value return apply_units(value, TIME_UNITS) -def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int: - if hasattr(value, '__int__'): + +def check_bytes(option: Optional["Option"], opt: str, value: Any) -> int: + if hasattr(value, "__int__"): return value # mypy: Incompatible return value type (got "Union[float, int]", expected "int") # we force "int" using "final=int" @@ -192,24 +211,34 @@ def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int: class Option(BaseOption): """override optik.Option to add some new option types """ - TYPES = BaseOption.TYPES + ('regexp', 'csv', 'yn', 'named', 'password', - 'multiple_choice', 'file', 'color', - 'time', 'bytes') - ATTRS = BaseOption.ATTRS + ['hide', 'level'] + + TYPES = BaseOption.TYPES + ( + "regexp", + "csv", + "yn", + "named", + "password", + "multiple_choice", + "file", + "color", + "time", + "bytes", + ) + ATTRS = BaseOption.ATTRS + ["hide", "level"] TYPE_CHECKER = copy(BaseOption.TYPE_CHECKER) - TYPE_CHECKER['regexp'] = check_regexp - TYPE_CHECKER['csv'] = check_csv - TYPE_CHECKER['yn'] = check_yn - TYPE_CHECKER['named'] = check_named - TYPE_CHECKER['multiple_choice'] = check_csv - TYPE_CHECKER['file'] = check_file - TYPE_CHECKER['color'] = check_color - TYPE_CHECKER['password'] = check_password - TYPE_CHECKER['time'] = check_time - TYPE_CHECKER['bytes'] = check_bytes + TYPE_CHECKER["regexp"] = check_regexp + TYPE_CHECKER["csv"] = check_csv + TYPE_CHECKER["yn"] = check_yn + TYPE_CHECKER["named"] = check_named + TYPE_CHECKER["multiple_choice"] = check_csv + TYPE_CHECKER["file"] = check_file + TYPE_CHECKER["color"] = check_color + TYPE_CHECKER["password"] = check_password + TYPE_CHECKER["time"] = check_time + TYPE_CHECKER["bytes"] = check_bytes if HAS_MX_DATETIME: - TYPES += ('date',) - TYPE_CHECKER['date'] = check_date + TYPES += ("date",) + TYPE_CHECKER["date"] = check_date def __init__(self, *opts: str, **attrs: Any) -> None: BaseOption.__init__(self, *opts, **attrs) @@ -224,15 +253,16 @@ class Option(BaseOption): # mypy: "Option" has no attribute "choices" # we know that option of this type has this attribute if self.choices is None: # type: ignore - raise OptionError( - "must supply a list of choices for type 'choice'", self) + raise OptionError("must supply a list of choices for type 'choice'", self) elif not isinstance(self.choices, (tuple, list)): # type: ignore raise OptionError( "choices must be a list of strings ('%s' supplied)" - % str(type(self.choices)).split("'")[1], self) # type: ignore + % str(type(self.choices)).split("'")[1], + self, + ) # type: ignore elif self.choices is not None: # type: ignore - raise OptionError( - "must not supply choices for type %r" % self.type, self) + raise OptionError("must not supply choices for type %r" % self.type, self) + # mypy: Unsupported target for indexed assignment # black magic? BaseOption.CHECK_METHODS[2] = _check_choice # type: ignore @@ -241,7 +271,7 @@ class Option(BaseOption): # First, convert the value(s) to the right type. Howl if any # value(s) are bogus. value = self.convert_value(opt, value) - if self.type == 'named': + if self.type == "named": assert self.dest is not None existant = getattr(values, self.dest) if existant: @@ -253,13 +283,13 @@ class Option(BaseOption): # mypy: Argument 2 to "take_action" of "Option" has incompatible type "Optional[str]"; # mypy: expected "str" # is it ok? - return self.take_action( - self.action, self.dest, opt, value, values, parser) # type: ignore + return self.take_action(self.action, self.dest, opt, value, values, parser) # type: ignore class OptionParser(BaseParser): """override optik.OptionParser to use our Option class """ + def __init__(self, option_class: type = Option, *args: Any, **kwargs: Any) -> None: # mypy: Argument "option_class" to "__init__" of "OptionParser" has incompatible type # mypy: "type"; expected "Option" @@ -269,7 +299,7 @@ class OptionParser(BaseParser): def format_option_help(self, formatter: Optional[HelpFormatter] = None) -> str: if formatter is None: formatter = self.formatter - outputlevel = getattr(formatter, 'output_level', 0) + outputlevel = getattr(formatter, "output_level", 0) formatter.store_option_strings(self) result = [] result.append(formatter.format_heading("Options")) @@ -281,7 +311,8 @@ class OptionParser(BaseParser): # mypy: "OptionParser" has no attribute "level" # but it has one no? if group.level <= outputlevel and ( # type: ignore - group.description or level_options(group, outputlevel)): + group.description or level_options(group, outputlevel) + ): result.append(group.format_help(formatter)) result.append("\n") formatter.dedent() @@ -293,19 +324,25 @@ class OptionParser(BaseParser): # monkeypatching OptionGroup.level = 0 # type: ignore + def level_options(group: BaseParser, outputlevel: int) -> List[BaseOption]: # mypy: "Option" has no attribute "help" # but it does - return [option for option in group.option_list - if (getattr(option, 'level', 0) or 0) <= outputlevel - and not option.help is SUPPRESS_HELP] # type: ignore + return [ + option + for option in group.option_list + if (getattr(option, "level", 0) or 0) <= outputlevel and not option.help is SUPPRESS_HELP + ] # type: ignore + def format_option_help(self, formatter): result = [] - outputlevel = getattr(formatter, 'output_level', 0) or 0 + outputlevel = getattr(formatter, "output_level", 0) or 0 for option in level_options(self, outputlevel): result.append(formatter.format_option(option)) return "".join(result) + + # mypy error: Cannot assign to a method # but we still do it because magic OptionContainer.format_option_help = format_option_help # type: ignore @@ -314,16 +351,17 @@ OptionContainer.format_option_help = format_option_help # type: ignore class ManHelpFormatter(HelpFormatter): """Format help using man pages ROFF format""" - def __init__ (self, - indent_increment: int = 0, - max_help_position: int = 24, - width: int = 79, - short_first: int = 0) -> None: - HelpFormatter.__init__ ( - self, indent_increment, max_help_position, width, short_first) + def __init__( + self, + indent_increment: int = 0, + max_help_position: int = 24, + width: int = 79, + short_first: int = 0, + ) -> None: + HelpFormatter.__init__(self, indent_increment, max_help_position, width, short_first) def format_heading(self, heading: str) -> str: - return '.SH %s\n' % heading.upper() + return ".SH %s\n" % heading.upper() def format_description(self, description): return description @@ -342,12 +380,15 @@ class ManHelpFormatter(HelpFormatter): # mypy: "OptionParser"; expected "Option" # it still works? help_text = self.expand_default(option) # type: ignore - help = ' '.join([l.strip() for l in help_text.splitlines()]) + help = " ".join([l.strip() for l in help_text.splitlines()]) else: - help = '' - return '''.IP "%s" + help = "" + return """.IP "%s" %s -''' % (optstring, help) +""" % ( + optstring, + help, + ) def format_head(self, optparser: OptionParser, pkginfo: attrdict, section: int = 1) -> str: long_desc = "" @@ -355,43 +396,54 @@ class ManHelpFormatter(HelpFormatter): short_desc = self.format_short_description(pgm, pkginfo.description) if hasattr(pkginfo, "long_desc"): long_desc = self.format_long_description(pgm, pkginfo.long_desc) - return '%s\n%s\n%s\n%s' % (self.format_title(pgm, section), - short_desc, self.format_synopsis(pgm), - long_desc) + return "%s\n%s\n%s\n%s" % ( + self.format_title(pgm, section), + short_desc, + self.format_synopsis(pgm), + long_desc, + ) def format_title(self, pgm: str, section: int) -> str: - date = '-'.join([str(num) for num in time.localtime()[:3]]) + date = "-".join([str(num) for num in time.localtime()[:3]]) return '.TH %s %s "%s" %s' % (pgm, section, date, pgm) def format_short_description(self, pgm: str, short_desc: str) -> str: - return '''.SH NAME + return """.SH NAME .B %s \- %s -''' % (pgm, short_desc.strip()) +""" % ( + pgm, + short_desc.strip(), + ) def format_synopsis(self, pgm: str) -> str: - return '''.SH SYNOPSIS + return ( + """.SH SYNOPSIS .B %s [ .I OPTIONS ] [ .I <arguments> ] -''' % pgm +""" + % pgm + ) def format_long_description(self, pgm, long_desc): - long_desc = '\n'.join([line.lstrip() - for line in long_desc.splitlines()]) - long_desc = long_desc.replace('\n.\n', '\n\n') + long_desc = "\n".join([line.lstrip() for line in long_desc.splitlines()]) + long_desc = long_desc.replace("\n.\n", "\n\n") if long_desc.lower().startswith(pgm): - long_desc = long_desc[len(pgm):] - return '''.SH DESCRIPTION + long_desc = long_desc[len(pgm) :] + return """.SH DESCRIPTION .B %s %s -''' % (pgm, long_desc.strip()) +""" % ( + pgm, + long_desc.strip(), + ) def format_tail(self, pkginfo: attrdict) -> str: - tail = '''.SH SEE ALSO + tail = """.SH SEE ALSO /usr/share/doc/pythonX.Y-%s/ .SH BUGS @@ -400,18 +452,32 @@ Please report bugs on the project\'s mailing list: .SH AUTHOR %s <%s> -''' % (getattr(pkginfo, 'debian_name', pkginfo.modname), - pkginfo.mailinglist, pkginfo.author, pkginfo.author_email) +""" % ( + getattr(pkginfo, "debian_name", pkginfo.modname), + pkginfo.mailinglist, + pkginfo.author, + pkginfo.author_email, + ) if hasattr(pkginfo, "copyright"): - tail += ''' + tail += ( + """ .SH COPYRIGHT %s -''' % pkginfo.copyright +""" + % pkginfo.copyright + ) return tail -def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int = 1, stream: StringIO = sys.stdout, level: int = 0) -> None: + +def generate_manpage( + optparser: OptionParser, + pkginfo: attrdict, + section: int = 1, + stream: StringIO = sys.stdout, + level: int = 0, +) -> None: """generate a man page from an optik parser""" formatter = ManHelpFormatter() # mypy: "ManHelpFormatter" has no attribute "output_level" @@ -423,5 +489,4 @@ def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int = print(formatter.format_tail(pkginfo), file=stream) -__all__ = ('OptionParser', 'Option', 'OptionGroup', 'OptionValueError', - 'Values') +__all__ = ("OptionParser", "Option", "OptionGroup", "OptionValueError", "Values") diff --git a/logilab/common/optparser.py b/logilab/common/optparser.py index aa17750..8dd6b36 100644 --- a/logilab/common/optparser.py +++ b/logilab/common/optparser.py @@ -34,32 +34,37 @@ from __future__ import print_function __docformat__ = "restructuredtext en" from warnings import warn -warn('lgc.optparser module is deprecated, use lgc.clcommands instead', DeprecationWarning, - stacklevel=2) + +warn( + "lgc.optparser module is deprecated, use lgc.clcommands instead", + DeprecationWarning, + stacklevel=2, +) import sys import optparse -class OptionParser(optparse.OptionParser): +class OptionParser(optparse.OptionParser): def __init__(self, *args, **kwargs): optparse.OptionParser.__init__(self, *args, **kwargs) self._commands = {} self.min_args, self.max_args = 0, 1 - def add_command(self, name, mod_or_funcs, help=''): + def add_command(self, name, mod_or_funcs, help=""): """name of the command, name of module or tuple of functions (run, add_options) """ - assert isinstance(mod_or_funcs, str) or isinstance(mod_or_funcs, tuple), \ - "mod_or_funcs has to be a module name or a tuple of functions" + assert isinstance(mod_or_funcs, str) or isinstance( + mod_or_funcs, tuple + ), "mod_or_funcs has to be a module name or a tuple of functions" self._commands[name] = (mod_or_funcs, help) def print_main_help(self): optparse.OptionParser.print_help(self) - print('\ncommands:') + print("\ncommands:") for cmdname, (_, help) in self._commands.items(): - print('% 10s - %s' % (cmdname, help)) + print("% 10s - %s" % (cmdname, help)) def parse_command(self, args): if len(args) == 0: @@ -68,25 +73,23 @@ class OptionParser(optparse.OptionParser): cmd = args[0] args = args[1:] if cmd not in self._commands: - if cmd in ('-h', '--help'): + if cmd in ("-h", "--help"): self.print_main_help() sys.exit(0) elif self.version is not None and cmd == "--version": self.print_version() sys.exit(0) - self.error('unknown command') - self.prog = '%s %s' % (self.prog, cmd) + self.error("unknown command") + self.prog = "%s %s" % (self.prog, cmd) mod_or_f, help = self._commands[cmd] # optparse inserts self.description between usage and options help self.description = help if isinstance(mod_or_f, str): - exec('from %s import run, add_options' % mod_or_f) + exec("from %s import run, add_options" % mod_or_f) else: run, add_options = mod_or_f add_options(self) (options, args) = self.parse_args(args) if not (self.min_args <= len(args) <= self.max_args): - self.error('incorrect number of arguments') + self.error("incorrect number of arguments") return run, options, args - - diff --git a/logilab/common/proc.py b/logilab/common/proc.py index 30e9494..2d2e78c 100644 --- a/logilab/common/proc.py +++ b/logilab/common/proc.py @@ -37,15 +37,19 @@ from time import time from logilab.common.tree import Node -class NoSuchProcess(Exception): pass + +class NoSuchProcess(Exception): + pass + def proc_exists(pid): """check the a pid is registered in /proc raise NoSuchProcess exception if not """ - if not os.path.exists('/proc/%s' % pid): + if not os.path.exists("/proc/%s" % pid): raise NoSuchProcess() + PPID = 3 UTIME = 13 STIME = 14 @@ -53,6 +57,7 @@ CUTIME = 15 CSTIME = 16 VSIZE = 22 + class ProcInfo(Node): """provide access to process information found in /proc""" @@ -60,19 +65,18 @@ class ProcInfo(Node): self.pid = int(pid) Node.__init__(self, self.pid) proc_exists(self.pid) - self.file = '/proc/%s/stat' % self.pid + self.file = "/proc/%s/stat" % self.pid self.ppid = int(self.status()[PPID]) def memory_usage(self): """return the memory usage of the process in Ko""" - try : + try: return int(self.status()[VSIZE]) except IOError: return 0 def lineage_memory_usage(self): - return self.memory_usage() + sum([child.lineage_memory_usage() - for child in self.children]) + return self.memory_usage() + sum([child.lineage_memory_usage() for child in self.children]) def time(self, children=0): """return the number of jiffies that this process has been scheduled @@ -90,13 +94,14 @@ class ProcInfo(Node): def name(self): """return the process name found in /proc/<pid>/stat """ - return self.status()[1].strip('()') + return self.status()[1].strip("()") def age(self): """return the age of the process """ return os.stat(self.file)[stat.ST_MTIME] + class ProcInfoLoader: """manage process information""" @@ -105,7 +110,7 @@ class ProcInfoLoader: def list_pids(self): """return a list of existent process ids""" - for subdir in os.listdir('/proc'): + for subdir in os.listdir("/proc"): if subdir.isdigit(): yield int(subdir) @@ -120,7 +125,6 @@ class ProcInfoLoader: self._loaded[pid] = procinfo return procinfo - def load_all(self): """load all processes information""" for pid in self.list_pids(): @@ -135,22 +139,29 @@ class ProcInfoLoader: class ResourceError(Exception): """Error raise when resource limit is reached""" + limit = "Unknown Resource Limit" class XCPUError(ResourceError): """Error raised when CPU Time limit is reached""" + limit = "CPU Time" + class LineageMemoryError(ResourceError): """Error raised when the total amount of memory used by a process and it's child is reached""" + limit = "Lineage total Memory" + class TimeoutError(ResourceError): """Error raised when the process is running for to much time""" + limit = "Real Time" + # Can't use subclass because the StandardError MemoryError raised RESOURCE_LIMIT_EXCEPTION = (ResourceError, MemoryError) @@ -159,6 +170,7 @@ class MemorySentinel(Thread): """A class checking a process don't use too much memory in a separated daemonic thread """ + def __init__(self, interval, memory_limit, gpid=os.getpid()): Thread.__init__(self, target=self._run, name="Test.Sentinel") self.memory_limit = memory_limit @@ -180,9 +192,7 @@ class MemorySentinel(Thread): class ResourceController: - - def __init__(self, max_cpu_time=None, max_time=None, max_memory=None, - max_reprieve=60): + def __init__(self, max_cpu_time=None, max_time=None, max_memory=None, max_reprieve=60): if SIGXCPU == -1: raise RuntimeError("Unsupported platform") self.max_time = max_time @@ -230,13 +240,12 @@ class ResourceController: def setup_limit(self): """set up the process limit""" - assert currentThread().getName() == 'MainThread' + assert currentThread().getName() == "MainThread" os.setpgrp() if self._limit_set <= 0: if self.max_time is not None: self._old_usr2_hdlr = signal(SIGUSR2, self._hangle_sig_timeout) - self._timer = Timer(max(1, int(self.max_time) - self._elapse_time), - self._time_out) + self._timer = Timer(max(1, int(self.max_time) - self._elapse_time), self._time_out) self._start_time = int(time()) self._timer.start() if self.max_cpu_time is not None: @@ -245,7 +254,7 @@ class ResourceController: self._old_sigxcpu_hdlr = signal(SIGXCPU, self._handle_sigxcpu) setrlimit(RLIMIT_CPU, cpu_limit) if self.max_memory is not None: - self._msentinel = MemorySentinel(1, int(self.max_memory) ) + self._msentinel = MemorySentinel(1, int(self.max_memory)) self._old_max_memory = getrlimit(RLIMIT_AS) self._old_usr1_hdlr = signal(SIGUSR1, self._hangle_sig_memory) as_limit = (int(self.max_memory), self._old_max_memory[1]) @@ -258,7 +267,7 @@ class ResourceController: if self._limit_set > 0: if self.max_time is not None: self._timer.cancel() - self._elapse_time += int(time())-self._start_time + self._elapse_time += int(time()) - self._start_time self._timer = None signal(SIGUSR2, self._old_usr2_hdlr) if self.max_cpu_time is not None: diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py index 6819c01..0f89ddf 100644 --- a/logilab/common/pytest.py +++ b/logilab/common/pytest.py @@ -124,6 +124,7 @@ import traceback from inspect import isgeneratorfunction, isclass, FrameInfo from random import shuffle from itertools import dropwhile + # mypy error: Module 'unittest.runner' has no attribute '_WritelnDecorator' # but it does from unittest.runner import _WritelnDecorator # type: ignore @@ -135,6 +136,7 @@ from logilab.common.deprecation import deprecated from logilab.common.fileutils import abspath_listdir from logilab.common import textutils from logilab.common import testlib, STD_BLACKLIST + # use the same unittest module as testlib from logilab.common.testlib import unittest, start_interactive_mode from logilab.common.testlib import nocoverage, pause_trace, replace_trace # bwcompat @@ -142,6 +144,7 @@ from logilab.common.debugger import Debugger, colorize_source import doctest import unittest as unittest_legacy + if not getattr(unittest_legacy, "__package__", None): try: import unittest2.suite as unittest_suite @@ -154,18 +157,24 @@ else: try: import django from logilab.common.modutils import modpath_from_file, load_module_from_modpath + DJANGO_FOUND = True except ImportError: DJANGO_FOUND = False -CONF_FILE = 'pytestconf.py' +CONF_FILE = "pytestconf.py" TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$") + + def this_is_a_testfile(filename: str) -> Optional[Match]: """returns True if `filename` seems to be a test file""" return TESTFILE_RE.match(osp.basename(filename)) + TESTDIR_RE = re.compile("^(unit)?tests?$") + + def this_is_a_testdir(dirpath: str) -> Optional[Match]: """returns True if `filename` seems to be a test directory""" return TESTDIR_RE.match(osp.basename(dirpath)) @@ -176,10 +185,10 @@ def load_pytest_conf(path, parser): and / or tester. """ namespace = {} - exec(open(path, 'rb').read(), namespace) - if 'update_parser' in namespace: - namespace['update_parser'](parser) - return namespace.get('CustomPyTester', PyTester) + exec(open(path, "rb").read(), namespace) + if "update_parser" in namespace: + namespace["update_parser"](parser) + return namespace.get("CustomPyTester", PyTester) def project_root(parser, projdir=os.getcwd()): @@ -189,8 +198,7 @@ def project_root(parser, projdir=os.getcwd()): conf_file_path = osp.join(curdir, CONF_FILE) if osp.isfile(conf_file_path): testercls = load_pytest_conf(conf_file_path, parser) - while this_is_a_testdir(curdir) or \ - osp.isfile(osp.join(curdir, '__init__.py')): + while this_is_a_testdir(curdir) or osp.isfile(osp.join(curdir, "__init__.py")): newdir = osp.normpath(osp.join(curdir, os.pardir)) if newdir == curdir: break @@ -204,6 +212,7 @@ def project_root(parser, projdir=os.getcwd()): class GlobalTestReport(object): """this class holds global test statistics""" + def __init__(self): self.ran = 0 self.skipped = 0 @@ -218,7 +227,7 @@ class GlobalTestReport(object): """integrates new test information into internal statistics""" ran = testresult.testsRun self.ran += ran - self.skipped += len(getattr(testresult, 'skipped', ())) + self.skipped += len(getattr(testresult, "skipped", ())) self.failures += len(testresult.failures) self.errors += len(testresult.errors) self.ttime += ttime @@ -243,27 +252,24 @@ class GlobalTestReport(object): def __str__(self): """this is just presentation stuff""" - line1 = ['Ran %s test cases in %.2fs (%.2fs CPU)' - % (self.ran, self.ttime, self.ctime)] + line1 = ["Ran %s test cases in %.2fs (%.2fs CPU)" % (self.ran, self.ttime, self.ctime)] if self.errors: - line1.append('%s errors' % self.errors) + line1.append("%s errors" % self.errors) if self.failures: - line1.append('%s failures' % self.failures) + line1.append("%s failures" % self.failures) if self.skipped: - line1.append('%s skipped' % self.skipped) + line1.append("%s skipped" % self.skipped) modulesok = self.modulescount - len(self.errmodules) if self.errors or self.failures: - line2 = '%s modules OK (%s failed)' % (modulesok, - len(self.errmodules)) - descr = ', '.join(['%s [%s/%s]' % info for info in self.errmodules]) - line3 = '\nfailures: %s' % descr + line2 = "%s modules OK (%s failed)" % (modulesok, len(self.errmodules)) + descr = ", ".join(["%s [%s/%s]" % info for info in self.errmodules]) + line3 = "\nfailures: %s" % descr elif modulesok: - line2 = 'All %s modules OK' % modulesok - line3 = '' + line2 = "All %s modules OK" % modulesok + line3 = "" else: - return '' - return '%s\n%s%s' % (', '.join(line1), line2, line3) - + return "" + return "%s\n%s%s" % (", ".join(line1), line2, line3) def remove_local_modules_from_sys(testdir): @@ -282,7 +288,7 @@ def remove_local_modules_from_sys(testdir): for modname, mod in list(sys.modules.items()): if mod is None: continue - if not hasattr(mod, '__file__'): + if not hasattr(mod, "__file__"): # this is the case of some built-in modules like sys, imp, marshal continue modfile = mod.__file__ @@ -292,7 +298,6 @@ def remove_local_modules_from_sys(testdir): del sys.modules[modname] - class PyTester(object): """encapsulates testrun logic""" @@ -317,6 +322,7 @@ class PyTester(object): def set_errcode(self, errcode): self._errcode = errcode + errcode = property(get_errcode, set_errcode) def testall(self, exitfirst=False): @@ -358,9 +364,11 @@ class PyTester(object): restartfile = open(FILE_RESTART, "w") restartfile.close() except Exception: - print("Error while overwriting succeeded test file :", - osp.join(os.getcwd(), FILE_RESTART), - file=sys.__stderr__) + print( + "Error while overwriting succeeded test file :", + osp.join(os.getcwd(), FILE_RESTART), + file=sys.__stderr__, + ) raise # run test and collect information prog = self.testfile(filename, batchmode=True) @@ -386,17 +394,24 @@ class PyTester(object): restartfile = open(FILE_RESTART, "w") restartfile.close() except Exception: - print("Error while overwriting succeeded test file :", - osp.join(os.getcwd(), FILE_RESTART), file=sys.__stderr__) + print( + "Error while overwriting succeeded test file :", + osp.join(os.getcwd(), FILE_RESTART), + file=sys.__stderr__, + ) raise modname = osp.basename(filename)[:-3] - print((' %s ' % osp.basename(filename)).center(70, '='), - file=sys.__stderr__) + print((" %s " % osp.basename(filename)).center(70, "="), file=sys.__stderr__) try: tstart, cstart = time(), process_time() try: - testprog = SkipAwareTestProgram(modname, batchmode=batchmode, cvg=self.cvg, - options=self.options, outstream=sys.stderr) + testprog = SkipAwareTestProgram( + modname, + batchmode=batchmode, + cvg=self.cvg, + options=self.options, + outstream=sys.stderr, + ) except KeyboardInterrupt: raise except SystemExit as exc: @@ -408,9 +423,9 @@ class PyTester(object): return None except Exception: self.report.failed_to_test_module(filename) - print('unhandled exception occurred while testing', modname, - file=sys.stderr) + print("unhandled exception occurred while testing", modname, file=sys.stderr) import traceback + traceback.print_exc(file=sys.stderr) return None @@ -423,23 +438,23 @@ class PyTester(object): os.chdir(here) - class DjangoTester(PyTester): - def load_django_settings(self, dirname): """try to find project's setting and load it""" curdir = osp.abspath(dirname) previousdir = curdir - while not osp.isfile(osp.join(curdir, 'settings.py')) and \ - osp.isfile(osp.join(curdir, '__init__.py')): + while not osp.isfile(osp.join(curdir, "settings.py")) and osp.isfile( + osp.join(curdir, "__init__.py") + ): newdir = osp.normpath(osp.join(curdir, os.pardir)) if newdir == curdir: - raise AssertionError('could not find settings.py') + raise AssertionError("could not find settings.py") previousdir = curdir curdir = newdir # late django initialization - settings = load_module_from_modpath(modpath_from_file(osp.join(curdir, 'settings.py'))) + settings = load_module_from_modpath(modpath_from_file(osp.join(curdir, "settings.py"))) from django.core.management import setup_environ + setup_environ(settings) settings.DEBUG = False self.settings = settings @@ -451,6 +466,7 @@ class DjangoTester(PyTester): # Those imports must be done **after** setup_environ was called from django.test.utils import setup_test_environment from django.test.utils import create_test_db + setup_test_environment() create_test_db(verbosity=0) self.dbname = self.settings.TEST_DATABASE_NAME @@ -459,8 +475,9 @@ class DjangoTester(PyTester): # Those imports must be done **after** setup_environ was called from django.test.utils import teardown_test_environment from django.test.utils import destroy_test_db + teardown_test_environment() - print('destroying', self.dbname) + print("destroying", self.dbname) destroy_test_db(self.dbname, verbosity=0) def testall(self, exitfirst=False): @@ -468,16 +485,16 @@ class DjangoTester(PyTester): which can be considered as a testdir and runs every test there """ for dirname, dirs, files in os.walk(os.getcwd()): - for skipped in ('CVS', '.svn', '.hg'): + for skipped in ("CVS", ".svn", ".hg"): if skipped in dirs: dirs.remove(skipped) - if 'tests.py' in files: + if "tests.py" in files: if not self.testonedir(dirname, exitfirst): break dirs[:] = [] else: basename = osp.basename(dirname) - if basename in ('test', 'tests'): + if basename in ("test", "tests"): print("going into", dirname) # we found a testdir, let's explore it ! if not self.testonedir(dirname, exitfirst): @@ -492,11 +509,10 @@ class DjangoTester(PyTester): """ # special django behaviour : if tests are splitted in several files, # remove the main tests.py file and tests each test file separately - testfiles = [fpath for fpath in abspath_listdir(testdir) - if this_is_a_testfile(fpath)] + testfiles = [fpath for fpath in abspath_listdir(testdir) if this_is_a_testfile(fpath)] if len(testfiles) > 1: try: - testfiles.remove(osp.join(testdir, 'tests.py')) + testfiles.remove(osp.join(testdir, "tests.py")) except ValueError: pass for filename in testfiles: @@ -519,8 +535,7 @@ class DjangoTester(PyTester): os.chdir(dirname) self.load_django_settings(dirname) modname = osp.basename(filename)[:-3] - print((' %s ' % osp.basename(filename)).center(70, '='), - file=sys.stderr) + print((" %s " % osp.basename(filename)).center(70, "="), file=sys.stderr) try: try: tstart, cstart = time(), process_time() @@ -534,10 +549,11 @@ class DjangoTester(PyTester): raise except Exception as exc: import traceback + traceback.print_exc() self.report.failed_to_test_module(filename) - print('unhandled exception occurred while testing', modname) - print('error: %s' % exc) + print("unhandled exception occurred while testing", modname) + print("error: %s" % exc) return None finally: self.after_testfile() @@ -549,9 +565,11 @@ def make_parser(): """creates the OptionParser instance """ from optparse import OptionParser + parser = OptionParser(usage=PYTEST_DOC) parser.newargs = [] + def rebuild_cmdline(option, opt, value, parser): """carry the option to unittest_main""" parser.newargs.append(opt) @@ -564,50 +582,89 @@ def make_parser(): setattr(parser.values, option.dest, True) def capture_and_rebuild(option, opt, value, parser): - warnings.simplefilter('ignore', DeprecationWarning) + warnings.simplefilter("ignore", DeprecationWarning) rebuild_cmdline(option, opt, value, parser) # logilab-pytest options - parser.add_option('-t', dest='testdir', default=None, - help="directory where the tests will be found") - parser.add_option('-d', dest='dbc', default=False, - action="store_true", help="enable design-by-contract") + parser.add_option( + "-t", dest="testdir", default=None, help="directory where the tests will be found" + ) + parser.add_option( + "-d", dest="dbc", default=False, action="store_true", help="enable design-by-contract" + ) # unittest_main options provided and passed through logilab-pytest - parser.add_option('-v', '--verbose', callback=rebuild_cmdline, - action="callback", help="Verbose output") - parser.add_option('-i', '--pdb', callback=rebuild_and_store, - dest="pdb", action="callback", - help="Enable test failure inspection") - parser.add_option('-x', '--exitfirst', callback=rebuild_and_store, - dest="exitfirst", default=False, - action="callback", help="Exit on first failure " - "(only make sense when logilab-pytest run one test file)") - parser.add_option('-R', '--restart', callback=rebuild_and_store, - dest="restart", default=False, - action="callback", - help="Restart tests from where it failed (implies exitfirst) " - "(only make sense if tests previously ran with exitfirst only)") - parser.add_option('--color', callback=rebuild_cmdline, - action="callback", - help="colorize tracebacks") - parser.add_option('-s', '--skip', - # XXX: I wish I could use the callback action but it - # doesn't seem to be able to get the value - # associated to the option - action="store", dest="skipped", default=None, - help="test names matching this name will be skipped " - "to skip several patterns, use commas") - parser.add_option('-q', '--quiet', callback=rebuild_cmdline, - action="callback", help="Minimal output") - parser.add_option('-P', '--profile', default=None, dest='profile', - help="Profile execution and store data in the given file") - parser.add_option('-m', '--match', default=None, dest='tags_pattern', - help="only execute test whose tag match the current pattern") + parser.add_option( + "-v", "--verbose", callback=rebuild_cmdline, action="callback", help="Verbose output" + ) + parser.add_option( + "-i", + "--pdb", + callback=rebuild_and_store, + dest="pdb", + action="callback", + help="Enable test failure inspection", + ) + parser.add_option( + "-x", + "--exitfirst", + callback=rebuild_and_store, + dest="exitfirst", + default=False, + action="callback", + help="Exit on first failure " "(only make sense when logilab-pytest run one test file)", + ) + parser.add_option( + "-R", + "--restart", + callback=rebuild_and_store, + dest="restart", + default=False, + action="callback", + help="Restart tests from where it failed (implies exitfirst) " + "(only make sense if tests previously ran with exitfirst only)", + ) + parser.add_option( + "--color", callback=rebuild_cmdline, action="callback", help="colorize tracebacks" + ) + parser.add_option( + "-s", + "--skip", + # XXX: I wish I could use the callback action but it + # doesn't seem to be able to get the value + # associated to the option + action="store", + dest="skipped", + default=None, + help="test names matching this name will be skipped " + "to skip several patterns, use commas", + ) + parser.add_option( + "-q", "--quiet", callback=rebuild_cmdline, action="callback", help="Minimal output" + ) + parser.add_option( + "-P", + "--profile", + default=None, + dest="profile", + help="Profile execution and store data in the given file", + ) + parser.add_option( + "-m", + "--match", + default=None, + dest="tags_pattern", + help="only execute test whose tag match the current pattern", + ) if DJANGO_FOUND: - parser.add_option('-J', '--django', dest='django', default=False, - action="store_true", - help='use logilab-pytest for django test cases') + parser.add_option( + "-J", + "--django", + dest="django", + default=False, + action="store_true", + help="use logilab-pytest for django test cases", + ) return parser @@ -617,7 +674,7 @@ def parseargs(parser): """ # parse the command line options, args = parser.parse_args() - filenames = [arg for arg in args if arg.endswith('.py')] + filenames = [arg for arg in args if arg.endswith(".py")] if filenames: if len(filenames) > 1: parser.error("only one filename is acceptable") @@ -629,7 +686,7 @@ def parseargs(parser): testlib.ENABLE_DBC = options.dbc newargs = parser.newargs if options.skipped: - newargs.extend(['--skip', options.skipped]) + newargs.extend(["--skip", options.skipped]) # restart implies exitfirst if options.restart: options.exitfirst = True @@ -639,8 +696,7 @@ def parseargs(parser): return options, explicitfile - -@deprecated('[logilab-common 1.3] logilab-pytest is deprecated, use another test runner') +@deprecated("[logilab-common 1.3] logilab-pytest is deprecated, use another test runner") def run(): parser = make_parser() rootdir, testercls = project_root(parser) @@ -648,8 +704,8 @@ def run(): # mock a new command line sys.argv[1:] = parser.newargs cvg = None - if not '' in sys.path: - sys.path.insert(0, '') + if not "" in sys.path: + sys.path.insert(0, "") if DJANGO_FOUND and options.django: tester = DjangoTester(cvg, options) else: @@ -664,21 +720,24 @@ def run(): try: if options.profile: import hotshot + prof = hotshot.Profile(options.profile) prof.runcall(cmd, *args) prof.close() - print('profile data saved in', options.profile) + print("profile data saved in", options.profile) else: cmd(*args) except SystemExit: raise except: import traceback + traceback.print_exc() finally: tester.show_report() sys.exit(tester.errcode) + class SkipAwareTestProgram(unittest.TestProgram): # XXX: don't try to stay close to unittest.py, use optparse USAGE = """\ @@ -705,15 +764,23 @@ Examples: %(progName)s MyTestCase - run all 'test*' test methods in MyTestCase """ - def __init__(self, module='__main__', defaultTest=None, batchmode=False, - cvg=None, options=None, outstream=sys.stderr): + + def __init__( + self, + module="__main__", + defaultTest=None, + batchmode=False, + cvg=None, + options=None, + outstream=sys.stderr, + ): self.batchmode = batchmode self.cvg = cvg self.options = options self.outstream = outstream super(SkipAwareTestProgram, self).__init__( - module=module, defaultTest=defaultTest, - testLoader=NonStrictTestLoader()) + module=module, defaultTest=defaultTest, testLoader=NonStrictTestLoader() + ) def parseArgs(self, argv): self.pdbmode = False @@ -724,40 +791,51 @@ Examples: self.colorize = False self.profile_name = None import getopt + try: - options, args = getopt.getopt(argv[1:], 'hHvixrqcp:s:m:P:', - ['help', 'verbose', 'quiet', 'pdb', - 'exitfirst', 'restart', - 'skip=', 'color', 'match=', 'profile=']) + options, args = getopt.getopt( + argv[1:], + "hHvixrqcp:s:m:P:", + [ + "help", + "verbose", + "quiet", + "pdb", + "exitfirst", + "restart", + "skip=", + "color", + "match=", + "profile=", + ], + ) for opt, value in options: - if opt in ('-h', '-H', '--help'): + if opt in ("-h", "-H", "--help"): self.usageExit() - if opt in ('-i', '--pdb'): + if opt in ("-i", "--pdb"): self.pdbmode = True - if opt in ('-x', '--exitfirst'): + if opt in ("-x", "--exitfirst"): self.exitfirst = True - if opt in ('-r', '--restart'): + if opt in ("-r", "--restart"): self.restart = True self.exitfirst = True - if opt in ('-q', '--quiet'): + if opt in ("-q", "--quiet"): self.verbosity = 0 - if opt in ('-v', '--verbose'): + if opt in ("-v", "--verbose"): self.verbosity = 2 - if opt in ('-s', '--skip'): - self.skipped_patterns = [pat.strip() for pat in - value.split(', ')] - if opt == '--color': + if opt in ("-s", "--skip"): + self.skipped_patterns = [pat.strip() for pat in value.split(", ")] + if opt == "--color": self.colorize = True - if opt in ('-m', '--match'): - #self.tags_pattern = value + if opt in ("-m", "--match"): + # self.tags_pattern = value self.options["tag_pattern"] = value - if opt in ('-P', '--profile'): + if opt in ("-P", "--profile"): self.profile_name = value self.testLoader.skipped_patterns = self.skipped_patterns if len(args) == 0 and self.defaultTest is None: - suitefunc = getattr(self.module, 'suite', None) - if isinstance(suitefunc, (types.FunctionType, - types.MethodType)): + suitefunc = getattr(self.module, "suite", None) + if isinstance(suitefunc, (types.FunctionType, types.MethodType)): self.test = self.module.suite() else: self.test = self.testLoader.loadTestsFromModule(self.module) @@ -766,7 +844,7 @@ Examples: self.test_pattern = args[0] self.testNames = args else: - self.testNames = (self.defaultTest, ) + self.testNames = (self.defaultTest,) self.createTests() except getopt.error as msg: self.usageExit(msg) @@ -774,21 +852,24 @@ Examples: def runTests(self): if self.profile_name: import cProfile - cProfile.runctx('self._runTests()', globals(), locals(), self.profile_name ) + + cProfile.runctx("self._runTests()", globals(), locals(), self.profile_name) else: return self._runTests() def _runTests(self): - self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity, - stream=self.outstream, - exitfirst=self.exitfirst, - pdbmode=self.pdbmode, - cvg=self.cvg, - test_pattern=self.test_pattern, - skipped_patterns=self.skipped_patterns, - colorize=self.colorize, - batchmode=self.batchmode, - options=self.options) + self.testRunner = SkipAwareTextTestRunner( + verbosity=self.verbosity, + stream=self.outstream, + exitfirst=self.exitfirst, + pdbmode=self.pdbmode, + cvg=self.cvg, + test_pattern=self.test_pattern, + skipped_patterns=self.skipped_patterns, + colorize=self.colorize, + batchmode=self.batchmode, + options=self.options, + ) def removeSucceededTests(obj, succTests): """ Recursive function that removes succTests from @@ -801,32 +882,33 @@ Examples: if isinstance(el, unittest.TestSuite): removeSucceededTests(el, succTests) elif isinstance(el, unittest.TestCase): - descr = '.'.join((el.__class__.__module__, - el.__class__.__name__, - el._testMethodName)) + descr = ".".join( + (el.__class__.__module__, el.__class__.__name__, el._testMethodName) + ) if descr in succTests: obj.remove(el) + # take care, self.options may be None - if getattr(self.options, 'restart', False): + if getattr(self.options, "restart", False): # retrieve succeeded tests from FILE_RESTART try: - restartfile = open(FILE_RESTART, 'r') + restartfile = open(FILE_RESTART, "r") try: - succeededtests = list(elem.rstrip('\n\r') for elem in - restartfile.readlines()) + succeededtests = list(elem.rstrip("\n\r") for elem in restartfile.readlines()) removeSucceededTests(self.test, succeededtests) finally: restartfile.close() except Exception as ex: - raise Exception("Error while reading succeeded tests into %s: %s" - % (osp.join(os.getcwd(), FILE_RESTART), ex)) + raise Exception( + "Error while reading succeeded tests into %s: %s" + % (osp.join(os.getcwd(), FILE_RESTART), ex) + ) result = self.testRunner.run(self.test) # help garbage collection: we want TestSuite, which hold refs to every # executed TestCase, to be gc'ed del self.test - if getattr(result, "debuggers", None) and \ - getattr(self, "pdbmode", None): + if getattr(result, "debuggers", None) and getattr(self, "pdbmode", None): start_interactive_mode(result) if not getattr(self, "batchmode", None): sys.exit(not result.wasSuccessful()) @@ -834,13 +916,20 @@ Examples: class SkipAwareTextTestRunner(unittest.TextTestRunner): - - def __init__(self, stream=sys.stderr, verbosity=1, - exitfirst=False, pdbmode=False, cvg=None, test_pattern=None, - skipped_patterns=(), colorize=False, batchmode=False, - options=None): - super(SkipAwareTextTestRunner, self).__init__(stream=stream, - verbosity=verbosity) + def __init__( + self, + stream=sys.stderr, + verbosity=1, + exitfirst=False, + pdbmode=False, + cvg=None, + test_pattern=None, + skipped_patterns=(), + colorize=False, + batchmode=False, + options=None, + ): + super(SkipAwareTextTestRunner, self).__init__(stream=stream, verbosity=verbosity) self.exitfirst = exitfirst self.pdbmode = pdbmode self.cvg = cvg @@ -859,23 +948,23 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): else: if isinstance(test, testlib.TestCase): meth = test._get_test_method() - testname = '%s.%s' % (test.__name__, meth.__name__) + testname = "%s.%s" % (test.__name__, meth.__name__) elif isinstance(test, types.FunctionType): func = test testname = func.__name__ elif isinstance(test, types.MethodType): cls = test.__self__.__class__ - testname = '%s.%s' % (cls.__name__, test.__name__) + testname = "%s.%s" % (cls.__name__, test.__name__) else: - return True # Not sure when this happens + return True # Not sure when this happens if isgeneratorfunction(test) and skipgenerator: - return self.does_match_tags(test) # Let inner tests decide at run time + return self.does_match_tags(test) # Let inner tests decide at run time if self._this_is_skipped(testname): - return False # this was explicitly skipped + return False # this was explicitly skipped if self.test_pattern is not None: try: - classpattern, testpattern = self.test_pattern.split('.') - klass, name = testname.split('.') + classpattern, testpattern = self.test_pattern.split(".") + klass, name = testname.split(".") if classpattern not in klass or testpattern not in name: return False except ValueError: @@ -886,18 +975,24 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): def does_match_tags(self, test: Callable) -> bool: if self.options is not None: - tags_pattern = getattr(self.options, 'tags_pattern', None) + tags_pattern = getattr(self.options, "tags_pattern", None) if tags_pattern is not None: - tags = getattr(test, 'tags', testlib.Tags()) + tags = getattr(test, "tags", testlib.Tags()) if tags.inherit and isinstance(test, types.MethodType): - tags = tags | getattr(test.__self__.__class__, 'tags', testlib.Tags()) + tags = tags | getattr(test.__self__.__class__, "tags", testlib.Tags()) return tags.match(tags_pattern) - return True # no pattern - - def _makeResult(self) -> 'SkipAwareTestResult': - return SkipAwareTestResult(self.stream, self.descriptions, - self.verbosity, self.exitfirst, - self.pdbmode, self.cvg, self.colorize) + return True # no pattern + + def _makeResult(self) -> "SkipAwareTestResult": + return SkipAwareTestResult( + self.stream, + self.descriptions, + self.verbosity, + self.exitfirst, + self.pdbmode, + self.cvg, + self.colorize, + ) def run(self, test): "Run the given test case or test suite." @@ -910,43 +1005,48 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner): if not self.batchmode: self.stream.writeln(result.separator2) run = result.testsRun - self.stream.writeln("Ran %d test%s in %.3fs" % - (run, run != 1 and "s" or "", timeTaken)) + self.stream.writeln("Ran %d test%s in %.3fs" % (run, run != 1 and "s" or "", timeTaken)) self.stream.writeln() if not result.wasSuccessful(): if self.colorize: - self.stream.write(textutils.colorize_ansi("FAILED", color='red')) + self.stream.write(textutils.colorize_ansi("FAILED", color="red")) else: self.stream.write("FAILED") else: if self.colorize: - self.stream.write(textutils.colorize_ansi("OK", color='green')) + self.stream.write(textutils.colorize_ansi("OK", color="green")) else: self.stream.write("OK") - failed, errored, skipped = map(len, (result.failures, - result.errors, - result.skipped)) + failed, errored, skipped = map(len, (result.failures, result.errors, result.skipped)) det_results = [] - for name, value in (("failures", result.failures), - ("errors",result.errors), - ("skipped", result.skipped)): + for name, value in ( + ("failures", result.failures), + ("errors", result.errors), + ("skipped", result.skipped), + ): if value: det_results.append("%s=%i" % (name, len(value))) if det_results: self.stream.write(" (") - self.stream.write(', '.join(det_results)) + self.stream.write(", ".join(det_results)) self.stream.write(")") self.stream.writeln("") return result class SkipAwareTestResult(unittest._TextTestResult): - - def __init__(self, stream: _WritelnDecorator, descriptions: bool, verbosity: int, - exitfirst: bool = False, pdbmode: bool = False, cvg: Optional[Any] = None, colorize: bool = False) -> None: - super(SkipAwareTestResult, self).__init__(stream, - descriptions, verbosity) + def __init__( + self, + stream: _WritelnDecorator, + descriptions: bool, + verbosity: int, + exitfirst: bool = False, + pdbmode: bool = False, + cvg: Optional[Any] = None, + colorize: bool = False, + ) -> None: + super(SkipAwareTestResult, self).__init__(stream, descriptions, verbosity) self.skipped: List[Tuple[Any, Any]] = [] self.debuggers: List = [] self.fail_descrs: List = [] @@ -959,10 +1059,10 @@ class SkipAwareTestResult(unittest._TextTestResult): self.verbose = verbosity > 1 def descrs_for(self, flavour: str) -> List[Tuple[int, str]]: - return getattr(self, '%s_descrs' % flavour.lower()) + return getattr(self, "%s_descrs" % flavour.lower()) def _create_pdb(self, test_descr: str, flavour: str) -> None: - self.descrs_for(flavour).append( (len(self.debuggers), test_descr) ) + self.descrs_for(flavour).append((len(self.debuggers), test_descr)) if self.pdbmode: self.debuggers.append(self.pdbclass(sys.exc_info()[2])) @@ -982,34 +1082,34 @@ class SkipAwareTestResult(unittest._TextTestResult): --verbose is passed """ exctype, exc, tb = err - output = ['Traceback (most recent call last)'] + output = ["Traceback (most recent call last)"] frames = inspect.getinnerframes(tb) colorize = self.colorize frames = enumerate(self._iter_valid_frames(frames)) for index, (frame, filename, lineno, funcname, ctx, ctxindex) in frames: filename = osp.abspath(filename) - if ctx is None: # pyc files or C extensions for instance - source = '<no source available>' + if ctx is None: # pyc files or C extensions for instance + source = "<no source available>" else: - source = ''.join(ctx) + source = "".join(ctx) if colorize: - filename = textutils.colorize_ansi(filename, 'magenta') + filename = textutils.colorize_ansi(filename, "magenta") source = colorize_source(source) output.append(' File "%s", line %s, in %s' % (filename, lineno, funcname)) - output.append(' %s' % source.strip()) + output.append(" %s" % source.strip()) if self.verbose: - output.append('%r == %r' % (dir(frame), test.__module__)) - output.append('') - output.append(' ' + ' local variables '.center(66, '-')) + output.append("%r == %r" % (dir(frame), test.__module__)) + output.append("") + output.append(" " + " local variables ".center(66, "-")) for varname, value in sorted(frame.f_locals.items()): - output.append(' %s: %r' % (varname, value)) - if varname == 'self': # special handy processing for self + output.append(" %s: %r" % (varname, value)) + if varname == "self": # special handy processing for self for varname, value in sorted(vars(value).items()): - output.append(' self.%s: %r' % (varname, value)) - output.append(' ' + '-' * 66) - output.append('') - output.append(''.join(traceback.format_exception_only(exctype, exc))) - return '\n'.join(output) + output.append(" self.%s: %r" % (varname, value)) + output.append(" " + "-" * 66) + output.append("") + output.append("".join(traceback.format_exception_only(exctype, exc))) + return "\n".join(output) def addError(self, test, err): """err -> (exc_type, exc, tcbk)""" @@ -1022,21 +1122,21 @@ class SkipAwareTestResult(unittest._TextTestResult): self.shouldStop = True descr = self.getDescription(test) super(SkipAwareTestResult, self).addError(test, err) - self._create_pdb(descr, 'error') + self._create_pdb(descr, "error") def addFailure(self, test, err): if self.exitfirst: self.shouldStop = True descr = self.getDescription(test) super(SkipAwareTestResult, self).addFailure(test, err) - self._create_pdb(descr, 'fail') + self._create_pdb(descr, "fail") def addSkip(self, test, reason): self.skipped.append((test, reason)) if self.showAll: self.stream.writeln("SKIPPED") elif self.dots: - self.stream.write('S') + self.stream.write("S") def printErrors(self) -> None: super(SkipAwareTestResult, self).printErrors() @@ -1047,7 +1147,7 @@ class SkipAwareTestResult(unittest._TextTestResult): for test, err in self.skipped: descr = self.getDescription(test) self.stream.writeln(self.separator1) - self.stream.writeln("%s: %s" % ('SKIPPED', descr)) + self.stream.writeln("%s: %s" % ("SKIPPED", descr)) self.stream.writeln("\t%s" % err) def printErrorList(self, flavour, errors): @@ -1056,32 +1156,42 @@ class SkipAwareTestResult(unittest._TextTestResult): self.stream.writeln("%s: %s" % (flavour, descr)) self.stream.writeln(self.separator2) self.stream.writeln(err) - self.stream.writeln('no stdout'.center(len(self.separator2))) - self.stream.writeln('no stderr'.center(len(self.separator2))) + self.stream.writeln("no stdout".center(len(self.separator2))) + self.stream.writeln("no stderr".center(len(self.separator2))) from .decorators import monkeypatch + orig_call = testlib.TestCase.__call__ -@monkeypatch(testlib.TestCase, '__call__') -def call(self: Any, result: SkipAwareTestResult = None, runcondition: Optional[Callable] = None, options: Optional[Any] = None) -> None: + + +@monkeypatch(testlib.TestCase, "__call__") +def call( + self: Any, + result: SkipAwareTestResult = None, + runcondition: Optional[Callable] = None, + options: Optional[Any] = None, +) -> None: orig_call(self, result=result, runcondition=runcondition, options=options) # mypy: Item "None" of "Optional[Any]" has no attribute "exitfirst" # we check it first in the if if hasattr(options, "exitfirst") and options.exitfirst: # type: ignore # add this test to restart file try: - restartfile = open(FILE_RESTART, 'a') + restartfile = open(FILE_RESTART, "a") try: - descr = '.'.join((self.__class__.__module__, - self.__class__.__name__, - self._testMethodName)) - restartfile.write(descr+os.linesep) + descr = ".".join( + (self.__class__.__module__, self.__class__.__name__, self._testMethodName) + ) + restartfile.write(descr + os.linesep) finally: restartfile.close() except Exception: - print("Error while saving succeeded test into", - osp.join(os.getcwd(), FILE_RESTART), - file=sys.__stderr__) + print( + "Error while saving succeeded test into", + osp.join(os.getcwd(), FILE_RESTART), + file=sys.__stderr__, + ) raise @@ -1129,7 +1239,7 @@ class NonStrictTestLoader(unittest.TestLoader): for obj in vars(module).values(): if isclass(obj) and issubclass(obj, unittest.TestCase): classname = obj.__name__ - if classname[0] == '_' or self._this_is_skipped(classname): + if classname[0] == "_" or self._this_is_skipped(classname): continue methodnames = [] # obj is a TestCase class @@ -1147,14 +1257,16 @@ class NonStrictTestLoader(unittest.TestLoader): suite = getattr(module, suitename)() except AttributeError: return [] - assert hasattr(suite, '_tests'), \ - "%s.%s is not a valid TestSuite" % (module.__name__, suitename) + assert hasattr(suite, "_tests"), "%s.%s is not a valid TestSuite" % ( + module.__name__, + suitename, + ) # python2.3 does not implement __iter__ on suites, we need to return # _tests explicitly return suite._tests def loadTestsFromName(self, name, module=None): - parts = name.split('.') + parts = name.split(".") if module is None or len(parts) > 2: # let the base class do its job here return [super(NonStrictTestLoader, self).loadTestsFromName(name)] @@ -1162,34 +1274,35 @@ class NonStrictTestLoader(unittest.TestLoader): collected = [] if len(parts) == 1: pattern = parts[0] - if callable(getattr(module, pattern, None) - ) and pattern not in tests: + if callable(getattr(module, pattern, None)) and pattern not in tests: # consider it as a suite return self.loadTestsFromSuite(module, pattern) if pattern in tests: # case python unittest_foo.py MyTestTC klass, methodnames = tests[pattern] for methodname in methodnames: - collected = [klass(methodname) - for methodname in methodnames] + collected = [klass(methodname) for methodname in methodnames] else: # case python unittest_foo.py something for klass, methodnames in tests.values(): # skip methodname if matched by skipped_patterns for skip_pattern in self.skipped_patterns: - methodnames = [methodname - for methodname in methodnames - if skip_pattern not in methodname] - collected += [klass(methodname) - for methodname in methodnames - if pattern in methodname] + methodnames = [ + methodname + for methodname in methodnames + if skip_pattern not in methodname + ] + collected += [ + klass(methodname) for methodname in methodnames if pattern in methodname + ] elif len(parts) == 2: # case "MyClass.test_1" classname, pattern = parts klass, methodnames = tests.get(classname, (None, [])) for methodname in methodnames: - collected = [klass(methodname) for methodname in methodnames - if pattern in methodname] + collected = [ + klass(methodname) for methodname in methodnames if pattern in methodname + ] return collected def _this_is_skipped(self, testedname: str) -> bool: @@ -1202,10 +1315,9 @@ class NonStrictTestLoader(unittest.TestLoader): """ is_skipped = self._this_is_skipped classname = testCaseClass.__name__ - if classname[0] == '_' or is_skipped(classname): + if classname[0] == "_" or is_skipped(classname): return [] - testnames = super(NonStrictTestLoader, self).getTestCaseNames( - testCaseClass) + testnames = super(NonStrictTestLoader, self).getTestCaseNames(testCaseClass) return [testname for testname in testnames if not is_skipped(testname)] @@ -1214,13 +1326,27 @@ class NonStrictTestLoader(unittest.TestLoader): # It is used to monkeypatch the original implementation to support # extra runcondition and options arguments (see in testlib.py) -def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: + +def _ts_run( + self: Any, + result: SkipAwareTestResult, + debug: bool = False, + runcondition: Callable = None, + options: Optional[Any] = None, +) -> SkipAwareTestResult: self._wrapped_run(result, runcondition=runcondition, options=options) self._tearDownPreviousClass(None, result) self._handleModuleTearDown(result) return result -def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: + +def _ts_wrapped_run( + self: Any, + result: SkipAwareTestResult, + debug: bool = False, + runcondition: Callable = None, + options: Optional[Any] = None, +) -> SkipAwareTestResult: for test in self: if result.shouldStop: break @@ -1229,8 +1355,9 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, self._handleModuleFixture(test, result) self._handleClassSetUp(test, result) result._previousTestClass = test.__class__ - if (getattr(test.__class__, '_classSetupFailed', False) or - getattr(result, '_moduleSetUpFailed', False)): + if getattr(test.__class__, "_classSetupFailed", False) or getattr( + result, "_moduleSetUpFailed", False + ): continue # --- modifications to deal with _wrapped_run --- @@ -1240,7 +1367,7 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, # test(result) # else: # test.debug() - if hasattr(test, '_wrapped_run'): + if hasattr(test, "_wrapped_run"): try: test._wrapped_run(result, debug, runcondition=runcondition, options=options) except TypeError: @@ -1255,13 +1382,20 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, # --- end of modifications to deal with _wrapped_run --- return result + if sys.version_info >= (2, 7): # The function below implements a modified version of the # TestSuite.run method that is provided with python 2.7, in # unittest/suite.py - def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult: + def _ts_run( + self: Any, + result: SkipAwareTestResult, + debug: bool = False, + runcondition: Callable = None, + options: Optional[Any] = None, + ) -> SkipAwareTestResult: topLevel = False - if getattr(result, '_testRunEntered', False) is False: + if getattr(result, "_testRunEntered", False) is False: result._testRunEntered = topLevel = True self._wrapped_run(result, debug, runcondition, options) @@ -1287,8 +1421,7 @@ def enable_dbc(*args): from logilab.aspects.weaver import weaver from logilab.aspects.lib.contracts import ContractAspect except ImportError: - sys.stderr.write( - 'Warning: logilab.aspects is not available. Contracts disabled.') + sys.stderr.write("Warning: logilab.aspects is not available. Contracts disabled.") return False for arg in args: weaver.weave_module(arg, ContractAspect) @@ -1304,13 +1437,12 @@ unittest.TestProgram = SkipAwareTestProgram if sys.version_info >= (2, 4): doctest.DocTestCase.__bases__ = (testlib.TestCase,) # XXX check python2.6 compatibility - #doctest.DocTestCase._cleanups = [] - #doctest.DocTestCase._out = [] + # doctest.DocTestCase._cleanups = [] + # doctest.DocTestCase._out = [] else: unittest.FunctionTestCase.__bases__ = (testlib.TestCase,) unittest.TestSuite.run = _ts_run unittest.TestSuite._wrapped_run = _ts_wrapped_run -if __name__ == '__main__': +if __name__ == "__main__": run() - diff --git a/logilab/common/registry.py b/logilab/common/registry.py index d9ae11b..83f4703 100644 --- a/logilab/common/registry.py +++ b/logilab/common/registry.py @@ -105,6 +105,7 @@ from logilab.common.deprecation import deprecated # selector base classes and operations ######################################## + def objectify_predicate(selector_func: Callable) -> Any: """Most of the time, a simple score function is enough to build a selector. The :func:`objectify_predicate` decorator turn it into a proper selector @@ -118,22 +119,29 @@ def objectify_predicate(selector_func: Callable) -> Any: __select__ = View.__select__ & one() """ - return type(selector_func.__name__, (Predicate,), - {'__doc__': selector_func.__doc__, - '__call__': lambda self, *a, **kw: selector_func(*a, **kw)}) + return type( + selector_func.__name__, + (Predicate,), + { + "__doc__": selector_func.__doc__, + "__call__": lambda self, *a, **kw: selector_func(*a, **kw), + }, + ) _PREDICATES: Dict[int, Type] = {} + def wrap_predicates(decorator: Callable) -> None: for predicate in _PREDICATES.values(): - if not '_decorators' in predicate.__dict__: + if not "_decorators" in predicate.__dict__: predicate._decorators = set() if decorator in predicate._decorators: continue predicate._decorators.add(decorator) predicate.__call__ = decorator(predicate.__call__) + class PredicateMetaClass(type): def __new__(mcs, *args, **kwargs): # use __new__ so subclasses doesn't have to call Predicate.__init__ @@ -164,36 +172,37 @@ class Predicate(object, metaclass=PredicateMetaClass): # backward compatibility return self.__class__.__name__ - def search_selector(self, selector: 'Predicate') -> Optional['Predicate']: + def search_selector(self, selector: "Predicate") -> Optional["Predicate"]: """search for the given selector, selector instance or tuple of selectors in the selectors tree. Return None if not found. """ if self is selector: return self - if (isinstance(selector, type) or isinstance(selector, tuple)) and \ - isinstance(self, selector): + if (isinstance(selector, type) or isinstance(selector, tuple)) and isinstance( + self, selector + ): return self return None def __str__(self): return self.__class__.__name__ - def __and__(self, other: 'Predicate') -> 'AndPredicate': + def __and__(self, other: "Predicate") -> "AndPredicate": return AndPredicate(self, other) - def __rand__(self, other: 'Predicate') -> 'AndPredicate': + def __rand__(self, other: "Predicate") -> "AndPredicate": return AndPredicate(other, self) - def __iand__(self, other: 'Predicate') -> 'AndPredicate': + def __iand__(self, other: "Predicate") -> "AndPredicate": return AndPredicate(self, other) - def __or__(self, other: 'Predicate') -> 'OrPredicate': + def __or__(self, other: "Predicate") -> "OrPredicate": return OrPredicate(self, other) - def __ror__(self, other: 'Predicate'): + def __ror__(self, other: "Predicate"): return OrPredicate(other, self) - def __ior__(self, other: 'Predicate') -> 'OrPredicate': + def __ior__(self, other: "Predicate") -> "OrPredicate": return OrPredicate(self, other) def __invert__(self): @@ -202,11 +211,12 @@ class Predicate(object, metaclass=PredicateMetaClass): # XXX (function | function) or (function & function) not managed yet def __call__(self, cls, *args, **kwargs): - return NotImplementedError("selector %s must implement its logic " - "in its __call__ method" % self.__class__) + return NotImplementedError( + "selector %s must implement its logic " "in its __call__ method" % self.__class__ + ) def __repr__(self): - return u'<Predicate %s at %x>' % (self.__class__.__name__, id(self)) + return "<Predicate %s at %x>" % (self.__class__.__name__, id(self)) class MultiPredicate(Predicate): @@ -216,8 +226,7 @@ class MultiPredicate(Predicate): self.selectors = self.merge_selectors(selectors) def __str__(self): - return '%s(%s)' % (self.__class__.__name__, - ','.join(str(s) for s in self.selectors)) + return "%s(%s)" % (self.__class__.__name__, ",".join(str(s) for s in self.selectors)) @classmethod def merge_selectors(cls, selectors: Sequence[Predicate]) -> List[Predicate]: @@ -258,6 +267,7 @@ class MultiPredicate(Predicate): class AndPredicate(MultiPredicate): """and-chained selectors""" + def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int: score = 0 for selector in self.selectors: @@ -270,6 +280,7 @@ class AndPredicate(MultiPredicate): class OrPredicate(MultiPredicate): """or-chained selectors""" + def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int: for selector in self.selectors: partscore = selector(cls, *args, **kwargs) @@ -277,8 +288,10 @@ class OrPredicate(MultiPredicate): return partscore return 0 + class NotPredicate(Predicate): """negation selector""" + def __init__(self, selector): self.selector = selector @@ -287,10 +300,10 @@ class NotPredicate(Predicate): return int(not score) def __str__(self): - return 'NOT(%s)' % self.selector + return "NOT(%s)" % self.selector -class yes(Predicate): # pylint: disable=C0103 +class yes(Predicate): # pylint: disable=C0103 """Return the score given as parameter, with a default score of 0.5 so any other selector take precedence. @@ -299,6 +312,7 @@ class yes(Predicate): # pylint: disable=C0103 Take care, `yes(0)` could be named 'no'... """ + def __init__(self, score: float = 0.5) -> None: self.score = score @@ -308,39 +322,50 @@ class yes(Predicate): # pylint: disable=C0103 # deprecated stuff ############################################################# -@deprecated('[lgc 0.59] use Registry.objid class method instead') + +@deprecated("[lgc 0.59] use Registry.objid class method instead") def classid(cls): - return '%s.%s' % (cls.__module__, cls.__name__) + return "%s.%s" % (cls.__module__, cls.__name__) -@deprecated('[lgc 0.59] use obj_registries function instead') + +@deprecated("[lgc 0.59] use obj_registries function instead") def class_registries(cls, registryname): return obj_registries(cls, registryname) + class RegistryException(Exception): """Base class for registry exception.""" + class RegistryNotFound(RegistryException): """Raised when an unknown registry is requested. This is usually a programming/typo error. """ + class ObjectNotFound(RegistryException): """Raised when an unregistered object is requested. This may be a programming/typo or a misconfiguration error. """ + class NoSelectableObject(RegistryException): """Raised when no object is selectable for a given context.""" + def __init__(self, args, kwargs, objects): self.args = args self.kwargs = kwargs self.objects = objects def __str__(self): - return ('args: %s, kwargs: %s\ncandidates: %s' - % (self.args, self.kwargs.keys(), self.objects)) + return "args: %s, kwargs: %s\ncandidates: %s" % ( + self.args, + self.kwargs.keys(), + self.objects, + ) + class SelectAmbiguity(RegistryException): """Raised when several objects compete at selection time with an equal @@ -362,12 +387,14 @@ def _modname_from_path(path: str, extrapath: Optional[Any] = None) -> str: # from package.__init__ import something # # which seems quite correct. - if modpath[-1] == '__init__': + if modpath[-1] == "__init__": modpath.pop() - return '.'.join(modpath) + return ".".join(modpath) -def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None) -> Tuple[Dict[str, str], List[Tuple[str, str]]]: +def _toload_info( + path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None +) -> Tuple[Dict[str, str], List[Tuple[str, str]]]: """Return a dictionary of <modname>: <modpath> and an ordered list of (file, module name) to load """ @@ -376,12 +403,12 @@ def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tu _toload = {}, [] for fileordir in path: - if isdir(fileordir) and exists(join(fileordir, '__init__.py')): + if isdir(fileordir) and exists(join(fileordir, "__init__.py")): subfiles = [join(fileordir, fname) for fname in listdir(fileordir)] _toload_info(subfiles, extrapath, _toload) - elif fileordir[-3:] == '.py': + elif fileordir[-3:] == ".py": modname = _modname_from_path(fileordir, extrapath) _toload[0][modname] = fileordir @@ -417,7 +444,7 @@ class RegistrableObject(object): __registry__: Optional[str] = None __regid__: Optional[str] = None __select__: Union[None, str, Predicate] = None - __abstract__ = True # see doc snipppets below (in Registry class) + __abstract__ = True # see doc snipppets below (in Registry class) @classproperty def __registries__(cls) -> Union[Tuple[str], Tuple]: @@ -435,12 +462,13 @@ class RegistrableInstance(RegistrableObject): """Add a __module__ attribute telling the module where the instance was created, for automatic registration. """ - module = kwargs.pop('__module__', None) + module = kwargs.pop("__module__", None) obj = super(RegistrableInstance, cls).__new__(cls) if module is None: - warn('instantiate {0} with ' - '__module__=__name__'.format(cls.__name__), - DeprecationWarning) + warn( + "instantiate {0} with " "__module__=__name__".format(cls.__name__), + DeprecationWarning, + ) # XXX subclass must no override __new__ filepath = tb.extract_stack(limit=2)[0][0] obj.__module__ = _modname_from_path(filepath) @@ -452,11 +480,19 @@ class RegistrableInstance(RegistrableObject): super(RegistrableInstance, self).__init__() -SelectBestReport = TypedDict("SelectBestReport", {"all_objects": List, "end_score": int, - "winners": List, - "winner": Optional[Any], "self": 'Registry', - "args": List, "kwargs": Dict, - "registry": 'Registry'}) +SelectBestReport = TypedDict( + "SelectBestReport", + { + "all_objects": List, + "end_score": int, + "winners": List, + "winner": Optional[Any], + "self": "Registry", + "args": List, + "kwargs": Dict, + "registry": "Registry", + }, +) class Registry(dict): @@ -492,6 +528,7 @@ class Registry(dict): .. automethod:: possible_objects .. automethod:: object_by_id """ + def __init__(self, debugmode: bool) -> None: super(Registry, self).__init__() self.debugmode = debugmode @@ -511,19 +548,19 @@ class Registry(dict): @classmethod def objid(cls, obj: Any) -> str: """returns a unique identifier for an object stored in the registry""" - return '%s.%s' % (obj.__module__, cls.objname(obj)) + return "%s.%s" % (obj.__module__, cls.objname(obj)) @classmethod def objname(cls, obj: Any) -> str: """returns a readable name for an object stored in the registry""" - return getattr(obj, '__name__', id(obj)) + return getattr(obj, "__name__", id(obj)) def initialization_completed(self) -> None: """call method __registered__() on registered objects when the callback is defined""" for objects in self.values(): for objectcls in objects: - registered = getattr(objectcls, '__registered__', None) + registered = getattr(objectcls, "__registered__", None) if registered: registered(self) if self.debugmode: @@ -531,16 +568,17 @@ class Registry(dict): def register(self, obj: Any, oid: Optional[Any] = None, clear: bool = False) -> None: """base method to add an object in the registry""" - assert not '__abstract__' in obj.__dict__, obj + assert not "__abstract__" in obj.__dict__, obj assert obj.__select__, obj oid = oid or obj.__regid__ - assert oid, ('no explicit name supplied to register object %s, ' - 'which has no __regid__ set' % obj) + assert oid, ( + "no explicit name supplied to register object %s, " "which has no __regid__ set" % obj + ) if clear: - objects = self[oid] = [] + objects = self[oid] = [] else: objects = self.setdefault(oid, []) - assert not obj in objects, 'object %s is already registered' % obj + assert not obj in objects, "object %s is already registered" % obj objects.append(obj) def register_and_replace(self, obj, replaced): @@ -551,15 +589,14 @@ class Registry(dict): if not isinstance(replaced, str): replaced = self.objid(replaced) # prevent from misspelling - assert obj is not replaced, 'replacing an object by itself: %s' % obj + assert obj is not replaced, "replacing an object by itself: %s" % obj registered_objs = self.get(obj.__regid__, ()) for index, registered in enumerate(registered_objs): if self.objid(registered) == replaced: del registered_objs[index] break else: - self.warning('trying to replace %s that is not registered with %s', - replaced, obj) + self.warning("trying to replace %s that is not registered with %s", replaced, obj) self.register(obj) def unregister(self, obj): @@ -573,8 +610,7 @@ class Registry(dict): self[oid].remove(registered) break else: - self.warning('can\'t remove %s, no id %s in the registry', - objid, oid) + self.warning("can't remove %s, no id %s in the registry", objid, oid) def all_objects(self): """return a list containing all objects in this registry. @@ -608,9 +644,9 @@ class Registry(dict): raise :exc:`NoSelectableObject` if no object can be selected """ - obj = self._select_best(self[__oid], *args, **kwargs) + obj = self._select_best(self[__oid], *args, **kwargs) if obj is None: - raise NoSelectableObject(args, kwargs, self[__oid] ) + raise NoSelectableObject(args, kwargs, self[__oid]) return obj def select_or_none(self, __oid, *args, **kwargs): @@ -627,7 +663,7 @@ class Registry(dict): context """ for objects in self.values(): - obj = self._select_best(objects, *args, **kwargs) + obj = self._select_best(objects, *args, **kwargs) if obj is None: continue yield obj @@ -695,7 +731,7 @@ class Registry(dict): if len(winners) > 1: # log in production environement / test, error while debugging - msg = 'select ambiguity: %s\n(args: %s, kwargs: %s)' + msg = "select ambiguity: %s\n(args: %s, kwargs: %s)" if self.debugmode: # raise bare exception in debug mode @@ -903,8 +939,9 @@ class RegistryStore(dict): :meth:`~logilab.common.registry.RegistryStore.register_and_replace` for instance). """ - assert isinstance(modname, str), \ - 'modname expected to be a module name (ie string), got %r' % modname + assert isinstance(modname, str), ( + "modname expected to be a module name (ie string), got %r" % modname + ) for obj in objects: if self.is_registrable(obj) and obj.__module__ == modname and not obj in butclasses: if isinstance(obj, type): @@ -912,7 +949,13 @@ class RegistryStore(dict): else: self.register(obj) - def register(self, obj: Any, registryname: Optional[Any] = None, oid: Optional[Any] = None, clear: bool = False) -> None: + def register( + self, + obj: Any, + registryname: Optional[Any] = None, + oid: Optional[Any] = None, + clear: bool = False, + ) -> None: """register `obj` implementation into `registryname` or `obj.__registries__` if not specified, with identifier `oid` or `obj.__regid__` if not specified. @@ -920,12 +963,13 @@ class RegistryStore(dict): If `clear` is true, all objects with the same identifier will be previously unregistered. """ - assert not obj.__dict__.get('__abstract__'), obj + assert not obj.__dict__.get("__abstract__"), obj for registryname in obj_registries(obj, registryname): registry = self.setdefault(registryname) registry.register(obj, oid=oid, clear=clear) - self.debug("register %s in %s['%s']", - registry.objname(obj), registryname, oid or obj.__regid__) + self.debug( + "register %s in %s['%s']", registry.objname(obj), registryname, oid or obj.__regid__ + ) self._loadedmods.setdefault(obj.__module__, {})[registry.objid(obj)] = obj def unregister(self, obj, registryname=None): @@ -935,8 +979,9 @@ class RegistryStore(dict): for registryname in obj_registries(obj, registryname): registry = self[registryname] registry.unregister(obj) - self.debug("unregister %s from %s['%s']", - registry.objname(obj), registryname, obj.__regid__) + self.debug( + "unregister %s from %s['%s']", registry.objname(obj), registryname, obj.__regid__ + ) def register_and_replace(self, obj, replaced, registryname=None): """register `obj` object into `registryname` or @@ -947,13 +992,19 @@ class RegistryStore(dict): for registryname in obj_registries(obj, registryname): registry = self[registryname] registry.register_and_replace(obj, replaced) - self.debug("register %s in %s['%s'] instead of %s", - registry.objname(obj), registryname, obj.__regid__, - registry.objname(replaced)) + self.debug( + "register %s in %s['%s'] instead of %s", + registry.objname(obj), + registryname, + obj.__regid__, + registry.objname(replaced), + ) # initialization methods ################################################### - def init_registration(self, path: List[str], extrapath: Optional[Any] = None) -> List[Tuple[str, str]]: + def init_registration( + self, path: List[str], extrapath: Optional[Any] = None + ) -> List[Tuple[str, str]]: """reset registry and walk down path to return list of (path, name) file modules to be loaded""" # XXX make this private by renaming it to _init_registration ? @@ -966,7 +1017,7 @@ class RegistryStore(dict): self._loadedmods: Dict[str, Dict[str, type]] = {} return filemods - @deprecated('use register_modnames() instead') + @deprecated("use register_modnames() instead") def register_objects(self, path: List[str], extrapath: Optional[Any] = None) -> None: """register all objects found walking down <path>""" # load views from each directory in the instance's path @@ -988,7 +1039,7 @@ class RegistryStore(dict): # mypy: "Loader" has no attribute "get_filename" # the selected class has one filepath = loader.get_filename() # type: ignore - if filepath[-4:] in ('.pyc', '.pyo'): + if filepath[-4:] in (".pyc", ".pyo"): # The source file *must* exists filepath = filepath[:-1] self._toloadmods[modname] = filepath @@ -1008,8 +1059,7 @@ class RegistryStore(dict): return stat(filepath)[-2] except OSError: # this typically happens on emacs backup files (.#foo.py) - self.warning('Unable to load %s. It is likely to be a backup file', - filepath) + self.warning("Unable to load %s. It is likely to be a backup file", filepath) return None def is_reload_needed(self, path): @@ -1018,19 +1068,18 @@ class RegistryStore(dict): """ lastmodifs = self._lastmodifs for fileordir in path: - if isdir(fileordir) and exists(join(fileordir, '__init__.py')): - if self.is_reload_needed([join(fileordir, fname) - for fname in listdir(fileordir)]): + if isdir(fileordir) and exists(join(fileordir, "__init__.py")): + if self.is_reload_needed([join(fileordir, fname) for fname in listdir(fileordir)]): return True - elif fileordir[-3:] == '.py': + elif fileordir[-3:] == ".py": mdate = self._mdate(fileordir) if mdate is None: - continue # backup file, see _mdate implementation + continue # backup file, see _mdate implementation elif "flymake" in fileordir: # flymake + pylint in use, don't consider these they will corrupt the registry continue if fileordir not in lastmodifs or lastmodifs[fileordir] < mdate: - self.info('File %s changed since last visit', fileordir) + self.info("File %s changed since last visit", fileordir) return True return False @@ -1041,7 +1090,7 @@ class RegistryStore(dict): self._loadedmods[modname] = {} mdate = self._mdate(filepath) if mdate is None: - return # backup file, see _mdate implementation + return # backup file, see _mdate implementation elif "flymake" in filepath: # flymake + pylint in use, don't consider these they will corrupt the registry return @@ -1052,7 +1101,7 @@ class RegistryStore(dict): # load the module if sys.version_info < (3,) and not isinstance(modname, str): modname = str(modname) - module = __import__(modname, fromlist=modname.split('.')[:-1]) + module = __import__(modname, fromlist=modname.split(".")[:-1]) self.load_module(module) def load_module(self, module: ModuleType) -> None: @@ -1074,15 +1123,17 @@ class RegistryStore(dict): - object class needs to have registries and identifier properly set to a non empty string to be registered. """ - self.info('loading %s from %s', module.__name__, module.__file__) - if hasattr(module, 'registration_callback'): + self.info("loading %s from %s", module.__name__, module.__file__) + if hasattr(module, "registration_callback"): # mypy: Module has no attribute "registration_callback" # we check that before module.registration_callback(self) # type: ignore else: self.register_all(vars(module).values(), module.__name__) - def _load_ancestors_then_object(self, modname: str, objectcls: type, butclasses: Sequence[Any] = ()) -> None: + def _load_ancestors_then_object( + self, modname: str, objectcls: type, butclasses: Sequence[Any] = () + ) -> None: """handle class registration according to rules defined in :meth:`load_module` """ @@ -1103,7 +1154,7 @@ class RegistryStore(dict): self.load_file(self._toloadmods[objmodname], objmodname) return # ensure object hasn't been already processed - clsid = '%s.%s' % (modname, objectcls.__name__) + clsid = "%s.%s" % (modname, objectcls.__name__) if clsid in self._loadedmods[modname]: return self._loadedmods[modname][clsid] = objectcls @@ -1115,10 +1166,13 @@ class RegistryStore(dict): return # backward compat reg = self.setdefault(obj_registries(objectcls)[0]) - if reg.objname(objectcls)[0] == '_': - warn("[lgc 0.59] object whose name start with '_' won't be " - "skipped anymore at some point, use __abstract__ = True " - "instead (%s)" % objectcls, DeprecationWarning) + if reg.objname(objectcls)[0] == "_": + warn( + "[lgc 0.59] object whose name start with '_' won't be " + "skipped anymore at some point, use __abstract__ = True " + "instead (%s)" % objectcls, + DeprecationWarning, + ) return # register, finally self.register(objectcls) @@ -1133,9 +1187,11 @@ class RegistryStore(dict): if isinstance(obj, type): if not issubclass(obj, RegistrableObject): # ducktyping backward compat - if not (getattr(obj, '__registries__', None) - and getattr(obj, '__regid__', None) - and getattr(obj, '__select__', None)): + if not ( + getattr(obj, "__registries__", None) + and getattr(obj, "__regid__", None) + and getattr(obj, "__select__", None) + ): return False elif issubclass(obj, RegistrableInstance): return False @@ -1144,26 +1200,26 @@ class RegistryStore(dict): return False if not obj.__regid__: - return False # no regid + return False # no regid registries = obj.__registries__ if not registries: - return False # no registries + return False # no registries selector = obj.__select__ if not selector: - return False # no selector + return False # no selector - if obj.__dict__.get('__abstract__', False): + if obj.__dict__.get("__abstract__", False): return False # then detect potential problems that should be warned if not isinstance(registries, (tuple, list)): - cls.warning('%s has __registries__ which is not a list or tuple', obj) + cls.warning("%s has __registries__ which is not a list or tuple", obj) return False if not callable(selector): - cls.warning('%s has not callable __select__', obj) + cls.warning("%s has not callable __select__", obj) return False return True @@ -1174,32 +1230,37 @@ class RegistryStore(dict): # init logging -set_log_methods(RegistryStore, getLogger('registry.store')) -set_log_methods(Registry, getLogger('registry')) +set_log_methods(RegistryStore, getLogger("registry.store")) +set_log_methods(Registry, getLogger("registry")) # helpers for debugging selectors TRACED_OIDS = None + def _trace_selector(cls, selector, args, ret): vobj = args[0] - if TRACED_OIDS == 'all' or vobj.__regid__ in TRACED_OIDS: - print('%s -> %s for %s(%s)' % (cls, ret, vobj, vobj.__regid__)) + if TRACED_OIDS == "all" or vobj.__regid__ in TRACED_OIDS: + print("%s -> %s for %s(%s)" % (cls, ret, vobj, vobj.__regid__)) + def _lltrace(selector): """use this decorator on your predicates so they become traceable with :class:`traced_selection` """ + def traced(cls, *args, **kwargs): ret = selector(cls, *args, **kwargs) if TRACED_OIDS is not None: _trace_selector(cls, selector, args, ret) return ret + traced.__name__ = selector.__name__ traced.__doc__ = selector.__doc__ return traced -class traced_selection(object): # pylint: disable=C0103 + +class traced_selection(object): # pylint: disable=C0103 """ Typical usage is : @@ -1227,7 +1288,7 @@ class traced_selection(object): # pylint: disable=C0103 the `logilab.common.registry.Registry.select` method body. """ - def __init__(self, traced='all'): + def __init__(self, traced="all"): self.traced = traced def __enter__(self): diff --git a/logilab/common/shellutils.py b/logilab/common/shellutils.py index 2764723..557e45d 100644 --- a/logilab/common/shellutils.py +++ b/logilab/common/shellutils.py @@ -46,7 +46,6 @@ from logilab.common.deprecation import deprecated class tempdir(object): - def __enter__(self): self.path = tempfile.mkdtemp() return self.path @@ -82,7 +81,8 @@ def chown(path, login=None, group=None): try: uid = int(login) except ValueError: - import pwd # Platforms: Unix + import pwd # Platforms: Unix + uid = pwd.getpwnam(login).pw_uid if group is None: gid = -1 @@ -91,9 +91,11 @@ def chown(path, login=None, group=None): gid = int(group) except ValueError: import grp + gid = grp.getgrnam(group).gr_gid os.chown(path, uid, gid) + def mv(source, destination, _action=shutil.move): """A shell-like mv, supporting wildcards. """ @@ -106,14 +108,14 @@ def mv(source, destination, _action=shutil.move): try: source = sources[0] except IndexError: - raise OSError('No file matching %s' % source) + raise OSError("No file matching %s" % source) if isdir(destination) and exists(destination): destination = join(destination, basename(source)) try: _action(source, destination) except OSError as ex: - raise OSError('Unable to move %r to %r (%s)' % ( - source, destination, ex)) + raise OSError("Unable to move %r to %r (%s)" % (source, destination, ex)) + def rm(*files): """A shell-like rm, supporting wildcards. @@ -127,12 +129,19 @@ def rm(*files): else: os.remove(filename) + def cp(source, destination): """A shell-like cp, supporting wildcards. """ mv(source, destination, _action=shutil.copy) -def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = False, blacklist: Tuple[str, ...] = STD_BLACKLIST) -> List[str]: + +def find( + directory: str, + exts: Union[Tuple[str, ...], str], + exclude: bool = False, + blacklist: Tuple[str, ...] = STD_BLACKLIST, +) -> List[str]: """Recursively find files ending with the given extensions from the directory. :type directory: str @@ -160,17 +169,21 @@ def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = Fals if isinstance(exts, str): exts = (exts,) if exclude: + def match(filename: str, exts: Tuple[str, ...]) -> bool: for ext in exts: if filename.endswith(ext): return False return True + else: + def match(filename: str, exts: Tuple[str, ...]) -> bool: for ext in exts: if filename.endswith(ext): return True return False + files = [] for dirpath, dirnames, filenames in os.walk(directory): _handle_blacklist(blacklist, dirnames, filenames) @@ -182,7 +195,11 @@ def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = Fals return files -def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST) -> Iterator[str]: +def globfind( + directory: str, + pattern: str, + blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST, +) -> Iterator[str]: """Recursively finds files matching glob `pattern` under `directory`. This is an alternative to `logilab.common.shellutils.find`. @@ -209,21 +226,23 @@ def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str, for fname in fnmatch.filter(filenames, pattern): yield join(curdir, fname) + def unzip(archive, destdir): import zipfile + if not exists(destdir): os.mkdir(destdir) zfobj = zipfile.ZipFile(archive) for name in zfobj.namelist(): - if name.endswith('/'): + if name.endswith("/"): os.mkdir(join(destdir, name)) else: - outfile = open(join(destdir, name), 'wb') + outfile = open(join(destdir, name), "wb") outfile.write(zfobj.read(name)) outfile.close() -@deprecated('Use subprocess.Popen instead') +@deprecated("Use subprocess.Popen instead") class Execute: """This is a deadlock safe version of popen2 (no stdin), that returns an object with errorlevel, out and err. @@ -238,11 +257,13 @@ class Execute: class ProgressBar(object): """A simple text progression bar.""" - def __init__(self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = '') -> None: + def __init__( + self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = "" + ) -> None: if title: - self._fstr = '\r%s [%%-%ss]' % (title, int(size)) + self._fstr = "\r%s [%%-%ss]" % (title, int(size)) else: - self._fstr = '\r[%%-%ss]' % int(size) + self._fstr = "\r[%%-%ss]" % int(size) self._stream = stream self._total = nbops self._size = size @@ -280,42 +301,45 @@ class ProgressBar(object): else: self._current += offset - progress = int((float(self._current)/float(self._total))*self._size) + progress = int((float(self._current) / float(self._total)) * self._size) if progress > self._progress: self._progress = progress self.refresh() def refresh(self) -> None: """Refresh the progression bar display.""" - self._stream.write(self._fstr % ('=' * min(self._progress, self._size)) ) + self._stream.write(self._fstr % ("=" * min(self._progress, self._size))) if self._last_text_write_size or self._current_text: - template = ' %%-%is' % (self._last_text_write_size) + template = " %%-%is" % (self._last_text_write_size) text = self._current_text if text is None: - text = '' + text = "" self._stream.write(template % text) self._last_text_write_size = len(text.rstrip()) self._stream.flush() def finish(self): - self._stream.write('\n') + self._stream.write("\n") self._stream.flush() class DummyProgressBar(object): - __slots__ = ('text',) + __slots__ = ("text",) def refresh(self): pass + def update(self): pass + def finish(self): pass _MARKER = object() -class progress(object): + +class progress(object): def __init__(self, nbops=_MARKER, size=_MARKER, stream=_MARKER, title=_MARKER, enabled=True): self.nbops = nbops self.size = size @@ -326,26 +350,30 @@ class progress(object): def __enter__(self): if self.enabled: kwargs = {} - for attr in ('nbops', 'size', 'stream', 'title'): + for attr in ("nbops", "size", "stream", "title"): value = getattr(self, attr) if value is not _MARKER: kwargs[attr] = value self.pb = ProgressBar(**kwargs) else: - self.pb = DummyProgressBar() + self.pb = DummyProgressBar() return self.pb def __exit__(self, exc_type, exc_val, exc_tb): self.pb.finish() -class RawInput(object): - def __init__(self, input_function: Optional[Callable] = None, printer: Optional[Callable] = None, **kwargs: Any) -> None: - if 'input' in kwargs: - input_function = kwargs.pop('input') +class RawInput(object): + def __init__( + self, + input_function: Optional[Callable] = None, + printer: Optional[Callable] = None, + **kwargs: Any, + ) -> None: + if "input" in kwargs: + input_function = kwargs.pop("input") warnings.warn( - "'input' argument is deprecated," - "use 'input_function' instead", + "'input' argument is deprecated," "use 'input_function' instead", DeprecationWarning, ) self._input = input_function or input @@ -360,35 +388,36 @@ class RawInput(object): else: label = option[0].lower() if len(option) > 1: - label += '(%s)' % option[1:].lower() + label += "(%s)" % option[1:].lower() choices.append((option, label)) - prompt = "%s [%s]: " % (question, - '/'.join([opt[1] for opt in choices])) + prompt = "%s [%s]: " % (question, "/".join([opt[1] for opt in choices])) tries = 3 while tries > 0: answer = self._input(prompt).strip().lower() if not answer: return default - possible = [option for option, label in choices - if option.lower().startswith(answer)] + possible = [option for option, label in choices if option.lower().startswith(answer)] if len(possible) == 1: return possible[0] elif len(possible) == 0: - msg = '%s is not an option.' % answer + msg = "%s is not an option." % answer else: - msg = ('%s is an ambiguous answer, do you mean %s ?' % ( - answer, ' or '.join(possible))) + msg = "%s is an ambiguous answer, do you mean %s ?" % ( + answer, + " or ".join(possible), + ) if self._print: self._print(msg) else: print(msg) tries -= 1 - raise Exception('unable to get a sensible answer') + raise Exception("unable to get a sensible answer") def confirm(self, question: str, default_is_yes: bool = True) -> bool: - default = default_is_yes and 'y' or 'n' - answer = self.ask(question, ('y', 'n'), default) - return answer == 'y' + default = default_is_yes and "y" or "n" + answer = self.ask(question, ("y", "n"), default) + return answer == "y" + ASK = RawInput() @@ -398,15 +427,17 @@ def getlogin(): (man 3 getlogin) Another solution would be to use $LOGNAME, $USER or $USERNAME """ - if sys.platform != 'win32': - import pwd # Platforms: Unix + if sys.platform != "win32": + import pwd # Platforms: Unix + return pwd.getpwuid(os.getuid())[0] else: - return os.environ['USERNAME'] + return os.environ["USERNAME"] + def generate_password(length=8, vocab=string.ascii_letters + string.digits): """dumb password generation function""" - pwd = '' + pwd = "" for i in range(length): pwd += random.choice(vocab) return pwd diff --git a/logilab/common/sphinx_ext.py b/logilab/common/sphinx_ext.py index a24608c..4ca30f7 100644 --- a/logilab/common/sphinx_ext.py +++ b/logilab/common/sphinx_ext.py @@ -19,30 +19,41 @@ from logilab.common.decorators import monkeypatch from sphinx.ext import autodoc + class DocstringOnlyModuleDocumenter(autodoc.ModuleDocumenter): - objtype = 'docstring' + objtype = "docstring" + def format_signature(self): pass + def add_directive_header(self, sig): pass + def document_members(self, all_members=False): pass def resolve_name(self, modname, parents, path, base): if modname is not None: return modname, parents + [base] - return (path or '') + base, [] + return (path or "") + base, [] + +# autodoc.add_documenter(DocstringOnlyModuleDocumenter) -#autodoc.add_documenter(DocstringOnlyModuleDocumenter) def setup(app): app.add_autodocumenter(DocstringOnlyModuleDocumenter) +from sphinx.ext.autodoc import ( + ViewList, + Options, + AutodocReporter, + nodes, + assemble_option_dict, + nested_parse_with_titles, +) -from sphinx.ext.autodoc import (ViewList, Options, AutodocReporter, nodes, - assemble_option_dict, nested_parse_with_titles) @monkeypatch(autodoc.AutoDirective) def run(self): @@ -56,8 +67,7 @@ def run(self): objtype = self.name[4:] doc_class = self._registry[objtype] # process the options with the selected documenter's option_spec - self.genopt = Options(assemble_option_dict( - self.options.items(), doc_class.option_spec)) + self.genopt = Options(assemble_option_dict(self.options.items(), doc_class.option_spec)) # generate the output documenter = doc_class(self, self.arguments[0]) documenter.generate(more_content=self.content) @@ -72,9 +82,8 @@ def run(self): # use a custom reporter that correctly assigns lines to source # filename/description and lineno old_reporter = self.state.memo.reporter - self.state.memo.reporter = AutodocReporter(self.result, - self.state.memo.reporter) - if self.name in ('automodule', 'autodocstring'): + self.state.memo.reporter = AutodocReporter(self.result, self.state.memo.reporter) + if self.name in ("automodule", "autodocstring"): node = nodes.section() # necessary so that the child nodes get the right source/line set node.document = self.state.document diff --git a/logilab/common/sphinxutils.py b/logilab/common/sphinxutils.py index ab6e8a1..350188d 100644 --- a/logilab/common/sphinxutils.py +++ b/logilab/common/sphinxutils.py @@ -37,18 +37,24 @@ from logilab.common import STD_BLACKLIST from logilab.common.shellutils import globfind from logilab.common.modutils import load_module_from_file, modpath_from_file + def module_members(module): members = [] for name, value in inspect.getmembers(module): - if getattr(value, '__module__', None) == module.__name__: - members.append( (name, value) ) + if getattr(value, "__module__", None) == module.__name__: + members.append((name, value)) return sorted(members) def class_members(klass): - return sorted([name for name in vars(klass) - if name not in ('__doc__', '__module__', - '__dict__', '__weakref__')]) + return sorted( + [ + name + for name in vars(klass) + if name not in ("__doc__", "__module__", "__dict__", "__weakref__") + ] + ) + class ModuleGenerator: file_header = """.. -*- coding: utf-8 -*-\n\n%s\n""" @@ -72,7 +78,7 @@ class ModuleGenerator: def generate(self, dest_file, exclude_dirs=STD_BLACKLIST): """make the module file""" - self.fn = open(dest_file, 'w') + self.fn = open(dest_file, "w") num = len(self.title) + 6 title = "=" * num + "\n %s API\n" % self.title + "=" * num self.fn.write(self.file_header % title) @@ -88,35 +94,34 @@ class ModuleGenerator: for objname, obj in module_members(module): if inspect.isclass(obj): classmembers = class_members(obj) - classes.append( (objname, classmembers) ) + classes.append((objname, classmembers)) else: modmembers.append(objname) - self.fn.write(self.module_def % (modname, '=' * len(modname), - modname, - ', '.join(modmembers))) + self.fn.write( + self.module_def % (modname, "=" * len(modname), modname, ", ".join(modmembers)) + ) for klass, members in classes: - self.fn.write(self.class_def % (klass, ', '.join(members))) + self.fn.write(self.class_def % (klass, ", ".join(members))) def find_modules(self, exclude_dirs): basepath = osp.dirname(self.code_dir) basedir = osp.basename(basepath) + osp.sep if basedir not in sys.path: sys.path.insert(1, basedir) - for filepath in globfind(self.code_dir, '*.py', exclude_dirs): - if osp.basename(filepath) in ('setup.py', '__pkginfo__.py'): + for filepath in globfind(self.code_dir, "*.py", exclude_dirs): + if osp.basename(filepath) in ("setup.py", "__pkginfo__.py"): continue try: module = load_module_from_file(filepath) - except: # module might be broken or magic + except: # module might be broken or magic dotted_path = modpath_from_file(filepath) - module = type('.'.join(dotted_path), (), {}) # mock it + module = type(".".join(dotted_path), (), {}) # mock it yield module -if __name__ == '__main__': +if __name__ == "__main__": # example : title, code_dir, outfile = sys.argv[1:] generator = ModuleGenerator(title, code_dir) # XXX modnames = ['logilab'] - generator.generate(outfile, ('test', 'tests', 'examples', - 'data', 'doc', '.hg', 'migration')) + generator.generate(outfile, ("test", "tests", "examples", "data", "doc", ".hg", "migration")) diff --git a/logilab/common/table.py b/logilab/common/table.py index e7b9195..983708b 100644 --- a/logilab/common/table.py +++ b/logilab/common/table.py @@ -34,7 +34,12 @@ class Table(object): forall(self.data, lambda x: len(x) <= len(self.col_names)) """ - def __init__(self, default_value: int = 0, col_names: Optional[List[str]] = None, row_names: Optional[Any] = None) -> None: + def __init__( + self, + default_value: int = 0, + col_names: Optional[List[str]] = None, + row_names: Optional[Any] = None, + ) -> None: self.col_names: List = [] self.row_names: List = [] self.data: List = [] @@ -45,7 +50,7 @@ class Table(object): self.create_rows(row_names) def _next_row_name(self) -> str: - return 'row%s' % (len(self.row_names)+1) + return "row%s" % (len(self.row_names) + 1) def __iter__(self) -> Iterator: return iter(self.data) @@ -83,7 +88,7 @@ class Table(object): """ self.row_names.extend(row_names) for row_name in row_names: - self.data.append([self.default_value]*len(self.col_names)) + self.data.append([self.default_value] * len(self.col_names)) def create_columns(self, col_names: List[str]) -> None: """Appends col_names to the list of existing columns @@ -96,8 +101,7 @@ class Table(object): """ row_name = row_name or self._next_row_name() self.row_names.append(row_name) - self.data.append([self.default_value]*len(self.col_names)) - + self.data.append([self.default_value] * len(self.col_names)) def create_column(self, col_name: str) -> None: """Creates a colname to the col_names list @@ -107,7 +111,7 @@ class Table(object): row.append(self.default_value) ## Sort by column ########################################################## - def sort_by_column_id(self, col_id: str, method: str = 'asc') -> None: + def sort_by_column_id(self, col_id: str, method: str = "asc") -> None: """Sorts the table (in-place) according to data stored in col_id """ try: @@ -116,17 +120,17 @@ class Table(object): except ValueError: raise KeyError("Col (%s) not found in table" % (col_id)) - - def sort_by_column_index(self, col_index: int, method: str = 'asc') -> None: + def sort_by_column_index(self, col_index: int, method: str = "asc") -> None: """Sorts the table 'in-place' according to data stored in col_index method should be in ('asc', 'desc') """ - sort_list = sorted([(row[col_index], row, row_name) - for row, row_name in zip(self.data, self.row_names)]) + sort_list = sorted( + [(row[col_index], row, row_name) for row, row_name in zip(self.data, self.row_names)] + ) # Sorting sort_list will sort according to col_index # If we want reverse sort, then reverse list - if method.lower() == 'desc': + if method.lower() == "desc": sort_list.reverse() # Rebuild data / row names @@ -136,8 +140,9 @@ class Table(object): self.data.append(row) self.row_names.append(row_name) - def groupby(self, colname: str, *others: str) -> Union[Dict[str, Dict[str, 'Table']], - Dict[str, 'Table']]: + def groupby( + self, colname: str, *others: str + ) -> Union[Dict[str, Dict[str, "Table"]], Dict[str, "Table"]]: """builds indexes of data :returns: nested dictionaries pointing to actual rows """ @@ -148,13 +153,14 @@ class Table(object): ptr = groups for col_index in col_indexes[:-1]: ptr = ptr.setdefault(row[col_index], {}) - table = ptr.setdefault(row[col_indexes[-1]], - Table(default_value=self.default_value, - col_names=self.col_names)) + table = ptr.setdefault( + row[col_indexes[-1]], + Table(default_value=self.default_value, col_names=self.col_names), + ) table.append_row(tuple(row)) return groups - def select(self, colname: str, value: str) -> 'Table': + def select(self, colname: str, value: str) -> "Table": grouped = self.groupby(colname) try: # mypy: Incompatible return value type (got "Union[Dict[str, Table], Table]", @@ -170,14 +176,12 @@ class Table(object): if row[col_index] == value: self.data.remove(row) - ## The 'setter' part ####################################################### def set_cell(self, row_index: int, col_index: int, data: int) -> None: """sets value of cell 'row_indew', 'col_index' to data """ self.data[row_index][col_index] = data - def set_cell_by_ids(self, row_id: str, col_id: str, data: Union[int, str]) -> None: """sets value of cell mapped by row_id and col_id to data Raises a KeyError if row_id or col_id are not found in the table @@ -193,7 +197,6 @@ class Table(object): except ValueError: raise KeyError("Column (%s) not found in table" % (col_id)) - def set_row(self, row_index: int, row_data: Union[List[float], List[int], List[str]]) -> None: """sets the 'row_index' row pre:: @@ -203,7 +206,6 @@ class Table(object): """ self.data[row_index] = row_data - def set_row_by_id(self, row_id: str, row_data: List[str]) -> None: """sets the 'row_id' column pre:: @@ -217,10 +219,11 @@ class Table(object): row_index = self.row_names.index(row_id) self.set_row(row_index, row_data) except ValueError: - raise KeyError('Row (%s) not found in table' % (row_id)) - + raise KeyError("Row (%s) not found in table" % (row_id)) - def append_row(self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None) -> int: + def append_row( + self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None + ) -> int: """Appends a row to the table pre:: @@ -245,7 +248,6 @@ class Table(object): self.row_names.insert(index, row_name) self.data.insert(index, row_data) - def delete_row(self, index: int) -> List[str]: """Deletes the 'index' row in the table, and returns it. Raises an IndexError if index is out of range @@ -253,7 +255,6 @@ class Table(object): self.row_names.pop(index) return self.data.pop(index) - def delete_row_by_id(self, row_id: str) -> None: """Deletes the 'row_id' row in the table. Raises a KeyError if row_id was not found. @@ -262,8 +263,7 @@ class Table(object): row_index = self.row_names.index(row_id) self.delete_row(row_index) except ValueError: - raise KeyError('Row (%s) not found in table' % (row_id)) - + raise KeyError("Row (%s) not found in table" % (row_id)) def set_column(self, col_index: int, col_data: Union[List[int], range]) -> None: """sets the 'col_index' column @@ -276,7 +276,6 @@ class Table(object): for row_index, cell_data in enumerate(col_data): self.data[row_index][col_index] = cell_data - def set_column_by_id(self, col_id: str, col_data: Union[List[int], range]) -> None: """sets the 'col_id' column pre:: @@ -290,8 +289,7 @@ class Table(object): col_index = self.col_names.index(col_id) self.set_column(col_index, col_data) except ValueError: - raise KeyError('Column (%s) not found in table' % (col_id)) - + raise KeyError("Column (%s) not found in table" % (col_id)) def append_column(self, col_data: range, col_name: str) -> None: """Appends the 'col_index' column @@ -304,7 +302,6 @@ class Table(object): for row_index, cell_data in enumerate(col_data): self.data[row_index].append(cell_data) - def insert_column(self, index: int, col_data: range, col_name: str) -> None: """Appends col_data before 'index' in the table. To make 'insert' behave like 'list.insert', inserting in an out of range index will @@ -318,7 +315,6 @@ class Table(object): for row_index, cell_data in enumerate(col_data): self.data[row_index].insert(index, cell_data) - def delete_column(self, index: int) -> List[int]: """Deletes the 'index' column in the table, and returns it. Raises an IndexError if index is out of range @@ -326,7 +322,6 @@ class Table(object): self.col_names.pop(index) return [row.pop(index) for row in self.data] - def delete_column_by_id(self, col_id: str) -> None: """Deletes the 'col_id' col in the table. Raises a KeyError if col_id was not found. @@ -335,8 +330,7 @@ class Table(object): col_index = self.col_names.index(col_id) self.delete_column(col_index) except ValueError: - raise KeyError('Column (%s) not found in table' % (col_id)) - + raise KeyError("Column (%s) not found in table" % (col_id)) ## The 'getter' part ####################################################### @@ -344,9 +338,12 @@ class Table(object): """Returns a tuple which represents the table's shape """ return len(self.row_names), len(self.col_names) + shape = property(get_shape) - def __getitem__(self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice]) -> Any: + def __getitem__( + self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice] + ) -> Any: """provided for convenience""" multirows: bool = False multicols: bool = False @@ -402,7 +399,7 @@ class Table(object): for idx, row in enumerate(self.data[rows]): tab.set_row(idx, row[cols]) - if multirows : + if multirows: if multicols: return tab else: @@ -457,14 +454,13 @@ class Table(object): col = list(set(col)) return col - def apply_stylesheet(self, stylesheet: 'TableStyleSheet') -> None: + def apply_stylesheet(self, stylesheet: "TableStyleSheet") -> None: """Applies the stylesheet to this table """ for instruction in stylesheet.instructions: eval(instruction) - - def transpose(self) -> 'Table': + def transpose(self) -> "Table": """Keeps the self object intact, and returns the transposed (rotated) table. """ @@ -475,7 +471,6 @@ class Table(object): transposed.set_row(col_index, column) return transposed - def pprint(self) -> str: """returns a string representing the table in a pretty printed 'text' format. @@ -490,10 +485,10 @@ class Table(object): lines = [] # Build the 'first' line <=> the col_names one # The first cell <=> an empty one - col_names_line = [' '*col_start] + col_names_line = [" " * col_start] for col_name in self.col_names: - col_names_line.append(col_name + ' '*5) - lines.append('|' + '|'.join(col_names_line) + '|') + col_names_line.append(col_name + " " * 5) + lines.append("|" + "|".join(col_names_line) + "|") max_line_length = len(lines[0]) # Build the table @@ -501,22 +496,21 @@ class Table(object): line = [] # First, build the row_name's cell row_name = self.row_names[row_index] - line.append(row_name + ' '*(col_start-len(row_name))) + line.append(row_name + " " * (col_start - len(row_name))) # Then, build all the table's cell for this line. for col_index, cell in enumerate(row): col_name_length = len(self.col_names[col_index]) + 5 data = str(cell) - line.append(data + ' '*(col_name_length - len(data))) - lines.append('|' + '|'.join(line) + '|') + line.append(data + " " * (col_name_length - len(data))) + lines.append("|" + "|".join(line) + "|") if len(lines[-1]) > max_line_length: max_line_length = len(lines[-1]) # Wrap the table with '-' to make a frame - lines.insert(0, '-'*max_line_length) - lines.append('-'*max_line_length) - return '\n'.join(lines) - + lines.insert(0, "-" * max_line_length) + lines.append("-" * max_line_length) + return "\n".join(lines) def __repr__(self) -> str: return repr(self.data) @@ -526,9 +520,8 @@ class Table(object): # We must convert cells into strings before joining them for row in self.data: data.append([str(cell) for cell in row]) - lines = ['\t'.join(row) for row in data] - return '\n'.join(lines) - + lines = ["\t".join(row) for row in data] + return "\n".join(lines) class TableStyle: @@ -538,18 +531,17 @@ class TableStyle: def __init__(self, table: Table) -> None: self._table = table - self.size = dict([(col_name, '1*') for col_name in table.col_names]) + self.size = dict([(col_name, "1*") for col_name in table.col_names]) # __row_column__ is a special key to define the first column which # actually has no name (<=> left most column <=> row names column) - self.size['__row_column__'] = '1*' - self.alignment = dict([(col_name, 'right') - for col_name in table.col_names]) - self.alignment['__row_column__'] = 'right' + self.size["__row_column__"] = "1*" + self.alignment = dict([(col_name, "right") for col_name in table.col_names]) + self.alignment["__row_column__"] = "right" # We shouldn't have to create an entry for # the 1st col (the row_column one) - self.units = dict([(col_name, '') for col_name in table.col_names]) - self.units['__row_column__'] = '' + self.units = dict([(col_name, "") for col_name in table.col_names]) + self.units["__row_column__"] = "" # XXX FIXME : params order should be reversed for all set() methods def set_size(self, value: str, col_id: str) -> None: @@ -563,38 +555,34 @@ class TableStyle: BE CAREFUL : the '0' column is the '__row_column__' one ! """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] self.size[col_id] = value - def set_alignment(self, value: str, col_id: str) -> None: """sets the alignment of the specified col_id to value """ self.alignment[col_id] = value - def set_alignment_by_index(self, value: str, col_index: int) -> None: """Allows to set the alignment according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] self.alignment[col_id] = value - def set_unit(self, value: str, col_id: str) -> None: """sets the unit of the specified col_id to value """ self.units[col_id] = value - def set_unit_by_index(self, value: str, col_index: int) -> None: """Allows to set the unit according to the column index rather than using the column's id. @@ -603,73 +591,69 @@ class TableStyle: for the 1st column (the __row__column__ one)) """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] self.units[col_id] = value - def get_size(self, col_id: str) -> str: """Returns the size of the specified col_id """ return self.size[col_id] - def get_size_by_index(self, col_index: int) -> str: """Allows to get the size according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] return self.size[col_id] - def get_alignment(self, col_id: str) -> str: """Returns the alignment of the specified col_id """ return self.alignment[col_id] - def get_alignment_by_index(self, col_index: int) -> str: """Allors to get the alignment according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] return self.alignment[col_id] - def get_unit(self, col_id: str) -> str: """Returns the unit of the specified col_id """ return self.units[col_id] - def get_unit_by_index(self, col_index: int) -> str: """Allors to get the unit according to the column index rather than using the column's id. BE CAREFUL : the '0' column is the '__row_column__' one ! """ if col_index == 0: - col_id = '__row_column__' + col_id = "__row_column__" else: - col_id = self._table.col_names[col_index-1] + col_id = self._table.col_names[col_index - 1] return self.units[col_id] import re + CELL_PROG = re.compile("([0-9]+)_([0-9]+)") + class TableStyleSheet: """A simple Table stylesheet Rules are expressions where cells are defined by the row_index @@ -694,21 +678,20 @@ class TableStyleSheet: for rule in rules: self.add_rule(rule) - def add_rule(self, rule: str) -> None: """Adds a rule to the stylesheet rules """ try: - source_code = ['from math import *'] - source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule)) - self.instructions.append(compile('\n'.join(source_code), - 'table.py', 'exec')) + source_code = ["from math import *"] + source_code.append(CELL_PROG.sub(r"self.data[\1][\2]", rule)) + self.instructions.append(compile("\n".join(source_code), "table.py", "exec")) self.rules.append(rule) except SyntaxError: print("Bad Stylesheet Rule : %s [skipped]" % rule) - - def add_rowsum_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None: + def add_rowsum_rule( + self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int + ) -> None: """Creates and adds a rule to sum over the row at row_index from start_col to end_col. dest_cell is a tuple of two elements (x,y) of the destination cell @@ -718,13 +701,13 @@ class TableStyleSheet: start_col >= 0 end_col > start_col """ - cell_list = ['%d_%d'%(row_index, index) for index in range(start_col, - end_col + 1)] - rule = '%d_%d=' % dest_cell + '+'.join(cell_list) + cell_list = ["%d_%d" % (row_index, index) for index in range(start_col, end_col + 1)] + rule = "%d_%d=" % dest_cell + "+".join(cell_list) self.add_rule(rule) - - def add_rowavg_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None: + def add_rowavg_rule( + self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int + ) -> None: """Creates and adds a rule to make the row average (from start_col to end_col) dest_cell is a tuple of two elements (x,y) of the destination cell @@ -734,14 +717,14 @@ class TableStyleSheet: start_col >= 0 end_col > start_col """ - cell_list = ['%d_%d'%(row_index, index) for index in range(start_col, - end_col + 1)] - num = (end_col - start_col + 1) - rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num + cell_list = ["%d_%d" % (row_index, index) for index in range(start_col, end_col + 1)] + num = end_col - start_col + 1 + rule = "%d_%d=" % dest_cell + "(" + "+".join(cell_list) + ")/%f" % num self.add_rule(rule) - - def add_colsum_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None: + def add_colsum_rule( + self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int + ) -> None: """Creates and adds a rule to sum over the col at col_index from start_row to end_row. dest_cell is a tuple of two elements (x,y) of the destination cell @@ -751,13 +734,13 @@ class TableStyleSheet: start_row >= 0 end_row > start_row """ - cell_list = ['%d_%d'%(index, col_index) for index in range(start_row, - end_row + 1)] - rule = '%d_%d=' % dest_cell + '+'.join(cell_list) + cell_list = ["%d_%d" % (index, col_index) for index in range(start_row, end_row + 1)] + rule = "%d_%d=" % dest_cell + "+".join(cell_list) self.add_rule(rule) - - def add_colavg_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None: + def add_colavg_rule( + self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int + ) -> None: """Creates and adds a rule to make the col average (from start_row to end_row) dest_cell is a tuple of two elements (x,y) of the destination cell @@ -767,14 +750,12 @@ class TableStyleSheet: start_row >= 0 end_row > start_row """ - cell_list = ['%d_%d'%(index, col_index) for index in range(start_row, - end_row + 1)] - num = (end_row - start_row + 1) - rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num + cell_list = ["%d_%d" % (index, col_index) for index in range(start_row, end_row + 1)] + num = end_row - start_row + 1 + rule = "%d_%d=" % dest_cell + "(" + "+".join(cell_list) + ")/%f" % num self.add_rule(rule) - class TableCellRenderer: """Defines a simple text renderer """ @@ -789,35 +770,36 @@ class TableCellRenderer: """ self.properties = properties - - def render_cell(self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle) -> Union[str, int]: + def render_cell( + self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle + ) -> Union[str, int]: """Renders the cell at 'cell_coord' in the table, using table_style """ row_index, col_index = cell_coord cell_value = table.data[row_index][col_index] - final_content = self._make_cell_content(cell_value, - table_style, col_index +1) - return self._render_cell_content(final_content, - table_style, col_index + 1) - + final_content = self._make_cell_content(cell_value, table_style, col_index + 1) + return self._render_cell_content(final_content, table_style, col_index + 1) - def render_row_cell(self, row_name: str, table: Table, table_style: TableStyle) -> Union[str, int]: + def render_row_cell( + self, row_name: str, table: Table, table_style: TableStyle + ) -> Union[str, int]: """Renders the cell for 'row_id' row """ cell_value = row_name return self._render_cell_content(cell_value, table_style, 0) - - def render_col_cell(self, col_name: str, table: Table, table_style: TableStyle) -> Union[str, int]: + def render_col_cell( + self, col_name: str, table: Table, table_style: TableStyle + ) -> Union[str, int]: """Renders the cell for 'col_id' row """ cell_value = col_name col_index = table.col_names.index(col_name) - return self._render_cell_content(cell_value, table_style, col_index +1) + return self._render_cell_content(cell_value, table_style, col_index + 1) - - - def _render_cell_content(self, content: Union[str, int], table_style: TableStyle, col_index: int) -> Union[str, int]: + def _render_cell_content( + self, content: Union[str, int], table_style: TableStyle, col_index: int + ) -> Union[str, int]: """Makes the appropriate rendering for this cell content. Rendering properties will be searched using the *table_style.get_xxx_by_index(col_index)' methods @@ -826,31 +808,30 @@ class TableCellRenderer: """ return content - - def _make_cell_content(self, cell_content: int, table_style: TableStyle, col_index: int) -> Union[int, str]: + def _make_cell_content( + self, cell_content: int, table_style: TableStyle, col_index: int + ) -> Union[int, str]: """Makes the cell content (adds decoration data, like units for example) """ final_content: Union[int, str] = cell_content - if 'skip_zero' in self.properties: - replacement_char = self.properties['skip_zero'] + if "skip_zero" in self.properties: + replacement_char = self.properties["skip_zero"] else: replacement_char = 0 if replacement_char and final_content == 0: return replacement_char try: - units_on = self.properties['units'] + units_on = self.properties["units"] if units_on: - final_content = self._add_unit( - cell_content, table_style, col_index) + final_content = self._add_unit(cell_content, table_style, col_index) except KeyError: pass return final_content - def _add_unit(self, cell_content: int, table_style: TableStyle, col_index: int) -> str: """Adds unit to the cell_content if needed """ @@ -858,7 +839,6 @@ class TableCellRenderer: return str(cell_content) + " " + unit - class DocbookRenderer(TableCellRenderer): """Defines how to render a cell for a docboook table """ @@ -867,21 +847,20 @@ class DocbookRenderer(TableCellRenderer): """Computes the colspec element according to the style """ size = table_style.get_size_by_index(col_index) - return '<colspec colname="c%d" colwidth="%s"/>\n' % \ - (col_index, size) - + return '<colspec colname="c%d" colwidth="%s"/>\n' % (col_index, size) - def _render_cell_content(self, cell_content: Union[int, str], table_style: TableStyle, col_index: int) -> str: + def _render_cell_content( + self, cell_content: Union[int, str], table_style: TableStyle, col_index: int + ) -> str: """Makes the appropriate rendering for this cell content. Rendering properties will be searched using the table_style.get_xxx_by_index(col_index)' methods. """ try: - align_on = self.properties['alignment'] + align_on = self.properties["alignment"] alignment = table_style.get_alignment_by_index(col_index) if align_on: - return "<entry align='%s'>%s</entry>\n" % \ - (alignment, cell_content) + return "<entry align='%s'>%s</entry>\n" % (alignment, cell_content) except KeyError: # KeyError <=> Default alignment return "<entry>%s</entry>\n" % cell_content @@ -894,39 +873,36 @@ class TableWriter: """A class to write tables """ - def __init__(self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any) -> None: + def __init__( + self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any + ) -> None: self._stream = stream self.style = style or TableStyle(table) self._table = table self.properties = properties self.renderer: Optional[DocbookRenderer] = None - def set_style(self, style): """sets the table's associated style """ self.style = style - def set_renderer(self, renderer: DocbookRenderer) -> None: """sets the way to render cell """ self.renderer = renderer - def update_properties(self, **properties): """Updates writer's properties (for cell rendering) """ self.properties.update(properties) - def write_table(self, title: str = "") -> None: """Writes the table """ raise NotImplementedError("write_table must be implemented !") - class DocbookTableWriter(TableWriter): """Defines an implementation of TableWriter to write a table in Docbook """ @@ -937,56 +913,48 @@ class DocbookTableWriter(TableWriter): assert self.renderer is not None # Define col_headers (colstpec elements) - for col_index in range(len(self._table.col_names)+1): - self._stream.write(self.renderer.define_col_header(col_index, - self.style)) + for col_index in range(len(self._table.col_names) + 1): + self._stream.write(self.renderer.define_col_header(col_index, self.style)) self._stream.write("<thead>\n<row>\n") # XXX FIXME : write an empty entry <=> the first (__row_column) column - self._stream.write('<entry></entry>\n') + self._stream.write("<entry></entry>\n") for col_name in self._table.col_names: - self._stream.write(self.renderer.render_col_cell( - col_name, self._table, - self.style)) + self._stream.write(self.renderer.render_col_cell(col_name, self._table, self.style)) self._stream.write("</row>\n</thead>\n") - def _write_body(self) -> None: """Writes the table body """ assert self.renderer is not None - self._stream.write('<tbody>\n') + self._stream.write("<tbody>\n") for row_index, row in enumerate(self._table.data): - self._stream.write('<row>\n') + self._stream.write("<row>\n") row_name = self._table.row_names[row_index] # Write the first entry (row_name) - self._stream.write(self.renderer.render_row_cell(row_name, - self._table, - self.style)) + self._stream.write(self.renderer.render_row_cell(row_name, self._table, self.style)) for col_index, cell in enumerate(row): - self._stream.write(self.renderer.render_cell( - (row_index, col_index), - self._table, self.style)) + self._stream.write( + self.renderer.render_cell((row_index, col_index), self._table, self.style) + ) - self._stream.write('</row>\n') - - self._stream.write('</tbody>\n') + self._stream.write("</row>\n") + self._stream.write("</tbody>\n") def write_table(self, title: str = "") -> None: """Writes the table """ - self._stream.write('<table>\n<title>%s></title>\n'%(title)) + self._stream.write("<table>\n<title>%s></title>\n" % (title)) self._stream.write( - '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n'% - (len(self._table.col_names)+1)) + '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n' + % (len(self._table.col_names) + 1) + ) self._write_headers() self._write_body() - self._stream.write('</tgroup>\n</table>\n') - - + self._stream.write("</tgroup>\n</table>\n") diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py index 4e3434e..0d4889d 100644 --- a/logilab/common/tasksqueue.py +++ b/logilab/common/tasksqueue.py @@ -29,22 +29,21 @@ MEDIUM = 10 HIGH = 100 PRIORITY = { - 'LOW': LOW, - 'MEDIUM': MEDIUM, - 'HIGH': HIGH, - } + "LOW": LOW, + "MEDIUM": MEDIUM, + "HIGH": HIGH, +} REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items()) class PrioritizedTasksQueue(queue.Queue): - def _init(self, maxsize: int) -> None: """Initialize the queue representation""" self.maxsize = maxsize # ordered list of task, from the lowest to the highest priority - self.queue: List['Task'] = [] # type: ignore + self.queue: List["Task"] = [] # type: ignore - def _put(self, item: 'Task') -> None: + def _put(self, item: "Task") -> None: """Put a new item in the queue""" for i, task in enumerate(self.queue): # equivalent task @@ -60,11 +59,11 @@ class PrioritizedTasksQueue(queue.Queue): return insort_left(self.queue, item) - def _get(self) -> 'Task': + def _get(self) -> "Task": """Get an item from the queue""" return self.queue.pop() - def __iter__(self) -> Iterator['Task']: + def __iter__(self) -> Iterator["Task"]: return iter(self.queue) def remove(self, tid: str) -> None: @@ -74,7 +73,8 @@ class PrioritizedTasksQueue(queue.Queue): if task.id == tid: self.queue.pop(i) return - raise ValueError('not task of id %s in queue' % tid) + raise ValueError("not task of id %s in queue" % tid) + class Task: def __init__(self, tid: str, priority: int = LOW) -> None: @@ -84,9 +84,9 @@ class Task: self.priority = priority def __repr__(self) -> str: - return '<Task %s @%#x>' % (self.id, id(self)) + return "<Task %s @%#x>" % (self.id, id(self)) - def __lt__(self, other: 'Task') -> bool: + def __lt__(self, other: "Task") -> bool: return self.priority < other.priority def __eq__(self, other: object) -> bool: @@ -94,5 +94,5 @@ class Task: __hash__ = object.__hash__ - def merge(self, other: 'Task') -> None: + def merge(self, other: "Task") -> None: pass diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py index 8348900..f8401c4 100644 --- a/logilab/common/testlib.py +++ b/logilab/common/testlib.py @@ -64,6 +64,7 @@ import configparser from logilab.common.deprecation import class_deprecated, deprecated import unittest as unittest_legacy + if not getattr(unittest_legacy, "__package__", None): try: import unittest2 as unittest @@ -83,22 +84,22 @@ from logilab.common.decorators import cached, classproperty from logilab.common import textutils -__all__ = ['unittest_main', 'find_tests', 'nocoverage', 'pause_trace'] +__all__ = ["unittest_main", "find_tests", "nocoverage", "pause_trace"] -DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest', - 'func', 'validation') +DEFAULT_PREFIXES = ("test", "regrtest", "smoketest", "unittest", "func", "validation") -is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction) +is_generator = deprecated("[lgc 0.63] use inspect.isgeneratorfunction")(isgeneratorfunction) # used by unittest to count the number of relevant levels in the traceback __unittest = 1 -@deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.') +@deprecated("with_tempdir is deprecated, use tempfile.TemporaryDirectory.") def with_tempdir(callable: Callable) -> Callable: """A decorator ensuring no temporary file left when the function return Work only for temporary file created with the tempfile module""" if isgeneratorfunction(callable): + def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]: old_tmpdir = tempfile.gettempdir() new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") @@ -111,9 +112,11 @@ def with_tempdir(callable: Callable) -> Callable: rmtree(new_tmpdir, ignore_errors=True) finally: tempfile.tempdir = old_tmpdir + return proxy else: + @wraps(callable) def proxy(*args: Any, **kargs: Any) -> Any: @@ -127,11 +130,14 @@ def with_tempdir(callable: Callable) -> Callable: rmtree(new_tmpdir, ignore_errors=True) finally: tempfile.tempdir = old_tmpdir + return proxy + def in_tempdir(callable): """A decorator moving the enclosed function inside the tempfile.tempfdir """ + @wraps(callable) def proxy(*args, **kargs): @@ -141,8 +147,10 @@ def in_tempdir(callable): return callable(*args, **kargs) finally: os.chdir(old_cwd) + return proxy + def within_tempdir(callable): """A decorator run the enclosed function inside a tmpdir removed after execution """ @@ -150,10 +158,8 @@ def within_tempdir(callable): proxy.__name__ = callable.__name__ return proxy -def find_tests(testdir, - prefixes=DEFAULT_PREFIXES, suffix=".py", - excludes=(), - remove_suffix=True): + +def find_tests(testdir, prefixes=DEFAULT_PREFIXES, suffix=".py", excludes=(), remove_suffix=True): """ Return a list of all applicable test modules. """ @@ -163,7 +169,7 @@ def find_tests(testdir, for prefix in prefixes: if name.startswith(prefix): if remove_suffix and name.endswith(suffix): - name = name[:-len(suffix)] + name = name[: -len(suffix)] if name not in excludes: tests.append(name) tests.sort() @@ -184,13 +190,12 @@ def start_interactive_mode(result): testindex = 0 print("Choose a test to debug:") # order debuggers in the same way than errors were printed - print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr) - in enumerate(descrs)])) + print("\n".join(["\t%s : %s" % (i, descr) for i, (_, descr) in enumerate(descrs)])) print("Type 'exit' (or ^D) to quit") print() try: - todebug = input('Enter a test name: ') - if todebug.strip().lower() == 'exit': + todebug = input("Enter a test name: ") + if todebug.strip().lower() == "exit": print() break else: @@ -198,7 +203,7 @@ def start_interactive_mode(result): testindex = int(todebug) debugger = debuggers[descrs[testindex][0]] except (ValueError, IndexError): - print("ERROR: invalid test number %r" % (todebug, )) + print("ERROR: invalid test number %r" % (todebug,)) else: debugger.start() except (EOFError, KeyboardInterrupt): @@ -208,6 +213,7 @@ def start_interactive_mode(result): # coverage pausing tools ##################################################### + @contextmanager def replace_trace(trace: Optional[Callable] = None) -> Iterator: """A context manager that temporary replaces the trace function""" @@ -218,8 +224,7 @@ def replace_trace(trace: Optional[Callable] = None) -> Iterator: finally: # specific hack to work around a bug in pycoverage, see # https://bitbucket.org/ned/coveragepy/issue/123 - if (oldtrace is not None and not callable(oldtrace) and - hasattr(oldtrace, 'pytrace')): + if oldtrace is not None and not callable(oldtrace) and hasattr(oldtrace, "pytrace"): oldtrace = oldtrace.pytrace sys.settrace(oldtrace) @@ -229,7 +234,7 @@ pause_trace = replace_trace def nocoverage(func: Callable) -> Callable: """Function decorator that pauses tracing functions""" - if hasattr(func, 'uncovered'): + if hasattr(func, "uncovered"): return func # mypy: "Callable[..., Any]" has no attribute "uncovered" # dynamic attribute for magic @@ -238,6 +243,7 @@ def nocoverage(func: Callable) -> Callable: def not_covered(*args: Any, **kwargs: Any) -> Any: with pause_trace(): return func(*args, **kwargs) + # mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered" # dynamic attribute for magic not_covered.uncovered = True # type: ignore @@ -249,49 +255,56 @@ def nocoverage(func: Callable) -> Callable: # Add deprecation warnings about new api used by module level fixtures in unittest2 # http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule -class _DebugResult(object): # simplify import statement among unittest flavors.. +class _DebugResult(object): # simplify import statement among unittest flavors.. "Used by the TestSuite to hold previous class when running in debug." _previousTestClass = None _moduleSetUpFailed = False shouldStop = False + # backward compatibility: TestSuite might be imported from lgc.testlib TestSuite = unittest.TestSuite + class keywords(dict): """Keyword args (**kwargs) support for generative tests.""" + class starargs(tuple): """Variable arguments (*args) for generative tests.""" + def __new__(cls, *args): return tuple.__new__(cls, args) + unittest_main = unittest.main class InnerTestSkipped(SkipTest): """raised when a test is skipped""" + pass + def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]: args = [] varargs = () kwargs: Dict = {} - flags = 0 # 2 <=> starargs, 4 <=> kwargs + flags = 0 # 2 <=> starargs, 4 <=> kwargs for param in params: if isinstance(param, starargs): varargs = param if flags: - raise TypeError('found starargs after keywords !') + raise TypeError("found starargs after keywords !") flags |= 2 args += list(varargs) elif isinstance(param, keywords): kwargs = param if flags & 4: - raise TypeError('got multiple keywords parameters') + raise TypeError("got multiple keywords parameters") flags |= 4 elif flags & 2 or flags & 4: - raise TypeError('found parameters after kwargs or args') + raise TypeError("found parameters after kwargs or args") else: args.append(param) @@ -304,13 +317,14 @@ class InnerTest(tuple): instance.name = name return instance + class Tags(set): """A set of tag able validate an expression""" def __init__(self, *tags: str, **kwargs: Any) -> None: - self.inherit = kwargs.pop('inherit', True) + self.inherit = kwargs.pop("inherit", True) if kwargs: - raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) + raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) if len(tags) == 1 and not isinstance(tags[0], str): tags = tags[0] @@ -328,25 +342,26 @@ class Tags(set): # mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet"; # mypy: supertype defines the argument type as "AbstractSet[_T]" # not sure how to fix this one - def __or__(self, other: 'Tags') -> 'Tags': # type: ignore + def __or__(self, other: "Tags") -> "Tags": # type: ignore return Tags(*super(Tags, self).__or__(other)) # duplicate definition from unittest2 of the _deprecate decorator def _deprecate(original_func): def deprecated_func(*args, **kwargs): - warnings.warn( - ('Please use %s instead.' % original_func.__name__), - DeprecationWarning, 2) + warnings.warn(("Please use %s instead." % original_func.__name__), DeprecationWarning, 2) return original_func(*args, **kwargs) + return deprecated_func + class TestCase(unittest.TestCase): """A unittest.TestCase extension with some additional methods.""" + maxDiff = None tags = Tags() - def __init__(self, methodName: str = 'runTest') -> None: + def __init__(self, methodName: str = "runTest") -> None: super(TestCase, self).__init__(methodName) self.__exc_info = sys.exc_info self.__testMethodName = self._testMethodName @@ -355,13 +370,14 @@ class TestCase(unittest.TestCase): @classproperty @cached - def datadir(cls) -> str: # pylint: disable=E0213 + def datadir(cls) -> str: # pylint: disable=E0213 """helper attribute holding the standard test's data directory NOTE: this is a logilab's standard """ mod = sys.modules[cls.__module__] - return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data') + return osp.join(osp.dirname(osp.abspath(mod.__file__)), "data") + # cache it (use a class method to cache on class since TestCase is # instantiated for each test run) @@ -392,11 +408,12 @@ class TestCase(unittest.TestCase): except (KeyboardInterrupt, SystemExit): raise except unittest.SkipTest as e: - if hasattr(result, 'addSkip'): + if hasattr(result, "addSkip"): result.addSkip(self, str(e)) else: - warnings.warn("TestResult has no addSkip method, skips not reported", - RuntimeWarning, 2) + warnings.warn( + "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2 + ) result.addSuccess(self) return False except: @@ -423,23 +440,26 @@ class TestCase(unittest.TestCase): # if result.cvg: # result.cvg.start() testMethod = self._get_test_method() - if (getattr(self.__class__, "__unittest_skip__", False) or - getattr(testMethod, "__unittest_skip__", False)): + if getattr(self.__class__, "__unittest_skip__", False) or getattr( + testMethod, "__unittest_skip__", False + ): # If the class or method was skipped. try: - skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') - or getattr(testMethod, '__unittest_skip_why__', '')) - if hasattr(result, 'addSkip'): + skip_why = getattr(self.__class__, "__unittest_skip_why__", "") or getattr( + testMethod, "__unittest_skip_why__", "" + ) + if hasattr(result, "addSkip"): result.addSkip(self, skip_why) else: - warnings.warn("TestResult has no addSkip method, skips not reported", - RuntimeWarning, 2) + warnings.warn( + "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2 + ) result.addSuccess(self) finally: result.stopTest(self) return if runcondition and not runcondition(testMethod): - return # test is skipped + return # test is skipped result.startTest(self) try: if not self.quiet_run(result, self.setUp): @@ -447,11 +467,10 @@ class TestCase(unittest.TestCase): generative = isgeneratorfunction(testMethod) # generative tests if generative: - self._proceed_generative(result, testMethod, - runcondition) + self._proceed_generative(result, testMethod, runcondition) else: status = self._proceed(result, testMethod) - success = (status == 0) + success = status == 0 if not self.quiet_run(result, self.tearDown): return if not generative and success: @@ -461,19 +480,19 @@ class TestCase(unittest.TestCase): # result.cvg.stop() result.stopTest(self) - def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool: + def _proceed_generative( + self, result: Any, testfunc: Callable, runcondition: Callable = None + ) -> bool: # cancel startTest()'s increment result.testsRun -= 1 success = True try: for params in testfunc(): - if runcondition and not runcondition(testfunc, - skipgenerator=False): - if not (isinstance(params, InnerTest) - and runcondition(params)): + if runcondition and not runcondition(testfunc, skipgenerator=False): + if not (isinstance(params, InnerTest) and runcondition(params)): continue if not isinstance(params, (tuple, list)): - params = (params, ) + params = (params,) func = params[0] args, kwargs = parse_generative_args(params[1:]) # increment test counter manually @@ -485,9 +504,9 @@ class TestCase(unittest.TestCase): else: success = False # XXX Don't stop anymore if an error occured - #if status == 2: + # if status == 2: # result.shouldStop = True - if result.shouldStop: # either on error or on exitfirst + error + if result.shouldStop: # either on error or on exitfirst + error break except self.failureException: result.addFailure(self, self.__exc_info()) @@ -500,7 +519,13 @@ class TestCase(unittest.TestCase): success = False return success - def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int: + def _proceed( + self, + result: Any, + testfunc: Callable, + args: Union[List[bool], List[int], Tuple[()]] = (), + kwargs: Optional[Dict] = None, + ) -> int: """proceed the actual test returns 0 on success, 1 on failure, 2 on error @@ -529,39 +554,40 @@ class TestCase(unittest.TestCase): def innerSkip(self, msg: str = None) -> NoReturn: """mark a generative test as skipped for the <msg> reason""" - msg = msg or 'test was skipped' + msg = msg or "test was skipped" raise InnerTestSkipped(msg) - if sys.version_info >= (3,2): + if sys.version_info >= (3, 2): assertItemsEqual = unittest.TestCase.assertCountEqual else: assertCountEqual = unittest.TestCase.assertItemsEqual -TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')( - TestCase.assertItemsEqual) + +TestCase.assertItemsEqual = deprecated("assertItemsEqual is deprecated, use assertCountEqual")( + TestCase.assertItemsEqual +) import doctest + class SkippedSuite(unittest.TestSuite): def test(self): """just there to trigger test execution""" - self.skipped_test('doctest module has no DocTestSuite class') + self.skipped_test("doctest module has no DocTestSuite class") class DocTestFinder(doctest.DocTestFinder): - def __init__(self, *args, **kwargs): - self.skipped = kwargs.pop('skipped', ()) + self.skipped = kwargs.pop("skipped", ()) doctest.DocTestFinder.__init__(self, *args, **kwargs) def _get_test(self, obj, name, module, globs, source_lines): """override default _get_test method to be able to skip tests according to skipped attribute's value """ - if getattr(obj, '__name__', '') in self.skipped: + if getattr(obj, "__name__", "") in self.skipped: return None - return doctest.DocTestFinder._get_test(self, obj, name, module, - globs, source_lines) + return doctest.DocTestFinder._get_test(self, obj, name, module, globs, source_lines) # mypy error: Invalid metaclass 'class_deprecated' @@ -571,10 +597,11 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore I don't know how to make unittest.main consider the DocTestSuite instance without this hack """ - __deprecation_warning__ = 'use stdlib doctest module with unittest API directly' + + __deprecation_warning__ = "use stdlib doctest module with unittest API directly" skipped = () - def __call__(self, result=None, runcondition=None, options=None):\ - # pylint: disable=W0613 + + def __call__(self, result=None, runcondition=None, options=None): # pylint: disable=W0613 try: finder = DocTestFinder(skipped=self.skipped) suite = doctest.DocTestSuite(self.module, test_finder=finder) @@ -590,6 +617,7 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore finally: builtins.__dict__.clear() builtins.__dict__.update(old_builtins) + run = __call__ def test(self): @@ -607,21 +635,27 @@ class MockConnection: def cursor(self): """Mock cursor method""" return self + def execute(self, query, args=None): """Mock execute method""" - self.received.append( (query, args) ) + self.received.append((query, args)) + def fetchone(self): """Mock fetchone method""" return self.results[0] + def fetchall(self): """Mock fetchall method""" return self.results + def commit(self): """Mock commiy method""" - self.states.append( ('commit', len(self.received)) ) + self.states.append(("commit", len(self.received))) + def rollback(self): """Mock rollback method""" - self.states.append( ('rollback', len(self.received)) ) + self.states.append(("rollback", len(self.received))) + def close(self): """Mock close method""" pass @@ -629,7 +663,7 @@ class MockConnection: # mypy error: Name 'Mock' is not defined # dynamic class created by this class -def mock_object(**params: Any) -> 'Mock': # type: ignore +def mock_object(**params: Any) -> "Mock": # type: ignore """creates an object using params to set attributes >>> option = mock_object(verbose=False, index=range(5)) >>> option.verbose @@ -637,7 +671,7 @@ def mock_object(**params: Any) -> 'Mock': # type: ignore >>> option.index [0, 1, 2, 3, 4] """ - return type('Mock', (), params)() + return type("Mock", (), params)() def create_files(paths: List[str], chroot: str) -> None: @@ -664,7 +698,7 @@ def create_files(paths: List[str], chroot: str) -> None: path = osp.join(chroot, path) filename = osp.basename(path) # path is a directory path - if filename == '': + if filename == "": dirs.add(path) # path is a filename path else: @@ -674,54 +708,69 @@ def create_files(paths: List[str], chroot: str) -> None: if not osp.isdir(dirpath): os.makedirs(dirpath) for filepath in files: - open(filepath, 'w').close() + open(filepath, "w").close() -class AttrObject: # XXX cf mock_object +class AttrObject: # XXX cf mock_object def __init__(self, **kwargs): self.__dict__.update(kwargs) + def tag(*args: str, **kwargs: Any) -> Callable: """descriptor adding tag to a function""" + def desc(func: Callable) -> Callable: - assert not hasattr(func, 'tags') + assert not hasattr(func, "tags") # mypy: "Callable[..., Any]" has no attribute "tags" # dynamic magic attribute func.tags = Tags(*args, **kwargs) # type: ignore return func + return desc + def require_version(version: str) -> Callable: """ Compare version of python interpreter to the given one. Skip the test if older. """ + def check_require_version(f: Callable) -> Callable: - version_elements = version.split('.') + version_elements = version.split(".") try: compare = tuple([int(v) for v in version_elements]) except ValueError: - raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version) + raise ValueError("%s is not a correct version : should be X.Y[.Z]." % version) current = sys.version_info[:3] if current < compare: + def new_f(self, *args, **kwargs): - self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current]))) + self.skipTest( + "Need at least %s version of python. Current version is %s." + % (version, ".".join([str(element) for element in current])) + ) + new_f.__name__ = f.__name__ return new_f else: return f + return check_require_version + def require_module(module: str) -> Callable: """ Check if the given module is loaded. Skip the test if not. """ + def check_require_module(f: Callable) -> Callable: try: __import__(module) return f except ImportError: + def new_f(self, *args, **kwargs): - self.skipTest('%s can not be imported.' % module) + self.skipTest("%s can not be imported." % module) + new_f.__name__ = f.__name__ return new_f - return check_require_module + return check_require_module diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py index 4b6ea98..b988c7a 100644 --- a/logilab/common/textutils.py +++ b/logilab/common/textutils.py @@ -50,33 +50,37 @@ from re import Pattern, Match from warnings import warn from unicodedata import normalize as _uninormalize from typing import Any, Optional, Tuple, List, Callable, Dict, Union + try: from os import linesep except ImportError: - linesep = '\n' # gae + linesep = "\n" # gae from logilab.common.deprecation import deprecated MANUAL_UNICODE_MAP = { - u'\xa1': u'!', # INVERTED EXCLAMATION MARK - u'\u0142': u'l', # LATIN SMALL LETTER L WITH STROKE - u'\u2044': u'/', # FRACTION SLASH - u'\xc6': u'AE', # LATIN CAPITAL LETTER AE - u'\xa9': u'(c)', # COPYRIGHT SIGN - u'\xab': u'"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK - u'\xe6': u'ae', # LATIN SMALL LETTER AE - u'\xae': u'(r)', # REGISTERED SIGN - u'\u0153': u'oe', # LATIN SMALL LIGATURE OE - u'\u0152': u'OE', # LATIN CAPITAL LIGATURE OE - u'\xd8': u'O', # LATIN CAPITAL LETTER O WITH STROKE - u'\xf8': u'o', # LATIN SMALL LETTER O WITH STROKE - u'\xbb': u'"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK - u'\xdf': u'ss', # LATIN SMALL LETTER SHARP S - u'\u2013': u'-', # HYPHEN - u'\u2019': u"'", # SIMPLE QUOTE - } - -def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None) -> str: + "\xa1": "!", # INVERTED EXCLAMATION MARK + "\u0142": "l", # LATIN SMALL LETTER L WITH STROKE + "\u2044": "/", # FRACTION SLASH + "\xc6": "AE", # LATIN CAPITAL LETTER AE + "\xa9": "(c)", # COPYRIGHT SIGN + "\xab": '"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK + "\xe6": "ae", # LATIN SMALL LETTER AE + "\xae": "(r)", # REGISTERED SIGN + "\u0153": "oe", # LATIN SMALL LIGATURE OE + "\u0152": "OE", # LATIN CAPITAL LIGATURE OE + "\xd8": "O", # LATIN CAPITAL LETTER O WITH STROKE + "\xf8": "o", # LATIN SMALL LETTER O WITH STROKE + "\xbb": '"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK + "\xdf": "ss", # LATIN SMALL LETTER SHARP S + "\u2013": "-", # HYPHEN + "\u2019": "'", # SIMPLE QUOTE +} + + +def unormalize( + ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None +) -> str: """replace diacritical characters with their corresponding ascii characters Convert the unicode string to its long normalized form (unicode character @@ -92,22 +96,26 @@ def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: O """ # backward compatibility, ignorenonascii was a boolean if ignorenonascii is not None: - warn("ignorenonascii is deprecated, use substitute named parameter instead", - DeprecationWarning, stacklevel=2) + warn( + "ignorenonascii is deprecated, use substitute named parameter instead", + DeprecationWarning, + stacklevel=2, + ) if ignorenonascii: - substitute = '' + substitute = "" res = [] for letter in ustring[:]: try: replacement = MANUAL_UNICODE_MAP[letter] except KeyError: - replacement = _uninormalize('NFKD', letter)[0] + replacement = _uninormalize("NFKD", letter)[0] if ord(replacement) >= 2 ** 7: if substitute is None: raise ValueError("can't deal with non-ascii based characters") replacement = substitute res.append(replacement) - return u''.join(res) + return "".join(res) + def unquote(string: str) -> str: """remove optional quotes (simple or double) from the string @@ -120,17 +128,18 @@ def unquote(string: str) -> str: """ if not string: return string - if string[0] in '"\'': + if string[0] in "\"'": string = string[1:] - if string[-1] in '"\'': + if string[-1] in "\"'": string = string[:-1] return string -_BLANKLINES_RGX = re.compile('\r?\n\r?\n') -_NORM_SPACES_RGX = re.compile('\s+') +_BLANKLINES_RGX = re.compile("\r?\n\r?\n") +_NORM_SPACES_RGX = re.compile("\s+") + -def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool = False) -> str: +def normalize_text(text: str, line_len: int = 80, indent: str = "", rest: bool = False) -> str: """normalize a text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized but blank lines are kept. The indentation string may be used to insert a @@ -158,10 +167,10 @@ def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool = result = [] for text in _BLANKLINES_RGX.split(text): result.append(normp(text, line_len, indent)) - return ('%s%s%s' % (linesep, indent, linesep)).join(result) + return ("%s%s%s" % (linesep, indent, linesep)).join(result) -def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: +def normalize_paragraph(text: str, line_len: int = 80, indent: str = "") -> str: """normalize a text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized. The indentation string may be used top insert a comment mark for @@ -182,7 +191,7 @@ def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: inferior to `line_len`, and optionally prefixed by an indentation string """ - text = _NORM_SPACES_RGX.sub(' ', text) + text = _NORM_SPACES_RGX.sub(" ", text) line_len = line_len - len(indent) lines = [] while text: @@ -190,7 +199,8 @@ def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: lines.append(indent + aline) return linesep.join(lines) -def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') -> str: + +def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = "") -> str: """normalize a ReST text to display it with a maximum line size and optionally arbitrary indentation. Line jumps are normalized. The indentation string may be used top insert a comment mark for @@ -211,21 +221,21 @@ def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') -> inferior to `line_len`, and optionally prefixed by an indentation string """ - toreport = '' + toreport = "" lines = [] line_len = line_len - len(indent) for line in text.splitlines(): - line = toreport + _NORM_SPACES_RGX.sub(' ', line.strip()) - toreport = '' + line = toreport + _NORM_SPACES_RGX.sub(" ", line.strip()) + toreport = "" while len(line) > line_len: # too long line, need split line, toreport = splittext(line, line_len) lines.append(indent + line) if toreport: - line = toreport + ' ' - toreport = '' + line = toreport + " " + toreport = "" else: - line = '' + line = "" if line: lines.append(indent + line.strip()) return linesep.join(lines) @@ -239,18 +249,18 @@ def splittext(text: str, line_len: int) -> Tuple[str, str]: * the rest of the text which has to be reported on another line """ if len(text) <= line_len: - return text, '' - pos = min(len(text)-1, line_len) - while pos > 0 and text[pos] != ' ': + return text, "" + pos = min(len(text) - 1, line_len) + while pos > 0 and text[pos] != " ": pos -= 1 if pos == 0: pos = min(len(text), line_len) - while len(text) > pos and text[pos] != ' ': + while len(text) > pos and text[pos] != " ": pos += 1 - return text[:pos], text[pos+1:].strip() + return text[:pos], text[pos + 1 :].strip() -def splitstrip(string: str, sep: str = ',') -> List[str]: +def splitstrip(string: str, sep: str = ",") -> List[str]: """return a list of stripped string by splitting the string given as argument on `sep` (',' by default). Empty string are discarded. @@ -271,15 +281,16 @@ def splitstrip(string: str, sep: str = ',') -> List[str]: """ return [word.strip() for word in string.split(sep) if word.strip()] -get_csv = deprecated('get_csv is deprecated, use splitstrip')(splitstrip) + +get_csv = deprecated("get_csv is deprecated, use splitstrip")(splitstrip) def split_url_or_path(url_or_path): """return the latest component of a string containing either an url of the form <scheme>://<path> or a local file system path """ - if '://' in url_or_path: - return url_or_path.rstrip('/').rsplit('/', 1) + if "://" in url_or_path: + return url_or_path.rstrip("/").rsplit("/", 1) return osp.split(url_or_path.rstrip(osp.sep)) @@ -303,8 +314,8 @@ def text_to_dict(text): return res for line in text.splitlines(): line = line.strip() - if line and not line.startswith('#'): - key, value = [w.strip() for w in line.split('=', 1)] + if line and not line.startswith("#"): + key, value = [w.strip() for w in line.split("=", 1)] if key in res: try: res[key].append(value) @@ -315,13 +326,12 @@ def text_to_dict(text): return res -_BLANK_URE = r'(\s|,)+' +_BLANK_URE = r"(\s|,)+" _BLANK_RE = re.compile(_BLANK_URE) -__VALUE_URE = r'-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))' -__UNITS_URE = r'[a-zA-Z]+' -_VALUE_RE = re.compile(r'(?P<value>%s)(?P<unit>%s)?'%(__VALUE_URE, __UNITS_URE)) -_VALIDATION_RE = re.compile(r'^((%s)(%s))*(%s)?$' % (__VALUE_URE, __UNITS_URE, - __VALUE_URE)) +__VALUE_URE = r"-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))" +__UNITS_URE = r"[a-zA-Z]+" +_VALUE_RE = re.compile(r"(?P<value>%s)(?P<unit>%s)?" % (__VALUE_URE, __UNITS_URE)) +_VALIDATION_RE = re.compile(r"^((%s)(%s))*(%s)?$" % (__VALUE_URE, __UNITS_URE, __VALUE_URE)) BYTE_UNITS = { "b": 1, @@ -336,11 +346,18 @@ TIME_UNITS = { "s": 1, "min": 60, "h": 60 * 60, - "d": 60 * 60 *24, + "d": 60 * 60 * 24, } -def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, type] = None, final: type = float, blank_reg: Pattern = _BLANK_RE, - value_reg: Pattern = _VALUE_RE) -> Union[float, int]: + +def apply_units( + string: str, + units: Dict[str, int], + inter: Union[Callable, None, type] = None, + final: type = float, + blank_reg: Pattern = _BLANK_RE, + value_reg: Pattern = _VALUE_RE, +) -> Union[float, int]: """Parse the string applying the units defined in units (e.g.: "1.5m",{'m',60} -> 80). @@ -361,7 +378,7 @@ def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, """ if inter is None: inter = final - fstring = _BLANK_RE.sub('', string) + fstring = _BLANK_RE.sub("", string) if not (fstring and _VALIDATION_RE.match(fstring)): raise ValueError("Invalid unit string: %r." % string) values = [] @@ -373,15 +390,15 @@ def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, try: value *= units[unit.lower()] except KeyError: - raise ValueError('invalid unit %s. valid units are %s' % - (unit, list(units.keys()))) + raise ValueError("invalid unit %s. valid units are %s" % (unit, list(units.keys()))) values.append(value) return final(sum(values)) -_LINE_RGX = re.compile('\r\n|\r+|\n') +_LINE_RGX = re.compile("\r\n|\r+|\n") + -def pretty_match(match: Match, string: str, underline_char: str = '^') -> str: +def pretty_match(match: Match, string: str, underline_char: str = "^") -> str: """return a string with the match location underlined: >>> import re @@ -419,7 +436,7 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str: result = [string[:start_line_pos]] start_line_pos += len(linesep) offset = start - start_line_pos - underline = ' ' * offset + underline_char * (end - start) + underline = " " * offset + underline_char * (end - start) end_line_pos = string.find(linesep, end) if end_line_pos == -1: string = string[start_line_pos:] @@ -429,7 +446,7 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str: # mypy: Incompatible types in assignment (expression has type "str", # mypy: variable has type "int") # but it's a str :| - end = string[end_line_pos + len(linesep):] # type: ignore + end = string[end_line_pos + len(linesep) :] # type: ignore string = string[start_line_pos:end_line_pos] result.append(string) result.append(underline) @@ -439,30 +456,31 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str: # Ansi colorization ########################################################### -ANSI_PREFIX = '\033[' -ANSI_END = 'm' -ANSI_RESET = '\033[0m' +ANSI_PREFIX = "\033[" +ANSI_END = "m" +ANSI_RESET = "\033[0m" ANSI_STYLES = { - 'reset': "0", - 'bold': "1", - 'italic': "3", - 'underline': "4", - 'blink': "5", - 'inverse': "7", - 'strike': "9", + "reset": "0", + "bold": "1", + "italic": "3", + "underline": "4", + "blink": "5", + "inverse": "7", + "strike": "9", } ANSI_COLORS = { - 'reset': "0", - 'black': "30", - 'red': "31", - 'green': "32", - 'yellow': "33", - 'blue': "34", - 'magenta': "35", - 'cyan': "36", - 'white': "37", + "reset": "0", + "black": "30", + "red": "31", + "green": "32", + "yellow": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", } + def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) -> str: """return ansi escape code corresponding to color and style @@ -488,13 +506,14 @@ def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) -> ansi_code.append(ANSI_STYLES[effect]) if color: if color.isdigit(): - ansi_code.extend(['38', '5']) + ansi_code.extend(["38", "5"]) ansi_code.append(color) else: ansi_code.append(ANSI_COLORS[color]) if ansi_code: - return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END - return '' + return ANSI_PREFIX + ";".join(ansi_code) + ANSI_END + return "" + def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] = None) -> str: """colorize message by wrapping it with ansi escape codes @@ -522,23 +541,24 @@ def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] = escape_code = _get_ansi_code(color, style) # If invalid (or unknown) color, don't wrap msg with ansi codes if escape_code: - return '%s%s%s' % (escape_code, msg, ANSI_RESET) + return "%s%s%s" % (escape_code, msg, ANSI_RESET) return msg -DIFF_STYLE = {'separator': 'cyan', 'remove': 'red', 'add': 'green'} + +DIFF_STYLE = {"separator": "cyan", "remove": "red", "add": "green"} + def diff_colorize_ansi(lines, out=sys.stdout, style=DIFF_STYLE): for line in lines: - if line[:4] in ('--- ', '+++ '): - out.write(colorize_ansi(line, style['separator'])) - elif line[0] == '-': - out.write(colorize_ansi(line, style['remove'])) - elif line[0] == '+': - out.write(colorize_ansi(line, style['add'])) - elif line[:4] == '--- ': - out.write(colorize_ansi(line, style['separator'])) - elif line[:4] == '+++ ': - out.write(colorize_ansi(line, style['separator'])) + if line[:4] in ("--- ", "+++ "): + out.write(colorize_ansi(line, style["separator"])) + elif line[0] == "-": + out.write(colorize_ansi(line, style["remove"])) + elif line[0] == "+": + out.write(colorize_ansi(line, style["add"])) + elif line[:4] == "--- ": + out.write(colorize_ansi(line, style["separator"])) + elif line[:4] == "+++ ": + out.write(colorize_ansi(line, style["separator"])) else: out.write(line) - diff --git a/logilab/common/tree.py b/logilab/common/tree.py index 1fc5a21..dbde2eb 100644 --- a/logilab/common/tree.py +++ b/logilab/common/tree.py @@ -32,9 +32,11 @@ from typing import Optional, Any, Union, List, Callable, TypeVar ## Exceptions ################################################################# + class NodeNotFound(Exception): """raised when a node has not been found""" + EX_SIBLING_NOT_FOUND: str = "No such sibling as '%s'" EX_CHILD_NOT_FOUND: str = "No such child as '%s'" EX_NODE_NOT_FOUND: str = "No such node as '%s'" @@ -49,7 +51,7 @@ NodeType = Any class Node(object): """a basic tree node, characterized by an id""" - def __init__(self, nid: Optional[str] = None) -> None : + def __init__(self, nid: Optional[str] = None) -> None: self.id = nid # navigation # should be something like Optional[type(self)] for subclasses but that's not possible? @@ -61,14 +63,14 @@ class Node(object): return iter(self.children) def __str__(self, indent=0): - s = ['%s%s %s' % (' '*indent, self.__class__.__name__, self.id)] + s = ["%s%s %s" % (" " * indent, self.__class__.__name__, self.id)] indent += 2 for child in self.children: try: s.append(child.__str__(indent)) except TypeError: s.append(child.__str__()) - return '\n'.join(s) + return "\n".join(s) def is_leaf(self): return not self.children @@ -103,7 +105,7 @@ class Node(object): try: assert self.parent is not None return self.parent.get_child_by_id(nid) - except NodeNotFound : + except NodeNotFound: raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid) def next_sibling(self): @@ -116,7 +118,7 @@ class Node(object): return None index = parent.children.index(self) try: - return parent.children[index+1] + return parent.children[index + 1] except IndexError: return None @@ -130,7 +132,7 @@ class Node(object): return None index = parent.children.index(self) if index > 0: - return parent.children[index-1] + return parent.children[index - 1] return None def get_node_by_id(self, nid: str) -> NodeType: @@ -140,7 +142,7 @@ class Node(object): root = self.root() try: return root.get_child_by_id(nid, 1) - except NodeNotFound : + except NodeNotFound: raise NodeNotFound(EX_NODE_NOT_FOUND % nid) def get_child_by_id(self, nid: str, recurse: Optional[bool] = None) -> NodeType: @@ -149,13 +151,13 @@ class Node(object): """ if self.id == nid: return self - for c in self.children : + for c in self.children: if recurse: try: return c.get_child_by_id(nid, 1) - except NodeNotFound : + except NodeNotFound: continue - if c.id == nid : + if c.id == nid: return c raise NodeNotFound(EX_CHILD_NOT_FOUND % nid) @@ -164,13 +166,13 @@ class Node(object): return child of given path (path is a list of ids) """ if len(path) > 0 and path[0] == self.id: - if len(path) == 1 : + if len(path) == 1: return self - else : - for c in self.children : + else: + for c in self.children: try: return c.get_child_by_path(path[1:]) - except NodeNotFound : + except NodeNotFound: pass raise NodeNotFound(EX_CHILD_NOT_FOUND % path) @@ -180,7 +182,7 @@ class Node(object): """ if self.parent is not None: return 1 + self.parent.depth() - else : + else: return 0 def depth_down(self) -> int: @@ -237,6 +239,7 @@ class Node(object): lst.extend(self.parent.lineage()) return lst + class VNode(Node, VisitedMixIn): # we should probably merge this VisitedMixIn here because it's only used here """a visitable node @@ -247,7 +250,8 @@ class VNode(Node, VisitedMixIn): class BinaryNode(VNode): """a binary node (i.e. only two children """ - def __init__(self, lhs=None, rhs=None) : + + def __init__(self, lhs=None, rhs=None): VNode.__init__(self) if lhs is not None or rhs is not None: assert lhs and rhs @@ -267,24 +271,29 @@ class BinaryNode(VNode): return self.children[0], self.children[1] - if sys.version_info[0:2] >= (2, 2): list_class = list else: from UserList import UserList + list_class = UserList + class ListNode(VNode, list_class): """Used to manipulate Nodes as Lists """ + def __init__(self): list_class.__init__(self) VNode.__init__(self) self.children = self def __str__(self, indent=0): - return '%s%s %s' % (indent*' ', self.__class__.__name__, - ', '.join([str(v) for v in self])) + return "%s%s %s" % ( + indent * " ", + self.__class__.__name__, + ", ".join([str(v) for v in self]), + ) def append(self, child): """add a node to children""" @@ -309,8 +318,10 @@ class ListNode(VNode, list_class): def __iter__(self): return list_class.__iter__(self) + # construct list from tree #################################################### + def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]: """ create a list with tree nodes for which the <filter> function returned true @@ -339,6 +350,7 @@ def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> poped = 1 return l + def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]: """ create a list with tree nodes for which the <filter> function returned true @@ -368,15 +380,18 @@ def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> L poped = 1 return l + class PostfixedDepthFirstIterator(FilteredIterator): """a postfixed depth first iterator, designed to be used with visitors """ + def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None: FilteredIterator.__init__(self, node, post_order_list, filter_func) + class PrefixedDepthFirstIterator(FilteredIterator): """a prefixed depth first iterator, designed to be used with visitors """ + def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None: FilteredIterator.__init__(self, node, pre_order_list, filter_func) - diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py index 77a6272..a759003 100644 --- a/logilab/common/umessage.py +++ b/logilab/common/umessage.py @@ -40,14 +40,14 @@ import logilab.common as lgc def decode_QP(string: str) -> str: parts: List[str] = [] for maybe_decoded, charset in decode_header(string): - if not charset : - charset = 'iso-8859-15' + if not charset: + charset = "iso-8859-15" # python 3 sometimes returns str and sometimes bytes. # the 'official' fix is to use the new 'policy' APIs # https://bugs.python.org/issue24797 # let's just handle this bug ourselves for now if isinstance(maybe_decoded, bytes): - decoded = maybe_decoded.decode(charset, 'replace') + decoded = maybe_decoded.decode(charset, "replace") else: decoded = maybe_decoded @@ -57,21 +57,24 @@ def decode_QP(string: str) -> str: if sys.version_info < (3, 3): # decoding was non-RFC compliant wrt to whitespace handling # see http://bugs.python.org/issue1079 - return u' '.join(parts) + return " ".join(parts) + + return "".join(parts) - return u''.join(parts) def message_from_file(fd): try: return UMessage(email.message_from_file(fd)) except email.errors.MessageParseError: - return '' + return "" + -def message_from_string(string: str) -> Union['UMessage', str]: +def message_from_string(string: str) -> Union["UMessage", str]: try: return UMessage(email.message_from_string(string)) except email.errors.MessageParseError: - return '' + return "" + class UMessage: """Encapsulates an email.Message instance and returns only unicode objects. @@ -92,8 +95,7 @@ class UMessage: return self.get(header) def get_all(self, header: str, default: Tuple[()] = ()) -> List[str]: - return [decode_QP(val) for val in self.message.get_all(header, default) - if val is not None] + return [decode_QP(val) for val in self.message.get_all(header, default) if val is not None] def is_multipart(self): return self.message.is_multipart() @@ -105,7 +107,9 @@ class UMessage: for part in self.message.walk(): yield UMessage(part) - def get_payload(self, index: Optional[Any] = None, decode: bool = False) -> Union[str, 'UMessage', List['UMessage']]: + def get_payload( + self, index: Optional[Any] = None, decode: bool = False + ) -> Union[str, "UMessage", List["UMessage"]]: message = self.message if index is None: @@ -118,17 +122,17 @@ class UMessage: if isinstance(payload, list): return [UMessage(msg) for msg in payload] - if message.get_content_maintype() != 'text': + if message.get_content_maintype() != "text": return payload if isinstance(payload, str): return payload - charset = message.get_content_charset() or 'iso-8859-1' + charset = message.get_content_charset() or "iso-8859-1" if search_function(charset) is None: - charset = 'iso-8859-1' + charset = "iso-8859-1" - return str(payload or b'', charset, "replace") + return str(payload or b"", charset, "replace") else: payload = UMessage(message.get_payload(index, decode)) @@ -147,7 +151,7 @@ class UMessage: try: return str(value) except UnicodeDecodeError: - return u'error decoding filename' + return "error decoding filename" # other convenience methods ############################################### @@ -155,8 +159,8 @@ class UMessage: """return an unicode string containing all the message's headers""" values = [] for header in self.message.keys(): - values.append(u'%s: %s' % (header, self.get(header))) - return '\n'.join(values) + values.append("%s: %s" % (header, self.get(header))) + return "\n".join(values) def multi_addrs(self, header): """return a list of 2-uple (name, address) for the given address (which @@ -172,7 +176,7 @@ class UMessage: """return a datetime object for the email's date or None if no date is set or if it can't be parsed """ - value = self.get('date') + value = self.get("date") if value is None and alternative_source: unix_from = self.message.get_unixfrom() if unix_from is not None: diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py index 9c0f1df..a539150 100644 --- a/logilab/common/ureports/__init__.py +++ b/logilab/common/ureports/__init__.py @@ -46,14 +46,14 @@ def layout_title(layout): """ for child in layout.children: if isinstance(child, Title): - return u' '.join([node.data for node in get_nodes(child, Text)]) + return " ".join([node.data for node in get_nodes(child, Text)]) def build_summary(layout, level=1): """make a summary for the report, including X level""" assert level > 0 level -= 1 - summary = List(klass=u'summary') + summary = List(klass="summary") for child in layout.children: if not isinstance(child, Section): continue @@ -61,8 +61,8 @@ def build_summary(layout, level=1): if not label and not child.id: continue if not child.id: - child.id = label.replace(' ', '-') - node = Link(u'#'+child.id, label=label or child.id) + child.id = label.replace(" ", "-") + node = Link("#" + child.id, label=label or child.id) # FIXME: Three following lines produce not very compliant # docbook: there are some useless <para><para>. They might be # replaced by the three commented lines but this then produces @@ -70,16 +70,21 @@ def build_summary(layout, level=1): if level and [n for n in child.children if isinstance(n, Section)]: node = Paragraph([node, build_summary(child, level)]) summary.append(node) -# summary.append(node) -# if level and [n for n in child.children if isinstance(n, Section)]: -# summary.append(build_summary(child, level)) + # summary.append(node) + # if level and [n for n in child.children if isinstance(n, Section)]: + # summary.append(build_summary(child, level)) return summary class BaseWriter(object): """base class for ureport writers""" - def format(self, layout: Any, stream: Optional[Union[StringIO, TextIO]] = None, encoding: Optional[Any] = None) -> None: + def format( + self, + layout: Any, + stream: Optional[Union[StringIO, TextIO]] = None, + encoding: Optional[Any] = None, + ) -> None: """format and write the given layout into the stream object unicode policy: unicode strings may be found in the layout; @@ -89,22 +94,22 @@ class BaseWriter(object): if stream is None: stream = sys.stdout if not encoding: - encoding = getattr(stream, 'encoding', 'UTF-8') - self.encoding = encoding or 'UTF-8' + encoding = getattr(stream, "encoding", "UTF-8") + self.encoding = encoding or "UTF-8" self.__compute_funcs: List[Tuple[Callable[[str], Any], Callable[[str], Any]]] = [] self.out = stream self.begin_format(layout) layout.accept(self) self.end_format(layout) - def format_children(self, layout: Union['Paragraph', 'Section', 'Title']) -> None: + def format_children(self, layout: Union["Paragraph", "Section", "Title"]) -> None: """recurse on the layout children and call their accept method (see the Visitor pattern) """ - for child in getattr(layout, 'children', ()): + for child in getattr(layout, "children", ()): child.accept(self) - def writeln(self, string: str = u'') -> None: + def writeln(self, string: str = "") -> None: """write a line in the output buffer""" self.write(string + linesep) @@ -146,7 +151,7 @@ class BaseWriter(object): # fill missing cells while len(result[-1]) < cols: - result[-1].append(u'') + result[-1].append("") return result @@ -166,13 +171,13 @@ class BaseWriter(object): # error from porting to python3? stream.write(data.encode(self.encoding)) # type: ignore - def writeln(data: str = u'') -> None: + def writeln(data: str = "") -> None: try: - stream.write(data+linesep) + stream.write(data + linesep) except UnicodeEncodeError: # mypy: Unsupported operand types for + ("bytes" and "str") # error from porting to python3? - stream.write(data.encode(self.encoding)+linesep) # type: ignore + stream.write(data.encode(self.encoding) + linesep) # type: ignore # mypy: Cannot assign to a method # this really looks like black dirty magic since self.write is reused elsewhere in the code @@ -202,6 +207,7 @@ class BaseWriter(object): del self.write del self.writeln + # mypy error: Incompatible import of "Table" (imported name has type # mypy error: "Type[logilab.common.ureports.nodes.Table]", local name has type # mypy error: "Type[logilab.common.table.Table]") diff --git a/logilab/common/ureports/docbook_writer.py b/logilab/common/ureports/docbook_writer.py index f28474e..7e7564f 100644 --- a/logilab/common/ureports/docbook_writer.py +++ b/logilab/common/ureports/docbook_writer.py @@ -29,15 +29,17 @@ class DocbookWriter(HTMLWriter): super(HTMLWriter, self).begin_format(layout) if self.snippet is None: self.writeln('<?xml version="1.0" encoding="ISO-8859-1"?>') - self.writeln(""" + self.writeln( + """ <book xmlns:xi='http://www.w3.org/2001/XInclude' lang='fr'> -""") +""" + ) def end_format(self, layout): """finished to format a layout""" if self.snippet is None: - self.writeln('</book>') + self.writeln("</book>") def visit_section(self, layout): """display a section (using <chapter> (level 0) or <section>)""" @@ -46,98 +48,95 @@ class DocbookWriter(HTMLWriter): else: tag = "section" self.section += 1 - self.writeln(self._indent('<%s%s>' % (tag, self.handle_attrs(layout)))) + self.writeln(self._indent("<%s%s>" % (tag, self.handle_attrs(layout)))) self.format_children(layout) - self.writeln(self._indent('</%s>' % tag)) + self.writeln(self._indent("</%s>" % tag)) self.section -= 1 def visit_title(self, layout): """display a title using <title>""" - self.write(self._indent(' <title%s>' % self.handle_attrs(layout))) + self.write(self._indent(" <title%s>" % self.handle_attrs(layout))) self.format_children(layout) - self.writeln('</title>') + self.writeln("</title>") def visit_table(self, layout): """display a table as html""" self.writeln( - self._indent(' <table%s><title>%s</title>' % ( - self.handle_attrs(layout), layout.title))) + self._indent(" <table%s><title>%s</title>" % (self.handle_attrs(layout), layout.title)) + ) self.writeln(self._indent(' <tgroup cols="%s">' % layout.cols)) for i in range(layout.cols): - self.writeln( - self._indent( - ' <colspec colname="c%s" colwidth="1*"/>' % i)) + self.writeln(self._indent(' <colspec colname="c%s" colwidth="1*"/>' % i)) table_content = self.get_table_content(layout) # write headers if layout.cheaders: - self.writeln(self._indent(' <thead>')) + self.writeln(self._indent(" <thead>")) self._write_row(table_content[0]) - self.writeln(self._indent(' </thead>')) + self.writeln(self._indent(" </thead>")) table_content = table_content[1:] elif layout.rcheaders: - self.writeln(self._indent(' <thead>')) + self.writeln(self._indent(" <thead>")) self._write_row(table_content[-1]) - self.writeln(self._indent(' </thead>')) + self.writeln(self._indent(" </thead>")) table_content = table_content[:-1] # write body - self.writeln(self._indent(' <tbody>')) + self.writeln(self._indent(" <tbody>")) for i in range(len(table_content)): row = table_content[i] - self.writeln(self._indent(' <row>')) + self.writeln(self._indent(" <row>")) for j in range(len(row)): - cell = row[j] or ' ' - self.writeln( - self._indent(' <entry>%s</entry>' % cell)) - self.writeln(self._indent(' </row>')) - self.writeln(self._indent(' </tbody>')) - self.writeln(self._indent(' </tgroup>')) - self.writeln(self._indent(' </table>')) + cell = row[j] or " " + self.writeln(self._indent(" <entry>%s</entry>" % cell)) + self.writeln(self._indent(" </row>")) + self.writeln(self._indent(" </tbody>")) + self.writeln(self._indent(" </tgroup>")) + self.writeln(self._indent(" </table>")) def _write_row(self, row): """write content of row (using <row> <entry>)""" - self.writeln(' <row>') + self.writeln(" <row>") for j in range(len(row)): - cell = row[j] or ' ' - self.writeln(' <entry>%s</entry>' % cell) - self.writeln(self._indent(' </row>')) + cell = row[j] or " " + self.writeln(" <entry>%s</entry>" % cell) + self.writeln(self._indent(" </row>")) def visit_list(self, layout): """display a list (using <itemizedlist>)""" - self.writeln(self._indent(' <itemizedlist%s>' - '' % self.handle_attrs(layout))) + self.writeln(self._indent(" <itemizedlist%s>" "" % self.handle_attrs(layout))) for row in list(self.compute_content(layout)): - self.writeln(' <listitem><para>%s</para></listitem>' % row) - self.writeln(self._indent(' </itemizedlist>')) + self.writeln(" <listitem><para>%s</para></listitem>" % row) + self.writeln(self._indent(" </itemizedlist>")) def visit_paragraph(self, layout): """display links (using <para>)""" - self.write(self._indent(' <para>')) + self.write(self._indent(" <para>")) self.format_children(layout) - self.writeln('</para>') + self.writeln("</para>") def visit_span(self, layout): """display links (using <p>)""" # TODO: translate in docbook - self.write('<literal %s>' % self.handle_attrs(layout)) + self.write("<literal %s>" % self.handle_attrs(layout)) self.format_children(layout) - self.write('</literal>') + self.write("</literal>") def visit_link(self, layout): """display links (using <ulink>)""" - self.write('<ulink url="%s"%s>%s</ulink>' % ( - layout.url, self.handle_attrs(layout), layout.label)) + self.write( + '<ulink url="%s"%s>%s</ulink>' % (layout.url, self.handle_attrs(layout), layout.label) + ) def visit_verbatimtext(self, layout): """display verbatim text (using <programlisting>)""" - self.writeln(self._indent(' <programlisting>')) - self.write(layout.data.replace('&', '&').replace('<', '<')) - self.writeln(self._indent(' </programlisting>')) + self.writeln(self._indent(" <programlisting>")) + self.write(layout.data.replace("&", "&").replace("<", "<")) + self.writeln(self._indent(" </programlisting>")) def visit_text(self, layout): """add some text""" - self.write(layout.data.replace('&', '&').replace('<', '<')) + self.write(layout.data.replace("&", "&").replace("<", "<")) def _indent(self, string): """correctly indent string according to section""" - return ' ' * 2*(self.section) + string + return " " * 2 * (self.section) + string diff --git a/logilab/common/ureports/html_writer.py b/logilab/common/ureports/html_writer.py index 0783075..23ff588 100644 --- a/logilab/common/ureports/html_writer.py +++ b/logilab/common/ureports/html_writer.py @@ -20,8 +20,16 @@ __docformat__ = "restructuredtext en" from logilab.common.ureports import BaseWriter -from logilab.common.ureports.nodes import (Section, Title, Table, List, - Paragraph, Link, VerbatimText, Text) +from logilab.common.ureports.nodes import ( + Section, + Title, + Table, + List, + Paragraph, + Link, + VerbatimText, + Text, +) from typing import Any @@ -34,100 +42,100 @@ class HTMLWriter(BaseWriter): def handle_attrs(self, layout: Any) -> str: """get an attribute string from layout member attributes""" - attrs = u'' - klass = getattr(layout, 'klass', None) + attrs = "" + klass = getattr(layout, "klass", None) if klass: - attrs += u' class="%s"' % klass - nid = getattr(layout, 'id', None) + attrs += ' class="%s"' % klass + nid = getattr(layout, "id", None) if nid: - attrs += u' id="%s"' % nid + attrs += ' id="%s"' % nid return attrs def begin_format(self, layout: Any) -> None: """begin to format a layout""" super(HTMLWriter, self).begin_format(layout) if self.snippet is None: - self.writeln(u'<html>') - self.writeln(u'<body>') + self.writeln("<html>") + self.writeln("<body>") def end_format(self, layout: Any) -> None: """finished to format a layout""" if self.snippet is None: - self.writeln(u'</body>') - self.writeln(u'</html>') + self.writeln("</body>") + self.writeln("</html>") def visit_section(self, layout: Section) -> None: """display a section as html, using div + h[section level]""" self.section += 1 - self.writeln(u'<div%s>' % self.handle_attrs(layout)) + self.writeln("<div%s>" % self.handle_attrs(layout)) self.format_children(layout) - self.writeln(u'</div>') + self.writeln("</div>") self.section -= 1 def visit_title(self, layout: Title) -> None: """display a title using <hX>""" - self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout))) + self.write("<h%s%s>" % (self.section, self.handle_attrs(layout))) self.format_children(layout) - self.writeln(u'</h%s>' % self.section) + self.writeln("</h%s>" % self.section) def visit_table(self, layout: Table) -> None: """display a table as html""" - self.writeln(u'<table%s>' % self.handle_attrs(layout)) + self.writeln("<table%s>" % self.handle_attrs(layout)) table_content = self.get_table_content(layout) for i in range(len(table_content)): row = table_content[i] if i == 0 and layout.rheaders: - self.writeln(u'<tr class="header">') - elif i+1 == len(table_content) and layout.rrheaders: - self.writeln(u'<tr class="header">') + self.writeln('<tr class="header">') + elif i + 1 == len(table_content) and layout.rrheaders: + self.writeln('<tr class="header">') else: - self.writeln(u'<tr class="%s">' % (i % 2 and 'even' or 'odd')) + self.writeln('<tr class="%s">' % (i % 2 and "even" or "odd")) for j in range(len(row)): - cell = row[j] or u' ' - if (layout.rheaders and i == 0) or \ - (layout.cheaders and j == 0) or \ - (layout.rrheaders and i+1 == len(table_content)) or \ - (layout.rcheaders and j+1 == len(row)): - self.writeln(u'<th>%s</th>' % cell) + cell = row[j] or " " + if ( + (layout.rheaders and i == 0) + or (layout.cheaders and j == 0) + or (layout.rrheaders and i + 1 == len(table_content)) + or (layout.rcheaders and j + 1 == len(row)) + ): + self.writeln("<th>%s</th>" % cell) else: - self.writeln(u'<td>%s</td>' % cell) - self.writeln(u'</tr>') - self.writeln(u'</table>') + self.writeln("<td>%s</td>" % cell) + self.writeln("</tr>") + self.writeln("</table>") def visit_list(self, layout: List) -> None: """display a list as html""" - self.writeln(u'<ul%s>' % self.handle_attrs(layout)) + self.writeln("<ul%s>" % self.handle_attrs(layout)) for row in list(self.compute_content(layout)): - self.writeln(u'<li>%s</li>' % row) - self.writeln(u'</ul>') + self.writeln("<li>%s</li>" % row) + self.writeln("</ul>") def visit_paragraph(self, layout: Paragraph) -> None: """display links (using <p>)""" - self.write(u'<p>') + self.write("<p>") self.format_children(layout) - self.write(u'</p>') + self.write("</p>") def visit_span(self, layout): """display links (using <p>)""" - self.write(u'<span%s>' % self.handle_attrs(layout)) + self.write("<span%s>" % self.handle_attrs(layout)) self.format_children(layout) - self.write(u'</span>') + self.write("</span>") def visit_link(self, layout: Link) -> None: """display links (using <a>)""" - self.write(u' <a href="%s"%s>%s</a>' % (layout.url, - self.handle_attrs(layout), - layout.label)) + self.write(' <a href="%s"%s>%s</a>' % (layout.url, self.handle_attrs(layout), layout.label)) def visit_verbatimtext(self, layout: VerbatimText) -> None: """display verbatim text (using <pre>)""" - self.write(u'<pre>') - self.write(layout.data.replace(u'&', u'&').replace(u'<', u'<')) - self.write(u'</pre>') + self.write("<pre>") + self.write(layout.data.replace("&", "&").replace("<", "<")) + self.write("</pre>") def visit_text(self, layout: Text) -> None: """add some text""" data = layout.data if layout.escaped: - data = data.replace(u'&', u'&').replace(u'<', u'<') + data = data.replace("&", "&").replace("<", "<") self.write(data) diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py index d086faf..26c6715 100644 --- a/logilab/common/ureports/nodes.py +++ b/logilab/common/ureports/nodes.py @@ -23,6 +23,7 @@ __docformat__ = "restructuredtext en" from logilab.common.tree import VNode from typing import Optional + # from logilab.common.ureports.nodes import List # from logilab.common.ureports.nodes import Paragraph # from logilab.common.ureports.nodes import Text @@ -39,6 +40,7 @@ class BaseComponent(VNode): * id : the component's optional id * klass : the component's optional klass """ + def __init__(self, id: Optional[str] = None, klass: Optional[str] = None) -> None: VNode.__init__(self, id) self.klass = klass @@ -51,11 +53,16 @@ class BaseLayout(BaseComponent): * BaseComponent attributes * children : components in this table (i.e. the table's cells) """ - def __init__(self, - children: Union[TypingList['Text'], - Tuple[Union['Paragraph', str], - Union[TypingList, str]], Tuple[str, ...]] = (), - **kwargs: Any) -> None: + + def __init__( + self, + children: Union[ + TypingList["Text"], + Tuple[Union["Paragraph", str], Union[TypingList, str]], + Tuple[str, ...], + ] = (), + **kwargs: Any, + ) -> None: super(BaseLayout, self).__init__(**kwargs) @@ -87,6 +94,7 @@ class BaseLayout(BaseComponent): # non container nodes ######################################################### + class Text(BaseComponent): """a text portion @@ -94,6 +102,7 @@ class Text(BaseComponent): * BaseComponent attributes * data : the text value as an encoded or unicode string """ + def __init__(self, data: str, escaped: bool = True, **kwargs: Any) -> None: super(Text, self).__init__(**kwargs) # if isinstance(data, unicode): @@ -120,6 +129,7 @@ class Link(BaseComponent): * url : the link's target (REQUIRED) * label : the link's label as a string (use the url by default) """ + def __init__(self, url: str, label: str = None, **kwargs: Any) -> None: super(Link, self).__init__(**kwargs) assert url @@ -136,6 +146,7 @@ class Image(BaseComponent): * stream : the stream object containing the image data (REQUIRED) * title : the image's optional title """ + def __init__(self, filename, stream, title=None, **kwargs): super(Image, self).__init__(**kwargs) assert filename @@ -147,6 +158,7 @@ class Image(BaseComponent): # container nodes ############################################################# + class Section(BaseLayout): """a section @@ -158,6 +170,7 @@ class Section(BaseLayout): a description may also be given to the constructor, it'll be added as a first paragraph """ + def __init__(self, title: str = None, description: str = None, **kwargs: Any) -> None: super(Section, self).__init__(**kwargs) if description: @@ -206,9 +219,17 @@ class Table(BaseLayout): * cheaders : the first col's elements are table's header * title : the table's optional title """ - def __init__(self, cols: int, title: Optional[Any] = None, - rheaders: int = 0, cheaders: int = 0, rrheaders: int = 0, rcheaders: int = 0, - **kwargs: Any) -> None: + + def __init__( + self, + cols: int, + title: Optional[Any] = None, + rheaders: int = 0, + cheaders: int = 0, + rrheaders: int = 0, + rcheaders: int = 0, + **kwargs: Any, + ) -> None: super(Table, self).__init__(**kwargs) assert isinstance(cols, int) self.cols = cols diff --git a/logilab/common/ureports/text_writer.py b/logilab/common/ureports/text_writer.py index f75d7c9..efe85b7 100644 --- a/logilab/common/ureports/text_writer.py +++ b/logilab/common/ureports/text_writer.py @@ -24,18 +24,27 @@ __docformat__ = "restructuredtext en" from logilab.common.textutils import linesep from logilab.common.ureports import BaseWriter -from logilab.common.ureports.nodes import (Section, Title, Table, List as NodeList, - Paragraph, Link, VerbatimText, Text) +from logilab.common.ureports.nodes import ( + Section, + Title, + Table, + List as NodeList, + Paragraph, + Link, + VerbatimText, + Text, +) -TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^'] -BULLETS = [u'*', u'-'] +TITLE_UNDERLINES = ["", "=", "-", "`", ".", "~", "^"] +BULLETS = ["*", "-"] class TextWriter(BaseWriter): """format layouts as text (ReStructured inspiration but not totally handled yet) """ + def begin_format(self, layout: Any) -> None: super(TextWriter, self).begin_format(layout) self.list_level = 0 @@ -50,20 +59,20 @@ class TextWriter(BaseWriter): if self.pending_urls: self.writeln() for label, url in self.pending_urls: - self.writeln(u'.. _`%s`: %s' % (label, url)) + self.writeln(".. _`%s`: %s" % (label, url)) self.pending_urls = [] self.section -= 1 self.writeln() def visit_title(self, layout: Title) -> None: - title = u''.join(list(self.compute_content(layout))) + title = "".join(list(self.compute_content(layout))) self.writeln(title) try: self.writeln(TITLE_UNDERLINES[self.section] * len(title)) except IndexError: print("FIXME TITLE TOO DEEP. TURNING TITLE INTO TEXT") - def visit_paragraph(self, layout: 'Paragraph') -> None: + def visit_paragraph(self, layout: "Paragraph") -> None: """enter a paragraph""" self.format_children(layout) self.writeln() @@ -76,64 +85,67 @@ class TextWriter(BaseWriter): """display a table as text""" table_content = self.get_table_content(layout) # get columns width - cols_width = [0]*len(table_content[0]) + cols_width = [0] * len(table_content[0]) for row in table_content: for index in range(len(row)): col = row[index] cols_width[index] = max(cols_width[index], len(col)) - if layout.klass == 'field': + if layout.klass == "field": self.field_table(layout, table_content, cols_width) else: self.default_table(layout, table_content, cols_width) self.writeln() - def default_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None: + def default_table( + self, layout: Table, table_content: List[List[str]], cols_width: List[int] + ) -> None: """format a table""" - cols_width = [size+1 for size in cols_width] + cols_width = [size + 1 for size in cols_width] - format_strings = u' '.join([u'%%-%ss'] * len(cols_width)) + format_strings = " ".join(["%%-%ss"] * len(cols_width)) format_strings = format_strings % tuple(cols_width) - format_strings_list = format_strings.split(' ') + format_strings_list = format_strings.split(" ") - table_linesep = ( - u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n') - headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n' + table_linesep = "\n+" + "+".join(["-" * w for w in cols_width]) + "+\n" + headsep = "\n+" + "+".join(["=" * w for w in cols_width]) + "+\n" # FIXME: layout.cheaders self.write(table_linesep) for i in range(len(table_content)): - self.write(u'|') + self.write("|") line = table_content[i] for j in range(len(line)): self.write(format_strings_list[j] % line[j]) - self.write(u'|') + self.write("|") if i == 0 and layout.rheaders: self.write(headsep) else: self.write(table_linesep) - def field_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None: + def field_table( + self, layout: Table, table_content: List[List[str]], cols_width: List[int] + ) -> None: """special case for field table""" assert layout.cols == 2 - format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0]) + format_string = "%s%%-%ss: %%s" % (linesep, cols_width[0]) for field, value in table_content: self.write(format_string % (field, value)) def visit_list(self, layout: NodeList) -> None: """display a list layout as text""" bullet = BULLETS[self.list_level % len(BULLETS)] - indent = ' ' * self.list_level + indent = " " * self.list_level self.list_level += 1 for child in layout.children: - self.write(u'%s%s%s ' % (linesep, indent, bullet)) + self.write("%s%s%s " % (linesep, indent, bullet)) child.accept(self) self.list_level -= 1 def visit_link(self, layout: Link) -> None: """add a hyperlink""" if layout.label != layout.url: - self.write(u'`%s`_' % layout.label) + self.write("`%s`_" % layout.label) self.pending_urls.append((layout.label, layout.url)) else: self.write(layout.url) @@ -141,11 +153,11 @@ class TextWriter(BaseWriter): def visit_verbatimtext(self, layout: VerbatimText) -> None: """display a verbatim layout as text (so difficult ;) """ - self.writeln(u'::\n') + self.writeln("::\n") for line in layout.data.splitlines(): - self.writeln(u' ' + line) + self.writeln(" " + line) self.writeln() def visit_text(self, layout: Text) -> None: """add some text""" - self.write(u'%s' % layout.data) + self.write("%s" % layout.data) diff --git a/logilab/common/urllib2ext.py b/logilab/common/urllib2ext.py index 339aec0..dfbafc1 100644 --- a/logilab/common/urllib2ext.py +++ b/logilab/common/urllib2ext.py @@ -5,22 +5,27 @@ import urllib2 import kerberos as krb + class GssapiAuthError(Exception): """raised on error during authentication process""" + import re -RGX = re.compile('(?:.*,)*\s*Negotiate\s*([^,]*),?', re.I) + +RGX = re.compile("(?:.*,)*\s*Negotiate\s*([^,]*),?", re.I) + def get_negociate_value(headers): - for authreq in headers.getheaders('www-authenticate'): + for authreq in headers.getheaders("www-authenticate"): match = RGX.search(authreq) if match: return match.group(1) + class HTTPGssapiAuthHandler(urllib2.BaseHandler): """Negotiate HTTP authentication using context from GSSAPI""" - handler_order = 400 # before Digest Auth + handler_order = 400 # before Digest Auth def __init__(self): self._reset() @@ -36,15 +41,16 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler): def http_error_401(self, req, fp, code, msg, headers): try: if self._retried > 5: - raise urllib2.HTTPError(req.get_full_url(), 401, - "negotiate auth failed", headers, None) + raise urllib2.HTTPError( + req.get_full_url(), 401, "negotiate auth failed", headers, None + ) self._retried += 1 - logging.debug('gssapi handler, try %s' % self._retried) + logging.debug("gssapi handler, try %s" % self._retried) negotiate = get_negociate_value(headers) if negotiate is None: - logging.debug('no negociate found in a www-authenticate header') + logging.debug("no negociate found in a www-authenticate header") return None - logging.debug('HTTPGssapiAuthHandler: negotiate 1 is %r' % negotiate) + logging.debug("HTTPGssapiAuthHandler: negotiate 1 is %r" % negotiate) result, self._context = krb.authGSSClientInit("HTTP@%s" % req.get_host()) if result < 1: raise GssapiAuthError("HTTPGssapiAuthHandler: init failed with %d" % result) @@ -52,14 +58,14 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler): if result < 0: raise GssapiAuthError("HTTPGssapiAuthHandler: step 1 failed with %d" % result) client_response = krb.authGSSClientResponse(self._context) - logging.debug('HTTPGssapiAuthHandler: client response is %s...' % client_response[:10]) - req.add_unredirected_header('Authorization', "Negotiate %s" % client_response) + logging.debug("HTTPGssapiAuthHandler: client response is %s..." % client_response[:10]) + req.add_unredirected_header("Authorization", "Negotiate %s" % client_response) server_response = self.parent.open(req) negotiate = get_negociate_value(server_response.info()) if negotiate is None: - logging.warning('HTTPGssapiAuthHandler: failed to authenticate server') + logging.warning("HTTPGssapiAuthHandler: failed to authenticate server") else: - logging.debug('HTTPGssapiAuthHandler negotiate 2: %s' % negotiate) + logging.debug("HTTPGssapiAuthHandler negotiate 2: %s" % negotiate) result = krb.authGSSClientStep(self._context, negotiate) if result < 1: raise GssapiAuthError("HTTPGssapiAuthHandler: step 2 failed with %d" % result) @@ -70,20 +76,25 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler): self.clean_context() self._reset() -if __name__ == '__main__': + +if __name__ == "__main__": import sys + # debug import httplib + httplib.HTTPConnection.debuglevel = 1 httplib.HTTPSConnection.debuglevel = 1 # debug import logging + logging.basicConfig(level=logging.DEBUG) # handle cookies import cookielib + cj = cookielib.CookieJar() ch = urllib2.HTTPCookieProcessor(cj) # test with url sys.argv[1] h = HTTPGssapiAuthHandler() response = urllib2.build_opener(h, ch).open(sys.argv[1]) - print('\nresponse: %s\n--------------\n' % response.code, response.info()) + print("\nresponse: %s\n--------------\n" % response.code, response.info()) diff --git a/logilab/common/vcgutils.py b/logilab/common/vcgutils.py index 9cd2acd..cd2b73a 100644 --- a/logilab/common/vcgutils.py +++ b/logilab/common/vcgutils.py @@ -33,101 +33,141 @@ __docformat__ = "restructuredtext en" import string ATTRS_VAL = { - 'algos': ('dfs', 'tree', 'minbackward', - 'left_to_right', 'right_to_left', - 'top_to_bottom', 'bottom_to_top', - 'maxdepth', 'maxdepthslow', 'mindepth', 'mindepthslow', - 'mindegree', 'minindegree', 'minoutdegree', - 'maxdegree', 'maxindegree', 'maxoutdegree'), - 'booleans': ('yes', 'no'), - 'colors': ('black', 'white', 'blue', 'red', 'green', 'yellow', - 'magenta', 'lightgrey', - 'cyan', 'darkgrey', 'darkblue', 'darkred', 'darkgreen', - 'darkyellow', 'darkmagenta', 'darkcyan', 'gold', - 'lightblue', 'lightred', 'lightgreen', 'lightyellow', - 'lightmagenta', 'lightcyan', 'lilac', 'turquoise', - 'aquamarine', 'khaki', 'purple', 'yellowgreen', 'pink', - 'orange', 'orchid'), - 'shapes': ('box', 'ellipse', 'rhomb', 'triangle'), - 'textmodes': ('center', 'left_justify', 'right_justify'), - 'arrowstyles': ('solid', 'line', 'none'), - 'linestyles': ('continuous', 'dashed', 'dotted', 'invisible'), - } + "algos": ( + "dfs", + "tree", + "minbackward", + "left_to_right", + "right_to_left", + "top_to_bottom", + "bottom_to_top", + "maxdepth", + "maxdepthslow", + "mindepth", + "mindepthslow", + "mindegree", + "minindegree", + "minoutdegree", + "maxdegree", + "maxindegree", + "maxoutdegree", + ), + "booleans": ("yes", "no"), + "colors": ( + "black", + "white", + "blue", + "red", + "green", + "yellow", + "magenta", + "lightgrey", + "cyan", + "darkgrey", + "darkblue", + "darkred", + "darkgreen", + "darkyellow", + "darkmagenta", + "darkcyan", + "gold", + "lightblue", + "lightred", + "lightgreen", + "lightyellow", + "lightmagenta", + "lightcyan", + "lilac", + "turquoise", + "aquamarine", + "khaki", + "purple", + "yellowgreen", + "pink", + "orange", + "orchid", + ), + "shapes": ("box", "ellipse", "rhomb", "triangle"), + "textmodes": ("center", "left_justify", "right_justify"), + "arrowstyles": ("solid", "line", "none"), + "linestyles": ("continuous", "dashed", "dotted", "invisible"), +} # meaning of possible values: # O -> string # 1 -> int # list -> value in list GRAPH_ATTRS = { - 'title': 0, - 'label': 0, - 'color': ATTRS_VAL['colors'], - 'textcolor': ATTRS_VAL['colors'], - 'bordercolor': ATTRS_VAL['colors'], - 'width': 1, - 'height': 1, - 'borderwidth': 1, - 'textmode': ATTRS_VAL['textmodes'], - 'shape': ATTRS_VAL['shapes'], - 'shrink': 1, - 'stretch': 1, - 'orientation': ATTRS_VAL['algos'], - 'vertical_order': 1, - 'horizontal_order': 1, - 'xspace': 1, - 'yspace': 1, - 'layoutalgorithm': ATTRS_VAL['algos'], - 'late_edge_labels': ATTRS_VAL['booleans'], - 'display_edge_labels': ATTRS_VAL['booleans'], - 'dirty_edge_labels': ATTRS_VAL['booleans'], - 'finetuning': ATTRS_VAL['booleans'], - 'manhattan_edges': ATTRS_VAL['booleans'], - 'smanhattan_edges': ATTRS_VAL['booleans'], - 'port_sharing': ATTRS_VAL['booleans'], - 'edges': ATTRS_VAL['booleans'], - 'nodes': ATTRS_VAL['booleans'], - 'splines': ATTRS_VAL['booleans'], - } + "title": 0, + "label": 0, + "color": ATTRS_VAL["colors"], + "textcolor": ATTRS_VAL["colors"], + "bordercolor": ATTRS_VAL["colors"], + "width": 1, + "height": 1, + "borderwidth": 1, + "textmode": ATTRS_VAL["textmodes"], + "shape": ATTRS_VAL["shapes"], + "shrink": 1, + "stretch": 1, + "orientation": ATTRS_VAL["algos"], + "vertical_order": 1, + "horizontal_order": 1, + "xspace": 1, + "yspace": 1, + "layoutalgorithm": ATTRS_VAL["algos"], + "late_edge_labels": ATTRS_VAL["booleans"], + "display_edge_labels": ATTRS_VAL["booleans"], + "dirty_edge_labels": ATTRS_VAL["booleans"], + "finetuning": ATTRS_VAL["booleans"], + "manhattan_edges": ATTRS_VAL["booleans"], + "smanhattan_edges": ATTRS_VAL["booleans"], + "port_sharing": ATTRS_VAL["booleans"], + "edges": ATTRS_VAL["booleans"], + "nodes": ATTRS_VAL["booleans"], + "splines": ATTRS_VAL["booleans"], +} NODE_ATTRS = { - 'title': 0, - 'label': 0, - 'color': ATTRS_VAL['colors'], - 'textcolor': ATTRS_VAL['colors'], - 'bordercolor': ATTRS_VAL['colors'], - 'width': 1, - 'height': 1, - 'borderwidth': 1, - 'textmode': ATTRS_VAL['textmodes'], - 'shape': ATTRS_VAL['shapes'], - 'shrink': 1, - 'stretch': 1, - 'vertical_order': 1, - 'horizontal_order': 1, - } + "title": 0, + "label": 0, + "color": ATTRS_VAL["colors"], + "textcolor": ATTRS_VAL["colors"], + "bordercolor": ATTRS_VAL["colors"], + "width": 1, + "height": 1, + "borderwidth": 1, + "textmode": ATTRS_VAL["textmodes"], + "shape": ATTRS_VAL["shapes"], + "shrink": 1, + "stretch": 1, + "vertical_order": 1, + "horizontal_order": 1, +} EDGE_ATTRS = { - 'sourcename': 0, - 'targetname': 0, - 'label': 0, - 'linestyle': ATTRS_VAL['linestyles'], - 'class': 1, - 'thickness': 0, - 'color': ATTRS_VAL['colors'], - 'textcolor': ATTRS_VAL['colors'], - 'arrowcolor': ATTRS_VAL['colors'], - 'backarrowcolor': ATTRS_VAL['colors'], - 'arrowsize': 1, - 'backarrowsize': 1, - 'arrowstyle': ATTRS_VAL['arrowstyles'], - 'backarrowstyle': ATTRS_VAL['arrowstyles'], - 'textmode': ATTRS_VAL['textmodes'], - 'priority': 1, - 'anchor': 1, - 'horizontal_order': 1, - } + "sourcename": 0, + "targetname": 0, + "label": 0, + "linestyle": ATTRS_VAL["linestyles"], + "class": 1, + "thickness": 0, + "color": ATTRS_VAL["colors"], + "textcolor": ATTRS_VAL["colors"], + "arrowcolor": ATTRS_VAL["colors"], + "backarrowcolor": ATTRS_VAL["colors"], + "arrowsize": 1, + "backarrowsize": 1, + "arrowstyle": ATTRS_VAL["arrowstyles"], + "backarrowstyle": ATTRS_VAL["arrowstyles"], + "textmode": ATTRS_VAL["textmodes"], + "priority": 1, + "anchor": 1, + "horizontal_order": 1, +} # Misc utilities ############################################################### + def latin_to_vcg(st): """Convert latin characters using vcg escape sequence. """ @@ -136,7 +176,7 @@ def latin_to_vcg(st): try: num = ord(char) if num >= 192: - st = st.replace(char, r'\fi%d'%ord(char)) + st = st.replace(char, r"\fi%d" % ord(char)) except: pass return st @@ -148,12 +188,12 @@ class VCGPrinter: def __init__(self, output_stream): self._stream = output_stream - self._indent = '' + self._indent = "" def open_graph(self, **args): """open a vcg graph """ - self._stream.write('%sgraph:{\n'%self._indent) + self._stream.write("%sgraph:{\n" % self._indent) self._inc_indent() self._write_attributes(GRAPH_ATTRS, **args) @@ -161,26 +201,24 @@ class VCGPrinter: """close a vcg graph """ self._dec_indent() - self._stream.write('%s}\n'%self._indent) - + self._stream.write("%s}\n" % self._indent) def node(self, title, **args): """draw a node """ self._stream.write('%snode: {title:"%s"' % (self._indent, title)) self._write_attributes(NODE_ATTRS, **args) - self._stream.write('}\n') + self._stream.write("}\n") - - def edge(self, from_node, to_node, edge_type='', **args): + def edge(self, from_node, to_node, edge_type="", **args): """draw an edge from a node to another. """ self._stream.write( - '%s%sedge: {sourcename:"%s" targetname:"%s"' % ( - self._indent, edge_type, from_node, to_node)) + '%s%sedge: {sourcename:"%s" targetname:"%s"' + % (self._indent, edge_type, from_node, to_node) + ) self._write_attributes(EDGE_ATTRS, **args) - self._stream.write('}\n') - + self._stream.write("}\n") # private ################################################################## @@ -189,26 +227,31 @@ class VCGPrinter: """ for key, value in args.items(): try: - _type = attributes_dict[key] + _type = attributes_dict[key] except KeyError: - raise Exception('''no such attribute %s -possible attributes are %s''' % (key, attributes_dict.keys())) + raise Exception( + """no such attribute %s +possible attributes are %s""" + % (key, attributes_dict.keys()) + ) if not _type: self._stream.write('%s%s:"%s"\n' % (self._indent, key, value)) elif _type == 1: - self._stream.write('%s%s:%s\n' % (self._indent, key, - int(value))) + self._stream.write("%s%s:%s\n" % (self._indent, key, int(value))) elif value in _type: - self._stream.write('%s%s:%s\n' % (self._indent, key, value)) + self._stream.write("%s%s:%s\n" % (self._indent, key, value)) else: - raise Exception('''value %s isn\'t correct for attribute %s -correct values are %s''' % (value, key, _type)) + raise Exception( + """value %s isn\'t correct for attribute %s +correct values are %s""" + % (value, key, _type) + ) def _inc_indent(self): """increment indentation """ - self._indent = ' %s' % self._indent + self._indent = " %s" % self._indent def _dec_indent(self): """decrement indentation diff --git a/logilab/common/visitor.py b/logilab/common/visitor.py index 0698bae..8d80d54 100644 --- a/logilab/common/visitor.py +++ b/logilab/common/visitor.py @@ -23,15 +23,16 @@ """ from typing import Any, Callable, Optional, Union from logilab.common.types import Node, HTMLWriter, TextWriter + __docformat__ = "restructuredtext en" def no_filter(_: Node) -> int: return 1 + # Iterators ################################################################### class FilteredIterator(object): - def __init__(self, node: Node, list_func: Callable, filter_func: Optional[Any] = None) -> None: self._next = [(node, 0)] if filter_func is None: @@ -41,14 +42,14 @@ class FilteredIterator(object): def __next__(self) -> Optional[Node]: try: return self._list.pop(0) - except : + except: return None next = __next__ + # Base Visitor ################################################################ class Visitor(object): - def __init__(self, iterator_class, filter_func=None): self._iter_class = iterator_class self.filter = filter_func @@ -87,11 +88,13 @@ class Visitor(object): """ return result + # standard visited mixin ###################################################### class VisitedMixIn(object): """ Visited interface allow node visitors to use the node """ + def get_visit_name(self) -> str: """ return the visit name for the mixed class. When calling 'accept', the @@ -101,14 +104,16 @@ class VisitedMixIn(object): try: # mypy: "VisitedMixIn" has no attribute "TYPE" # dynamic attribute - return self.TYPE.replace('-', '_') # type: ignore + return self.TYPE.replace("-", "_") # type: ignore except: return self.__class__.__name__.lower() - def accept(self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any) -> Optional[Any]: - func = getattr(visitor, 'visit_%s' % self.get_visit_name()) + def accept( + self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any + ) -> Optional[Any]: + func = getattr(visitor, "visit_%s" % self.get_visit_name()) return func(self, *args, **kwargs) def leave(self, visitor, *args, **kwargs): - func = getattr(visitor, 'leave_%s' % self.get_visit_name()) + func = getattr(visitor, "leave_%s" % self.get_visit_name()) return func(self, *args, **kwargs) diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py index 7b12c45..14e3762 100644 --- a/logilab/common/xmlutils.py +++ b/logilab/common/xmlutils.py @@ -34,6 +34,7 @@ from typing import Dict, Optional, Union RE_DOUBLE_QUOTE = re.compile('([\w\-\.]+)="([^"]+)"') RE_SIMPLE_QUOTE = re.compile("([\w\-\.]+)='([^']+)'") + def parse_pi_data(pi_data: str) -> Dict[str, Optional[str]]: """ Utility function that parses the data contained in an XML |