summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLaurent Peuch <cortex@worlddomination.be>2020-04-01 00:11:10 +0200
committerLaurent Peuch <cortex@worlddomination.be>2020-04-01 00:11:10 +0200
commitb8899451fa861b04568e2a0bb4e3fe4acc0daee3 (patch)
tree1446c809e19b5571b31b1999246aa0e50b19f5c8
parent32cd73810056594f55eff0ffafebbdeb50c7a860 (diff)
downloadlogilab-common-b8899451fa861b04568e2a0bb4e3fe4acc0daee3.tar.gz
Black the whole code base
-rw-r--r--__pkginfo__.py40
-rw-r--r--logilab/__init__.py2
-rw-r--r--logilab/common/__init__.py29
-rw-r--r--logilab/common/cache.py13
-rw-r--r--logilab/common/changelog.py74
-rw-r--r--logilab/common/clcommands.py126
-rw-r--r--logilab/common/compat.py8
-rw-r--r--logilab/common/configuration.py503
-rw-r--r--logilab/common/daemon.py23
-rw-r--r--logilab/common/date.py180
-rw-r--r--logilab/common/debugger.py34
-rw-r--r--logilab/common/decorators.py61
-rw-r--r--logilab/common/deprecation.py2
-rw-r--r--logilab/common/fileutils.py49
-rw-r--r--logilab/common/graph.py108
-rw-r--r--logilab/common/interface.py5
-rw-r--r--logilab/common/logging_ext.py91
-rw-r--r--logilab/common/modutils.py182
-rw-r--r--logilab/common/optik_ext.py245
-rw-r--r--logilab/common/optparser.py33
-rw-r--r--logilab/common/proc.py43
-rw-r--r--logilab/common/pytest.py630
-rw-r--r--logilab/common/registry.py269
-rw-r--r--logilab/common/shellutils.py119
-rw-r--r--logilab/common/sphinx_ext.py29
-rw-r--r--logilab/common/sphinxutils.py41
-rw-r--r--logilab/common/table.py314
-rw-r--r--logilab/common/tasksqueue.py26
-rw-r--r--logilab/common/testlib.py213
-rw-r--r--logilab/common/textutils.py224
-rw-r--r--logilab/common/tree.py55
-rw-r--r--logilab/common/umessage.py42
-rw-r--r--logilab/common/ureports/__init__.py40
-rw-r--r--logilab/common/ureports/docbook_writer.py87
-rw-r--r--logilab/common/ureports/html_writer.py94
-rw-r--r--logilab/common/ureports/nodes.py37
-rw-r--r--logilab/common/ureports/text_writer.py64
-rw-r--r--logilab/common/urllib2ext.py39
-rw-r--r--logilab/common/vcgutils.py251
-rw-r--r--logilab/common/visitor.py19
-rw-r--r--logilab/common/xmlutils.py1
-rw-r--r--setup.py30
-rw-r--r--test/data/__pkginfo__.py36
-rw-r--r--test/data/deprecation.py1
-rw-r--r--test/data/lmfp/foo.py3
-rw-r--r--test/data/module.py18
-rw-r--r--test/data/module2.py67
-rw-r--r--test/data/noendingnewline.py11
-rw-r--r--test/data/nonregr.py3
-rw-r--r--test/data/regobjects.py15
-rw-r--r--test/data/regobjects2.py6
-rw-r--r--test/data/sub/momo.py2
-rw-r--r--test/test_cache.py125
-rw-r--r--test/test_changelog.py4
-rw-r--r--test/test_configuration.py373
-rw-r--r--test/test_date.py64
-rw-r--r--test/test_decorators.py114
-rw-r--r--test/test_fileutils.py76
-rw-r--r--test/test_graph.py50
-rw-r--r--test/test_interface.py32
-rw-r--r--test/test_pytest.py3
-rw-r--r--test/test_shellutils.py244
-rw-r--r--test/test_table.py287
-rw-r--r--test/test_taskqueue.py62
-rw-r--r--test/test_testlib.py503
-rw-r--r--test/test_textutils.py273
-rw-r--r--test/test_tree.py131
-rw-r--r--test/test_umessage.py39
-rw-r--r--test/test_ureports_html.py30
-rw-r--r--test/test_ureports_text.py41
-rw-r--r--test/test_xmlutils.py22
-rw-r--r--test/utils.py61
72 files changed, 4203 insertions, 2968 deletions
diff --git a/__pkginfo__.py b/__pkginfo__.py
index 6ad6cb6..6005ae7 100644
--- a/__pkginfo__.py
+++ b/__pkginfo__.py
@@ -23,41 +23,41 @@ __docformat__ = "restructuredtext en"
import os
from os.path import join
-distname = 'logilab-common'
-modname = 'common'
-subpackage_of = 'logilab'
+distname = "logilab-common"
+modname = "common"
+subpackage_of = "logilab"
subpackage_master = True
numversion = (1, 6, 1)
version = '.'.join([str(num) for num in numversion])
-license = 'LGPL' # 2.1 or later
-description = ("collection of low-level Python packages and modules"
- " used by Logilab projects")
+license = "LGPL" # 2.1 or later
+description = "collection of low-level Python packages and modules" " used by Logilab projects"
web = "http://www.logilab.org/project/%s" % distname
mailinglist = "mailto://python-projects@lists.logilab.org"
author = "Logilab"
author_email = "contact@logilab.fr"
-scripts = [join('bin', 'logilab-pytest')]
-include_dirs = [join('test', 'data')]
+scripts = [join("bin", "logilab-pytest")]
+include_dirs = [join("test", "data")]
install_requires = [
- 'setuptools',
- 'mypy-extensions',
- 'typing_extensions',
+ "setuptools",
+ "mypy-extensions",
+ "typing_extensions",
]
tests_require = [
- 'pytz',
- 'egenix-mx-base',
+ "pytz",
+ "egenix-mx-base",
]
-if os.name == 'nt':
- install_requires.append('colorama')
+if os.name == "nt":
+ install_requires.append("colorama")
-classifiers = ["Topic :: Utilities",
- "Programming Language :: Python",
- "Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3 :: Only",
- ]
+classifiers = [
+ "Topic :: Utilities",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3 :: Only",
+]
diff --git a/logilab/__init__.py b/logilab/__init__.py
index de40ea7..5284146 100644
--- a/logilab/__init__.py
+++ b/logilab/__init__.py
@@ -1 +1 @@
-__import__('pkg_resources').declare_namespace(__name__)
+__import__("pkg_resources").declare_namespace(__name__)
diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py
index 0d7f183..34e6c1b 100644
--- a/logilab/common/__init__.py
+++ b/logilab/common/__init__.py
@@ -31,19 +31,19 @@ import types
import pkg_resources
from typing import List, Sequence
-__version__ = pkg_resources.get_distribution('logilab-common').version
+__version__ = pkg_resources.get_distribution("logilab-common").version
# deprecated, but keep compatibility with pylint < 1.4.4
-__pkginfo__ = types.ModuleType('__pkginfo__')
+__pkginfo__ = types.ModuleType("__pkginfo__")
__pkginfo__.__package__ = __name__
# mypy output: Module has no attribute "version"
# logilab's magic
__pkginfo__.version = __version__ # type: ignore
-sys.modules['logilab.common.__pkginfo__'] = __pkginfo__
+sys.modules["logilab.common.__pkginfo__"] = __pkginfo__
-STD_BLACKLIST = ('CVS', '.svn', '.hg', '.git', '.tox', 'debian', 'dist', 'build')
+STD_BLACKLIST = ("CVS", ".svn", ".hg", ".git", ".tox", "debian", "dist", "build")
-IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc', '~', '.swp', '.orig')
+IGNORED_EXTENSIONS = (".pyc", ".pyo", ".elc", "~", ".swp", ".orig")
# set this to False if you've mx DateTime installed but you don't want your db
# adapter to use it (should be set before you got a connection)
@@ -52,12 +52,14 @@ USE_MX_DATETIME = True
class attrdict(dict):
"""A dictionary for which keys are also accessible as attributes."""
+
def __getattr__(self, attr: str) -> str:
try:
return self[attr]
except KeyError:
raise AttributeError(attr)
+
class dictattr(dict):
def __init__(self, proxy):
self.__proxy = proxy
@@ -68,13 +70,17 @@ class dictattr(dict):
except AttributeError:
raise KeyError(attr)
+
class nullobject(object):
def __repr__(self):
- return '<nullobject>'
+ return "<nullobject>"
+
def __bool__(self):
return False
+
__nonzero__ = __bool__
+
class tempattr(object):
def __init__(self, obj, attr, value):
self.obj = obj
@@ -90,7 +96,6 @@ class tempattr(object):
setattr(self.obj, self.attr, self.oldvalue)
-
# flatten -----
# XXX move in a specific module and use yield instead
# do not mix flatten and translate
@@ -105,10 +110,10 @@ class tempattr(object):
# except (TypeError, ValueError): return False
# return True
#
-#def is_scalar(obj):
+# def is_scalar(obj):
# return is_string_like(obj) or not iterable(obj)
#
-#def flatten(seq):
+# def flatten(seq):
# for item in seq:
# if is_scalar(item):
# yield item
@@ -116,6 +121,7 @@ class tempattr(object):
# for subitem in flatten(item):
# yield subitem
+
def flatten(iterable, tr_func=None, results=None):
"""Flatten a list of list with any level.
@@ -141,6 +147,7 @@ def flatten(iterable, tr_func=None, results=None):
# XXX is function below still used ?
+
def make_domains(lists):
"""
Given a list of lists, return a list of domain for each list to produce all
@@ -157,7 +164,7 @@ def make_domains(lists):
for iterable in lists:
new_domain = iterable[:]
for i in range(len(domains)):
- domains[i] = domains[i]*len(iterable)
+ domains[i] = domains[i] * len(iterable)
if domains:
missing = (len(domains[0]) - len(iterable)) / len(iterable)
i = 0
@@ -173,6 +180,7 @@ def make_domains(lists):
# private stuff ################################################################
+
def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames: List[str]) -> None:
"""remove files/directories in the black list
@@ -183,4 +191,3 @@ def _handle_blacklist(blacklist: Sequence[str], dirnames: List[str], filenames:
dirnames.remove(norecurs)
elif norecurs in filenames:
filenames.remove(norecurs)
-
diff --git a/logilab/common/cache.py b/logilab/common/cache.py
index c47f481..6b673a2 100644
--- a/logilab/common/cache.py
+++ b/logilab/common/cache.py
@@ -47,7 +47,7 @@ class Cache(dict):
""" Warning : Cache.__init__() != dict.__init__().
Constructor does not take any arguments beside size.
"""
- assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)'
+ assert size >= 0, "cache size must be >= 0 (0 meaning no caching)"
self.size = size
self._usage: List = []
self._lock = Lock()
@@ -74,12 +74,13 @@ class Cache(dict):
del self._usage[0]
self._usage.append(key)
else:
- pass # key is already the most recently used key
+ pass # key is already the most recently used key
def __getitem__(self, key: _KeyType):
value = super(Cache, self).__getitem__(key)
self._update_usage(key)
return value
+
__getitem__ = locked(_acquire, _release)(__getitem__)
def __setitem__(self, key: _KeyType, item):
@@ -87,24 +88,28 @@ class Cache(dict):
if self.size > 0:
super(Cache, self).__setitem__(key, item)
self._update_usage(key)
+
__setitem__ = locked(_acquire, _release)(__setitem__)
def __delitem__(self, key: _KeyType):
super(Cache, self).__delitem__(key)
self._usage.remove(key)
+
__delitem__ = locked(_acquire, _release)(__delitem__)
def clear(self):
super(Cache, self).clear()
self._usage = []
+
clear = locked(_acquire, _release)(clear)
def pop(self, key: _KeyType, default=_marker):
if key in self:
self._usage.remove(key)
- #if default is _marker:
+ # if default is _marker:
# return super(Cache, self).pop(key)
return super(Cache, self).pop(key, default)
+
pop = locked(_acquire, _release)(pop)
def popitem(self):
@@ -115,5 +120,3 @@ class Cache(dict):
def update(self, other):
raise NotImplementedError()
-
-
diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py
index c128eb7..cec1b5e 100644
--- a/logilab/common/changelog.py
+++ b/logilab/common/changelog.py
@@ -52,9 +52,9 @@ import codecs
from typing import List, Any, Optional, Tuple
from _io import StringIO
-BULLET = '*'
-SUBBULLET = '-'
-INDENT = ' ' * 4
+BULLET = "*"
+SUBBULLET = "-"
+INDENT = " " * 4
class NoEntry(Exception):
@@ -69,9 +69,10 @@ class Version(tuple):
"""simple class to handle soft version number has a tuple while
correctly printing it as X.Y.Z
"""
+
def __new__(cls, versionstr):
if isinstance(versionstr, str):
- versionstr = versionstr.strip(' :') # XXX (syt) duh?
+ versionstr = versionstr.strip(" :") # XXX (syt) duh?
parsed = cls.parse(versionstr)
else:
parsed = versionstr
@@ -79,26 +80,29 @@ class Version(tuple):
@classmethod
def parse(cls, versionstr: str) -> List[int]:
- versionstr = versionstr.strip(' :')
+ versionstr = versionstr.strip(" :")
try:
- return [int(i) for i in versionstr.split('.')]
+ return [int(i) for i in versionstr.split(".")]
except ValueError as ex:
- raise ValueError("invalid literal for version '%s' (%s)" %
- (versionstr, ex))
+ raise ValueError("invalid literal for version '%s' (%s)" % (versionstr, ex))
def __str__(self) -> str:
- return '.'.join([str(i) for i in self])
+ return ".".join([str(i) for i in self])
# upstream change log #########################################################
+
class ChangeLogEntry(object):
"""a change log entry, i.e. a set of messages associated to a version and
its release date
"""
+
version_class = Version
- def __init__(self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any) -> None:
+ def __init__(
+ self, date: Optional[str] = None, version: Optional[str] = None, **kwargs: Any
+ ) -> None:
self.__dict__.update(kwargs)
self.version: Optional[Version]
if version:
@@ -116,8 +120,7 @@ class ChangeLogEntry(object):
"""complete the latest added message
"""
if not self.messages:
- raise ValueError('unable to complete last message as '
- 'there is no previous message)')
+ raise ValueError("unable to complete last message as " "there is no previous message)")
if self.messages[-1][1]: # sub messages
self.messages[-1][1][-1].append(msg_suite)
else: # message
@@ -125,29 +128,26 @@ class ChangeLogEntry(object):
def add_sub_message(self, sub_msg: str, key: Optional[Any] = None) -> None:
if not self.messages:
- raise ValueError('unable to complete last message as '
- 'there is no previous message)')
+ raise ValueError("unable to complete last message as " "there is no previous message)")
if key is None:
self.messages[-1][1].append([sub_msg])
else:
- raise NotImplementedError('sub message to specific key '
- 'are not implemented yet')
+ raise NotImplementedError("sub message to specific key " "are not implemented yet")
def write(self, stream: StringIO = sys.stdout) -> None:
"""write the entry to file """
- stream.write(u'%s -- %s\n' % (self.date or '', self.version or ''))
+ stream.write("%s -- %s\n" % (self.date or "", self.version or ""))
for msg, sub_msgs in self.messages:
- stream.write(u'%s%s %s\n' % (INDENT, BULLET, msg[0]))
- stream.write(u''.join(msg[1:]))
+ stream.write("%s%s %s\n" % (INDENT, BULLET, msg[0]))
+ stream.write("".join(msg[1:]))
if sub_msgs:
- stream.write(u'\n')
+ stream.write("\n")
for sub_msg in sub_msgs:
- stream.write(u'%s%s %s\n' %
- (INDENT * 2, SUBBULLET, sub_msg[0]))
- stream.write(u''.join(sub_msg[1:]))
- stream.write(u'\n')
+ stream.write("%s%s %s\n" % (INDENT * 2, SUBBULLET, sub_msg[0]))
+ stream.write("".join(sub_msg[1:]))
+ stream.write("\n")
- stream.write(u'\n\n')
+ stream.write("\n\n")
class ChangeLog(object):
@@ -155,23 +155,22 @@ class ChangeLog(object):
entry_class = ChangeLogEntry
- def __init__(self, changelog_file: str, title: str = u'') -> None:
+ def __init__(self, changelog_file: str, title: str = "") -> None:
self.file = changelog_file
- assert isinstance(title, type(u'')), 'title must be a unicode object'
+ assert isinstance(title, type("")), "title must be a unicode object"
self.title = title
- self.additional_content = u''
+ self.additional_content = ""
self.entries: List[ChangeLogEntry] = []
self.load()
def __repr__(self):
- return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self),
- len(self.entries))
+ return "<ChangeLog %s at %s (%s entries)>" % (self.file, id(self), len(self.entries))
def add_entry(self, entry: ChangeLogEntry) -> None:
"""add a new entry to the change log"""
self.entries.append(entry)
- def get_entry(self, version='', create=None):
+ def get_entry(self, version="", create=None):
""" return a given changelog entry
if version is omitted, return the current entry
"""
@@ -197,7 +196,7 @@ class ChangeLog(object):
def load(self) -> None:
""" read a logilab's ChangeLog from file """
try:
- stream = codecs.open(self.file, encoding='utf-8')
+ stream = codecs.open(self.file, encoding="utf-8")
except IOError:
return
@@ -209,20 +208,20 @@ class ChangeLog(object):
words = sline.split()
# if new entry
- if len(words) == 1 and words[0] == '--':
+ if len(words) == 1 and words[0] == "--":
expect_sub = False
last = self.entry_class()
self.add_entry(last)
# if old entry
- elif len(words) == 3 and words[1] == '--':
+ elif len(words) == 3 and words[1] == "--":
expect_sub = False
last = self.entry_class(words[0], words[2])
self.add_entry(last)
# if title
elif sline and last is None:
- self.title = '%s%s' % (self.title, line)
+ self.title = "%s%s" % (self.title, line)
# if new entry
elif sline and sline[0] == BULLET:
expect_sub = False
@@ -243,14 +242,15 @@ class ChangeLog(object):
stream.close()
def format_title(self) -> str:
- return u'%s\n\n' % self.title.strip()
+ return "%s\n\n" % self.title.strip()
def save(self):
"""write back change log"""
# filetutils isn't importable in appengine, so import locally
from logilab.common.fileutils import ensure_fs_mode
+
ensure_fs_mode(self.file, S_IWRITE)
- self.write(codecs.open(self.file, 'w', encoding='utf-8'))
+ self.write(codecs.open(self.file, "w", encoding="utf-8"))
def write(self, stream: StringIO = sys.stdout) -> None:
"""write changelog to stream"""
diff --git a/logilab/common/clcommands.py b/logilab/common/clcommands.py
index 4778b99..f89a4b4 100644
--- a/logilab/common/clcommands.py
+++ b/logilab/common/clcommands.py
@@ -42,6 +42,7 @@ class BadCommandUsage(Exception):
Trigger display of command usage.
"""
+
class CommandError(Exception):
"""Raised when a command can't be processed and we want to display it and
exit, without traceback nor usage displayed.
@@ -50,6 +51,7 @@ class CommandError(Exception):
# command line access point ####################################################
+
class CommandLine(dict):
"""Usage:
@@ -77,9 +79,17 @@ class CommandLine(dict):
* `logger`, logger to propagate to commands, default to
`logging.getLogger(self.pgm))`
"""
- def __init__(self, pgm=None, doc=None, copyright=None, version=None,
- rcfile=None, logthreshold=logging.ERROR,
- check_duplicated_command=True):
+
+ def __init__(
+ self,
+ pgm=None,
+ doc=None,
+ copyright=None,
+ version=None,
+ rcfile=None,
+ logthreshold=logging.ERROR,
+ check_duplicated_command=True,
+ ):
if pgm is None:
pgm = basename(sys.argv[0])
self.pgm = pgm
@@ -93,8 +103,9 @@ class CommandLine(dict):
def register(self, cls, force=False):
"""register the given :class:`Command` subclass"""
- assert not self.check_duplicated_command or force or not cls.name in self, \
- 'a command %s is already defined' % cls.name
+ assert not self.check_duplicated_command or force or not cls.name in self, (
+ "a command %s is already defined" % cls.name
+ )
self[cls.name] = cls
return cls
@@ -107,20 +118,22 @@ class CommandLine(dict):
Terminate by :exc:`SystemExit`
"""
- init_log(debug=True, # so that we use StreamHandler
- logthreshold=self.logthreshold,
- logformat='%(levelname)s: %(message)s')
+ init_log(
+ debug=True, # so that we use StreamHandler
+ logthreshold=self.logthreshold,
+ logformat="%(levelname)s: %(message)s",
+ )
try:
arg = args.pop(0)
except IndexError:
self.usage_and_exit(1)
- if arg in ('-h', '--help'):
+ if arg in ("-h", "--help"):
self.usage_and_exit(0)
- if self.version is not None and arg in ('--version'):
+ if self.version is not None and arg in ("--version"):
print(self.version)
sys.exit(0)
rcfile = self.rcfile
- if rcfile is not None and arg in ('-C', '--rc-file'):
+ if rcfile is not None and arg in ("-C", "--rc-file"):
try:
rcfile = args.pop(0)
arg = args.pop(0)
@@ -129,19 +142,19 @@ class CommandLine(dict):
try:
command = self.get_command(arg)
except KeyError:
- print('ERROR: no %s command' % arg)
+ print("ERROR: no %s command" % arg)
print()
self.usage_and_exit(1)
try:
sys.exit(command.main_run(args, rcfile))
except KeyboardInterrupt as exc:
- print('Interrupted', end=' ')
+ print("Interrupted", end=" ")
if str(exc):
- print(': %s' % exc, end=' ')
+ print(": %s" % exc, end=" ")
print()
sys.exit(4)
except BadCommandUsage as err:
- print('ERROR:', err)
+ print("ERROR:", err)
print()
print(command.help())
sys.exit(1)
@@ -166,32 +179,44 @@ class CommandLine(dict):
"""display usage for the main program (i.e. when no command supplied)
and exit
"""
- print('usage:', self.pgm, end=' ')
+ print("usage:", self.pgm, end=" ")
if self.rcfile:
- print('[--rc-file=<configuration file>]', end=' ')
- print('<command> [options] <command argument>...')
+ print("[--rc-file=<configuration file>]", end=" ")
+ print("<command> [options] <command argument>...")
if self.doc:
- print('\n%s' % self.doc)
- print('''
+ print("\n%s" % self.doc)
+ print(
+ """
Type "%(pgm)s <command> --help" for more information about a specific
-command. Available commands are :\n''' % self.__dict__)
+command. Available commands are :\n"""
+ % self.__dict__
+ )
max_len = max([len(cmd) for cmd in self])
- padding = ' ' * max_len
+ padding = " " * max_len
for cmdname, cmd in sorted(self.items()):
if not cmd.hidden:
- print(' ', (cmdname + padding)[:max_len], cmd.short_description())
+ print(" ", (cmdname + padding)[:max_len], cmd.short_description())
if self.rcfile:
- print('''
+ print(
+ """
Use --rc-file=<configuration file> / -C <configuration file> before the command
to specify a configuration file. Default to %s.
-''' % self.rcfile)
- print('''%(pgm)s -h/--help
- display this usage information and exit''' % self.__dict__)
+"""
+ % self.rcfile
+ )
+ print(
+ """%(pgm)s -h/--help
+ display this usage information and exit"""
+ % self.__dict__
+ )
if self.version:
- print('''%(pgm)s -v/--version
- display version configuration and exit''' % self.__dict__)
+ print(
+ """%(pgm)s -v/--version
+ display version configuration and exit"""
+ % self.__dict__
+ )
if self.copyright:
- print('\n', self.copyright)
+ print("\n", self.copyright)
def usage_and_exit(self, status):
self.usage()
@@ -200,6 +225,7 @@ to specify a configuration file. Default to %s.
# base command classes #########################################################
+
class Command(Configuration):
"""Base class for command line commands.
@@ -219,8 +245,8 @@ class Command(Configuration):
* `options`, options list, as allowed by :mod:configuration
"""
- arguments = ''
- name = ''
+ arguments = ""
+ name = ""
# hidden from help ?
hidden = False
# max/min args, None meaning unspecified
@@ -229,24 +255,23 @@ class Command(Configuration):
@classmethod
def description(cls):
- return cls.__doc__.replace(' ', '')
+ return cls.__doc__.replace(" ", "")
@classmethod
def short_description(cls):
- return cls.description().split('.')[0]
+ return cls.description().split(".")[0]
def __init__(self, logger):
- usage = '%%prog %s %s\n\n%s' % (self.name, self.arguments,
- self.description())
+ usage = "%%prog %s %s\n\n%s" % (self.name, self.arguments, self.description())
Configuration.__init__(self, usage=usage)
self.logger = logger
def check_args(self, args):
"""check command's arguments are provided"""
if self.min_args is not None and len(args) < self.min_args:
- raise BadCommandUsage('missing argument')
+ raise BadCommandUsage("missing argument")
if self.max_args is not None and len(args) > self.max_args:
- raise BadCommandUsage('too many arguments')
+ raise BadCommandUsage("too many arguments")
def main_run(self, args, rcfile=None):
"""Run the command and return status 0 if everything went fine.
@@ -275,8 +300,9 @@ class Command(Configuration):
class ListCommandsCommand(Command):
"""list available commands, useful for bash completion."""
- name = 'listcommands'
- arguments = '[command]'
+
+ name = "listcommands"
+ arguments = "[command]"
hidden = True
def run(self, args):
@@ -285,8 +311,8 @@ class ListCommandsCommand(Command):
command = args.pop()
cmd = _COMMANDS[command]
for optname, optdict in cmd.options:
- print('--help')
- print('--' + optname)
+ print("--help")
+ print("--" + optname)
else:
commands = sorted(_COMMANDS.keys())
for command in commands:
@@ -299,17 +325,19 @@ class ListCommandsCommand(Command):
_COMMANDS = CommandLine()
-DEFAULT_COPYRIGHT = '''\
+DEFAULT_COPYRIGHT = """\
Copyright (c) 2004-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
-http://www.logilab.fr/ -- mailto:contact@logilab.fr'''
+http://www.logilab.fr/ -- mailto:contact@logilab.fr"""
+
-@deprecated('use cls.register(cli)')
+@deprecated("use cls.register(cli)")
def register_commands(commands):
"""register existing commands"""
for command_klass in commands:
_COMMANDS.register(command_klass)
-@deprecated('use args.pop(0)')
+
+@deprecated("use args.pop(0)")
def main_run(args, doc=None, copyright=None, version=None):
"""command line tool: run command specified by argument list (without the
program name). Raise SystemExit with status 0 if everything went fine.
@@ -321,7 +349,8 @@ def main_run(args, doc=None, copyright=None, version=None):
_COMMANDS.version = version
_COMMANDS.run(args)
-@deprecated('use args.pop(0)')
+
+@deprecated("use args.pop(0)")
def pop_arg(args_list, expected_size_after=None, msg="Missing argument"):
"""helper function to get and check command line arguments"""
try:
@@ -329,6 +358,5 @@ def pop_arg(args_list, expected_size_after=None, msg="Missing argument"):
except IndexError:
raise BadCommandUsage(msg)
if expected_size_after is not None and len(args_list) > expected_size_after:
- raise BadCommandUsage('too many arguments')
+ raise BadCommandUsage("too many arguments")
return value
-
diff --git a/logilab/common/compat.py b/logilab/common/compat.py
index 4ca540b..e601b26 100644
--- a/logilab/common/compat.py
+++ b/logilab/common/compat.py
@@ -38,18 +38,25 @@ from typing import Union
# not used here, but imported to preserve API
import builtins
+
def str_to_bytes(string):
return str.encode(string)
+
+
# we have to ignore the encoding in py3k to be able to write a string into a
# TextIOWrapper or like object (which expect an unicode string)
def str_encode(string: Union[int, str], encoding: str) -> str:
return str(string)
+
# See also http://bugs.python.org/issue11776
if sys.version_info[0] == 3:
+
def method_type(callable, instance, klass):
# api change. klass is no more considered
return types.MethodType(callable, instance)
+
+
else:
# alias types otherwise
method_type = types.MethodType
@@ -57,6 +64,7 @@ else:
# Pythons 2 and 3 differ on where to get StringIO
if sys.version_info < (3, 0):
from cStringIO import StringIO
+
FileIO = file
BytesIO = StringIO
reload = reload
diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py
index 61c2e97..4c83030 100644
--- a/logilab/common/configuration.py
+++ b/logilab/common/configuration.py
@@ -111,9 +111,13 @@ from __future__ import print_function
__docformat__ = "restructuredtext en"
-__all__ = ('OptionsManagerMixIn', 'OptionsProviderMixIn',
- 'ConfigurationMixIn', 'Configuration',
- 'OptionsManager2ConfigurationAdapter')
+__all__ = (
+ "OptionsManagerMixIn",
+ "OptionsProviderMixIn",
+ "ConfigurationMixIn",
+ "Configuration",
+ "OptionsManager2ConfigurationAdapter",
+)
import os
import sys
@@ -139,14 +143,16 @@ OptionError = optik_ext.OptionError
REQUIRED: List = []
+
class UnsupportedAction(Exception):
"""raised by set_option when it doesn't know what to do for an action"""
def _get_encoding(encoding: Optional[str], stream: Union[StringIO, TextIOWrapper]) -> str:
- encoding = encoding or getattr(stream, 'encoding', None)
+ encoding = encoding or getattr(stream, "encoding", None)
if not encoding:
import locale
+
encoding = locale.getpreferredencoding()
return encoding
@@ -158,19 +164,20 @@ _ValueType = Union[List[str], Tuple[str, ...], str]
# validators will return the validated value or raise optparse.OptionValueError
# XXX add to documentation
+
def choice_validator(optdict: Dict[str, Any], name: str, value: str) -> str:
"""validate and return a converted value for option of type 'choice'
"""
- if not value in optdict['choices']:
+ if not value in optdict["choices"]:
msg = "option %s: invalid value: %r, should be in %s"
- raise optik_ext.OptionValueError(msg % (name, value, optdict['choices']))
+ raise optik_ext.OptionValueError(msg % (name, value, optdict["choices"]))
return value
def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType:
"""validate and return a converted value for option of type 'choice'
"""
- choices = optdict['choices']
+ choices = optdict["choices"]
values = optik_ext.check_csv(None, name, value)
for value in values:
if not value in choices:
@@ -178,67 +185,81 @@ def multiple_choice_validator(optdict: Dict[str, Any], name: str, value: _ValueT
raise optik_ext.OptionValueError(msg % (name, value, choices))
return values
+
def csv_validator(optdict: Dict[str, Any], name: str, value: _ValueType) -> _ValueType:
"""validate and return a converted value for option of type 'csv'
"""
return optik_ext.check_csv(None, name, value)
+
def yn_validator(optdict: Dict[str, Any], name: str, value: Union[bool, str]) -> bool:
"""validate and return a converted value for option of type 'yn'
"""
return optik_ext.check_yn(None, name, value)
-def named_validator(optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str]) -> Dict[str, str]:
+
+def named_validator(
+ optdict: Dict[str, Any], name: str, value: Union[Dict[str, str], str]
+) -> Dict[str, str]:
"""validate and return a converted value for option of type 'named'
"""
return optik_ext.check_named(None, name, value)
+
def file_validator(optdict, name, value):
"""validate and return a filepath for option of type 'file'"""
return optik_ext.check_file(None, name, value)
+
def color_validator(optdict, name, value):
"""validate and return a valid color for option of type 'color'"""
return optik_ext.check_color(None, name, value)
+
def password_validator(optdict, name, value):
"""validate and return a string for option of type 'password'"""
return optik_ext.check_password(None, name, value)
+
def date_validator(optdict, name, value):
"""validate and return a mx DateTime object for option of type 'date'"""
return optik_ext.check_date(None, name, value)
+
def time_validator(optdict, name, value):
"""validate and return a time object for option of type 'time'"""
return optik_ext.check_time(None, name, value)
+
def bytes_validator(optdict: Dict[str, str], name: str, value: Union[int, str]) -> int:
"""validate and return an integer for option of type 'bytes'"""
return optik_ext.check_bytes(None, name, value)
VALIDATORS: Dict[str, Callable] = {
- 'string': unquote,
- 'int': int,
- 'float': float,
- 'file': file_validator,
- 'font': unquote,
- 'color': color_validator,
- 'regexp': re.compile,
- 'csv': csv_validator,
- 'yn': yn_validator,
- 'bool': yn_validator,
- 'named': named_validator,
- 'password': password_validator,
- 'date': date_validator,
- 'time': time_validator,
- 'bytes': bytes_validator,
- 'choice': choice_validator,
- 'multiple_choice': multiple_choice_validator,
+ "string": unquote,
+ "int": int,
+ "float": float,
+ "file": file_validator,
+ "font": unquote,
+ "color": color_validator,
+ "regexp": re.compile,
+ "csv": csv_validator,
+ "yn": yn_validator,
+ "bool": yn_validator,
+ "named": named_validator,
+ "password": password_validator,
+ "date": date_validator,
+ "time": time_validator,
+ "bytes": bytes_validator,
+ "choice": choice_validator,
+ "multiple_choice": multiple_choice_validator,
}
-def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str]) -> Union[List[str], int, str]:
+
+def _call_validator(
+ opttype: str, optdict: Dict[str, Any], option: str, value: Union[List[str], int, str]
+) -> Union[List[str], int, str]:
if opttype not in VALIDATORS:
raise Exception('Unsupported type "%s"' % opttype)
try:
@@ -249,8 +270,10 @@ def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: U
except optik_ext.OptionValueError:
raise
except:
- raise optik_ext.OptionValueError('%s value (%r) should be of type %s' %
- (option, value, opttype))
+ raise optik_ext.OptionValueError(
+ "%s value (%r) should be of type %s" % (option, value, opttype)
+ )
+
# user input functions ########################################################
@@ -258,19 +281,23 @@ def _call_validator(opttype: str, optdict: Dict[str, Any], option: str, value: U
# the result and return the validated value or raise optparse.OptionValueError
# XXX add to documentation
-def input_password(optdict, question='password:'):
+
+def input_password(optdict, question="password:"):
from getpass import getpass
+
while True:
value = getpass(question)
- value2 = getpass('confirm: ')
+ value2 = getpass("confirm: ")
if value == value2:
return value
- print('password mismatch, try again')
+ print("password mismatch, try again")
+
def input_string(optdict, question):
value = input(question).strip()
return value or None
+
def _make_input_function(opttype):
def input_validator(optdict, question):
while True:
@@ -280,14 +307,15 @@ def _make_input_function(opttype):
try:
return _call_validator(opttype, optdict, None, value)
except optik_ext.OptionValueError as ex:
- msg = str(ex).split(':', 1)[-1].strip()
- print('bad value: %s' % msg)
+ msg = str(ex).split(":", 1)[-1].strip()
+ print("bad value: %s" % msg)
+
return input_validator
INPUT_FUNCTIONS: Dict[str, Callable] = {
- 'string': input_string,
- 'password': input_password,
+ "string": input_string,
+ "password": input_password,
}
for opttype in VALIDATORS.keys():
@@ -295,6 +323,7 @@ for opttype in VALIDATORS.keys():
# utility functions ############################################################
+
def expand_default(self, option):
"""monkey patch OptionParser.expand_default since we have a particular
way to handle defaults to avoid overriding values in the configuration
@@ -317,125 +346,144 @@ def expand_default(self, option):
return option.help.replace(self.default_tag, str(value))
-def _validate(value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = '') -> Union[List[str], int, str]:
+def _validate(
+ value: Union[List[str], int, str], optdict: Dict[str, Any], name: str = ""
+) -> Union[List[str], int, str]:
"""return a validated value for an option according to its type
optional argument name is only used for error message formatting
"""
try:
- _type = optdict['type']
+ _type = optdict["type"]
except KeyError:
# FIXME
return value
return _call_validator(_type, optdict, name, value)
-convert = deprecated('[0.60] convert() was renamed _validate()')(_validate)
+
+
+convert = deprecated("[0.60] convert() was renamed _validate()")(_validate)
# format and output functions ##################################################
+
def comment(string):
"""return string as a comment"""
lines = [line.strip() for line in string.splitlines()]
- return '# ' + ('%s# ' % os.linesep).join(lines)
+ return "# " + ("%s# " % os.linesep).join(lines)
+
def format_time(value):
if not value:
- return '0'
+ return "0"
if value != int(value):
- return '%.2fs' % value
+ return "%.2fs" % value
value = int(value)
nbmin, nbsec = divmod(value, 60)
if nbsec:
- return '%ss' % value
+ return "%ss" % value
nbhour, nbmin_ = divmod(nbmin, 60)
if nbmin_:
- return '%smin' % nbmin
+ return "%smin" % nbmin
nbday, nbhour_ = divmod(nbhour, 24)
if nbhour_:
- return '%sh' % nbhour
- return '%sd' % nbday
+ return "%sh" % nbhour
+ return "%sd" % nbday
+
def format_bytes(value: int) -> str:
if not value:
- return '0'
+ return "0"
if value != int(value):
- return '%.2fB' % value
+ return "%.2fB" % value
value = int(value)
- prevunit = 'B'
- for unit in ('KB', 'MB', 'GB', 'TB'):
+ prevunit = "B"
+ for unit in ("KB", "MB", "GB", "TB"):
next, remain = divmod(value, 1024)
if remain:
- return '%s%s' % (value, prevunit)
+ return "%s%s" % (value, prevunit)
prevunit = unit
value = next
- return '%s%s' % (value, unit)
+ return "%s%s" % (value, unit)
+
def format_option_value(optdict: Dict[str, Any], value: Any) -> Union[None, int, str]:
"""return the user input's value from a 'compiled' value"""
if isinstance(value, (list, tuple)):
- value = ','.join(value)
+ value = ",".join(value)
elif isinstance(value, dict):
- value = ','.join(['%s:%s' % (k, v) for k, v in value.items()])
- elif hasattr(value, 'match'): # optdict.get('type') == 'regexp'
+ value = ",".join(["%s:%s" % (k, v) for k, v in value.items()])
+ elif hasattr(value, "match"): # optdict.get('type') == 'regexp'
# compiled regexp
value = value.pattern
- elif optdict.get('type') == 'yn':
- value = value and 'yes' or 'no'
+ elif optdict.get("type") == "yn":
+ value = value and "yes" or "no"
elif isinstance(value, str) and value.isspace():
value = "'%s'" % value
- elif optdict.get('type') == 'time' and isinstance(value, (float, int)):
+ elif optdict.get("type") == "time" and isinstance(value, (float, int)):
value = format_time(value)
- elif optdict.get('type') == 'bytes' and hasattr(value, '__int__'):
+ elif optdict.get("type") == "bytes" and hasattr(value, "__int__"):
value = format_bytes(value)
return value
-def ini_format_section(stream: Union[StringIO, TextIOWrapper], section: str, options: Any, encoding: str = None, doc: Optional[Any] = None) -> None:
+
+def ini_format_section(
+ stream: Union[StringIO, TextIOWrapper],
+ section: str,
+ options: Any,
+ encoding: str = None,
+ doc: Optional[Any] = None,
+) -> None:
"""format an options section using the INI format"""
encoding = _get_encoding(encoding, stream)
if doc:
print(_encode(comment(doc), encoding), file=stream)
- print('[%s]' % section, file=stream)
+ print("[%s]" % section, file=stream)
ini_format(stream, options, encoding)
+
def ini_format(stream: Union[StringIO, TextIOWrapper], options: Any, encoding: str) -> None:
"""format options using the INI format"""
for optname, optdict, value in options:
value = format_option_value(optdict, value)
- help = optdict.get('help')
+ help = optdict.get("help")
if help:
- help = normalize_text(help, line_len=79, indent='# ')
+ help = normalize_text(help, line_len=79, indent="# ")
print(file=stream)
print(_encode(help, encoding), file=stream)
else:
print(file=stream)
if value is None:
- print('#%s=' % optname, file=stream)
+ print("#%s=" % optname, file=stream)
else:
value = _encode(value, encoding).strip()
- if optdict.get('type') == 'string' and '\n' in value:
- prefix = '\n '
- value = prefix + prefix.join(value.split('\n'))
- print('%s=%s' % (optname, value), file=stream)
+ if optdict.get("type") == "string" and "\n" in value:
+ prefix = "\n "
+ value = prefix + prefix.join(value.split("\n"))
+ print("%s=%s" % (optname, value), file=stream)
+
format_section = ini_format_section
+
def rest_format_section(stream, section, options, encoding=None, doc=None):
"""format an options section using as ReST formatted output"""
encoding = _get_encoding(encoding, stream)
if section:
- print('%s\n%s' % (section, "'"*len(section)), file=stream)
+ print("%s\n%s" % (section, "'" * len(section)), file=stream)
if doc:
- print(_encode(normalize_text(doc, line_len=79, indent=''), encoding), file=stream)
+ print(_encode(normalize_text(doc, line_len=79, indent=""), encoding), file=stream)
print(file=stream)
for optname, optdict, value in options:
- help = optdict.get('help')
- print(':%s:' % optname, file=stream)
+ help = optdict.get("help")
+ print(":%s:" % optname, file=stream)
if help:
- help = normalize_text(help, line_len=79, indent=' ')
+ help = normalize_text(help, line_len=79, indent=" ")
print(_encode(help, encoding), file=stream)
if value:
value = _encode(format_option_value(optdict, value), encoding)
print(file=stream)
- print(' Default: ``%s``' % value.replace("`` ", "```` ``"), file=stream)
+ print(" Default: ``%s``" % value.replace("`` ", "```` ``"), file=stream)
+
# Options Manager ##############################################################
@@ -445,7 +493,13 @@ class OptionsManagerMixIn(object):
command line options
"""
- def __init__(self, usage: Optional[str], config_file: Optional[Any] = None, version: Optional[Any] = None, quiet: int = 0) -> None:
+ def __init__(
+ self,
+ usage: Optional[str],
+ config_file: Optional[Any] = None,
+ version: Optional[Any] = None,
+ quiet: int = 0,
+ ) -> None:
self.config_file = config_file
self.reset_parsers(usage, version=version)
# list of registered options providers
@@ -459,7 +513,7 @@ class OptionsManagerMixIn(object):
self.quiet = quiet
self._maxlevel = 0
- def reset_parsers(self, usage: Optional[str] = '', version: Optional[Any] = None) -> None:
+ def reset_parsers(self, usage: Optional[str] = "", version: Optional[Any] = None) -> None:
# configuration file parser
self.cfgfile_parser = cp.ConfigParser()
# command line parser
@@ -469,7 +523,9 @@ class OptionsManagerMixIn(object):
self.cmdline_parser.options_manager = self # type: ignore
self._optik_option_attrs = set(self.cmdline_parser.option_class.ATTRS)
- def register_options_provider(self, provider: 'ConfigurationMixIn', own_group: bool = True) -> None:
+ def register_options_provider(
+ self, provider: "ConfigurationMixIn", own_group: bool = True
+ ) -> None:
"""register an options provider"""
assert provider.priority <= 0, "provider's priority can't be >= 0"
for i in range(len(self.options_providers)):
@@ -481,13 +537,17 @@ class OptionsManagerMixIn(object):
# mypy: Need type annotation for 'option'
# you can't type variable of a list comprehension, right?
- non_group_spec_options: List = [option for option in provider.options # type: ignore
- if 'group' not in option[1]] # type: ignore
+ non_group_spec_options: List = [
+ option
+ for option in provider.options # type: ignore
+ if "group" not in option[1]
+ ] # type: ignore
- groups = getattr(provider, 'option_groups', ())
+ groups = getattr(provider, "option_groups", ())
if own_group and non_group_spec_options:
- self.add_option_group(provider.name.upper(), provider.__doc__,
- non_group_spec_options, provider)
+ self.add_option_group(
+ provider.name.upper(), provider.__doc__, non_group_spec_options, provider
+ )
else:
for opt, optdict in non_group_spec_options:
self.add_optik_option(provider, self.cmdline_parser, opt, optdict)
@@ -496,11 +556,20 @@ class OptionsManagerMixIn(object):
# mypy: Need type annotation for 'option'
# you can't type variable of a list comprehension, right?
- goptions: List = [option for option in provider.options # type: ignore
- if option[1].get('group', '').upper() == gname] # type: ignore
+ goptions: List = [
+ option
+ for option in provider.options # type: ignore
+ if option[1].get("group", "").upper() == gname
+ ] # type: ignore
self.add_option_group(gname, gdoc, goptions, provider)
- def add_option_group(self, group_name: str, doc: Optional[str], options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]], provider: 'ConfigurationMixIn') -> None:
+ def add_option_group(
+ self,
+ group_name: str,
+ doc: Optional[str],
+ options: Union[List[Tuple[str, Dict[str, Any]]], List[Tuple[str, Dict[str, str]]]],
+ provider: "ConfigurationMixIn",
+ ) -> None:
"""add an option group including the listed options
"""
assert options
@@ -508,8 +577,7 @@ class OptionsManagerMixIn(object):
if group_name in self._mygroups:
group = self._mygroups[group_name]
else:
- group = optik_ext.OptionGroup(self.cmdline_parser,
- title=group_name.capitalize())
+ group = optik_ext.OptionGroup(self.cmdline_parser, title=group_name.capitalize())
self.cmdline_parser.add_option_group(group)
# mypy: "OptionGroup" has no attribute "level"
# dynamic attribute
@@ -522,48 +590,63 @@ class OptionsManagerMixIn(object):
for opt, optdict in options:
self.add_optik_option(provider, group, opt, optdict)
- def add_optik_option(self, provider: 'ConfigurationMixIn', optikcontainer: Union[OptionParser, OptionGroup], opt: str, optdict: Dict[str, Any]) -> None:
- if 'inputlevel' in optdict:
- warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,'
- ' use "level"' % opt, DeprecationWarning)
- optdict['level'] = optdict.pop('inputlevel')
+ def add_optik_option(
+ self,
+ provider: "ConfigurationMixIn",
+ optikcontainer: Union[OptionParser, OptionGroup],
+ opt: str,
+ optdict: Dict[str, Any],
+ ) -> None:
+ if "inputlevel" in optdict:
+ warn(
+ '[0.50] "inputlevel" in option dictionary for %s is deprecated,'
+ ' use "level"' % opt,
+ DeprecationWarning,
+ )
+ optdict["level"] = optdict.pop("inputlevel")
args, optdict = self.optik_option(provider, opt, optdict)
option = optikcontainer.add_option(*args, **optdict)
self._all_options[opt] = provider
self._maxlevel = max(self._maxlevel, option.level or 0)
- def optik_option(self, provider: 'ConfigurationMixIn', opt: str, optdict: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any]]:
+ def optik_option(
+ self, provider: "ConfigurationMixIn", opt: str, optdict: Dict[str, Any]
+ ) -> Tuple[List[str], Dict[str, Any]]:
"""get our personal option definition and return a suitable form for
use with optik/optparse
"""
optdict = copy(optdict)
- if 'action' in optdict:
+ if "action" in optdict:
self._nocallback_options[provider] = opt
else:
- optdict['action'] = 'callback'
- optdict['callback'] = self.cb_set_provider_option
+ optdict["action"] = "callback"
+ optdict["callback"] = self.cb_set_provider_option
# default is handled here and *must not* be given to optik if you
# want the whole machinery to work
- if 'default' in optdict:
- if ('help' in optdict
- and optdict.get('default') is not None
- and not optdict['action'] in ('store_true', 'store_false')):
- optdict['help'] += ' [current: %default]'
- del optdict['default']
- args = ['--' + str(opt)]
- if 'short' in optdict:
- self._short_options[optdict['short']] = opt
- args.append('-' + optdict['short'])
- del optdict['short']
+ if "default" in optdict:
+ if (
+ "help" in optdict
+ and optdict.get("default") is not None
+ and not optdict["action"] in ("store_true", "store_false")
+ ):
+ optdict["help"] += " [current: %default]"
+ del optdict["default"]
+ args = ["--" + str(opt)]
+ if "short" in optdict:
+ self._short_options[optdict["short"]] = opt
+ args.append("-" + optdict["short"])
+ del optdict["short"]
# cleanup option definition dict before giving it to optik
for key in list(optdict.keys()):
if not key in self._optik_option_attrs:
optdict.pop(key)
return args, optdict
- def cb_set_provider_option(self, option: 'Option', opt: str, value: Union[List[str], int, str], parser: 'OptionParser') -> None:
+ def cb_set_provider_option(
+ self, option: "Option", opt: str, value: Union[List[str], int, str], parser: "OptionParser"
+ ) -> None:
"""optik callback for option setting"""
- if opt.startswith('--'):
+ if opt.startswith("--"):
# remove -- on long option
opt = opt[2:]
else:
@@ -578,7 +661,12 @@ class OptionsManagerMixIn(object):
"""set option on the correct option provider"""
self._all_options[opt].set_option(opt, value)
- def generate_config(self, stream: Union[StringIO, TextIOWrapper] = None, skipsections: Tuple[()] = (), encoding: Optional[Any] = None) -> None:
+ def generate_config(
+ self,
+ stream: Union[StringIO, TextIOWrapper] = None,
+ skipsections: Tuple[()] = (),
+ encoding: Optional[Any] = None,
+ ) -> None:
"""write a configuration file according to the current configuration
into the given stream or stdout
"""
@@ -591,8 +679,7 @@ class OptionsManagerMixIn(object):
section = provider.name
if section in skipsections:
continue
- options = [(n, d, v) for (n, d, v) in options
- if d.get('type') is not None]
+ options = [(n, d, v) for (n, d, v) in options if d.get("type") is not None]
if not options:
continue
if not section in sections:
@@ -604,20 +691,25 @@ class OptionsManagerMixIn(object):
printed = False
for section in sections:
if printed:
- print('\n', file=stream)
- format_section(stream, section.upper(), options_by_section[section],
- encoding)
+ print("\n", file=stream)
+ format_section(stream, section.upper(), options_by_section[section], encoding)
printed = True
- def generate_manpage(self, pkginfo: attrdict, section: int = 1, stream: StringIO = None) -> None:
+ def generate_manpage(
+ self, pkginfo: attrdict, section: int = 1, stream: StringIO = None
+ ) -> None:
"""write a man page for the current configuration into the given
stream or stdout
"""
self._monkeypatch_expand_default()
try:
- optik_ext.generate_manpage(self.cmdline_parser, pkginfo,
- section, stream=stream or sys.stdout,
- level=self._maxlevel)
+ optik_ext.generate_manpage(
+ self.cmdline_parser,
+ pkginfo,
+ section,
+ stream=stream or sys.stdout,
+ level=self._maxlevel,
+ )
finally:
self._unmonkeypatch_expand_default()
@@ -639,18 +731,19 @@ class OptionsManagerMixIn(object):
"""
helplevel = 1
while helplevel <= self._maxlevel:
- opt = '-'.join(['long'] * helplevel) + '-help'
+ opt = "-".join(["long"] * helplevel) + "-help"
if opt in self._all_options:
- break # already processed
+ break # already processed
+
def helpfunc(option, opt, val, p, level=helplevel):
print(self.help(level))
sys.exit(0)
- helpmsg = '%s verbose help.' % ' '.join(['more'] * helplevel)
- optdict = {'action' : 'callback', 'callback' : helpfunc,
- 'help' : helpmsg}
+
+ helpmsg = "%s verbose help." % " ".join(["more"] * helplevel)
+ optdict = {"action": "callback", "callback": helpfunc, "help": helpmsg}
provider = self.options_providers[0]
self.add_optik_option(provider, self.cmdline_parser, opt, optdict)
- provider.options += ( (opt, optdict), )
+ provider.options += ((opt, optdict),)
helplevel += 1
if config_file is None:
config_file = self.config_file
@@ -666,7 +759,7 @@ class OptionsManagerMixIn(object):
if not sect.isupper() and values:
parser._sections[sect.upper()] = values # type: ignore
elif not self.quiet:
- msg = 'No config file found, using default configuration'
+ msg = "No config file found, using default configuration"
print(msg, file=sys.stderr)
return
@@ -680,7 +773,7 @@ class OptionsManagerMixIn(object):
for section, option, optdict in provider.all_options():
if onlysection is not None and section != onlysection:
continue
- if not 'type' in optdict:
+ if not "type" in optdict:
# ignore action without type (callback, store_true...)
continue
provider.input_option(option, optdict, inputlevel)
@@ -694,18 +787,18 @@ class OptionsManagerMixIn(object):
"""
parser = self.cfgfile_parser
for section in parser.sections():
- for option, value in parser.items(section):
- try:
- self.global_set_option(option, value)
- except (KeyError, OptionError):
- # TODO handle here undeclared options appearing in the config file
- continue
+ for option, value in parser.items(section):
+ try:
+ self.global_set_option(option, value)
+ except (KeyError, OptionError):
+ # TODO handle here undeclared options appearing in the config file
+ continue
def load_configuration(self, **kwargs: Any) -> None:
"""override configuration according to given parameters
"""
for opt, opt_value in kwargs.items():
- opt = opt.replace('_', '-')
+ opt = opt.replace("_", "-")
provider = self._all_options[opt]
provider.set_option(opt, opt_value)
@@ -733,14 +826,13 @@ class OptionsManagerMixIn(object):
finally:
self._unmonkeypatch_expand_default()
-
# help methods ############################################################
def add_help_section(self, title: str, description: str, level: int = 0) -> None:
"""add a dummy option section for help purpose """
- group = optik_ext.OptionGroup(self.cmdline_parser,
- title=title.capitalize(),
- description=description)
+ group = optik_ext.OptionGroup(
+ self.cmdline_parser, title=title.capitalize(), description=description
+ )
# mypy: "OptionGroup" has no attribute "level"
# it does, it is set in the optik_ext module
group.level = level # type: ignore
@@ -757,9 +849,10 @@ class OptionsManagerMixIn(object):
except AttributeError:
# python < 2.4: nothing to be done
pass
+
def _unmonkeypatch_expand_default(self) -> None:
# remove monkey patch
- if hasattr(optik_ext.HelpFormatter, 'expand_default'):
+ if hasattr(optik_ext.HelpFormatter, "expand_default"):
# mypy: Cannot assign to a method
# it's dirty but you can
@@ -782,27 +875,30 @@ class Method(object):
"""used to ease late binding of default method (so you can define options
on the class using default methods on the configuration instance)
"""
+
def __init__(self, methname):
self.method = methname
self._inst = None
- def bind(self, instance: 'Configuration') -> None:
+ def bind(self, instance: "Configuration") -> None:
"""bind the method to its instance"""
if self._inst is None:
self._inst = instance
def __call__(self, *args: Any, **kwargs: Any) -> Dict[str, str]:
- assert self._inst, 'unbound method'
+ assert self._inst, "unbound method"
return getattr(self._inst, self.method)(*args, **kwargs)
+
# Options Provider #############################################################
+
class OptionsProviderMixIn(object):
"""Mixin to provide options to an OptionsManager"""
# those attributes should be overridden
priority = -1
- name = 'default'
+ name = "default"
options: Tuple = ()
level = 0
@@ -812,18 +908,18 @@ class OptionsProviderMixIn(object):
try:
option, optdict = option_tuple
except ValueError:
- raise Exception('Bad option: %s' % str(option_tuple))
- if isinstance(optdict.get('default'), Method):
- optdict['default'].bind(self)
- elif isinstance(optdict.get('callback'), Method):
- optdict['callback'].bind(self)
+ raise Exception("Bad option: %s" % str(option_tuple))
+ if isinstance(optdict.get("default"), Method):
+ optdict["default"].bind(self)
+ elif isinstance(optdict.get("callback"), Method):
+ optdict["callback"].bind(self)
self.load_defaults()
def load_defaults(self) -> None:
"""initialize the provider using default values"""
for opt, optdict in self.options:
- action = optdict.get('action')
- if action != 'callback':
+ action = optdict.get("action")
+ if action != "callback":
# callback action have no default
default = self.option_default(opt, optdict)
if default is REQUIRED:
@@ -834,7 +930,7 @@ class OptionsProviderMixIn(object):
"""return the default value for an option"""
if optdict is None:
optdict = self.get_option_def(opt)
- default = optdict.get('default')
+ default = optdict.get("default")
if callable(default):
default = default()
return default
@@ -844,8 +940,11 @@ class OptionsProviderMixIn(object):
"""
if optdict is None:
optdict = self.get_option_def(opt)
- return optdict.get('dest', opt.replace('-', '_'))
- option_name = deprecated('[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()')(option_attrname)
+ return optdict.get("dest", opt.replace("-", "_"))
+
+ option_name = deprecated(
+ "[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()"
+ )(option_attrname)
def option_value(self, opt):
"""get the current value for the given option"""
@@ -859,20 +958,20 @@ class OptionsProviderMixIn(object):
if value is not None:
value = _validate(value, optdict, opt)
if action is None:
- action = optdict.get('action', 'store')
- if optdict.get('type') == 'named': # XXX need specific handling
+ action = optdict.get("action", "store")
+ if optdict.get("type") == "named": # XXX need specific handling
optname = self.option_attrname(opt, optdict)
currentvalue = getattr(self.config, optname, None)
if currentvalue:
currentvalue.update(value)
value = currentvalue
- if action == 'store':
+ if action == "store":
setattr(self.config, self.option_attrname(opt, optdict), value)
- elif action in ('store_true', 'count'):
+ elif action in ("store_true", "count"):
setattr(self.config, self.option_attrname(opt, optdict), 0)
- elif action == 'store_false':
+ elif action == "store_false":
setattr(self.config, self.option_attrname(opt, optdict), 1)
- elif action == 'append':
+ elif action == "append":
opt = self.option_attrname(opt, optdict)
_list = getattr(self.config, opt, None)
if _list is None:
@@ -886,28 +985,28 @@ class OptionsProviderMixIn(object):
setattr(self.config, opt, _list + (value,))
else:
_list.append(value)
- elif action == 'callback':
- optdict['callback'](None, opt, value, None)
+ elif action == "callback":
+ optdict["callback"](None, opt, value, None)
else:
raise UnsupportedAction(action)
def input_option(self, option, optdict, inputlevel=99):
default = self.option_default(option, optdict)
if default is REQUIRED:
- defaultstr = '(required): '
- elif optdict.get('level', 0) > inputlevel:
+ defaultstr = "(required): "
+ elif optdict.get("level", 0) > inputlevel:
return
- elif optdict['type'] == 'password' or default is None:
- defaultstr = ': '
+ elif optdict["type"] == "password" or default is None:
+ defaultstr = ": "
else:
- defaultstr = '(default: %s): ' % format_option_value(optdict, default)
- print(':%s:' % option)
- print(optdict.get('help') or option)
- inputfunc = INPUT_FUNCTIONS[optdict['type']]
+ defaultstr = "(default: %s): " % format_option_value(optdict, default)
+ print(":%s:" % option)
+ print(optdict.get("help") or option)
+ inputfunc = INPUT_FUNCTIONS[optdict["type"]]
value = inputfunc(optdict, defaultstr)
while default is REQUIRED and not value:
- print('please specify a value')
- value = inputfunc(optdict, '%s: ' % option)
+ print("please specify a value")
+ value = inputfunc(optdict, "%s: " % option)
if value is None and default is not None:
value = default
self.set_option(option, value, optdict=optdict)
@@ -920,9 +1019,7 @@ class OptionsProviderMixIn(object):
return option[1]
# mypy: Argument 2 to "OptionError" has incompatible type "str"; expected "Option"
# seems to be working?
- raise OptionError('no such option %s in section %r'
- % (opt, self.name), opt) # type: ignore
-
+ raise OptionError("no such option %s in section %r" % (opt, self.name), opt) # type: ignore
def all_options(self):
"""return an iterator on available options for this provider
@@ -944,8 +1041,9 @@ class OptionsProviderMixIn(object):
"""
sections: Dict[str, List[Tuple[str, Dict[str, Any], Any]]] = {}
for optname, optdict in self.options:
- sections.setdefault(optdict.get('group'), []).append(
- (optname, optdict, self.option_value(optname)))
+ sections.setdefault(optdict.get("group"), []).append(
+ (optname, optdict, self.option_value(optname))
+ )
if None in sections:
# mypy: No overload variant of "pop" of "MutableMapping" matches argument type "None"
# it actually works
@@ -959,23 +1057,26 @@ class OptionsProviderMixIn(object):
for optname, optdict in options:
yield (optname, optdict, self.option_value(optname))
+
# configuration ################################################################
+
class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn):
"""basic mixin for simple configurations which don't need the
manager / providers model
"""
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
if not args:
- kwargs.setdefault('usage', '')
- kwargs.setdefault('quiet', 1)
+ kwargs.setdefault("usage", "")
+ kwargs.setdefault("quiet", 1)
OptionsManagerMixIn.__init__(self, *args, **kwargs)
OptionsProviderMixIn.__init__(self)
- if not getattr(self, 'option_groups', None):
+ if not getattr(self, "option_groups", None):
self.option_groups: List[Tuple[Any, str]] = []
for option, optdict in self.options:
try:
- gdef = (optdict['group'].upper(), '')
+ gdef = (optdict["group"].upper(), "")
except KeyError:
continue
if not gdef in self.option_groups:
@@ -986,7 +1087,9 @@ class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn):
"""add some options to the configuration"""
options_by_group = {}
for optname, optdict in options:
- options_by_group.setdefault(optdict.get('group', self.name.upper()), []).append((optname, optdict))
+ options_by_group.setdefault(optdict.get("group", self.name.upper()), []).append(
+ (optname, optdict)
+ )
for group, group_options in options_by_group.items():
self.add_option_group(group, None, group_options, self)
self.options += tuple(options)
@@ -1020,8 +1123,9 @@ class Configuration(ConfigurationMixIn):
configuration values are accessible through a dict like interface
"""
- def __init__(self, config_file=None, options=None, name=None,
- usage=None, doc=None, version=None):
+ def __init__(
+ self, config_file=None, options=None, name=None, usage=None, doc=None, version=None
+ ):
if options is not None:
self.options = options
if name is not None:
@@ -1035,6 +1139,7 @@ class OptionsManager2ConfigurationAdapter(object):
"""Adapt an option manager to behave like a
`logilab.common.configuration.Configuration` instance
"""
+
def __init__(self, provider):
self.config = provider
@@ -1057,8 +1162,10 @@ class OptionsManager2ConfigurationAdapter(object):
except KeyError:
return default
+
# other functions ##############################################################
+
def read_old_config(newconfig, changes, configfile):
"""initialize newconfig from a deprecated configuration file
@@ -1070,38 +1177,38 @@ def read_old_config(newconfig, changes, configfile):
# build an index of changes
changesindex = {}
for action in changes:
- if action[0] == 'moved':
+ if action[0] == "moved":
option, oldgroup, newgroup = action[1:]
changesindex.setdefault(option, []).append((action[0], oldgroup, newgroup))
continue
- if action[0] == 'renamed':
+ if action[0] == "renamed":
oldname, newname = action[1:]
changesindex.setdefault(newname, []).append((action[0], oldname))
continue
- if action[0] == 'typechanged':
+ if action[0] == "typechanged":
option, oldtype, newvalue = action[1:]
changesindex.setdefault(option, []).append((action[0], oldtype, newvalue))
continue
- if action[0] in ('added', 'removed'):
- continue # nothing to do here
- raise Exception('unknown change %s' % action[0])
+ if action[0] in ("added", "removed"):
+ continue # nothing to do here
+ raise Exception("unknown change %s" % action[0])
# build a config object able to read the old config
options = []
for optname, optdef in newconfig.options:
for action in changesindex.pop(optname, ()):
- if action[0] == 'moved':
+ if action[0] == "moved":
oldgroup, newgroup = action[1:]
optdef = optdef.copy()
- optdef['group'] = oldgroup
- elif action[0] == 'renamed':
+ optdef["group"] = oldgroup
+ elif action[0] == "renamed":
optname = action[1]
- elif action[0] == 'typechanged':
+ elif action[0] == "typechanged":
oldtype = action[1]
optdef = optdef.copy()
- optdef['type'] = oldtype
+ optdef["type"] = oldtype
options.append((optname, optdef))
if changesindex:
- raise Exception('unapplied changes: %s' % changesindex)
+ raise Exception("unapplied changes: %s" % changesindex)
oldconfig = Configuration(options=options, name=newconfig.name)
# read the old config
oldconfig.load_file_configuration(configfile)
@@ -1109,16 +1216,16 @@ def read_old_config(newconfig, changes, configfile):
changes.reverse()
done = set()
for action in changes:
- if action[0] == 'renamed':
+ if action[0] == "renamed":
oldname, newname = action[1:]
newconfig[newname] = oldconfig[oldname]
done.add(newname)
- elif action[0] == 'typechanged':
+ elif action[0] == "typechanged":
optname, oldtype, newvalue = action[1:]
newconfig[optname] = newvalue
done.add(optname)
for optname, optdef in newconfig.options:
- if optdef.get('type') and not optname in done:
+ if optdef.get("type") and not optname in done:
newconfig.set_option(optname, oldconfig[optname], optdict=optdef)
@@ -1131,7 +1238,7 @@ def merge_options(options, optgroup=None):
"""
alloptions = {}
options = list(options)
- for i in range(len(options)-1, -1, -1):
+ for i in range(len(options) - 1, -1, -1):
optname, optdict = options[i]
if optname in alloptions:
options.pop(i)
@@ -1141,5 +1248,5 @@ def merge_options(options, optgroup=None):
options[i] = (optname, optdict)
alloptions[optname] = optdict
if optgroup is not None:
- alloptions[optname]['group'] = optgroup
+ alloptions[optname]["group"] = optgroup
return tuple(options)
diff --git a/logilab/common/daemon.py b/logilab/common/daemon.py
index 78e4743..c4c8d93 100644
--- a/logilab/common/daemon.py
+++ b/logilab/common/daemon.py
@@ -33,21 +33,24 @@ def setugid(user):
Argument is a numeric user id or a user name"""
try:
from pwd import getpwuid
+
passwd = getpwuid(int(user))
except ValueError:
from pwd import getpwnam
+
passwd = getpwnam(user)
- if hasattr(os, 'initgroups'): # python >= 2.7
+ if hasattr(os, "initgroups"): # python >= 2.7
os.initgroups(passwd.pw_name, passwd.pw_gid)
else:
import ctypes
+
if ctypes.CDLL(None).initgroups(passwd.pw_name, passwd.pw_gid) < 0:
- err = ctypes.c_int.in_dll(ctypes.pythonapi,"errno").value
- raise OSError(err, os.strerror(err), 'initgroups')
+ err = ctypes.c_int.in_dll(ctypes.pythonapi, "errno").value
+ raise OSError(err, os.strerror(err), "initgroups")
os.setgid(passwd.pw_gid)
os.setuid(passwd.pw_uid)
- os.environ['HOME'] = passwd.pw_dir
+ os.environ["HOME"] = passwd.pw_dir
def daemonize(pidfile=None, uid=None, umask=0o77):
@@ -59,19 +62,19 @@ def daemonize(pidfile=None, uid=None, umask=0o77):
# http://www.faqs.org/faqs/unix-faq/programmer/faq/
#
# fork so the parent can exit
- if os.fork(): # launch child and...
+ if os.fork(): # launch child and...
return 1
# disconnect from tty and create a new session
os.setsid()
# fork again so the parent, (the session group leader), can exit.
# as a non-session group leader, we can never regain a controlling
# terminal.
- if os.fork(): # launch child again.
+ if os.fork(): # launch child again.
return 2
# move to the root to avoit mount pb
- os.chdir('/')
+ os.chdir("/")
# redirect standard descriptors
- null = os.open('/dev/null', os.O_RDWR)
+ null = os.open("/dev/null", os.O_RDWR)
for i in range(3):
try:
os.dup2(null, i)
@@ -80,7 +83,7 @@ def daemonize(pidfile=None, uid=None, umask=0o77):
raise
os.close(null)
# filter warnings
- warnings.filterwarnings('ignore')
+ warnings.filterwarnings("ignore")
# write pid in a file
if pidfile:
# ensure the directory where the pid-file should be set exists (for
@@ -88,7 +91,7 @@ def daemonize(pidfile=None, uid=None, umask=0o77):
piddir = os.path.dirname(pidfile)
if not os.path.exists(piddir):
os.makedirs(piddir)
- f = file(pidfile, 'w')
+ f = file(pidfile, "w")
f.write(str(os.getpid()))
f.close()
# set umask if specified
diff --git a/logilab/common/date.py b/logilab/common/date.py
index 2d2ed22..5f43d3e 100644
--- a/logilab/common/date.py
+++ b/logilab/common/date.py
@@ -42,63 +42,59 @@ else:
# as we have in lgc.db ?
FRENCH_FIXED_HOLIDAYS = {
- 'jour_an': '%s-01-01',
- 'fete_travail': '%s-05-01',
- 'armistice1945': '%s-05-08',
- 'fete_nat': '%s-07-14',
- 'assomption': '%s-08-15',
- 'toussaint': '%s-11-01',
- 'armistice1918': '%s-11-11',
- 'noel': '%s-12-25',
- }
+ "jour_an": "%s-01-01",
+ "fete_travail": "%s-05-01",
+ "armistice1945": "%s-05-08",
+ "fete_nat": "%s-07-14",
+ "assomption": "%s-08-15",
+ "toussaint": "%s-11-01",
+ "armistice1918": "%s-11-11",
+ "noel": "%s-12-25",
+}
FRENCH_MOBILE_HOLIDAYS = {
- 'paques2004': '2004-04-12',
- 'ascension2004': '2004-05-20',
- 'pentecote2004': '2004-05-31',
-
- 'paques2005': '2005-03-28',
- 'ascension2005': '2005-05-05',
- 'pentecote2005': '2005-05-16',
-
- 'paques2006': '2006-04-17',
- 'ascension2006': '2006-05-25',
- 'pentecote2006': '2006-06-05',
-
- 'paques2007': '2007-04-09',
- 'ascension2007': '2007-05-17',
- 'pentecote2007': '2007-05-28',
-
- 'paques2008': '2008-03-24',
- 'ascension2008': '2008-05-01',
- 'pentecote2008': '2008-05-12',
-
- 'paques2009': '2009-04-13',
- 'ascension2009': '2009-05-21',
- 'pentecote2009': '2009-06-01',
-
- 'paques2010': '2010-04-05',
- 'ascension2010': '2010-05-13',
- 'pentecote2010': '2010-05-24',
-
- 'paques2011': '2011-04-25',
- 'ascension2011': '2011-06-02',
- 'pentecote2011': '2011-06-13',
-
- 'paques2012': '2012-04-09',
- 'ascension2012': '2012-05-17',
- 'pentecote2012': '2012-05-28',
- }
+ "paques2004": "2004-04-12",
+ "ascension2004": "2004-05-20",
+ "pentecote2004": "2004-05-31",
+ "paques2005": "2005-03-28",
+ "ascension2005": "2005-05-05",
+ "pentecote2005": "2005-05-16",
+ "paques2006": "2006-04-17",
+ "ascension2006": "2006-05-25",
+ "pentecote2006": "2006-06-05",
+ "paques2007": "2007-04-09",
+ "ascension2007": "2007-05-17",
+ "pentecote2007": "2007-05-28",
+ "paques2008": "2008-03-24",
+ "ascension2008": "2008-05-01",
+ "pentecote2008": "2008-05-12",
+ "paques2009": "2009-04-13",
+ "ascension2009": "2009-05-21",
+ "pentecote2009": "2009-06-01",
+ "paques2010": "2010-04-05",
+ "ascension2010": "2010-05-13",
+ "pentecote2010": "2010-05-24",
+ "paques2011": "2011-04-25",
+ "ascension2011": "2011-06-02",
+ "pentecote2011": "2011-06-13",
+ "paques2012": "2012-04-09",
+ "ascension2012": "2012-05-17",
+ "pentecote2012": "2012-05-28",
+}
# XXX this implementation cries for multimethod dispatching
+
def get_step(dateobj: Union[date, datetime], nbdays: int = 1) -> timedelta:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(dateobj, date):
return ONEDAY * nbdays
- return nbdays # mx.DateTime is ok with integers
+ return nbdays # mx.DateTime is ok with integers
-def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetime]) -> Union[date, datetime]:
+
+def datefactory(
+ year: int, month: int, day: int, sampledate: Union[date, datetime]
+) -> Union[date, datetime]:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(sampledate, datetime):
return datetime(year, month, day)
@@ -106,17 +102,20 @@ def datefactory(year: int, month: int, day: int, sampledate: Union[date, datetim
return date(year, month, day)
return Date(year, month, day)
+
def weekday(dateobj: Union[date, datetime]) -> int:
# assume date is either a python datetime or a mx.DateTime object
if isinstance(dateobj, date):
return dateobj.weekday()
return dateobj.day_of_week
+
def str2date(datestr: str, sampledate: Union[date, datetime]) -> Union[date, datetime]:
# NOTE: datetime.strptime is not an option until we drop py2.4 compat
- year, month, day = [int(chunk) for chunk in datestr.split('-')]
+ year, month, day = [int(chunk) for chunk in datestr.split("-")]
return datefactory(year, month, day, sampledate)
+
def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> int:
if isinstance(start, date):
# mypy: No overload variant of "__sub__" of "datetime" matches argument type "date"
@@ -130,32 +129,35 @@ def days_between(start: Union[date, datetime], end: Union[date, datetime]) -> in
else:
return int(math.ceil((end - start).days))
-def get_national_holidays(begin: Union[date, datetime], end: Union[date, datetime]) -> Union[List[date], List[datetime]]:
+
+def get_national_holidays(
+ begin: Union[date, datetime], end: Union[date, datetime]
+) -> Union[List[date], List[datetime]]:
"""return french national days off between begin and end"""
begin = datefactory(begin.year, begin.month, begin.day, begin)
end = datefactory(end.year, end.month, end.day, end)
- holidays = [str2date(datestr, begin)
- for datestr in FRENCH_MOBILE_HOLIDAYS.values()]
- for year in range(begin.year, end.year+1):
+ holidays = [str2date(datestr, begin) for datestr in FRENCH_MOBILE_HOLIDAYS.values()]
+ for year in range(begin.year, end.year + 1):
for datestr in FRENCH_FIXED_HOLIDAYS.values():
date = str2date(datestr % year, begin)
if date not in holidays:
holidays.append(date)
return [day for day in holidays if begin <= day < end]
+
def add_days_worked(start: date, days: int) -> date:
"""adds date but try to only take days worked into account"""
step = get_step(start)
weeks, plus = divmod(days, 5)
end = start + ((weeks * 7) + plus) * step
- if weekday(end) >= 5: # saturday or sunday
- end += (2 * step)
- end += len([x for x in get_national_holidays(start, end + step)
- if weekday(x) < 5]) * step
- if weekday(end) >= 5: # saturday or sunday
- end += (2 * step)
+ if weekday(end) >= 5: # saturday or sunday
+ end += 2 * step
+ end += len([x for x in get_national_holidays(start, end + step) if weekday(x) < 5]) * step
+ if weekday(end) >= 5: # saturday or sunday
+ end += 2 * step
return end
+
def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> int:
assert start <= end
step = get_step(start)
@@ -166,15 +168,18 @@ def nb_open_days(start: Union[date, datetime], end: Union[date, datetime]) -> in
elif weekday(end) == 6:
plus -= 1
open_days = weeks * 5 + plus
- nb_week_holidays = len([x for x in get_national_holidays(start, end+step)
- if weekday(x) < 5 and x < end])
+ nb_week_holidays = len(
+ [x for x in get_national_holidays(start, end + step) if weekday(x) < 5 and x < end]
+ )
open_days -= nb_week_holidays
if open_days < 0:
return 0
return open_days
-def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None) -> Generator[date, Any, None]:
+def date_range(
+ begin: date, end: date, incday: Optional[Any] = None, incmonth: Optional[bool] = None
+) -> Generator[date, Any, None]:
"""yields each date between begin and end
:param begin: the start date
@@ -202,6 +207,7 @@ def date_range(begin: date, end: date, incday: Optional[Any] = None, incmonth: O
yield begin
begin += incr
+
# makes py datetime usable #####################################################
ONEDAY: timedelta = timedelta(days=1)
@@ -209,14 +215,17 @@ ONEWEEK: timedelta = timedelta(days=7)
try:
strptime = datetime.strptime
-except AttributeError: # py < 2.5
+except AttributeError: # py < 2.5
from time import strptime as time_strptime
+
def strptime(value, format):
return datetime(*time_strptime(value, format)[:6])
-def strptime_time(value, format='%H:%M'):
+
+def strptime_time(value, format="%H:%M"):
return time(*time_strptime(value, format)[3:6])
+
def todate(somedate: date) -> date:
"""return a date from a date (leaving unchanged) or a datetime"""
if isinstance(somedate, datetime):
@@ -224,6 +233,7 @@ def todate(somedate: date) -> date:
assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
return somedate
+
def totime(somedate):
"""return a time from a time (leaving unchanged), date or datetime"""
# XXX mx compat
@@ -232,6 +242,7 @@ def totime(somedate):
assert isinstance(somedate, (time)), repr(somedate)
return somedate
+
def todatetime(somedate):
"""return a date from a date (leaving unchanged) or a datetime"""
# take care, datetime is a subclass of date
@@ -240,8 +251,10 @@ def todatetime(somedate):
assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
return datetime(somedate.year, somedate.month, somedate.day)
+
def datetime2ticks(somedate: Union[date, datetime]) -> int:
- return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, 'microsecond', 0) / 1000)
+ return timegm(somedate.timetuple()) * 1000 + int(getattr(somedate, "microsecond", 0) / 1000)
+
def ticks2datetime(ticks: int) -> datetime:
miliseconds, microseconds = divmod(ticks, 1000)
@@ -256,9 +269,11 @@ def ticks2datetime(ticks: int) -> datetime:
except (ValueError, OverflowError):
raise
+
def days_in_month(somedate: date) -> int:
return monthrange(somedate.year, somedate.month)[1]
+
def days_in_year(somedate):
feb = date(somedate.year, 2, 1)
if days_in_month(feb) == 29:
@@ -266,25 +281,30 @@ def days_in_year(somedate):
else:
return 365
+
def previous_month(somedate, nbmonth=1):
while nbmonth:
somedate = first_day(somedate) - ONEDAY
nbmonth -= 1
return somedate
+
def next_month(somedate: date, nbmonth: int = 1) -> date:
while nbmonth:
somedate = last_day(somedate) + ONEDAY
nbmonth -= 1
return somedate
+
def first_day(somedate):
return date(somedate.year, somedate.month, 1)
+
def last_day(somedate: date) -> date:
return date(somedate.year, somedate.month, days_in_month(somedate))
-def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str:
+
+def ustrftime(somedate: datetime, fmt: str = "%Y-%m-%d") -> str:
"""like strftime, but returns a unicode string instead of an encoded
string which may be problematic with localized date.
"""
@@ -294,7 +314,7 @@ def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str:
else:
try:
if sys.version_info < (3, 0):
- encoding = getlocale(LC_TIME)[1] or 'ascii'
+ encoding = getlocale(LC_TIME)[1] or "ascii"
return unicode(somedate.strftime(str(fmt)), encoding)
else:
return somedate.strftime(fmt)
@@ -304,37 +324,41 @@ def ustrftime(somedate: datetime, fmt: str = '%Y-%m-%d') -> str:
# datetime is not happy with dates before 1900
# we try to work around this, assuming a simple
# format string
- fields = {'Y': somedate.year,
- 'm': somedate.month,
- 'd': somedate.day,
- }
+ fields = {
+ "Y": somedate.year,
+ "m": somedate.month,
+ "d": somedate.day,
+ }
if isinstance(somedate, datetime):
- fields.update({'H': somedate.hour,
- 'M': somedate.minute,
- 'S': somedate.second})
- fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt)
+ fields.update({"H": somedate.hour, "M": somedate.minute, "S": somedate.second})
+ fmt = re.sub("%([YmdHMS])", r"%(\1)02d", fmt)
return unicode(fmt) % fields
+
def utcdatetime(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt
# mypy: No overload variant of "__sub__" of "datetime" matches argument type "None"
- return (dt.replace(tzinfo=None) - dt.utcoffset()) # type: ignore
+ return dt.replace(tzinfo=None) - dt.utcoffset() # type: ignore
+
def utctime(dt):
if dt.tzinfo is None:
return dt
return (dt + dt.utcoffset() + dt.dst()).replace(tzinfo=None)
+
def datetime_to_seconds(date):
"""return the number of seconds since the begining of the day for that date
"""
- return date.second+60*date.minute + 3600*date.hour
+ return date.second + 60 * date.minute + 3600 * date.hour
+
def timedelta_to_days(delta):
"""return the time delta as a number of seconds"""
- return delta.days + delta.seconds / (3600*24)
+ return delta.days + delta.seconds / (3600 * 24)
+
def timedelta_to_seconds(delta):
"""return the time delta as a fraction of days"""
- return delta.days*(3600*24) + delta.seconds
+ return delta.days * (3600 * 24) + delta.seconds
diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py
index 2df84ad..6553557 100644
--- a/logilab/common/debugger.py
+++ b/logilab/common/debugger.py
@@ -49,12 +49,17 @@ from logilab.common.compat import StringIO
try:
from IPython import PyColorize
except ImportError:
+
def colorize(source, start_lineno, curlineno):
"""fallback colorize function"""
return source
+
def colorize_source(source):
return source
+
+
else:
+
def colorize(source, start_lineno, curlineno):
"""colorize and annotate source with linenos
(as in pdb's list command)
@@ -66,10 +71,10 @@ else:
for index, line in enumerate(output.getvalue().splitlines()):
lineno = index + start_lineno
if lineno == curlineno:
- annotated.append('%4s\t->\t%s' % (lineno, line))
+ annotated.append("%4s\t->\t%s" % (lineno, line))
else:
- annotated.append('%4s\t\t%s' % (lineno, line))
- return '\n'.join(annotated)
+ annotated.append("%4s\t\t%s" % (lineno, line))
+ return "\n".join(annotated)
def colorize_source(source):
"""colorize given source"""
@@ -86,7 +91,7 @@ def getsource(obj):
or code object. The source code is returned as a single string. An
IOError is raised if the source code cannot be retrieved."""
lines, lnum = inspect.getsourcelines(obj)
- return ''.join(lines), lnum
+ return "".join(lines), lnum
################################################################
@@ -98,6 +103,7 @@ class Debugger(Pdb):
- overrides list command to search for current block instead
of using 5 lines of context
"""
+
def __init__(self, tcbk=None):
Pdb.__init__(self)
self.reset()
@@ -137,11 +143,10 @@ class Debugger(Pdb):
"""provide variable names completion for the ``p`` command"""
namespace = dict(self.curframe.f_globals)
namespace.update(self.curframe.f_locals)
- if '.' in text:
+ if "." in text:
return self.attr_matches(text, namespace)
return [varname for varname in namespace if varname.startswith(text)]
-
def attr_matches(self, text, namespace):
"""implementation coming from rlcompleter.Completer.attr_matches
Compute matches when text contains a dot.
@@ -156,14 +161,15 @@ class Debugger(Pdb):
"""
import re
+
m = re.match(r"(\w+(\.\w+)*)\.(\w*)", text)
if not m:
return
expr, attr = m.group(1, 3)
object = eval(expr, namespace)
words = dir(object)
- if hasattr(object, '__class__'):
- words.append('__class__')
+ if hasattr(object, "__class__"):
+ words.append("__class__")
words = words + self.get_class_members(object.__class__)
matches = []
n = len(attr)
@@ -175,7 +181,7 @@ class Debugger(Pdb):
def get_class_members(self, klass):
"""implementation coming from rlcompleter.get_class_members"""
ret = dir(klass)
- if hasattr(klass, '__bases__'):
+ if hasattr(klass, "__bases__"):
for base in klass.__bases__:
ret = ret + self.get_class_members(base)
return ret
@@ -185,33 +191,35 @@ class Debugger(Pdb):
"""overrides default list command to display the surrounding block
instead of 5 lines of context
"""
- self.lastcmd = 'list'
+ self.lastcmd = "list"
if not arg:
try:
source, start_lineno = getsource(self.curframe)
- print(colorize(''.join(source), start_lineno,
- self.curframe.f_lineno))
+ print(colorize("".join(source), start_lineno, self.curframe.f_lineno))
except KeyboardInterrupt:
pass
except IOError:
Pdb.do_list(self, arg)
else:
Pdb.do_list(self, arg)
+
do_l = do_list
def do_open(self, arg):
"""opens source file corresponding to the current stack level"""
filename = self.curframe.f_code.co_filename
lineno = self.curframe.f_lineno
- cmd = 'emacsclient --no-wait +%s %s' % (lineno, filename)
+ cmd = "emacsclient --no-wait +%s %s" % (lineno, filename)
os.system(cmd)
do_o = do_open
+
def pm():
"""use our custom debugger"""
dbg = Debugger(sys.last_traceback)
dbg.start()
+
def set_trace():
Debugger().set_trace(sys._getframe().f_back)
diff --git a/logilab/common/decorators.py b/logilab/common/decorators.py
index 27ed7ee..a471353 100644
--- a/logilab/common/decorators.py
+++ b/logilab/common/decorators.py
@@ -34,13 +34,16 @@ from logilab.common.compat import method_type
# XXX rewrite so we can use the decorator syntax when keyarg has to be specified
+
class cached_decorator(object):
def __init__(self, cacheattr: Optional[str] = None, keyarg: Optional[int] = None) -> None:
self.cacheattr = cacheattr
self.keyarg = keyarg
+
def __call__(self, callableobj: Optional[Callable] = None) -> Callable:
- assert not isgeneratorfunction(callableobj), \
- 'cannot cache generator function: %s' % callableobj
+ assert not isgeneratorfunction(callableobj), (
+ "cannot cache generator function: %s" % callableobj
+ )
assert callableobj is not None
if len(getfullargspec(callableobj).args) == 1 or self.keyarg == 0:
cache = _SingleValueCache(callableobj, self.cacheattr)
@@ -50,11 +53,12 @@ class cached_decorator(object):
cache = _MultiValuesCache(callableobj, self.cacheattr)
return cache.closure()
+
class _SingleValueCache(object):
def __init__(self, callableobj: Callable, cacheattr: Optional[str] = None) -> None:
self.callable = callableobj
if cacheattr is None:
- self.cacheattr = '_%s_cache_' % callableobj.__name__
+ self.cacheattr = "_%s_cache_" % callableobj.__name__
else:
assert cacheattr != callableobj.__name__
self.cacheattr = cacheattr
@@ -70,6 +74,7 @@ class _SingleValueCache(object):
def closure(self) -> Callable:
def wrapped(*args, **kwargs):
return self.__call__(*args, **kwargs)
+
# mypy: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no attribute "cache_obj"
# dynamic attribute for magic
wrapped.cache_obj = self # type: ignore
@@ -101,6 +106,7 @@ class _MultiValuesCache(_SingleValueCache):
_cache[args] = __me.callable(self, *args)
return _cache[args]
+
class _MultiValuesKeyArgCache(_MultiValuesCache):
def __init__(self, callableobj: Callable, keyarg: int, cacheattr: Optional[str] = None) -> None:
super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr)
@@ -108,7 +114,7 @@ class _MultiValuesKeyArgCache(_MultiValuesCache):
def __call__(__me, self, *args, **kwargs):
_cache = __me._get_cache(self)
- key = args[__me.keyarg-1]
+ key = args[__me.keyarg - 1]
try:
return _cache[key]
except KeyError:
@@ -116,9 +122,11 @@ class _MultiValuesKeyArgCache(_MultiValuesCache):
return _cache[key]
-def cached(callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any) -> Union[Callable, cached_decorator]:
+def cached(
+ callableobj: Optional[Callable] = None, keyarg: Optional[int] = None, **kwargs: Any
+) -> Union[Callable, cached_decorator]:
"""Simple decorator to cache result of method call."""
- kwargs['keyarg'] = keyarg
+ kwargs["keyarg"] = keyarg
decorator = cached_decorator(**kwargs)
if callableobj is None:
return decorator
@@ -140,23 +148,22 @@ class cachedproperty(object):
.. _pyramid: http://pypi.python.org/pypi/pyramid
.. _mercurial: http://pypi.python.org/pypi/Mercurial
"""
- __slots__ = ('wrapped',)
+
+ __slots__ = ("wrapped",)
def __init__(self, wrapped):
try:
wrapped.__name__
except AttributeError:
- raise TypeError('%s must have a __name__ attribute' %
- wrapped)
+ raise TypeError("%s must have a __name__ attribute" % wrapped)
self.wrapped = wrapped
# mypy: Signature of "__doc__" incompatible with supertype "object"
# but this works?
@property
def __doc__(self) -> str: # type: ignore
- doc = getattr(self.wrapped, '__doc__', None)
- return ('<wrapped by the cachedproperty decorator>%s'
- % ('\n%s' % doc if doc else ''))
+ doc = getattr(self.wrapped, "__doc__", None)
+ return "<wrapped by the cachedproperty decorator>%s" % ("\n%s" % doc if doc else "")
def __get__(self, inst, objtype=None):
if inst is None:
@@ -173,6 +180,7 @@ def get_cache_impl(obj, funcname):
member = member.fget
return member.cache_obj
+
def clear_cache(obj, funcname):
"""Clear a cache handled by the :func:`cached` decorator. If 'x' class has
@cached on its method `foo`, type
@@ -183,6 +191,7 @@ def clear_cache(obj, funcname):
"""
get_cache_impl(obj, funcname).clear(obj)
+
def copy_cache(obj, funcname, cacheobj):
"""Copy cache for <funcname> from cacheobj to obj."""
cacheattr = get_cache_impl(obj, funcname).cacheattr
@@ -196,9 +205,10 @@ class wproperty(object):
"""Simple descriptor expecting to take a modifier function as first argument
and looking for a _<function name> to retrieve the attribute.
"""
+
def __init__(self, setfunc):
self.setfunc = setfunc
- self.attrname = '_%s' % setfunc.__name__
+ self.attrname = "_%s" % setfunc.__name__
def __set__(self, obj, value):
self.setfunc(obj, value)
@@ -211,22 +221,27 @@ class wproperty(object):
class classproperty(object):
"""this is a simple property-like class but for class attributes.
"""
+
def __init__(self, get):
self.get = get
+
def __get__(self, inst, cls):
return self.get(cls)
class iclassmethod(object):
- '''Descriptor for method which should be available as class method if called
+ """Descriptor for method which should be available as class method if called
on the class or instance method if called on an instance.
- '''
+ """
+
def __init__(self, func):
self.func = func
+
def __get__(self, instance, objtype):
if instance is None:
return method_type(self.func, objtype, objtype.__class__)
return method_type(self.func, instance, objtype)
+
def __set__(self, instance, value):
raise AttributeError("can't set attribute")
@@ -236,9 +251,9 @@ def timed(f):
t = time()
c = process_time()
res = f(*args, **kwargs)
- print('%s clock: %.9f / time: %.9f' % (f.__name__,
- process_time() - c, time() - t))
+ print("%s clock: %.9f / time: %.9f" % (f.__name__, process_time() - c, time() - t))
return res
+
return wrap
@@ -247,6 +262,7 @@ def locked(acquire, release):
returning a decorator function which will call the inner method after
having called acquire(self) et will call release(self) afterwards.
"""
+
def decorator(f):
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
acquire(self)
@@ -254,7 +270,9 @@ def locked(acquire, release):
return f(self, *args, **kwargs)
finally:
release(self)
+
return wrapper
+
return decorator
@@ -278,13 +296,16 @@ def monkeypatch(klass: type, methodname: Optional[str] = None) -> Callable:
>>> a.foo()
12
"""
+
def decorator(func):
try:
name = methodname or func.__name__
except AttributeError:
- raise AttributeError('%s has no __name__ attribute: '
- 'you should provide an explicit `methodname`'
- % func)
+ raise AttributeError(
+ "%s has no __name__ attribute: "
+ "you should provide an explicit `methodname`" % func
+ )
setattr(klass, name, func)
return func
+
return decorator
diff --git a/logilab/common/deprecation.py b/logilab/common/deprecation.py
index b147b43..15f8087 100644
--- a/logilab/common/deprecation.py
+++ b/logilab/common/deprecation.py
@@ -38,7 +38,7 @@ class DeprecationWrapper(object):
return getattr(self._proxied, attr)
def __setattr__(self, attr, value):
- if attr in ('_proxied', '_msg'):
+ if attr in ("_proxied", "_msg"):
self.__dict__[attr] = value
else:
send_warning(self._msg, stacklevel=3, version=self.version)
diff --git a/logilab/common/fileutils.py b/logilab/common/fileutils.py
index 102cd7c..1b1ed5b 100644
--- a/logilab/common/fileutils.py
+++ b/logilab/common/fileutils.py
@@ -44,6 +44,7 @@ from logilab.common.shellutils import find
from logilab.common.deprecation import deprecated
from logilab.common.compat import FileIO
+
def first_level_directory(path: str) -> str:
"""Return the first level directory of a path.
@@ -69,6 +70,7 @@ def first_level_directory(path: str) -> str:
# path was absolute, head is the fs root
return head
+
def abspath_listdir(path):
"""Lists path's content using absolute paths."""
path = abspath(path)
@@ -90,7 +92,7 @@ def is_binary(filename: str) -> int:
try:
# mypy: Item "None" of "Optional[str]" has no attribute "startswith"
# it's handle by the exception
- return not mimetypes.guess_type(filename)[0].startswith('text') # type: ignore
+ return not mimetypes.guess_type(filename)[0].startswith("text") # type: ignore
except AttributeError:
return 1
@@ -105,8 +107,8 @@ def write_open_mode(filename: str) -> str:
:return: the mode that should be use to open the file ('w' or 'wb')
"""
if is_binary(filename):
- return 'wb'
- return 'w'
+ return "wb"
+ return "w"
def ensure_fs_mode(filepath, desired_mode=S_IWRITE):
@@ -147,10 +149,11 @@ class ProtectedFile(FileIO):
- on close()/del(), write/append the StringIO content to the file and
do the chmod only once
"""
+
def __init__(self, filepath: str, mode: str) -> None:
self.original_mode = stat(filepath)[ST_MODE]
self.mode_changed = False
- if mode in ('w', 'a', 'wb', 'ab'):
+ if mode in ("w", "a", "wb", "ab"):
if not self.original_mode & S_IWRITE:
chmod(filepath, self.original_mode | S_IWRITE)
self.mode_changed = True
@@ -178,6 +181,7 @@ class UnresolvableError(Exception):
path between two paths.
"""
+
def relative_path(from_file, to_file):
"""Try to get a relative path from `from_file` to `to_file`
(path will be absolute if to_file is an absolute file). This function
@@ -224,7 +228,7 @@ def relative_path(from_file, to_file):
from_file = normpath(from_file)
to_file = normpath(to_file)
if from_file == to_file:
- return ''
+ return ""
if isabs(to_file):
if not isabs(from_file):
return to_file
@@ -240,7 +244,7 @@ def relative_path(from_file, to_file):
to_parts.pop(0)
else:
idem = 0
- result.append('..')
+ result.append("..")
result += to_parts
return sep.join(result)
@@ -254,9 +258,12 @@ def norm_read(path):
:rtype: str
:return: the content of the file with normalized line feeds
"""
- return open(path, 'U').read()
+ return open(path, "U").read()
+
+
norm_read = deprecated("use \"open(path, 'U').read()\"")(norm_read)
+
def norm_open(path):
"""Return a stream for a file with content with normalized line feeds.
@@ -266,9 +273,12 @@ def norm_open(path):
:rtype: file or StringIO
:return: the opened file with normalized line feeds
"""
- return open(path, 'U')
+ return open(path, "U")
+
+
norm_open = deprecated("use \"open(path, 'U')\"")(norm_open)
+
def lines(path: str, comments: Optional[str] = None) -> List[str]:
"""Return a list of non empty lines in the file located at `path`.
@@ -321,9 +331,13 @@ def stream_lines(stream: TextIOWrapper, comments: Optional[str] = None) -> List[
return result
-def export(from_dir: str, to_dir: str,
- blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST, ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS,
- verbose: int = 0) -> None:
+def export(
+ from_dir: str,
+ to_dir: str,
+ blacklist: Tuple[str, str, str, str, str, str, str, str] = BASE_BLACKLIST,
+ ignore_ext: Tuple[str, str, str, str, str, str] = IGNORED_EXTENSIONS,
+ verbose: int = 0,
+) -> None:
"""Make a mirror of `from_dir` in `to_dir`, omitting directories and
files listed in the black list or ending with one of the given
extensions.
@@ -352,8 +366,8 @@ def export(from_dir: str, to_dir: str,
try:
mkdir(to_dir)
except OSError:
- pass # FIXME we should use "exists" if the point is about existing dir
- # else (permission problems?) shouldn't return / raise ?
+ pass # FIXME we should use "exists" if the point is about existing dir
+ # else (permission problems?) shouldn't return / raise ?
for directory, dirnames, filenames in walk(from_dir):
for norecurs in blacklist:
try:
@@ -362,7 +376,7 @@ def export(from_dir: str, to_dir: str,
continue
for dirname in dirnames:
src = join(directory, dirname)
- dest = to_dir + src[len(from_dir):]
+ dest = to_dir + src[len(from_dir) :]
if isdir(src):
if not exists(dest):
mkdir(dest)
@@ -372,9 +386,9 @@ def export(from_dir: str, to_dir: str,
if any([filename.endswith(ext) for ext in ignore_ext]):
continue
src = join(directory, filename)
- dest = to_dir + src[len(from_dir):]
+ dest = to_dir + src[len(from_dir) :]
if verbose:
- print(src, '->', dest, file=sys.stderr)
+ print(src, "->", dest, file=sys.stderr)
if exists(dest):
remove(dest)
shutil.copy2(src, dest)
@@ -396,6 +410,5 @@ def remove_dead_links(directory, verbose=0):
src = join(dirpath, filename)
if islink(src) and not exists(src):
if verbose:
- print('remove dead link', src)
+ print("remove dead link", src)
remove(src)
-
diff --git a/logilab/common/graph.py b/logilab/common/graph.py
index fffa172..82c3a32 100644
--- a/logilab/common/graph.py
+++ b/logilab/common/graph.py
@@ -32,47 +32,59 @@ import codecs
import errno
from typing import Dict, List, Tuple, Union, Any, Optional, Set, TypeVar, Iterable
+
def escape(value):
"""Make <value> usable in a dot file."""
- lines = [line.replace('"', '\\"') for line in value.split('\n')]
- data = '\\l'.join(lines)
- return '\\n' + data
+ lines = [line.replace('"', '\\"') for line in value.split("\n")]
+ data = "\\l".join(lines)
+ return "\\n" + data
+
def target_info_from_filename(filename):
"""Transforms /some/path/foo.png into ('/some/path', 'foo.png', 'png')."""
basename = osp.basename(filename)
storedir = osp.dirname(osp.abspath(filename))
- target = filename.split('.')[-1]
+ target = filename.split(".")[-1]
return storedir, basename, target
class DotBackend:
"""Dot File backend."""
- def __init__(self, graphname, rankdir=None, size=None, ratio=None,
- charset='utf-8', renderer='dot', additionnal_param={}):
+
+ def __init__(
+ self,
+ graphname,
+ rankdir=None,
+ size=None,
+ ratio=None,
+ charset="utf-8",
+ renderer="dot",
+ additionnal_param={},
+ ):
self.graphname = graphname
self.renderer = renderer
self.lines = []
self._source = None
self.emit("digraph %s {" % normalize_node_id(graphname))
if rankdir:
- self.emit('rankdir=%s' % rankdir)
+ self.emit("rankdir=%s" % rankdir)
if ratio:
- self.emit('ratio=%s' % ratio)
+ self.emit("ratio=%s" % ratio)
if size:
self.emit('size="%s"' % size)
if charset:
- assert charset.lower() in ('utf-8', 'iso-8859-1', 'latin1'), \
- 'unsupported charset %s' % charset
+ assert charset.lower() in ("utf-8", "iso-8859-1", "latin1"), (
+ "unsupported charset %s" % charset
+ )
self.emit('charset="%s"' % charset)
for param in sorted(additionnal_param.items()):
- self.emit('='.join(param))
+ self.emit("=".join(param))
def get_source(self):
"""returns self._source"""
if self._source is None:
self.emit("}\n")
- self._source = '\n'.join(self.lines)
+ self._source = "\n".join(self.lines)
del self.lines
return self._source
@@ -87,14 +99,15 @@ class DotBackend:
:rtype: str
:return: a path to the generated file
"""
- import subprocess # introduced in py 2.4
+ import subprocess # introduced in py 2.4
+
name = self.graphname
if not dotfile:
# if 'outputfile' is a dot file use it as 'dotfile'
if outputfile and outputfile.endswith(".dot"):
dotfile = outputfile
else:
- dotfile = '%s.dot' % name
+ dotfile = "%s.dot" % name
if outputfile is not None:
storedir, basename, target = target_info_from_filename(outputfile)
if target != "dot":
@@ -103,30 +116,43 @@ class DotBackend:
else:
dot_sourcepath = osp.join(storedir, dotfile)
else:
- target = 'png'
+ target = "png"
pdot, dot_sourcepath = tempfile.mkstemp(".dot", name)
ppng, outputfile = tempfile.mkstemp(".png", name)
os.close(pdot)
os.close(ppng)
- pdot = codecs.open(dot_sourcepath, 'w', encoding='utf8')
+ pdot = codecs.open(dot_sourcepath, "w", encoding="utf8")
pdot.write(self.source)
pdot.close()
- if target != 'dot':
- if sys.platform == 'win32':
+ if target != "dot":
+ if sys.platform == "win32":
use_shell = True
else:
use_shell = False
try:
if mapfile:
- subprocess.call([self.renderer, '-Tcmapx', '-o', mapfile, '-T', target, dot_sourcepath, '-o', outputfile],
- shell=use_shell)
+ subprocess.call(
+ [
+ self.renderer,
+ "-Tcmapx",
+ "-o",
+ mapfile,
+ "-T",
+ target,
+ dot_sourcepath,
+ "-o",
+ outputfile,
+ ],
+ shell=use_shell,
+ )
else:
- subprocess.call([self.renderer, '-T', target,
- dot_sourcepath, '-o', outputfile],
- shell=use_shell)
+ subprocess.call(
+ [self.renderer, "-T", target, dot_sourcepath, "-o", outputfile],
+ shell=use_shell,
+ )
except OSError as e:
if e.errno == errno.ENOENT:
- e.strerror = 'File not found: {0}'.format(self.renderer)
+ e.strerror = "File not found: {0}".format(self.renderer)
raise
os.unlink(dot_sourcepath)
return outputfile
@@ -141,19 +167,21 @@ class DotBackend:
"""
attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()]
n_from, n_to = normalize_node_id(name1), normalize_node_id(name2)
- self.emit('%s -> %s [%s];' % (n_from, n_to, ', '.join(sorted(attrs))) )
+ self.emit("%s -> %s [%s];" % (n_from, n_to, ", ".join(sorted(attrs))))
def emit_node(self, name, **props):
"""emit a node with given properties.
node properties: see http://www.graphviz.org/doc/info/attrs.html
"""
attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()]
- self.emit('%s [%s];' % (normalize_node_id(name), ', '.join(sorted(attrs))))
+ self.emit("%s [%s];" % (normalize_node_id(name), ", ".join(sorted(attrs))))
+
def normalize_node_id(nid):
"""Returns a suitable DOT node id for `nid`."""
return '"%s"' % nid
+
class GraphGenerator:
def __init__(self, backend):
# the backend is responsible to output the graph in a particular format
@@ -194,8 +222,8 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]:
cycles: List[List[V]] = get_cycles(graph)
if cycles:
- bad_cycles = '\n'.join([' -> '.join(map(str, cycle)) for cycle in cycles])
- raise UnorderableGraph('cycles in graph: %s' % bad_cycles)
+ bad_cycles = "\n".join([" -> ".join(map(str, cycle)) for cycle in cycles])
+ raise UnorderableGraph("cycles in graph: %s" % bad_cycles)
vertices = set(graph)
to_vertices = set()
@@ -205,7 +233,7 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]:
missing_vertices = to_vertices - vertices
if missing_vertices:
- raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices))
+ raise UnorderableGraph("missing vertices: %s" % ", ".join(missing_vertices))
# order vertices
order = []
@@ -214,7 +242,7 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]:
while graph:
if old_len == len(graph):
- raise UnorderableGraph('unknown problem with %s' % graph)
+ raise UnorderableGraph("unknown problem with %s" % graph)
old_len = len(graph)
deps_ok = []
@@ -240,12 +268,11 @@ def ordered_nodes(graph: _Graph) -> Tuple[V, ...]:
return tuple(result)
-def get_cycles(graph_dict: _Graph,
- vertices: Optional[Iterable] = None) -> List[List]:
- '''given a dictionary representing an ordered graph (i.e. key are vertices
+def get_cycles(graph_dict: _Graph, vertices: Optional[Iterable] = None) -> List[List]:
+ """given a dictionary representing an ordered graph (i.e. key are vertices
and values is a list of destination vertices representing edges), return a
list of detected cycles
- '''
+ """
if not graph_dict:
return []
@@ -259,11 +286,9 @@ def get_cycles(graph_dict: _Graph,
return result
-def _get_cycles(graph_dict: _Graph,
- path: List,
- visited: Set,
- result: List[List],
- vertice: V) -> None:
+def _get_cycles(
+ graph_dict: _Graph, path: List, visited: Set, result: List[List], vertice: V
+) -> None:
"""recursive function doing the real work for get_cycles"""
if vertice in path:
cycle = [vertice]
@@ -299,7 +324,9 @@ def _get_cycles(graph_dict: _Graph,
path.pop()
-def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None) -> Optional[List[str]]:
+def has_path(
+ graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path: Optional[List[str]] = None
+) -> Optional[List[str]]:
"""generic function taking a simple graph definition as a dictionary, with
node has key associated to a list of nodes directly reachable from it.
@@ -316,4 +343,3 @@ def has_path(graph_dict: Dict[str, List[str]], fromnode: str, tonode: str, path:
return path[1:] + [tonode]
path.pop()
return None
-
diff --git a/logilab/common/interface.py b/logilab/common/interface.py
index 8248a27..4d4b92d 100644
--- a/logilab/common/interface.py
+++ b/logilab/common/interface.py
@@ -28,6 +28,7 @@ __docformat__ = "restructuredtext en"
class Interface(object):
"""Base class for interfaces."""
+
@classmethod
def is_implemented_by(cls, instance: type) -> bool:
return implements(instance, cls)
@@ -37,7 +38,7 @@ def implements(obj: type, interface: type) -> bool:
"""Return true if the give object (maybe an instance or class) implements
the interface.
"""
- kimplements = getattr(obj, '__implements__', ())
+ kimplements = getattr(obj, "__implements__", ())
if not isinstance(kimplements, (list, tuple)):
kimplements = (kimplements,)
for implementedinterface in kimplements:
@@ -62,7 +63,7 @@ def extend(klass: type, interface: type, _recurs: bool = False) -> None:
kimplementsklass = tuple
kimplements = []
kimplements.append(interface)
- klass.__implements__ = kimplementsklass(kimplements) #type: ignore
+ klass.__implements__ = kimplementsklass(kimplements) # type: ignore
for subklass in klass.__subclasses__():
extend(subklass, interface, _recurs=True)
elif _recurs:
diff --git a/logilab/common/logging_ext.py b/logilab/common/logging_ext.py
index 9657581..e1df45d 100644
--- a/logilab/common/logging_ext.py
+++ b/logilab/common/logging_ext.py
@@ -30,13 +30,14 @@ from logilab.common.textutils import colorize_ansi
def set_log_methods(cls, logger):
"""bind standard logger's methods as methods on the class"""
cls.__logger = logger
- for attr in ('debug', 'info', 'warning', 'error', 'critical', 'exception'):
+ for attr in ("debug", "info", "warning", "error", "critical", "exception"):
setattr(cls, attr, getattr(logger, attr))
def xxx_cyan(record):
- if 'XXX' in record.message:
- return 'cyan'
+ if "XXX" in record.message:
+ return "cyan"
+
class ColorFormatter(logging.Formatter):
"""
@@ -54,12 +55,13 @@ class ColorFormatter(logging.Formatter):
def __init__(self, fmt=None, datefmt=None, colors=None):
logging.Formatter.__init__(self, fmt, datefmt)
self.colorfilters = []
- self.colors = {'CRITICAL': 'red',
- 'ERROR': 'red',
- 'WARNING': 'magenta',
- 'INFO': 'green',
- 'DEBUG': 'yellow',
- }
+ self.colors = {
+ "CRITICAL": "red",
+ "ERROR": "red",
+ "WARNING": "magenta",
+ "INFO": "green",
+ "DEBUG": "yellow",
+ }
if colors is not None:
assert isinstance(colors, dict)
self.colors.update(colors)
@@ -76,6 +78,7 @@ class ColorFormatter(logging.Formatter):
return colorize_ansi(msg, color)
return msg
+
def set_color_formatter(logger=None, **kw):
"""
Install a color formatter on the 'logger'. If not given, it will
@@ -94,37 +97,41 @@ def set_color_formatter(logger=None, **kw):
logger.handlers[0].setFormatter(fmt)
-LOG_FORMAT = '%(asctime)s - (%(name)s) %(levelname)s: %(message)s'
-LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
+LOG_FORMAT = "%(asctime)s - (%(name)s) %(levelname)s: %(message)s"
+LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
+
def get_handler(debug=False, syslog=False, logfile=None, rotation_parameters=None):
"""get an apropriate handler according to given parameters"""
- if os.environ.get('APYCOT_ROOT'):
+ if os.environ.get("APYCOT_ROOT"):
handler = logging.StreamHandler(sys.stdout)
if debug:
handler = logging.StreamHandler()
elif logfile is None:
if syslog:
from logging import handlers
+
handler = handlers.SysLogHandler()
else:
handler = logging.StreamHandler()
else:
try:
if rotation_parameters is None:
- if os.name == 'posix' and sys.version_info >= (2, 6):
+ if os.name == "posix" and sys.version_info >= (2, 6):
from logging.handlers import WatchedFileHandler
+
handler = WatchedFileHandler(logfile)
else:
handler = logging.FileHandler(logfile)
else:
from logging.handlers import TimedRotatingFileHandler
- handler = TimedRotatingFileHandler(
- logfile, **rotation_parameters)
+
+ handler = TimedRotatingFileHandler(logfile, **rotation_parameters)
except IOError:
handler = logging.StreamHandler()
return handler
+
def get_threshold(debug=False, logthreshold=None):
if logthreshold is None:
if debug:
@@ -132,15 +139,15 @@ def get_threshold(debug=False, logthreshold=None):
else:
logthreshold = logging.ERROR
elif isinstance(logthreshold, str):
- logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold,
- logthreshold))
+ logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold, logthreshold))
return logthreshold
+
def _colorable_terminal():
- isatty = hasattr(sys.__stdout__, 'isatty') and sys.__stdout__.isatty()
+ isatty = hasattr(sys.__stdout__, "isatty") and sys.__stdout__.isatty()
if not isatty:
return False
- if os.name == 'nt':
+ if os.name == "nt":
try:
from colorama import init as init_win32_colors
except ImportError:
@@ -148,22 +155,34 @@ def _colorable_terminal():
init_win32_colors()
return True
+
def get_formatter(logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT):
if _colorable_terminal():
fmt = ColorFormatter(logformat, logdateformat)
+
def col_fact(record):
- if 'XXX' in record.message:
- return 'cyan'
- if 'kick' in record.message:
- return 'red'
+ if "XXX" in record.message:
+ return "cyan"
+ if "kick" in record.message:
+ return "red"
+
fmt.colorfilters.append(col_fact)
else:
fmt = logging.Formatter(logformat, logdateformat)
return fmt
-def init_log(debug=False, syslog=False, logthreshold=None, logfile=None,
- logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT, fmt=None,
- rotation_parameters=None, handler=None):
+
+def init_log(
+ debug=False,
+ syslog=False,
+ logthreshold=None,
+ logfile=None,
+ logformat=LOG_FORMAT,
+ logdateformat=LOG_DATE_FORMAT,
+ fmt=None,
+ rotation_parameters=None,
+ handler=None,
+):
"""init the log service"""
logger = logging.getLogger()
if handler is None:
@@ -181,13 +200,15 @@ def init_log(debug=False, syslog=False, logthreshold=None, logfile=None,
handler.setFormatter(fmt)
return handler
+
# map logilab.common.logger thresholds to logging thresholds
-THRESHOLD_MAP = {'LOG_DEBUG': 'DEBUG',
- 'LOG_INFO': 'INFO',
- 'LOG_NOTICE': 'INFO',
- 'LOG_WARN': 'WARNING',
- 'LOG_WARNING': 'WARNING',
- 'LOG_ERR': 'ERROR',
- 'LOG_ERROR': 'ERROR',
- 'LOG_CRIT': 'CRITICAL',
- }
+THRESHOLD_MAP = {
+ "LOG_DEBUG": "DEBUG",
+ "LOG_INFO": "INFO",
+ "LOG_NOTICE": "INFO",
+ "LOG_WARN": "WARNING",
+ "LOG_WARNING": "WARNING",
+ "LOG_ERR": "ERROR",
+ "LOG_ERROR": "ERROR",
+ "LOG_CRIT": "CRITICAL",
+}
diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py
index 76c4ac4..9ca4c81 100644
--- a/logilab/common/modutils.py
+++ b/logilab/common/modutils.py
@@ -32,8 +32,18 @@ __docformat__ = "restructuredtext en"
import sys
import os
-from os.path import (splitext, join, abspath, isdir, dirname, exists,
- basename, expanduser, normcase, realpath)
+from os.path import (
+ splitext,
+ join,
+ abspath,
+ isdir,
+ dirname,
+ exists,
+ basename,
+ expanduser,
+ normcase,
+ realpath,
+)
from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY
from distutils.sysconfig import get_config_var, get_python_lib
from distutils.errors import DistutilsPlatformError
@@ -59,19 +69,19 @@ from logilab.common.deprecation import deprecated
#
# :see: `Problems with /usr/lib64 builds <http://bugs.python.org/issue1294959>`_
# :see: `FHS <http://www.pathname.com/fhs/pub/fhs-2.3.html#LIBLTQUALGTALTERNATEFORMATESSENTIAL>`_
-if sys.platform.startswith('win'):
- PY_SOURCE_EXTS = ('py', 'pyw')
- PY_COMPILED_EXTS = ('dll', 'pyd')
+if sys.platform.startswith("win"):
+ PY_SOURCE_EXTS = ("py", "pyw")
+ PY_COMPILED_EXTS = ("dll", "pyd")
else:
- PY_SOURCE_EXTS = ('py',)
- PY_COMPILED_EXTS = ('so',)
+ PY_SOURCE_EXTS = ("py",)
+ PY_COMPILED_EXTS = ("so",)
try:
STD_LIB_DIR = get_python_lib(standard_lib=True)
# get_python_lib(standard_lib=1) is not available on pypy, set STD_LIB_DIR to
# non-valid path, see https://bugs.pypy.org/issue1164
except DistutilsPlatformError:
- STD_LIB_DIR = '//'
+ STD_LIB_DIR = "//"
EXT_LIB_DIR = get_python_lib()
@@ -83,6 +93,7 @@ class NoSourceFile(Exception):
source file for a precompiled file
"""
+
class LazyObject(object):
def __init__(self, module, obj):
self.module = module
@@ -91,8 +102,7 @@ class LazyObject(object):
def _getobj(self):
if self._imported is None:
- self._imported = getattr(load_module_from_name(self.module),
- self.obj)
+ self._imported = getattr(load_module_from_name(self.module), self.obj)
return self._imported
def __getattribute__(self, attr):
@@ -105,7 +115,9 @@ class LazyObject(object):
return self._getobj()(*args, **kwargs)
-def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys: int = True) -> ModuleType:
+def load_module_from_name(
+ dotted_name: str, path: Optional[Any] = None, use_sys: int = True
+) -> ModuleType:
"""Load a Python module from its name.
:type dotted_name: str
@@ -127,13 +139,15 @@ def load_module_from_name(dotted_name: str, path: Optional[Any] = None, use_sys:
:rtype: module
:return: the loaded module
"""
- module = load_module_from_modpath(dotted_name.split('.'), path, use_sys)
+ module = load_module_from_modpath(dotted_name.split("."), path, use_sys)
if module is None:
raise ImportError("module %s doesn't exist" % dotted_name)
return module
-def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_sys: int = True) -> Optional[ModuleType]:
+def load_module_from_modpath(
+ parts: List[str], path: Optional[Any] = None, use_sys: int = True
+) -> Optional[ModuleType]:
"""Load a python module from its splitted name.
:type parts: list(str) or tuple(str)
@@ -156,14 +170,14 @@ def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_s
"""
if use_sys:
try:
- return sys.modules['.'.join(parts)]
+ return sys.modules[".".join(parts)]
except KeyError:
pass
modpath = []
prevmodule = None
for part in parts:
modpath.append(part)
- curname = '.'.join(modpath)
+ curname = ".".join(modpath)
module = None
if len(modpath) != len(parts):
# even with use_sys=False, should try to get outer packages from sys.modules
@@ -180,13 +194,13 @@ def load_module_from_modpath(parts: List[str], path: Optional[Any] = None, use_s
mp_file.close()
if prevmodule:
setattr(prevmodule, part, module)
- _file = getattr(module, '__file__', '')
+ _file = getattr(module, "__file__", "")
prevmodule = module
if not _file and _is_namespace(curname):
continue
if not _file and len(modpath) != len(parts):
- raise ImportError('no module in %s' % '.'.join(parts[len(modpath):]) )
- path = [dirname( _file )]
+ raise ImportError("no module in %s" % ".".join(parts[len(modpath) :]))
+ path = [dirname(_file)]
return module
@@ -222,7 +236,7 @@ def _check_init(path: str, mod_path: List[str]) -> bool:
for part in mod_path:
modpath.append(part)
path = join(path, part)
- if not _is_namespace('.'.join(modpath)) and not _has_init(path):
+ if not _is_namespace(".".join(modpath)) and not _has_init(path):
return False
return True
@@ -231,8 +245,7 @@ def _canonicalize_path(path: str) -> str:
return realpath(expanduser(path))
-
-@deprecated('you should avoid using modpath_from_file()')
+@deprecated("you should avoid using modpath_from_file()")
def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None) -> List[str]:
"""DEPRECATED: doens't play well with symlinks and sys.meta_path
@@ -261,23 +274,23 @@ def modpath_from_file(filename: str, extrapath: Optional[Dict[str, str]] = None)
if extrapath is not None:
for path_ in map(_canonicalize_path, extrapath):
path = abspath(path_)
- if path and normcase(base[:len(path)]) == normcase(path):
- submodpath = [pkg for pkg in base[len(path):].split(os.sep)
- if pkg]
+ if path and normcase(base[: len(path)]) == normcase(path):
+ submodpath = [pkg for pkg in base[len(path) :].split(os.sep) if pkg]
if _check_init(path, submodpath[:-1]):
- return extrapath[path_].split('.') + submodpath
+ return extrapath[path_].split(".") + submodpath
for path in map(_canonicalize_path, sys.path):
if path and normcase(base).startswith(path):
- modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg]
+ modpath = [pkg for pkg in base[len(path) :].split(os.sep) if pkg]
if _check_init(path, modpath[:-1]):
return modpath
- raise ImportError('Unable to find module for %s in %s' % (
- filename, ', \n'.join(sys.path)))
+ raise ImportError("Unable to find module for %s in %s" % (filename, ", \n".join(sys.path)))
-def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None) -> Optional[str]:
+def file_from_modpath(
+ modpath: List[str], path: Optional[Any] = None, context_file: Optional[str] = None
+) -> Optional[str]:
"""given a mod path (i.e. splitted module / package name), return the
corresponding file, giving priority to source file over precompiled
file if it exists
@@ -312,19 +325,18 @@ def file_from_modpath(modpath: List[str], path: Optional[Any] = None, context_fi
context = dirname(context_file)
else:
context = context_file
- if modpath[0] == 'xml':
+ if modpath[0] == "xml":
# handle _xmlplus
try:
- return _file_from_modpath(['_xmlplus'] + modpath[1:], path, context)
+ return _file_from_modpath(["_xmlplus"] + modpath[1:], path, context)
except ImportError:
return _file_from_modpath(modpath, path, context)
- elif modpath == ['os', 'path']:
+ elif modpath == ["os", "path"]:
# FIXME: currently ignoring search_path...
return os.path.__file__
return _file_from_modpath(modpath, path, context)
-
def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str:
"""given a dotted name return the module part of the name :
@@ -352,9 +364,9 @@ def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str
(see #10066)
"""
# os.path trick
- if dotted_name.startswith('os.path'):
- return 'os.path'
- parts = dotted_name.split('.')
+ if dotted_name.startswith("os.path"):
+ return "os.path"
+ parts = dotted_name.split(".")
if context_file is not None:
# first check for builtin module which won't be considered latter
# in that case (path != None)
@@ -365,27 +377,27 @@ def get_module_part(dotted_name: str, context_file: Optional[str] = None) -> str
# don't use += or insert, we want a new list to be created !
path: Optional[List] = None
starti = 0
- if parts[0] == '':
- assert context_file is not None, \
- 'explicit relative import, but no context_file?'
- path = [] # prevent resolving the import non-relatively
+ if parts[0] == "":
+ assert context_file is not None, "explicit relative import, but no context_file?"
+ path = [] # prevent resolving the import non-relatively
starti = 1
- while parts[starti] == '': # for all further dots: change context
+ while parts[starti] == "": # for all further dots: change context
starti += 1
assert context_file is not None
context_file = dirname(context_file)
for i in range(starti, len(parts)):
try:
- file_from_modpath(parts[starti:i+1],
- path=path, context_file=context_file)
+ file_from_modpath(parts[starti : i + 1], path=path, context_file=context_file)
except ImportError:
if not i >= max(1, len(parts) - 2):
raise
- return '.'.join(parts[:i])
+ return ".".join(parts[:i])
return dotted_name
-def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]:
+def get_modules(
+ package: str, src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST
+) -> List[str]:
"""given a package directory return a list of all available python
modules in the package and its subpackages
@@ -410,21 +422,20 @@ def get_modules(package: str, src_directory: str, blacklist: Sequence[str] = STD
for directory, dirnames, filenames in os.walk(src_directory):
_handle_blacklist(blacklist, dirnames, filenames)
# check for __init__.py
- if not '__init__.py' in filenames:
+ if not "__init__.py" in filenames:
dirnames[:] = ()
continue
if directory != src_directory:
- dir_package = directory[len(src_directory):].replace(os.sep, '.')
+ dir_package = directory[len(src_directory) :].replace(os.sep, ".")
modules.append(package + dir_package)
for filename in filenames:
- if _is_python_file(filename) and filename != '__init__.py':
+ if _is_python_file(filename) and filename != "__init__.py":
src = join(directory, filename)
- module = package + src[len(src_directory):-3]
- modules.append(module.replace(os.sep, '.'))
+ module = package + src[len(src_directory) : -3]
+ modules.append(module.replace(os.sep, "."))
return modules
-
def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIST) -> List[str]:
"""given a package directory return a list of all available python
module's files in the package and its subpackages
@@ -447,7 +458,7 @@ def get_module_files(src_directory: str, blacklist: Sequence[str] = STD_BLACKLIS
for directory, dirnames, filenames in os.walk(src_directory):
_handle_blacklist(blacklist, dirnames, filenames)
# check for __init__.py
- if not '__init__.py' in filenames:
+ if not "__init__.py" in filenames:
dirnames[:] = ()
continue
for filename in filenames:
@@ -473,7 +484,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
"""
base, orig_ext = splitext(abspath(filename))
for ext in PY_SOURCE_EXTS:
- source_path = '%s.%s' % (base, ext)
+ source_path = "%s.%s" % (base, ext)
if exists(source_path):
return source_path
if include_no_ext and not orig_ext and exists(base):
@@ -485,7 +496,7 @@ def cleanup_sys_modules(directories):
"""remove submodules of `directories` from `sys.modules`"""
cleaned = []
for modname, module in list(sys.modules.items()):
- modfile = getattr(module, '__file__', None)
+ modfile = getattr(module, "__file__", None)
if modfile:
for directory in directories:
if modfile.startswith(directory):
@@ -515,7 +526,9 @@ def is_python_source(filename):
return splitext(filename)[1][1:] in PY_SOURCE_EXTS
-def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,)) -> bool:
+def is_standard_module(
+ modname: str, std_path: Union[List[str], Tuple[str]] = (STD_LIB_DIR,)
+) -> bool:
"""try to guess if a module is a standard python module (by default,
see `std_path` parameter's description)
@@ -535,7 +548,7 @@ def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (S
Note: this function is known to return wrong values when inside virtualenv.
See https://www.logilab.org/ticket/294756.
"""
- modname = modname.split('.')[0]
+ modname = modname.split(".")[0]
try:
filename = file_from_modpath([modname])
except ImportError as ex:
@@ -556,7 +569,6 @@ def is_standard_module(modname: str, std_path: Union[List[str], Tuple[str]] = (S
return False
-
def is_relative(modname: str, from_file: str) -> bool:
"""return true if the given module name is relative to the given
file name
@@ -577,7 +589,7 @@ def is_relative(modname: str, from_file: str) -> bool:
if from_file in sys.path:
return False
try:
- find_module(modname.split('.')[0], [from_file])
+ find_module(modname.split(".")[0], [from_file])
return True
except ImportError:
return False
@@ -585,7 +597,10 @@ def is_relative(modname: str, from_file: str) -> bool:
# internal only functions #####################################################
-def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None) -> Optional[str]:
+
+def _file_from_modpath(
+ modpath: List[str], path: Optional[Any] = None, context: Optional[str] = None
+) -> Optional[str]:
"""given a mod path (i.e. splitted module / package name), return the
corresponding file
@@ -614,15 +629,20 @@ def _file_from_modpath(modpath: List[str], path: Optional[Any] = None, context:
mp_filename = _has_init(mp_filename)
return mp_filename
-def _search_zip(modpath: List[str], pic: Dict[str, Optional[FileFinder]]) -> Tuple[object, str, str]:
+
+def _search_zip(
+ modpath: List[str], pic: Dict[str, Optional[FileFinder]]
+) -> Tuple[object, str, str]:
for filepath, importer in pic.items():
if importer is not None:
if importer.find_module(modpath[0]):
- if not importer.find_module('/'.join(modpath)):
- raise ImportError('No module named %s in %s/%s' % (
- '.'.join(modpath[1:]), filepath, modpath))
- return ZIPFILE, abspath(filepath) + '/' + '/'.join(modpath), filepath
- raise ImportError('No module named %s' % '.'.join(modpath))
+ if not importer.find_module("/".join(modpath)):
+ raise ImportError(
+ "No module named %s in %s/%s" % (".".join(modpath[1:]), filepath, modpath)
+ )
+ return ZIPFILE, abspath(filepath) + "/" + "/".join(modpath), filepath
+ raise ImportError("No module named %s" % ".".join(modpath))
+
try:
import pkg_resources
@@ -635,11 +655,14 @@ except ImportError:
def _is_namespace(modname: str) -> bool:
# mypy: Module has no attribute "_namespace_packages"; maybe "fixup_namespace_packages"?"
# but is still has? or is it a failure from python3 port?
- return (pkg_resources is not None
- and modname in pkg_resources._namespace_packages) # type: ignore
+ return (
+ pkg_resources is not None and modname in pkg_resources._namespace_packages
+ ) # type: ignore
-def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[Union[int, object], Optional[str]]:
+def _module_file(
+ modpath: List[str], path: Optional[List[str]] = None
+) -> Tuple[Union[int, object], Optional[str]]:
"""get a module type / file path
:type modpath: list or tuple
@@ -670,7 +693,7 @@ def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[
except AttributeError:
checkeggs = False
# pkg_resources support (aka setuptools namespace packages)
- if (_is_namespace(modpath[0]) and modpath[0] in sys.modules):
+ if _is_namespace(modpath[0]) and modpath[0] in sys.modules:
# setuptools has added into sys.modules a module object with proper
# __path__, get back information from there
module = sys.modules[modpath.pop(0)]
@@ -720,31 +743,30 @@ def _module_file(modpath: List[str], path: Optional[List[str]] = None) -> Tuple[
mtype = mp_desc[2]
if modpath:
if mtype != PKG_DIRECTORY:
- raise ImportError('No module %s in %s' % ('.'.join(modpath),
- '.'.join(imported)))
+ raise ImportError("No module %s in %s" % (".".join(modpath), ".".join(imported)))
# XXX guess if package is using pkgutil.extend_path by looking for
# those keywords in the first four Kbytes
try:
- with open(join(mp_filename, '__init__.py')) as stream:
+ with open(join(mp_filename, "__init__.py")) as stream:
data = stream.read(4096)
except IOError:
path = [mp_filename]
else:
- if 'pkgutil' in data and 'extend_path' in data:
+ if "pkgutil" in data and "extend_path" in data:
# extend_path is called, search sys.path for module/packages
# of this name see pkgutil.extend_path documentation
- path = [join(p, *imported) for p in sys.path
- if isdir(join(p, *imported))]
+ path = [join(p, *imported) for p in sys.path if isdir(join(p, *imported))]
else:
path = [mp_filename]
return mtype, mp_filename
+
def _is_python_file(filename: str) -> bool:
"""return true if the given filename should be considered as a python file
.pyc and .pyo are ignored
"""
- for ext in ('.py', '.so', '.pyd', '.pyw'):
+ for ext in (".py", ".so", ".pyd", ".pyw"):
if filename.endswith(ext):
return True
return False
@@ -754,10 +776,10 @@ def _has_init(directory: str) -> Optional[str]:
"""if the given directory has a valid __init__ file, return its path,
else return None
"""
- mod_or_pack = join(directory, '__init__')
+ mod_or_pack = join(directory, "__init__")
- for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'):
- if exists(mod_or_pack + '.' + ext):
- return mod_or_pack + '.' + ext
+ for ext in PY_SOURCE_EXTS + ("pyc", "pyo"):
+ if exists(mod_or_pack + "." + ext):
+ return mod_or_pack + "." + ext
return None
diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py
index 11e2155..3f321b5 100644
--- a/logilab/common/optik_ext.py
+++ b/logilab/common/optik_ext.py
@@ -62,33 +62,44 @@ from optparse import Values, IndentedHelpFormatter, OptionGroup
from _io import StringIO
# python >= 2.3
-from optparse import OptionParser as BaseParser, Option as BaseOption, \
- OptionGroup, OptionContainer, OptionValueError, OptionError, \
- Values, HelpFormatter, NO_DEFAULT, SUPPRESS_HELP
+from optparse import (
+ OptionParser as BaseParser,
+ Option as BaseOption,
+ OptionGroup,
+ OptionContainer,
+ OptionValueError,
+ OptionError,
+ Values,
+ HelpFormatter,
+ NO_DEFAULT,
+ SUPPRESS_HELP,
+)
try:
from mx import DateTime
+
HAS_MX_DATETIME = True
except ImportError:
HAS_MX_DATETIME = False
-from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, \
- apply_units
+from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, apply_units
def check_regexp(option, opt, value):
"""check a regexp value by trying to compile it
return the compiled regexp
"""
- if hasattr(value, 'pattern'):
+ if hasattr(value, "pattern"):
return value
try:
return re.compile(value)
except ValueError:
- raise OptionValueError(
- "option %s: invalid regexp value: %r" % (opt, value))
+ raise OptionValueError("option %s: invalid regexp value: %r" % (opt, value))
+
-def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tuple[str, ...], str]) -> Union[List[str], Tuple[str, ...]]:
+def check_csv(
+ option: Optional["Option"], opt: str, value: Union[List[str], Tuple[str, ...], str]
+) -> Union[List[str], Tuple[str, ...]]:
"""check a csv value by trying to split it
return the list of separated values
"""
@@ -97,23 +108,26 @@ def check_csv(option: Optional['Option'], opt: str, value: Union[List[str], Tupl
try:
return splitstrip(value)
except ValueError:
- raise OptionValueError(
- "option %s: invalid csv value: %r" % (opt, value))
+ raise OptionValueError("option %s: invalid csv value: %r" % (opt, value))
+
-def check_yn(option: Optional['Option'], opt: str, value: Union[bool, str]) -> bool:
+def check_yn(option: Optional["Option"], opt: str, value: Union[bool, str]) -> bool:
"""check a yn value
return true for yes and false for no
"""
if isinstance(value, int):
return bool(value)
- if value in ('y', 'yes'):
+ if value in ("y", "yes"):
return True
- if value in ('n', 'no'):
+ if value in ("n", "no"):
return False
msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)"
raise OptionValueError(msg % (opt, value))
-def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], str]) -> Dict[str, str]:
+
+def check_named(
+ option: Optional[Any], opt: str, value: Union[Dict[str, str], str]
+) -> Dict[str, str]:
"""check a named value
return a dictionary containing (name, value) associations
"""
@@ -124,22 +138,24 @@ def check_named(option: Optional[Any], opt: str, value: Union[Dict[str, str], st
# mypy: Argument 1 to "append" of "list" has incompatible type "List[str]";
# mypy: expected "Tuple[str, str]"
# we know that the split will give a 2 items list
- if value.find('=') != -1:
- values.append(value.split('=', 1)) # type: ignore
- elif value.find(':') != -1:
- values.append(value.split(':', 1)) # type: ignore
+ if value.find("=") != -1:
+ values.append(value.split("=", 1)) # type: ignore
+ elif value.find(":") != -1:
+ values.append(value.split(":", 1)) # type: ignore
if values:
return dict(values)
msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \
<NAME>:<VALUE>"
raise OptionValueError(msg % (opt, value))
+
def check_password(option, opt, value):
"""check a password value (can't be empty)
"""
# no actual checking, monkey patch if you want more
return value
+
def check_file(option, opt, value):
"""check a file value
return the filepath
@@ -149,6 +165,7 @@ def check_file(option, opt, value):
msg = "option %s: file %r does not exist"
raise OptionValueError(msg % (opt, value))
+
# XXX use python datetime
def check_date(option, opt, value):
"""check a file value
@@ -156,9 +173,9 @@ def check_date(option, opt, value):
"""
try:
return DateTime.strptime(value, "%Y/%m/%d")
- except DateTime.Error :
- raise OptionValueError(
- "expected format of %s is yyyy/mm/dd" % opt)
+ except DateTime.Error:
+ raise OptionValueError("expected format of %s is yyyy/mm/dd" % opt)
+
def check_color(option, opt, value):
"""check a color value and returns it
@@ -166,23 +183,25 @@ def check_color(option, opt, value):
checks hexadecimal forms
"""
# Case (1) : color label, we trust the end-user
- if re.match('[a-z0-9 ]+$', value, re.I):
+ if re.match("[a-z0-9 ]+$", value, re.I):
return value
# Case (2) : only accepts hexadecimal forms
- if re.match('#[a-f0-9]{6}', value, re.I):
+ if re.match("#[a-f0-9]{6}", value, re.I):
return value
# Else : not a color label neither a valid hexadecimal form => error
msg = "option %s: invalid color : %r, should be either hexadecimal \
value or predefined color"
raise OptionValueError(msg % (opt, value))
+
def check_time(option, opt, value):
if isinstance(value, (int, float)):
return value
return apply_units(value, TIME_UNITS)
-def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int:
- if hasattr(value, '__int__'):
+
+def check_bytes(option: Optional["Option"], opt: str, value: Any) -> int:
+ if hasattr(value, "__int__"):
return value
# mypy: Incompatible return value type (got "Union[float, int]", expected "int")
# we force "int" using "final=int"
@@ -192,24 +211,34 @@ def check_bytes(option: Optional['Option'], opt: str, value: Any) -> int:
class Option(BaseOption):
"""override optik.Option to add some new option types
"""
- TYPES = BaseOption.TYPES + ('regexp', 'csv', 'yn', 'named', 'password',
- 'multiple_choice', 'file', 'color',
- 'time', 'bytes')
- ATTRS = BaseOption.ATTRS + ['hide', 'level']
+
+ TYPES = BaseOption.TYPES + (
+ "regexp",
+ "csv",
+ "yn",
+ "named",
+ "password",
+ "multiple_choice",
+ "file",
+ "color",
+ "time",
+ "bytes",
+ )
+ ATTRS = BaseOption.ATTRS + ["hide", "level"]
TYPE_CHECKER = copy(BaseOption.TYPE_CHECKER)
- TYPE_CHECKER['regexp'] = check_regexp
- TYPE_CHECKER['csv'] = check_csv
- TYPE_CHECKER['yn'] = check_yn
- TYPE_CHECKER['named'] = check_named
- TYPE_CHECKER['multiple_choice'] = check_csv
- TYPE_CHECKER['file'] = check_file
- TYPE_CHECKER['color'] = check_color
- TYPE_CHECKER['password'] = check_password
- TYPE_CHECKER['time'] = check_time
- TYPE_CHECKER['bytes'] = check_bytes
+ TYPE_CHECKER["regexp"] = check_regexp
+ TYPE_CHECKER["csv"] = check_csv
+ TYPE_CHECKER["yn"] = check_yn
+ TYPE_CHECKER["named"] = check_named
+ TYPE_CHECKER["multiple_choice"] = check_csv
+ TYPE_CHECKER["file"] = check_file
+ TYPE_CHECKER["color"] = check_color
+ TYPE_CHECKER["password"] = check_password
+ TYPE_CHECKER["time"] = check_time
+ TYPE_CHECKER["bytes"] = check_bytes
if HAS_MX_DATETIME:
- TYPES += ('date',)
- TYPE_CHECKER['date'] = check_date
+ TYPES += ("date",)
+ TYPE_CHECKER["date"] = check_date
def __init__(self, *opts: str, **attrs: Any) -> None:
BaseOption.__init__(self, *opts, **attrs)
@@ -224,15 +253,16 @@ class Option(BaseOption):
# mypy: "Option" has no attribute "choices"
# we know that option of this type has this attribute
if self.choices is None: # type: ignore
- raise OptionError(
- "must supply a list of choices for type 'choice'", self)
+ raise OptionError("must supply a list of choices for type 'choice'", self)
elif not isinstance(self.choices, (tuple, list)): # type: ignore
raise OptionError(
"choices must be a list of strings ('%s' supplied)"
- % str(type(self.choices)).split("'")[1], self) # type: ignore
+ % str(type(self.choices)).split("'")[1],
+ self,
+ ) # type: ignore
elif self.choices is not None: # type: ignore
- raise OptionError(
- "must not supply choices for type %r" % self.type, self)
+ raise OptionError("must not supply choices for type %r" % self.type, self)
+
# mypy: Unsupported target for indexed assignment
# black magic?
BaseOption.CHECK_METHODS[2] = _check_choice # type: ignore
@@ -241,7 +271,7 @@ class Option(BaseOption):
# First, convert the value(s) to the right type. Howl if any
# value(s) are bogus.
value = self.convert_value(opt, value)
- if self.type == 'named':
+ if self.type == "named":
assert self.dest is not None
existant = getattr(values, self.dest)
if existant:
@@ -253,13 +283,13 @@ class Option(BaseOption):
# mypy: Argument 2 to "take_action" of "Option" has incompatible type "Optional[str]";
# mypy: expected "str"
# is it ok?
- return self.take_action(
- self.action, self.dest, opt, value, values, parser) # type: ignore
+ return self.take_action(self.action, self.dest, opt, value, values, parser) # type: ignore
class OptionParser(BaseParser):
"""override optik.OptionParser to use our Option class
"""
+
def __init__(self, option_class: type = Option, *args: Any, **kwargs: Any) -> None:
# mypy: Argument "option_class" to "__init__" of "OptionParser" has incompatible type
# mypy: "type"; expected "Option"
@@ -269,7 +299,7 @@ class OptionParser(BaseParser):
def format_option_help(self, formatter: Optional[HelpFormatter] = None) -> str:
if formatter is None:
formatter = self.formatter
- outputlevel = getattr(formatter, 'output_level', 0)
+ outputlevel = getattr(formatter, "output_level", 0)
formatter.store_option_strings(self)
result = []
result.append(formatter.format_heading("Options"))
@@ -281,7 +311,8 @@ class OptionParser(BaseParser):
# mypy: "OptionParser" has no attribute "level"
# but it has one no?
if group.level <= outputlevel and ( # type: ignore
- group.description or level_options(group, outputlevel)):
+ group.description or level_options(group, outputlevel)
+ ):
result.append(group.format_help(formatter))
result.append("\n")
formatter.dedent()
@@ -293,19 +324,25 @@ class OptionParser(BaseParser):
# monkeypatching
OptionGroup.level = 0 # type: ignore
+
def level_options(group: BaseParser, outputlevel: int) -> List[BaseOption]:
# mypy: "Option" has no attribute "help"
# but it does
- return [option for option in group.option_list
- if (getattr(option, 'level', 0) or 0) <= outputlevel
- and not option.help is SUPPRESS_HELP] # type: ignore
+ return [
+ option
+ for option in group.option_list
+ if (getattr(option, "level", 0) or 0) <= outputlevel and not option.help is SUPPRESS_HELP
+ ] # type: ignore
+
def format_option_help(self, formatter):
result = []
- outputlevel = getattr(formatter, 'output_level', 0) or 0
+ outputlevel = getattr(formatter, "output_level", 0) or 0
for option in level_options(self, outputlevel):
result.append(formatter.format_option(option))
return "".join(result)
+
+
# mypy error: Cannot assign to a method
# but we still do it because magic
OptionContainer.format_option_help = format_option_help # type: ignore
@@ -314,16 +351,17 @@ OptionContainer.format_option_help = format_option_help # type: ignore
class ManHelpFormatter(HelpFormatter):
"""Format help using man pages ROFF format"""
- def __init__ (self,
- indent_increment: int = 0,
- max_help_position: int = 24,
- width: int = 79,
- short_first: int = 0) -> None:
- HelpFormatter.__init__ (
- self, indent_increment, max_help_position, width, short_first)
+ def __init__(
+ self,
+ indent_increment: int = 0,
+ max_help_position: int = 24,
+ width: int = 79,
+ short_first: int = 0,
+ ) -> None:
+ HelpFormatter.__init__(self, indent_increment, max_help_position, width, short_first)
def format_heading(self, heading: str) -> str:
- return '.SH %s\n' % heading.upper()
+ return ".SH %s\n" % heading.upper()
def format_description(self, description):
return description
@@ -342,12 +380,15 @@ class ManHelpFormatter(HelpFormatter):
# mypy: "OptionParser"; expected "Option"
# it still works?
help_text = self.expand_default(option) # type: ignore
- help = ' '.join([l.strip() for l in help_text.splitlines()])
+ help = " ".join([l.strip() for l in help_text.splitlines()])
else:
- help = ''
- return '''.IP "%s"
+ help = ""
+ return """.IP "%s"
%s
-''' % (optstring, help)
+""" % (
+ optstring,
+ help,
+ )
def format_head(self, optparser: OptionParser, pkginfo: attrdict, section: int = 1) -> str:
long_desc = ""
@@ -355,43 +396,54 @@ class ManHelpFormatter(HelpFormatter):
short_desc = self.format_short_description(pgm, pkginfo.description)
if hasattr(pkginfo, "long_desc"):
long_desc = self.format_long_description(pgm, pkginfo.long_desc)
- return '%s\n%s\n%s\n%s' % (self.format_title(pgm, section),
- short_desc, self.format_synopsis(pgm),
- long_desc)
+ return "%s\n%s\n%s\n%s" % (
+ self.format_title(pgm, section),
+ short_desc,
+ self.format_synopsis(pgm),
+ long_desc,
+ )
def format_title(self, pgm: str, section: int) -> str:
- date = '-'.join([str(num) for num in time.localtime()[:3]])
+ date = "-".join([str(num) for num in time.localtime()[:3]])
return '.TH %s %s "%s" %s' % (pgm, section, date, pgm)
def format_short_description(self, pgm: str, short_desc: str) -> str:
- return '''.SH NAME
+ return """.SH NAME
.B %s
\- %s
-''' % (pgm, short_desc.strip())
+""" % (
+ pgm,
+ short_desc.strip(),
+ )
def format_synopsis(self, pgm: str) -> str:
- return '''.SH SYNOPSIS
+ return (
+ """.SH SYNOPSIS
.B %s
[
.I OPTIONS
] [
.I <arguments>
]
-''' % pgm
+"""
+ % pgm
+ )
def format_long_description(self, pgm, long_desc):
- long_desc = '\n'.join([line.lstrip()
- for line in long_desc.splitlines()])
- long_desc = long_desc.replace('\n.\n', '\n\n')
+ long_desc = "\n".join([line.lstrip() for line in long_desc.splitlines()])
+ long_desc = long_desc.replace("\n.\n", "\n\n")
if long_desc.lower().startswith(pgm):
- long_desc = long_desc[len(pgm):]
- return '''.SH DESCRIPTION
+ long_desc = long_desc[len(pgm) :]
+ return """.SH DESCRIPTION
.B %s
%s
-''' % (pgm, long_desc.strip())
+""" % (
+ pgm,
+ long_desc.strip(),
+ )
def format_tail(self, pkginfo: attrdict) -> str:
- tail = '''.SH SEE ALSO
+ tail = """.SH SEE ALSO
/usr/share/doc/pythonX.Y-%s/
.SH BUGS
@@ -400,18 +452,32 @@ Please report bugs on the project\'s mailing list:
.SH AUTHOR
%s <%s>
-''' % (getattr(pkginfo, 'debian_name', pkginfo.modname),
- pkginfo.mailinglist, pkginfo.author, pkginfo.author_email)
+""" % (
+ getattr(pkginfo, "debian_name", pkginfo.modname),
+ pkginfo.mailinglist,
+ pkginfo.author,
+ pkginfo.author_email,
+ )
if hasattr(pkginfo, "copyright"):
- tail += '''
+ tail += (
+ """
.SH COPYRIGHT
%s
-''' % pkginfo.copyright
+"""
+ % pkginfo.copyright
+ )
return tail
-def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int = 1, stream: StringIO = sys.stdout, level: int = 0) -> None:
+
+def generate_manpage(
+ optparser: OptionParser,
+ pkginfo: attrdict,
+ section: int = 1,
+ stream: StringIO = sys.stdout,
+ level: int = 0,
+) -> None:
"""generate a man page from an optik parser"""
formatter = ManHelpFormatter()
# mypy: "ManHelpFormatter" has no attribute "output_level"
@@ -423,5 +489,4 @@ def generate_manpage(optparser: OptionParser, pkginfo: attrdict, section: int =
print(formatter.format_tail(pkginfo), file=stream)
-__all__ = ('OptionParser', 'Option', 'OptionGroup', 'OptionValueError',
- 'Values')
+__all__ = ("OptionParser", "Option", "OptionGroup", "OptionValueError", "Values")
diff --git a/logilab/common/optparser.py b/logilab/common/optparser.py
index aa17750..8dd6b36 100644
--- a/logilab/common/optparser.py
+++ b/logilab/common/optparser.py
@@ -34,32 +34,37 @@ from __future__ import print_function
__docformat__ = "restructuredtext en"
from warnings import warn
-warn('lgc.optparser module is deprecated, use lgc.clcommands instead', DeprecationWarning,
- stacklevel=2)
+
+warn(
+ "lgc.optparser module is deprecated, use lgc.clcommands instead",
+ DeprecationWarning,
+ stacklevel=2,
+)
import sys
import optparse
-class OptionParser(optparse.OptionParser):
+class OptionParser(optparse.OptionParser):
def __init__(self, *args, **kwargs):
optparse.OptionParser.__init__(self, *args, **kwargs)
self._commands = {}
self.min_args, self.max_args = 0, 1
- def add_command(self, name, mod_or_funcs, help=''):
+ def add_command(self, name, mod_or_funcs, help=""):
"""name of the command, name of module or tuple of functions
(run, add_options)
"""
- assert isinstance(mod_or_funcs, str) or isinstance(mod_or_funcs, tuple), \
- "mod_or_funcs has to be a module name or a tuple of functions"
+ assert isinstance(mod_or_funcs, str) or isinstance(
+ mod_or_funcs, tuple
+ ), "mod_or_funcs has to be a module name or a tuple of functions"
self._commands[name] = (mod_or_funcs, help)
def print_main_help(self):
optparse.OptionParser.print_help(self)
- print('\ncommands:')
+ print("\ncommands:")
for cmdname, (_, help) in self._commands.items():
- print('% 10s - %s' % (cmdname, help))
+ print("% 10s - %s" % (cmdname, help))
def parse_command(self, args):
if len(args) == 0:
@@ -68,25 +73,23 @@ class OptionParser(optparse.OptionParser):
cmd = args[0]
args = args[1:]
if cmd not in self._commands:
- if cmd in ('-h', '--help'):
+ if cmd in ("-h", "--help"):
self.print_main_help()
sys.exit(0)
elif self.version is not None and cmd == "--version":
self.print_version()
sys.exit(0)
- self.error('unknown command')
- self.prog = '%s %s' % (self.prog, cmd)
+ self.error("unknown command")
+ self.prog = "%s %s" % (self.prog, cmd)
mod_or_f, help = self._commands[cmd]
# optparse inserts self.description between usage and options help
self.description = help
if isinstance(mod_or_f, str):
- exec('from %s import run, add_options' % mod_or_f)
+ exec("from %s import run, add_options" % mod_or_f)
else:
run, add_options = mod_or_f
add_options(self)
(options, args) = self.parse_args(args)
if not (self.min_args <= len(args) <= self.max_args):
- self.error('incorrect number of arguments')
+ self.error("incorrect number of arguments")
return run, options, args
-
-
diff --git a/logilab/common/proc.py b/logilab/common/proc.py
index 30e9494..2d2e78c 100644
--- a/logilab/common/proc.py
+++ b/logilab/common/proc.py
@@ -37,15 +37,19 @@ from time import time
from logilab.common.tree import Node
-class NoSuchProcess(Exception): pass
+
+class NoSuchProcess(Exception):
+ pass
+
def proc_exists(pid):
"""check the a pid is registered in /proc
raise NoSuchProcess exception if not
"""
- if not os.path.exists('/proc/%s' % pid):
+ if not os.path.exists("/proc/%s" % pid):
raise NoSuchProcess()
+
PPID = 3
UTIME = 13
STIME = 14
@@ -53,6 +57,7 @@ CUTIME = 15
CSTIME = 16
VSIZE = 22
+
class ProcInfo(Node):
"""provide access to process information found in /proc"""
@@ -60,19 +65,18 @@ class ProcInfo(Node):
self.pid = int(pid)
Node.__init__(self, self.pid)
proc_exists(self.pid)
- self.file = '/proc/%s/stat' % self.pid
+ self.file = "/proc/%s/stat" % self.pid
self.ppid = int(self.status()[PPID])
def memory_usage(self):
"""return the memory usage of the process in Ko"""
- try :
+ try:
return int(self.status()[VSIZE])
except IOError:
return 0
def lineage_memory_usage(self):
- return self.memory_usage() + sum([child.lineage_memory_usage()
- for child in self.children])
+ return self.memory_usage() + sum([child.lineage_memory_usage() for child in self.children])
def time(self, children=0):
"""return the number of jiffies that this process has been scheduled
@@ -90,13 +94,14 @@ class ProcInfo(Node):
def name(self):
"""return the process name found in /proc/<pid>/stat
"""
- return self.status()[1].strip('()')
+ return self.status()[1].strip("()")
def age(self):
"""return the age of the process
"""
return os.stat(self.file)[stat.ST_MTIME]
+
class ProcInfoLoader:
"""manage process information"""
@@ -105,7 +110,7 @@ class ProcInfoLoader:
def list_pids(self):
"""return a list of existent process ids"""
- for subdir in os.listdir('/proc'):
+ for subdir in os.listdir("/proc"):
if subdir.isdigit():
yield int(subdir)
@@ -120,7 +125,6 @@ class ProcInfoLoader:
self._loaded[pid] = procinfo
return procinfo
-
def load_all(self):
"""load all processes information"""
for pid in self.list_pids():
@@ -135,22 +139,29 @@ class ProcInfoLoader:
class ResourceError(Exception):
"""Error raise when resource limit is reached"""
+
limit = "Unknown Resource Limit"
class XCPUError(ResourceError):
"""Error raised when CPU Time limit is reached"""
+
limit = "CPU Time"
+
class LineageMemoryError(ResourceError):
"""Error raised when the total amount of memory used by a process and
it's child is reached"""
+
limit = "Lineage total Memory"
+
class TimeoutError(ResourceError):
"""Error raised when the process is running for to much time"""
+
limit = "Real Time"
+
# Can't use subclass because the StandardError MemoryError raised
RESOURCE_LIMIT_EXCEPTION = (ResourceError, MemoryError)
@@ -159,6 +170,7 @@ class MemorySentinel(Thread):
"""A class checking a process don't use too much memory in a separated
daemonic thread
"""
+
def __init__(self, interval, memory_limit, gpid=os.getpid()):
Thread.__init__(self, target=self._run, name="Test.Sentinel")
self.memory_limit = memory_limit
@@ -180,9 +192,7 @@ class MemorySentinel(Thread):
class ResourceController:
-
- def __init__(self, max_cpu_time=None, max_time=None, max_memory=None,
- max_reprieve=60):
+ def __init__(self, max_cpu_time=None, max_time=None, max_memory=None, max_reprieve=60):
if SIGXCPU == -1:
raise RuntimeError("Unsupported platform")
self.max_time = max_time
@@ -230,13 +240,12 @@ class ResourceController:
def setup_limit(self):
"""set up the process limit"""
- assert currentThread().getName() == 'MainThread'
+ assert currentThread().getName() == "MainThread"
os.setpgrp()
if self._limit_set <= 0:
if self.max_time is not None:
self._old_usr2_hdlr = signal(SIGUSR2, self._hangle_sig_timeout)
- self._timer = Timer(max(1, int(self.max_time) - self._elapse_time),
- self._time_out)
+ self._timer = Timer(max(1, int(self.max_time) - self._elapse_time), self._time_out)
self._start_time = int(time())
self._timer.start()
if self.max_cpu_time is not None:
@@ -245,7 +254,7 @@ class ResourceController:
self._old_sigxcpu_hdlr = signal(SIGXCPU, self._handle_sigxcpu)
setrlimit(RLIMIT_CPU, cpu_limit)
if self.max_memory is not None:
- self._msentinel = MemorySentinel(1, int(self.max_memory) )
+ self._msentinel = MemorySentinel(1, int(self.max_memory))
self._old_max_memory = getrlimit(RLIMIT_AS)
self._old_usr1_hdlr = signal(SIGUSR1, self._hangle_sig_memory)
as_limit = (int(self.max_memory), self._old_max_memory[1])
@@ -258,7 +267,7 @@ class ResourceController:
if self._limit_set > 0:
if self.max_time is not None:
self._timer.cancel()
- self._elapse_time += int(time())-self._start_time
+ self._elapse_time += int(time()) - self._start_time
self._timer = None
signal(SIGUSR2, self._old_usr2_hdlr)
if self.max_cpu_time is not None:
diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py
index 6819c01..0f89ddf 100644
--- a/logilab/common/pytest.py
+++ b/logilab/common/pytest.py
@@ -124,6 +124,7 @@ import traceback
from inspect import isgeneratorfunction, isclass, FrameInfo
from random import shuffle
from itertools import dropwhile
+
# mypy error: Module 'unittest.runner' has no attribute '_WritelnDecorator'
# but it does
from unittest.runner import _WritelnDecorator # type: ignore
@@ -135,6 +136,7 @@ from logilab.common.deprecation import deprecated
from logilab.common.fileutils import abspath_listdir
from logilab.common import textutils
from logilab.common import testlib, STD_BLACKLIST
+
# use the same unittest module as testlib
from logilab.common.testlib import unittest, start_interactive_mode
from logilab.common.testlib import nocoverage, pause_trace, replace_trace # bwcompat
@@ -142,6 +144,7 @@ from logilab.common.debugger import Debugger, colorize_source
import doctest
import unittest as unittest_legacy
+
if not getattr(unittest_legacy, "__package__", None):
try:
import unittest2.suite as unittest_suite
@@ -154,18 +157,24 @@ else:
try:
import django
from logilab.common.modutils import modpath_from_file, load_module_from_modpath
+
DJANGO_FOUND = True
except ImportError:
DJANGO_FOUND = False
-CONF_FILE = 'pytestconf.py'
+CONF_FILE = "pytestconf.py"
TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$")
+
+
def this_is_a_testfile(filename: str) -> Optional[Match]:
"""returns True if `filename` seems to be a test file"""
return TESTFILE_RE.match(osp.basename(filename))
+
TESTDIR_RE = re.compile("^(unit)?tests?$")
+
+
def this_is_a_testdir(dirpath: str) -> Optional[Match]:
"""returns True if `filename` seems to be a test directory"""
return TESTDIR_RE.match(osp.basename(dirpath))
@@ -176,10 +185,10 @@ def load_pytest_conf(path, parser):
and / or tester.
"""
namespace = {}
- exec(open(path, 'rb').read(), namespace)
- if 'update_parser' in namespace:
- namespace['update_parser'](parser)
- return namespace.get('CustomPyTester', PyTester)
+ exec(open(path, "rb").read(), namespace)
+ if "update_parser" in namespace:
+ namespace["update_parser"](parser)
+ return namespace.get("CustomPyTester", PyTester)
def project_root(parser, projdir=os.getcwd()):
@@ -189,8 +198,7 @@ def project_root(parser, projdir=os.getcwd()):
conf_file_path = osp.join(curdir, CONF_FILE)
if osp.isfile(conf_file_path):
testercls = load_pytest_conf(conf_file_path, parser)
- while this_is_a_testdir(curdir) or \
- osp.isfile(osp.join(curdir, '__init__.py')):
+ while this_is_a_testdir(curdir) or osp.isfile(osp.join(curdir, "__init__.py")):
newdir = osp.normpath(osp.join(curdir, os.pardir))
if newdir == curdir:
break
@@ -204,6 +212,7 @@ def project_root(parser, projdir=os.getcwd()):
class GlobalTestReport(object):
"""this class holds global test statistics"""
+
def __init__(self):
self.ran = 0
self.skipped = 0
@@ -218,7 +227,7 @@ class GlobalTestReport(object):
"""integrates new test information into internal statistics"""
ran = testresult.testsRun
self.ran += ran
- self.skipped += len(getattr(testresult, 'skipped', ()))
+ self.skipped += len(getattr(testresult, "skipped", ()))
self.failures += len(testresult.failures)
self.errors += len(testresult.errors)
self.ttime += ttime
@@ -243,27 +252,24 @@ class GlobalTestReport(object):
def __str__(self):
"""this is just presentation stuff"""
- line1 = ['Ran %s test cases in %.2fs (%.2fs CPU)'
- % (self.ran, self.ttime, self.ctime)]
+ line1 = ["Ran %s test cases in %.2fs (%.2fs CPU)" % (self.ran, self.ttime, self.ctime)]
if self.errors:
- line1.append('%s errors' % self.errors)
+ line1.append("%s errors" % self.errors)
if self.failures:
- line1.append('%s failures' % self.failures)
+ line1.append("%s failures" % self.failures)
if self.skipped:
- line1.append('%s skipped' % self.skipped)
+ line1.append("%s skipped" % self.skipped)
modulesok = self.modulescount - len(self.errmodules)
if self.errors or self.failures:
- line2 = '%s modules OK (%s failed)' % (modulesok,
- len(self.errmodules))
- descr = ', '.join(['%s [%s/%s]' % info for info in self.errmodules])
- line3 = '\nfailures: %s' % descr
+ line2 = "%s modules OK (%s failed)" % (modulesok, len(self.errmodules))
+ descr = ", ".join(["%s [%s/%s]" % info for info in self.errmodules])
+ line3 = "\nfailures: %s" % descr
elif modulesok:
- line2 = 'All %s modules OK' % modulesok
- line3 = ''
+ line2 = "All %s modules OK" % modulesok
+ line3 = ""
else:
- return ''
- return '%s\n%s%s' % (', '.join(line1), line2, line3)
-
+ return ""
+ return "%s\n%s%s" % (", ".join(line1), line2, line3)
def remove_local_modules_from_sys(testdir):
@@ -282,7 +288,7 @@ def remove_local_modules_from_sys(testdir):
for modname, mod in list(sys.modules.items()):
if mod is None:
continue
- if not hasattr(mod, '__file__'):
+ if not hasattr(mod, "__file__"):
# this is the case of some built-in modules like sys, imp, marshal
continue
modfile = mod.__file__
@@ -292,7 +298,6 @@ def remove_local_modules_from_sys(testdir):
del sys.modules[modname]
-
class PyTester(object):
"""encapsulates testrun logic"""
@@ -317,6 +322,7 @@ class PyTester(object):
def set_errcode(self, errcode):
self._errcode = errcode
+
errcode = property(get_errcode, set_errcode)
def testall(self, exitfirst=False):
@@ -358,9 +364,11 @@ class PyTester(object):
restartfile = open(FILE_RESTART, "w")
restartfile.close()
except Exception:
- print("Error while overwriting succeeded test file :",
- osp.join(os.getcwd(), FILE_RESTART),
- file=sys.__stderr__)
+ print(
+ "Error while overwriting succeeded test file :",
+ osp.join(os.getcwd(), FILE_RESTART),
+ file=sys.__stderr__,
+ )
raise
# run test and collect information
prog = self.testfile(filename, batchmode=True)
@@ -386,17 +394,24 @@ class PyTester(object):
restartfile = open(FILE_RESTART, "w")
restartfile.close()
except Exception:
- print("Error while overwriting succeeded test file :",
- osp.join(os.getcwd(), FILE_RESTART), file=sys.__stderr__)
+ print(
+ "Error while overwriting succeeded test file :",
+ osp.join(os.getcwd(), FILE_RESTART),
+ file=sys.__stderr__,
+ )
raise
modname = osp.basename(filename)[:-3]
- print((' %s ' % osp.basename(filename)).center(70, '='),
- file=sys.__stderr__)
+ print((" %s " % osp.basename(filename)).center(70, "="), file=sys.__stderr__)
try:
tstart, cstart = time(), process_time()
try:
- testprog = SkipAwareTestProgram(modname, batchmode=batchmode, cvg=self.cvg,
- options=self.options, outstream=sys.stderr)
+ testprog = SkipAwareTestProgram(
+ modname,
+ batchmode=batchmode,
+ cvg=self.cvg,
+ options=self.options,
+ outstream=sys.stderr,
+ )
except KeyboardInterrupt:
raise
except SystemExit as exc:
@@ -408,9 +423,9 @@ class PyTester(object):
return None
except Exception:
self.report.failed_to_test_module(filename)
- print('unhandled exception occurred while testing', modname,
- file=sys.stderr)
+ print("unhandled exception occurred while testing", modname, file=sys.stderr)
import traceback
+
traceback.print_exc(file=sys.stderr)
return None
@@ -423,23 +438,23 @@ class PyTester(object):
os.chdir(here)
-
class DjangoTester(PyTester):
-
def load_django_settings(self, dirname):
"""try to find project's setting and load it"""
curdir = osp.abspath(dirname)
previousdir = curdir
- while not osp.isfile(osp.join(curdir, 'settings.py')) and \
- osp.isfile(osp.join(curdir, '__init__.py')):
+ while not osp.isfile(osp.join(curdir, "settings.py")) and osp.isfile(
+ osp.join(curdir, "__init__.py")
+ ):
newdir = osp.normpath(osp.join(curdir, os.pardir))
if newdir == curdir:
- raise AssertionError('could not find settings.py')
+ raise AssertionError("could not find settings.py")
previousdir = curdir
curdir = newdir
# late django initialization
- settings = load_module_from_modpath(modpath_from_file(osp.join(curdir, 'settings.py')))
+ settings = load_module_from_modpath(modpath_from_file(osp.join(curdir, "settings.py")))
from django.core.management import setup_environ
+
setup_environ(settings)
settings.DEBUG = False
self.settings = settings
@@ -451,6 +466,7 @@ class DjangoTester(PyTester):
# Those imports must be done **after** setup_environ was called
from django.test.utils import setup_test_environment
from django.test.utils import create_test_db
+
setup_test_environment()
create_test_db(verbosity=0)
self.dbname = self.settings.TEST_DATABASE_NAME
@@ -459,8 +475,9 @@ class DjangoTester(PyTester):
# Those imports must be done **after** setup_environ was called
from django.test.utils import teardown_test_environment
from django.test.utils import destroy_test_db
+
teardown_test_environment()
- print('destroying', self.dbname)
+ print("destroying", self.dbname)
destroy_test_db(self.dbname, verbosity=0)
def testall(self, exitfirst=False):
@@ -468,16 +485,16 @@ class DjangoTester(PyTester):
which can be considered as a testdir and runs every test there
"""
for dirname, dirs, files in os.walk(os.getcwd()):
- for skipped in ('CVS', '.svn', '.hg'):
+ for skipped in ("CVS", ".svn", ".hg"):
if skipped in dirs:
dirs.remove(skipped)
- if 'tests.py' in files:
+ if "tests.py" in files:
if not self.testonedir(dirname, exitfirst):
break
dirs[:] = []
else:
basename = osp.basename(dirname)
- if basename in ('test', 'tests'):
+ if basename in ("test", "tests"):
print("going into", dirname)
# we found a testdir, let's explore it !
if not self.testonedir(dirname, exitfirst):
@@ -492,11 +509,10 @@ class DjangoTester(PyTester):
"""
# special django behaviour : if tests are splitted in several files,
# remove the main tests.py file and tests each test file separately
- testfiles = [fpath for fpath in abspath_listdir(testdir)
- if this_is_a_testfile(fpath)]
+ testfiles = [fpath for fpath in abspath_listdir(testdir) if this_is_a_testfile(fpath)]
if len(testfiles) > 1:
try:
- testfiles.remove(osp.join(testdir, 'tests.py'))
+ testfiles.remove(osp.join(testdir, "tests.py"))
except ValueError:
pass
for filename in testfiles:
@@ -519,8 +535,7 @@ class DjangoTester(PyTester):
os.chdir(dirname)
self.load_django_settings(dirname)
modname = osp.basename(filename)[:-3]
- print((' %s ' % osp.basename(filename)).center(70, '='),
- file=sys.stderr)
+ print((" %s " % osp.basename(filename)).center(70, "="), file=sys.stderr)
try:
try:
tstart, cstart = time(), process_time()
@@ -534,10 +549,11 @@ class DjangoTester(PyTester):
raise
except Exception as exc:
import traceback
+
traceback.print_exc()
self.report.failed_to_test_module(filename)
- print('unhandled exception occurred while testing', modname)
- print('error: %s' % exc)
+ print("unhandled exception occurred while testing", modname)
+ print("error: %s" % exc)
return None
finally:
self.after_testfile()
@@ -549,9 +565,11 @@ def make_parser():
"""creates the OptionParser instance
"""
from optparse import OptionParser
+
parser = OptionParser(usage=PYTEST_DOC)
parser.newargs = []
+
def rebuild_cmdline(option, opt, value, parser):
"""carry the option to unittest_main"""
parser.newargs.append(opt)
@@ -564,50 +582,89 @@ def make_parser():
setattr(parser.values, option.dest, True)
def capture_and_rebuild(option, opt, value, parser):
- warnings.simplefilter('ignore', DeprecationWarning)
+ warnings.simplefilter("ignore", DeprecationWarning)
rebuild_cmdline(option, opt, value, parser)
# logilab-pytest options
- parser.add_option('-t', dest='testdir', default=None,
- help="directory where the tests will be found")
- parser.add_option('-d', dest='dbc', default=False,
- action="store_true", help="enable design-by-contract")
+ parser.add_option(
+ "-t", dest="testdir", default=None, help="directory where the tests will be found"
+ )
+ parser.add_option(
+ "-d", dest="dbc", default=False, action="store_true", help="enable design-by-contract"
+ )
# unittest_main options provided and passed through logilab-pytest
- parser.add_option('-v', '--verbose', callback=rebuild_cmdline,
- action="callback", help="Verbose output")
- parser.add_option('-i', '--pdb', callback=rebuild_and_store,
- dest="pdb", action="callback",
- help="Enable test failure inspection")
- parser.add_option('-x', '--exitfirst', callback=rebuild_and_store,
- dest="exitfirst", default=False,
- action="callback", help="Exit on first failure "
- "(only make sense when logilab-pytest run one test file)")
- parser.add_option('-R', '--restart', callback=rebuild_and_store,
- dest="restart", default=False,
- action="callback",
- help="Restart tests from where it failed (implies exitfirst) "
- "(only make sense if tests previously ran with exitfirst only)")
- parser.add_option('--color', callback=rebuild_cmdline,
- action="callback",
- help="colorize tracebacks")
- parser.add_option('-s', '--skip',
- # XXX: I wish I could use the callback action but it
- # doesn't seem to be able to get the value
- # associated to the option
- action="store", dest="skipped", default=None,
- help="test names matching this name will be skipped "
- "to skip several patterns, use commas")
- parser.add_option('-q', '--quiet', callback=rebuild_cmdline,
- action="callback", help="Minimal output")
- parser.add_option('-P', '--profile', default=None, dest='profile',
- help="Profile execution and store data in the given file")
- parser.add_option('-m', '--match', default=None, dest='tags_pattern',
- help="only execute test whose tag match the current pattern")
+ parser.add_option(
+ "-v", "--verbose", callback=rebuild_cmdline, action="callback", help="Verbose output"
+ )
+ parser.add_option(
+ "-i",
+ "--pdb",
+ callback=rebuild_and_store,
+ dest="pdb",
+ action="callback",
+ help="Enable test failure inspection",
+ )
+ parser.add_option(
+ "-x",
+ "--exitfirst",
+ callback=rebuild_and_store,
+ dest="exitfirst",
+ default=False,
+ action="callback",
+ help="Exit on first failure " "(only make sense when logilab-pytest run one test file)",
+ )
+ parser.add_option(
+ "-R",
+ "--restart",
+ callback=rebuild_and_store,
+ dest="restart",
+ default=False,
+ action="callback",
+ help="Restart tests from where it failed (implies exitfirst) "
+ "(only make sense if tests previously ran with exitfirst only)",
+ )
+ parser.add_option(
+ "--color", callback=rebuild_cmdline, action="callback", help="colorize tracebacks"
+ )
+ parser.add_option(
+ "-s",
+ "--skip",
+ # XXX: I wish I could use the callback action but it
+ # doesn't seem to be able to get the value
+ # associated to the option
+ action="store",
+ dest="skipped",
+ default=None,
+ help="test names matching this name will be skipped "
+ "to skip several patterns, use commas",
+ )
+ parser.add_option(
+ "-q", "--quiet", callback=rebuild_cmdline, action="callback", help="Minimal output"
+ )
+ parser.add_option(
+ "-P",
+ "--profile",
+ default=None,
+ dest="profile",
+ help="Profile execution and store data in the given file",
+ )
+ parser.add_option(
+ "-m",
+ "--match",
+ default=None,
+ dest="tags_pattern",
+ help="only execute test whose tag match the current pattern",
+ )
if DJANGO_FOUND:
- parser.add_option('-J', '--django', dest='django', default=False,
- action="store_true",
- help='use logilab-pytest for django test cases')
+ parser.add_option(
+ "-J",
+ "--django",
+ dest="django",
+ default=False,
+ action="store_true",
+ help="use logilab-pytest for django test cases",
+ )
return parser
@@ -617,7 +674,7 @@ def parseargs(parser):
"""
# parse the command line
options, args = parser.parse_args()
- filenames = [arg for arg in args if arg.endswith('.py')]
+ filenames = [arg for arg in args if arg.endswith(".py")]
if filenames:
if len(filenames) > 1:
parser.error("only one filename is acceptable")
@@ -629,7 +686,7 @@ def parseargs(parser):
testlib.ENABLE_DBC = options.dbc
newargs = parser.newargs
if options.skipped:
- newargs.extend(['--skip', options.skipped])
+ newargs.extend(["--skip", options.skipped])
# restart implies exitfirst
if options.restart:
options.exitfirst = True
@@ -639,8 +696,7 @@ def parseargs(parser):
return options, explicitfile
-
-@deprecated('[logilab-common 1.3] logilab-pytest is deprecated, use another test runner')
+@deprecated("[logilab-common 1.3] logilab-pytest is deprecated, use another test runner")
def run():
parser = make_parser()
rootdir, testercls = project_root(parser)
@@ -648,8 +704,8 @@ def run():
# mock a new command line
sys.argv[1:] = parser.newargs
cvg = None
- if not '' in sys.path:
- sys.path.insert(0, '')
+ if not "" in sys.path:
+ sys.path.insert(0, "")
if DJANGO_FOUND and options.django:
tester = DjangoTester(cvg, options)
else:
@@ -664,21 +720,24 @@ def run():
try:
if options.profile:
import hotshot
+
prof = hotshot.Profile(options.profile)
prof.runcall(cmd, *args)
prof.close()
- print('profile data saved in', options.profile)
+ print("profile data saved in", options.profile)
else:
cmd(*args)
except SystemExit:
raise
except:
import traceback
+
traceback.print_exc()
finally:
tester.show_report()
sys.exit(tester.errcode)
+
class SkipAwareTestProgram(unittest.TestProgram):
# XXX: don't try to stay close to unittest.py, use optparse
USAGE = """\
@@ -705,15 +764,23 @@ Examples:
%(progName)s MyTestCase - run all 'test*' test methods
in MyTestCase
"""
- def __init__(self, module='__main__', defaultTest=None, batchmode=False,
- cvg=None, options=None, outstream=sys.stderr):
+
+ def __init__(
+ self,
+ module="__main__",
+ defaultTest=None,
+ batchmode=False,
+ cvg=None,
+ options=None,
+ outstream=sys.stderr,
+ ):
self.batchmode = batchmode
self.cvg = cvg
self.options = options
self.outstream = outstream
super(SkipAwareTestProgram, self).__init__(
- module=module, defaultTest=defaultTest,
- testLoader=NonStrictTestLoader())
+ module=module, defaultTest=defaultTest, testLoader=NonStrictTestLoader()
+ )
def parseArgs(self, argv):
self.pdbmode = False
@@ -724,40 +791,51 @@ Examples:
self.colorize = False
self.profile_name = None
import getopt
+
try:
- options, args = getopt.getopt(argv[1:], 'hHvixrqcp:s:m:P:',
- ['help', 'verbose', 'quiet', 'pdb',
- 'exitfirst', 'restart',
- 'skip=', 'color', 'match=', 'profile='])
+ options, args = getopt.getopt(
+ argv[1:],
+ "hHvixrqcp:s:m:P:",
+ [
+ "help",
+ "verbose",
+ "quiet",
+ "pdb",
+ "exitfirst",
+ "restart",
+ "skip=",
+ "color",
+ "match=",
+ "profile=",
+ ],
+ )
for opt, value in options:
- if opt in ('-h', '-H', '--help'):
+ if opt in ("-h", "-H", "--help"):
self.usageExit()
- if opt in ('-i', '--pdb'):
+ if opt in ("-i", "--pdb"):
self.pdbmode = True
- if opt in ('-x', '--exitfirst'):
+ if opt in ("-x", "--exitfirst"):
self.exitfirst = True
- if opt in ('-r', '--restart'):
+ if opt in ("-r", "--restart"):
self.restart = True
self.exitfirst = True
- if opt in ('-q', '--quiet'):
+ if opt in ("-q", "--quiet"):
self.verbosity = 0
- if opt in ('-v', '--verbose'):
+ if opt in ("-v", "--verbose"):
self.verbosity = 2
- if opt in ('-s', '--skip'):
- self.skipped_patterns = [pat.strip() for pat in
- value.split(', ')]
- if opt == '--color':
+ if opt in ("-s", "--skip"):
+ self.skipped_patterns = [pat.strip() for pat in value.split(", ")]
+ if opt == "--color":
self.colorize = True
- if opt in ('-m', '--match'):
- #self.tags_pattern = value
+ if opt in ("-m", "--match"):
+ # self.tags_pattern = value
self.options["tag_pattern"] = value
- if opt in ('-P', '--profile'):
+ if opt in ("-P", "--profile"):
self.profile_name = value
self.testLoader.skipped_patterns = self.skipped_patterns
if len(args) == 0 and self.defaultTest is None:
- suitefunc = getattr(self.module, 'suite', None)
- if isinstance(suitefunc, (types.FunctionType,
- types.MethodType)):
+ suitefunc = getattr(self.module, "suite", None)
+ if isinstance(suitefunc, (types.FunctionType, types.MethodType)):
self.test = self.module.suite()
else:
self.test = self.testLoader.loadTestsFromModule(self.module)
@@ -766,7 +844,7 @@ Examples:
self.test_pattern = args[0]
self.testNames = args
else:
- self.testNames = (self.defaultTest, )
+ self.testNames = (self.defaultTest,)
self.createTests()
except getopt.error as msg:
self.usageExit(msg)
@@ -774,21 +852,24 @@ Examples:
def runTests(self):
if self.profile_name:
import cProfile
- cProfile.runctx('self._runTests()', globals(), locals(), self.profile_name )
+
+ cProfile.runctx("self._runTests()", globals(), locals(), self.profile_name)
else:
return self._runTests()
def _runTests(self):
- self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity,
- stream=self.outstream,
- exitfirst=self.exitfirst,
- pdbmode=self.pdbmode,
- cvg=self.cvg,
- test_pattern=self.test_pattern,
- skipped_patterns=self.skipped_patterns,
- colorize=self.colorize,
- batchmode=self.batchmode,
- options=self.options)
+ self.testRunner = SkipAwareTextTestRunner(
+ verbosity=self.verbosity,
+ stream=self.outstream,
+ exitfirst=self.exitfirst,
+ pdbmode=self.pdbmode,
+ cvg=self.cvg,
+ test_pattern=self.test_pattern,
+ skipped_patterns=self.skipped_patterns,
+ colorize=self.colorize,
+ batchmode=self.batchmode,
+ options=self.options,
+ )
def removeSucceededTests(obj, succTests):
""" Recursive function that removes succTests from
@@ -801,32 +882,33 @@ Examples:
if isinstance(el, unittest.TestSuite):
removeSucceededTests(el, succTests)
elif isinstance(el, unittest.TestCase):
- descr = '.'.join((el.__class__.__module__,
- el.__class__.__name__,
- el._testMethodName))
+ descr = ".".join(
+ (el.__class__.__module__, el.__class__.__name__, el._testMethodName)
+ )
if descr in succTests:
obj.remove(el)
+
# take care, self.options may be None
- if getattr(self.options, 'restart', False):
+ if getattr(self.options, "restart", False):
# retrieve succeeded tests from FILE_RESTART
try:
- restartfile = open(FILE_RESTART, 'r')
+ restartfile = open(FILE_RESTART, "r")
try:
- succeededtests = list(elem.rstrip('\n\r') for elem in
- restartfile.readlines())
+ succeededtests = list(elem.rstrip("\n\r") for elem in restartfile.readlines())
removeSucceededTests(self.test, succeededtests)
finally:
restartfile.close()
except Exception as ex:
- raise Exception("Error while reading succeeded tests into %s: %s"
- % (osp.join(os.getcwd(), FILE_RESTART), ex))
+ raise Exception(
+ "Error while reading succeeded tests into %s: %s"
+ % (osp.join(os.getcwd(), FILE_RESTART), ex)
+ )
result = self.testRunner.run(self.test)
# help garbage collection: we want TestSuite, which hold refs to every
# executed TestCase, to be gc'ed
del self.test
- if getattr(result, "debuggers", None) and \
- getattr(self, "pdbmode", None):
+ if getattr(result, "debuggers", None) and getattr(self, "pdbmode", None):
start_interactive_mode(result)
if not getattr(self, "batchmode", None):
sys.exit(not result.wasSuccessful())
@@ -834,13 +916,20 @@ Examples:
class SkipAwareTextTestRunner(unittest.TextTestRunner):
-
- def __init__(self, stream=sys.stderr, verbosity=1,
- exitfirst=False, pdbmode=False, cvg=None, test_pattern=None,
- skipped_patterns=(), colorize=False, batchmode=False,
- options=None):
- super(SkipAwareTextTestRunner, self).__init__(stream=stream,
- verbosity=verbosity)
+ def __init__(
+ self,
+ stream=sys.stderr,
+ verbosity=1,
+ exitfirst=False,
+ pdbmode=False,
+ cvg=None,
+ test_pattern=None,
+ skipped_patterns=(),
+ colorize=False,
+ batchmode=False,
+ options=None,
+ ):
+ super(SkipAwareTextTestRunner, self).__init__(stream=stream, verbosity=verbosity)
self.exitfirst = exitfirst
self.pdbmode = pdbmode
self.cvg = cvg
@@ -859,23 +948,23 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
else:
if isinstance(test, testlib.TestCase):
meth = test._get_test_method()
- testname = '%s.%s' % (test.__name__, meth.__name__)
+ testname = "%s.%s" % (test.__name__, meth.__name__)
elif isinstance(test, types.FunctionType):
func = test
testname = func.__name__
elif isinstance(test, types.MethodType):
cls = test.__self__.__class__
- testname = '%s.%s' % (cls.__name__, test.__name__)
+ testname = "%s.%s" % (cls.__name__, test.__name__)
else:
- return True # Not sure when this happens
+ return True # Not sure when this happens
if isgeneratorfunction(test) and skipgenerator:
- return self.does_match_tags(test) # Let inner tests decide at run time
+ return self.does_match_tags(test) # Let inner tests decide at run time
if self._this_is_skipped(testname):
- return False # this was explicitly skipped
+ return False # this was explicitly skipped
if self.test_pattern is not None:
try:
- classpattern, testpattern = self.test_pattern.split('.')
- klass, name = testname.split('.')
+ classpattern, testpattern = self.test_pattern.split(".")
+ klass, name = testname.split(".")
if classpattern not in klass or testpattern not in name:
return False
except ValueError:
@@ -886,18 +975,24 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
def does_match_tags(self, test: Callable) -> bool:
if self.options is not None:
- tags_pattern = getattr(self.options, 'tags_pattern', None)
+ tags_pattern = getattr(self.options, "tags_pattern", None)
if tags_pattern is not None:
- tags = getattr(test, 'tags', testlib.Tags())
+ tags = getattr(test, "tags", testlib.Tags())
if tags.inherit and isinstance(test, types.MethodType):
- tags = tags | getattr(test.__self__.__class__, 'tags', testlib.Tags())
+ tags = tags | getattr(test.__self__.__class__, "tags", testlib.Tags())
return tags.match(tags_pattern)
- return True # no pattern
-
- def _makeResult(self) -> 'SkipAwareTestResult':
- return SkipAwareTestResult(self.stream, self.descriptions,
- self.verbosity, self.exitfirst,
- self.pdbmode, self.cvg, self.colorize)
+ return True # no pattern
+
+ def _makeResult(self) -> "SkipAwareTestResult":
+ return SkipAwareTestResult(
+ self.stream,
+ self.descriptions,
+ self.verbosity,
+ self.exitfirst,
+ self.pdbmode,
+ self.cvg,
+ self.colorize,
+ )
def run(self, test):
"Run the given test case or test suite."
@@ -910,43 +1005,48 @@ class SkipAwareTextTestRunner(unittest.TextTestRunner):
if not self.batchmode:
self.stream.writeln(result.separator2)
run = result.testsRun
- self.stream.writeln("Ran %d test%s in %.3fs" %
- (run, run != 1 and "s" or "", timeTaken))
+ self.stream.writeln("Ran %d test%s in %.3fs" % (run, run != 1 and "s" or "", timeTaken))
self.stream.writeln()
if not result.wasSuccessful():
if self.colorize:
- self.stream.write(textutils.colorize_ansi("FAILED", color='red'))
+ self.stream.write(textutils.colorize_ansi("FAILED", color="red"))
else:
self.stream.write("FAILED")
else:
if self.colorize:
- self.stream.write(textutils.colorize_ansi("OK", color='green'))
+ self.stream.write(textutils.colorize_ansi("OK", color="green"))
else:
self.stream.write("OK")
- failed, errored, skipped = map(len, (result.failures,
- result.errors,
- result.skipped))
+ failed, errored, skipped = map(len, (result.failures, result.errors, result.skipped))
det_results = []
- for name, value in (("failures", result.failures),
- ("errors",result.errors),
- ("skipped", result.skipped)):
+ for name, value in (
+ ("failures", result.failures),
+ ("errors", result.errors),
+ ("skipped", result.skipped),
+ ):
if value:
det_results.append("%s=%i" % (name, len(value)))
if det_results:
self.stream.write(" (")
- self.stream.write(', '.join(det_results))
+ self.stream.write(", ".join(det_results))
self.stream.write(")")
self.stream.writeln("")
return result
class SkipAwareTestResult(unittest._TextTestResult):
-
- def __init__(self, stream: _WritelnDecorator, descriptions: bool, verbosity: int,
- exitfirst: bool = False, pdbmode: bool = False, cvg: Optional[Any] = None, colorize: bool = False) -> None:
- super(SkipAwareTestResult, self).__init__(stream,
- descriptions, verbosity)
+ def __init__(
+ self,
+ stream: _WritelnDecorator,
+ descriptions: bool,
+ verbosity: int,
+ exitfirst: bool = False,
+ pdbmode: bool = False,
+ cvg: Optional[Any] = None,
+ colorize: bool = False,
+ ) -> None:
+ super(SkipAwareTestResult, self).__init__(stream, descriptions, verbosity)
self.skipped: List[Tuple[Any, Any]] = []
self.debuggers: List = []
self.fail_descrs: List = []
@@ -959,10 +1059,10 @@ class SkipAwareTestResult(unittest._TextTestResult):
self.verbose = verbosity > 1
def descrs_for(self, flavour: str) -> List[Tuple[int, str]]:
- return getattr(self, '%s_descrs' % flavour.lower())
+ return getattr(self, "%s_descrs" % flavour.lower())
def _create_pdb(self, test_descr: str, flavour: str) -> None:
- self.descrs_for(flavour).append( (len(self.debuggers), test_descr) )
+ self.descrs_for(flavour).append((len(self.debuggers), test_descr))
if self.pdbmode:
self.debuggers.append(self.pdbclass(sys.exc_info()[2]))
@@ -982,34 +1082,34 @@ class SkipAwareTestResult(unittest._TextTestResult):
--verbose is passed
"""
exctype, exc, tb = err
- output = ['Traceback (most recent call last)']
+ output = ["Traceback (most recent call last)"]
frames = inspect.getinnerframes(tb)
colorize = self.colorize
frames = enumerate(self._iter_valid_frames(frames))
for index, (frame, filename, lineno, funcname, ctx, ctxindex) in frames:
filename = osp.abspath(filename)
- if ctx is None: # pyc files or C extensions for instance
- source = '<no source available>'
+ if ctx is None: # pyc files or C extensions for instance
+ source = "<no source available>"
else:
- source = ''.join(ctx)
+ source = "".join(ctx)
if colorize:
- filename = textutils.colorize_ansi(filename, 'magenta')
+ filename = textutils.colorize_ansi(filename, "magenta")
source = colorize_source(source)
output.append(' File "%s", line %s, in %s' % (filename, lineno, funcname))
- output.append(' %s' % source.strip())
+ output.append(" %s" % source.strip())
if self.verbose:
- output.append('%r == %r' % (dir(frame), test.__module__))
- output.append('')
- output.append(' ' + ' local variables '.center(66, '-'))
+ output.append("%r == %r" % (dir(frame), test.__module__))
+ output.append("")
+ output.append(" " + " local variables ".center(66, "-"))
for varname, value in sorted(frame.f_locals.items()):
- output.append(' %s: %r' % (varname, value))
- if varname == 'self': # special handy processing for self
+ output.append(" %s: %r" % (varname, value))
+ if varname == "self": # special handy processing for self
for varname, value in sorted(vars(value).items()):
- output.append(' self.%s: %r' % (varname, value))
- output.append(' ' + '-' * 66)
- output.append('')
- output.append(''.join(traceback.format_exception_only(exctype, exc)))
- return '\n'.join(output)
+ output.append(" self.%s: %r" % (varname, value))
+ output.append(" " + "-" * 66)
+ output.append("")
+ output.append("".join(traceback.format_exception_only(exctype, exc)))
+ return "\n".join(output)
def addError(self, test, err):
"""err -> (exc_type, exc, tcbk)"""
@@ -1022,21 +1122,21 @@ class SkipAwareTestResult(unittest._TextTestResult):
self.shouldStop = True
descr = self.getDescription(test)
super(SkipAwareTestResult, self).addError(test, err)
- self._create_pdb(descr, 'error')
+ self._create_pdb(descr, "error")
def addFailure(self, test, err):
if self.exitfirst:
self.shouldStop = True
descr = self.getDescription(test)
super(SkipAwareTestResult, self).addFailure(test, err)
- self._create_pdb(descr, 'fail')
+ self._create_pdb(descr, "fail")
def addSkip(self, test, reason):
self.skipped.append((test, reason))
if self.showAll:
self.stream.writeln("SKIPPED")
elif self.dots:
- self.stream.write('S')
+ self.stream.write("S")
def printErrors(self) -> None:
super(SkipAwareTestResult, self).printErrors()
@@ -1047,7 +1147,7 @@ class SkipAwareTestResult(unittest._TextTestResult):
for test, err in self.skipped:
descr = self.getDescription(test)
self.stream.writeln(self.separator1)
- self.stream.writeln("%s: %s" % ('SKIPPED', descr))
+ self.stream.writeln("%s: %s" % ("SKIPPED", descr))
self.stream.writeln("\t%s" % err)
def printErrorList(self, flavour, errors):
@@ -1056,32 +1156,42 @@ class SkipAwareTestResult(unittest._TextTestResult):
self.stream.writeln("%s: %s" % (flavour, descr))
self.stream.writeln(self.separator2)
self.stream.writeln(err)
- self.stream.writeln('no stdout'.center(len(self.separator2)))
- self.stream.writeln('no stderr'.center(len(self.separator2)))
+ self.stream.writeln("no stdout".center(len(self.separator2)))
+ self.stream.writeln("no stderr".center(len(self.separator2)))
from .decorators import monkeypatch
+
orig_call = testlib.TestCase.__call__
-@monkeypatch(testlib.TestCase, '__call__')
-def call(self: Any, result: SkipAwareTestResult = None, runcondition: Optional[Callable] = None, options: Optional[Any] = None) -> None:
+
+
+@monkeypatch(testlib.TestCase, "__call__")
+def call(
+ self: Any,
+ result: SkipAwareTestResult = None,
+ runcondition: Optional[Callable] = None,
+ options: Optional[Any] = None,
+) -> None:
orig_call(self, result=result, runcondition=runcondition, options=options)
# mypy: Item "None" of "Optional[Any]" has no attribute "exitfirst"
# we check it first in the if
if hasattr(options, "exitfirst") and options.exitfirst: # type: ignore
# add this test to restart file
try:
- restartfile = open(FILE_RESTART, 'a')
+ restartfile = open(FILE_RESTART, "a")
try:
- descr = '.'.join((self.__class__.__module__,
- self.__class__.__name__,
- self._testMethodName))
- restartfile.write(descr+os.linesep)
+ descr = ".".join(
+ (self.__class__.__module__, self.__class__.__name__, self._testMethodName)
+ )
+ restartfile.write(descr + os.linesep)
finally:
restartfile.close()
except Exception:
- print("Error while saving succeeded test into",
- osp.join(os.getcwd(), FILE_RESTART),
- file=sys.__stderr__)
+ print(
+ "Error while saving succeeded test into",
+ osp.join(os.getcwd(), FILE_RESTART),
+ file=sys.__stderr__,
+ )
raise
@@ -1129,7 +1239,7 @@ class NonStrictTestLoader(unittest.TestLoader):
for obj in vars(module).values():
if isclass(obj) and issubclass(obj, unittest.TestCase):
classname = obj.__name__
- if classname[0] == '_' or self._this_is_skipped(classname):
+ if classname[0] == "_" or self._this_is_skipped(classname):
continue
methodnames = []
# obj is a TestCase class
@@ -1147,14 +1257,16 @@ class NonStrictTestLoader(unittest.TestLoader):
suite = getattr(module, suitename)()
except AttributeError:
return []
- assert hasattr(suite, '_tests'), \
- "%s.%s is not a valid TestSuite" % (module.__name__, suitename)
+ assert hasattr(suite, "_tests"), "%s.%s is not a valid TestSuite" % (
+ module.__name__,
+ suitename,
+ )
# python2.3 does not implement __iter__ on suites, we need to return
# _tests explicitly
return suite._tests
def loadTestsFromName(self, name, module=None):
- parts = name.split('.')
+ parts = name.split(".")
if module is None or len(parts) > 2:
# let the base class do its job here
return [super(NonStrictTestLoader, self).loadTestsFromName(name)]
@@ -1162,34 +1274,35 @@ class NonStrictTestLoader(unittest.TestLoader):
collected = []
if len(parts) == 1:
pattern = parts[0]
- if callable(getattr(module, pattern, None)
- ) and pattern not in tests:
+ if callable(getattr(module, pattern, None)) and pattern not in tests:
# consider it as a suite
return self.loadTestsFromSuite(module, pattern)
if pattern in tests:
# case python unittest_foo.py MyTestTC
klass, methodnames = tests[pattern]
for methodname in methodnames:
- collected = [klass(methodname)
- for methodname in methodnames]
+ collected = [klass(methodname) for methodname in methodnames]
else:
# case python unittest_foo.py something
for klass, methodnames in tests.values():
# skip methodname if matched by skipped_patterns
for skip_pattern in self.skipped_patterns:
- methodnames = [methodname
- for methodname in methodnames
- if skip_pattern not in methodname]
- collected += [klass(methodname)
- for methodname in methodnames
- if pattern in methodname]
+ methodnames = [
+ methodname
+ for methodname in methodnames
+ if skip_pattern not in methodname
+ ]
+ collected += [
+ klass(methodname) for methodname in methodnames if pattern in methodname
+ ]
elif len(parts) == 2:
# case "MyClass.test_1"
classname, pattern = parts
klass, methodnames = tests.get(classname, (None, []))
for methodname in methodnames:
- collected = [klass(methodname) for methodname in methodnames
- if pattern in methodname]
+ collected = [
+ klass(methodname) for methodname in methodnames if pattern in methodname
+ ]
return collected
def _this_is_skipped(self, testedname: str) -> bool:
@@ -1202,10 +1315,9 @@ class NonStrictTestLoader(unittest.TestLoader):
"""
is_skipped = self._this_is_skipped
classname = testCaseClass.__name__
- if classname[0] == '_' or is_skipped(classname):
+ if classname[0] == "_" or is_skipped(classname):
return []
- testnames = super(NonStrictTestLoader, self).getTestCaseNames(
- testCaseClass)
+ testnames = super(NonStrictTestLoader, self).getTestCaseNames(testCaseClass)
return [testname for testname in testnames if not is_skipped(testname)]
@@ -1214,13 +1326,27 @@ class NonStrictTestLoader(unittest.TestLoader):
# It is used to monkeypatch the original implementation to support
# extra runcondition and options arguments (see in testlib.py)
-def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
+
+def _ts_run(
+ self: Any,
+ result: SkipAwareTestResult,
+ debug: bool = False,
+ runcondition: Callable = None,
+ options: Optional[Any] = None,
+) -> SkipAwareTestResult:
self._wrapped_run(result, runcondition=runcondition, options=options)
self._tearDownPreviousClass(None, result)
self._handleModuleTearDown(result)
return result
-def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
+
+def _ts_wrapped_run(
+ self: Any,
+ result: SkipAwareTestResult,
+ debug: bool = False,
+ runcondition: Callable = None,
+ options: Optional[Any] = None,
+) -> SkipAwareTestResult:
for test in self:
if result.shouldStop:
break
@@ -1229,8 +1355,9 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False,
self._handleModuleFixture(test, result)
self._handleClassSetUp(test, result)
result._previousTestClass = test.__class__
- if (getattr(test.__class__, '_classSetupFailed', False) or
- getattr(result, '_moduleSetUpFailed', False)):
+ if getattr(test.__class__, "_classSetupFailed", False) or getattr(
+ result, "_moduleSetUpFailed", False
+ ):
continue
# --- modifications to deal with _wrapped_run ---
@@ -1240,7 +1367,7 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False,
# test(result)
# else:
# test.debug()
- if hasattr(test, '_wrapped_run'):
+ if hasattr(test, "_wrapped_run"):
try:
test._wrapped_run(result, debug, runcondition=runcondition, options=options)
except TypeError:
@@ -1255,13 +1382,20 @@ def _ts_wrapped_run(self: Any, result: SkipAwareTestResult, debug: bool = False,
# --- end of modifications to deal with _wrapped_run ---
return result
+
if sys.version_info >= (2, 7):
# The function below implements a modified version of the
# TestSuite.run method that is provided with python 2.7, in
# unittest/suite.py
- def _ts_run(self: Any, result: SkipAwareTestResult, debug: bool = False, runcondition: Callable = None, options: Optional[Any] = None) -> SkipAwareTestResult:
+ def _ts_run(
+ self: Any,
+ result: SkipAwareTestResult,
+ debug: bool = False,
+ runcondition: Callable = None,
+ options: Optional[Any] = None,
+ ) -> SkipAwareTestResult:
topLevel = False
- if getattr(result, '_testRunEntered', False) is False:
+ if getattr(result, "_testRunEntered", False) is False:
result._testRunEntered = topLevel = True
self._wrapped_run(result, debug, runcondition, options)
@@ -1287,8 +1421,7 @@ def enable_dbc(*args):
from logilab.aspects.weaver import weaver
from logilab.aspects.lib.contracts import ContractAspect
except ImportError:
- sys.stderr.write(
- 'Warning: logilab.aspects is not available. Contracts disabled.')
+ sys.stderr.write("Warning: logilab.aspects is not available. Contracts disabled.")
return False
for arg in args:
weaver.weave_module(arg, ContractAspect)
@@ -1304,13 +1437,12 @@ unittest.TestProgram = SkipAwareTestProgram
if sys.version_info >= (2, 4):
doctest.DocTestCase.__bases__ = (testlib.TestCase,)
# XXX check python2.6 compatibility
- #doctest.DocTestCase._cleanups = []
- #doctest.DocTestCase._out = []
+ # doctest.DocTestCase._cleanups = []
+ # doctest.DocTestCase._out = []
else:
unittest.FunctionTestCase.__bases__ = (testlib.TestCase,)
unittest.TestSuite.run = _ts_run
unittest.TestSuite._wrapped_run = _ts_wrapped_run
-if __name__ == '__main__':
+if __name__ == "__main__":
run()
-
diff --git a/logilab/common/registry.py b/logilab/common/registry.py
index d9ae11b..83f4703 100644
--- a/logilab/common/registry.py
+++ b/logilab/common/registry.py
@@ -105,6 +105,7 @@ from logilab.common.deprecation import deprecated
# selector base classes and operations ########################################
+
def objectify_predicate(selector_func: Callable) -> Any:
"""Most of the time, a simple score function is enough to build a selector.
The :func:`objectify_predicate` decorator turn it into a proper selector
@@ -118,22 +119,29 @@ def objectify_predicate(selector_func: Callable) -> Any:
__select__ = View.__select__ & one()
"""
- return type(selector_func.__name__, (Predicate,),
- {'__doc__': selector_func.__doc__,
- '__call__': lambda self, *a, **kw: selector_func(*a, **kw)})
+ return type(
+ selector_func.__name__,
+ (Predicate,),
+ {
+ "__doc__": selector_func.__doc__,
+ "__call__": lambda self, *a, **kw: selector_func(*a, **kw),
+ },
+ )
_PREDICATES: Dict[int, Type] = {}
+
def wrap_predicates(decorator: Callable) -> None:
for predicate in _PREDICATES.values():
- if not '_decorators' in predicate.__dict__:
+ if not "_decorators" in predicate.__dict__:
predicate._decorators = set()
if decorator in predicate._decorators:
continue
predicate._decorators.add(decorator)
predicate.__call__ = decorator(predicate.__call__)
+
class PredicateMetaClass(type):
def __new__(mcs, *args, **kwargs):
# use __new__ so subclasses doesn't have to call Predicate.__init__
@@ -164,36 +172,37 @@ class Predicate(object, metaclass=PredicateMetaClass):
# backward compatibility
return self.__class__.__name__
- def search_selector(self, selector: 'Predicate') -> Optional['Predicate']:
+ def search_selector(self, selector: "Predicate") -> Optional["Predicate"]:
"""search for the given selector, selector instance or tuple of
selectors in the selectors tree. Return None if not found.
"""
if self is selector:
return self
- if (isinstance(selector, type) or isinstance(selector, tuple)) and \
- isinstance(self, selector):
+ if (isinstance(selector, type) or isinstance(selector, tuple)) and isinstance(
+ self, selector
+ ):
return self
return None
def __str__(self):
return self.__class__.__name__
- def __and__(self, other: 'Predicate') -> 'AndPredicate':
+ def __and__(self, other: "Predicate") -> "AndPredicate":
return AndPredicate(self, other)
- def __rand__(self, other: 'Predicate') -> 'AndPredicate':
+ def __rand__(self, other: "Predicate") -> "AndPredicate":
return AndPredicate(other, self)
- def __iand__(self, other: 'Predicate') -> 'AndPredicate':
+ def __iand__(self, other: "Predicate") -> "AndPredicate":
return AndPredicate(self, other)
- def __or__(self, other: 'Predicate') -> 'OrPredicate':
+ def __or__(self, other: "Predicate") -> "OrPredicate":
return OrPredicate(self, other)
- def __ror__(self, other: 'Predicate'):
+ def __ror__(self, other: "Predicate"):
return OrPredicate(other, self)
- def __ior__(self, other: 'Predicate') -> 'OrPredicate':
+ def __ior__(self, other: "Predicate") -> "OrPredicate":
return OrPredicate(self, other)
def __invert__(self):
@@ -202,11 +211,12 @@ class Predicate(object, metaclass=PredicateMetaClass):
# XXX (function | function) or (function & function) not managed yet
def __call__(self, cls, *args, **kwargs):
- return NotImplementedError("selector %s must implement its logic "
- "in its __call__ method" % self.__class__)
+ return NotImplementedError(
+ "selector %s must implement its logic " "in its __call__ method" % self.__class__
+ )
def __repr__(self):
- return u'<Predicate %s at %x>' % (self.__class__.__name__, id(self))
+ return "<Predicate %s at %x>" % (self.__class__.__name__, id(self))
class MultiPredicate(Predicate):
@@ -216,8 +226,7 @@ class MultiPredicate(Predicate):
self.selectors = self.merge_selectors(selectors)
def __str__(self):
- return '%s(%s)' % (self.__class__.__name__,
- ','.join(str(s) for s in self.selectors))
+ return "%s(%s)" % (self.__class__.__name__, ",".join(str(s) for s in self.selectors))
@classmethod
def merge_selectors(cls, selectors: Sequence[Predicate]) -> List[Predicate]:
@@ -258,6 +267,7 @@ class MultiPredicate(Predicate):
class AndPredicate(MultiPredicate):
"""and-chained selectors"""
+
def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int:
score = 0
for selector in self.selectors:
@@ -270,6 +280,7 @@ class AndPredicate(MultiPredicate):
class OrPredicate(MultiPredicate):
"""or-chained selectors"""
+
def __call__(self, cls: Optional[Any], *args: Any, **kwargs: Any) -> int:
for selector in self.selectors:
partscore = selector(cls, *args, **kwargs)
@@ -277,8 +288,10 @@ class OrPredicate(MultiPredicate):
return partscore
return 0
+
class NotPredicate(Predicate):
"""negation selector"""
+
def __init__(self, selector):
self.selector = selector
@@ -287,10 +300,10 @@ class NotPredicate(Predicate):
return int(not score)
def __str__(self):
- return 'NOT(%s)' % self.selector
+ return "NOT(%s)" % self.selector
-class yes(Predicate): # pylint: disable=C0103
+class yes(Predicate): # pylint: disable=C0103
"""Return the score given as parameter, with a default score of 0.5 so any
other selector take precedence.
@@ -299,6 +312,7 @@ class yes(Predicate): # pylint: disable=C0103
Take care, `yes(0)` could be named 'no'...
"""
+
def __init__(self, score: float = 0.5) -> None:
self.score = score
@@ -308,39 +322,50 @@ class yes(Predicate): # pylint: disable=C0103
# deprecated stuff #############################################################
-@deprecated('[lgc 0.59] use Registry.objid class method instead')
+
+@deprecated("[lgc 0.59] use Registry.objid class method instead")
def classid(cls):
- return '%s.%s' % (cls.__module__, cls.__name__)
+ return "%s.%s" % (cls.__module__, cls.__name__)
-@deprecated('[lgc 0.59] use obj_registries function instead')
+
+@deprecated("[lgc 0.59] use obj_registries function instead")
def class_registries(cls, registryname):
return obj_registries(cls, registryname)
+
class RegistryException(Exception):
"""Base class for registry exception."""
+
class RegistryNotFound(RegistryException):
"""Raised when an unknown registry is requested.
This is usually a programming/typo error.
"""
+
class ObjectNotFound(RegistryException):
"""Raised when an unregistered object is requested.
This may be a programming/typo or a misconfiguration error.
"""
+
class NoSelectableObject(RegistryException):
"""Raised when no object is selectable for a given context."""
+
def __init__(self, args, kwargs, objects):
self.args = args
self.kwargs = kwargs
self.objects = objects
def __str__(self):
- return ('args: %s, kwargs: %s\ncandidates: %s'
- % (self.args, self.kwargs.keys(), self.objects))
+ return "args: %s, kwargs: %s\ncandidates: %s" % (
+ self.args,
+ self.kwargs.keys(),
+ self.objects,
+ )
+
class SelectAmbiguity(RegistryException):
"""Raised when several objects compete at selection time with an equal
@@ -362,12 +387,14 @@ def _modname_from_path(path: str, extrapath: Optional[Any] = None) -> str:
# from package.__init__ import something
#
# which seems quite correct.
- if modpath[-1] == '__init__':
+ if modpath[-1] == "__init__":
modpath.pop()
- return '.'.join(modpath)
+ return ".".join(modpath)
-def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
+def _toload_info(
+ path: List[str], extrapath: Optional[Any], _toload: Optional[Tuple[Dict[str, str], List]] = None
+) -> Tuple[Dict[str, str], List[Tuple[str, str]]]:
"""Return a dictionary of <modname>: <modpath> and an ordered list of
(file, module name) to load
"""
@@ -376,12 +403,12 @@ def _toload_info(path: List[str], extrapath: Optional[Any], _toload: Optional[Tu
_toload = {}, []
for fileordir in path:
- if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
+ if isdir(fileordir) and exists(join(fileordir, "__init__.py")):
subfiles = [join(fileordir, fname) for fname in listdir(fileordir)]
_toload_info(subfiles, extrapath, _toload)
- elif fileordir[-3:] == '.py':
+ elif fileordir[-3:] == ".py":
modname = _modname_from_path(fileordir, extrapath)
_toload[0][modname] = fileordir
@@ -417,7 +444,7 @@ class RegistrableObject(object):
__registry__: Optional[str] = None
__regid__: Optional[str] = None
__select__: Union[None, str, Predicate] = None
- __abstract__ = True # see doc snipppets below (in Registry class)
+ __abstract__ = True # see doc snipppets below (in Registry class)
@classproperty
def __registries__(cls) -> Union[Tuple[str], Tuple]:
@@ -435,12 +462,13 @@ class RegistrableInstance(RegistrableObject):
"""Add a __module__ attribute telling the module where the instance was
created, for automatic registration.
"""
- module = kwargs.pop('__module__', None)
+ module = kwargs.pop("__module__", None)
obj = super(RegistrableInstance, cls).__new__(cls)
if module is None:
- warn('instantiate {0} with '
- '__module__=__name__'.format(cls.__name__),
- DeprecationWarning)
+ warn(
+ "instantiate {0} with " "__module__=__name__".format(cls.__name__),
+ DeprecationWarning,
+ )
# XXX subclass must no override __new__
filepath = tb.extract_stack(limit=2)[0][0]
obj.__module__ = _modname_from_path(filepath)
@@ -452,11 +480,19 @@ class RegistrableInstance(RegistrableObject):
super(RegistrableInstance, self).__init__()
-SelectBestReport = TypedDict("SelectBestReport", {"all_objects": List, "end_score": int,
- "winners": List,
- "winner": Optional[Any], "self": 'Registry',
- "args": List, "kwargs": Dict,
- "registry": 'Registry'})
+SelectBestReport = TypedDict(
+ "SelectBestReport",
+ {
+ "all_objects": List,
+ "end_score": int,
+ "winners": List,
+ "winner": Optional[Any],
+ "self": "Registry",
+ "args": List,
+ "kwargs": Dict,
+ "registry": "Registry",
+ },
+)
class Registry(dict):
@@ -492,6 +528,7 @@ class Registry(dict):
.. automethod:: possible_objects
.. automethod:: object_by_id
"""
+
def __init__(self, debugmode: bool) -> None:
super(Registry, self).__init__()
self.debugmode = debugmode
@@ -511,19 +548,19 @@ class Registry(dict):
@classmethod
def objid(cls, obj: Any) -> str:
"""returns a unique identifier for an object stored in the registry"""
- return '%s.%s' % (obj.__module__, cls.objname(obj))
+ return "%s.%s" % (obj.__module__, cls.objname(obj))
@classmethod
def objname(cls, obj: Any) -> str:
"""returns a readable name for an object stored in the registry"""
- return getattr(obj, '__name__', id(obj))
+ return getattr(obj, "__name__", id(obj))
def initialization_completed(self) -> None:
"""call method __registered__() on registered objects when the callback
is defined"""
for objects in self.values():
for objectcls in objects:
- registered = getattr(objectcls, '__registered__', None)
+ registered = getattr(objectcls, "__registered__", None)
if registered:
registered(self)
if self.debugmode:
@@ -531,16 +568,17 @@ class Registry(dict):
def register(self, obj: Any, oid: Optional[Any] = None, clear: bool = False) -> None:
"""base method to add an object in the registry"""
- assert not '__abstract__' in obj.__dict__, obj
+ assert not "__abstract__" in obj.__dict__, obj
assert obj.__select__, obj
oid = oid or obj.__regid__
- assert oid, ('no explicit name supplied to register object %s, '
- 'which has no __regid__ set' % obj)
+ assert oid, (
+ "no explicit name supplied to register object %s, " "which has no __regid__ set" % obj
+ )
if clear:
- objects = self[oid] = []
+ objects = self[oid] = []
else:
objects = self.setdefault(oid, [])
- assert not obj in objects, 'object %s is already registered' % obj
+ assert not obj in objects, "object %s is already registered" % obj
objects.append(obj)
def register_and_replace(self, obj, replaced):
@@ -551,15 +589,14 @@ class Registry(dict):
if not isinstance(replaced, str):
replaced = self.objid(replaced)
# prevent from misspelling
- assert obj is not replaced, 'replacing an object by itself: %s' % obj
+ assert obj is not replaced, "replacing an object by itself: %s" % obj
registered_objs = self.get(obj.__regid__, ())
for index, registered in enumerate(registered_objs):
if self.objid(registered) == replaced:
del registered_objs[index]
break
else:
- self.warning('trying to replace %s that is not registered with %s',
- replaced, obj)
+ self.warning("trying to replace %s that is not registered with %s", replaced, obj)
self.register(obj)
def unregister(self, obj):
@@ -573,8 +610,7 @@ class Registry(dict):
self[oid].remove(registered)
break
else:
- self.warning('can\'t remove %s, no id %s in the registry',
- objid, oid)
+ self.warning("can't remove %s, no id %s in the registry", objid, oid)
def all_objects(self):
"""return a list containing all objects in this registry.
@@ -608,9 +644,9 @@ class Registry(dict):
raise :exc:`NoSelectableObject` if no object can be selected
"""
- obj = self._select_best(self[__oid], *args, **kwargs)
+ obj = self._select_best(self[__oid], *args, **kwargs)
if obj is None:
- raise NoSelectableObject(args, kwargs, self[__oid] )
+ raise NoSelectableObject(args, kwargs, self[__oid])
return obj
def select_or_none(self, __oid, *args, **kwargs):
@@ -627,7 +663,7 @@ class Registry(dict):
context
"""
for objects in self.values():
- obj = self._select_best(objects, *args, **kwargs)
+ obj = self._select_best(objects, *args, **kwargs)
if obj is None:
continue
yield obj
@@ -695,7 +731,7 @@ class Registry(dict):
if len(winners) > 1:
# log in production environement / test, error while debugging
- msg = 'select ambiguity: %s\n(args: %s, kwargs: %s)'
+ msg = "select ambiguity: %s\n(args: %s, kwargs: %s)"
if self.debugmode:
# raise bare exception in debug mode
@@ -903,8 +939,9 @@ class RegistryStore(dict):
:meth:`~logilab.common.registry.RegistryStore.register_and_replace` for
instance).
"""
- assert isinstance(modname, str), \
- 'modname expected to be a module name (ie string), got %r' % modname
+ assert isinstance(modname, str), (
+ "modname expected to be a module name (ie string), got %r" % modname
+ )
for obj in objects:
if self.is_registrable(obj) and obj.__module__ == modname and not obj in butclasses:
if isinstance(obj, type):
@@ -912,7 +949,13 @@ class RegistryStore(dict):
else:
self.register(obj)
- def register(self, obj: Any, registryname: Optional[Any] = None, oid: Optional[Any] = None, clear: bool = False) -> None:
+ def register(
+ self,
+ obj: Any,
+ registryname: Optional[Any] = None,
+ oid: Optional[Any] = None,
+ clear: bool = False,
+ ) -> None:
"""register `obj` implementation into `registryname` or
`obj.__registries__` if not specified, with identifier `oid` or
`obj.__regid__` if not specified.
@@ -920,12 +963,13 @@ class RegistryStore(dict):
If `clear` is true, all objects with the same identifier will be
previously unregistered.
"""
- assert not obj.__dict__.get('__abstract__'), obj
+ assert not obj.__dict__.get("__abstract__"), obj
for registryname in obj_registries(obj, registryname):
registry = self.setdefault(registryname)
registry.register(obj, oid=oid, clear=clear)
- self.debug("register %s in %s['%s']",
- registry.objname(obj), registryname, oid or obj.__regid__)
+ self.debug(
+ "register %s in %s['%s']", registry.objname(obj), registryname, oid or obj.__regid__
+ )
self._loadedmods.setdefault(obj.__module__, {})[registry.objid(obj)] = obj
def unregister(self, obj, registryname=None):
@@ -935,8 +979,9 @@ class RegistryStore(dict):
for registryname in obj_registries(obj, registryname):
registry = self[registryname]
registry.unregister(obj)
- self.debug("unregister %s from %s['%s']",
- registry.objname(obj), registryname, obj.__regid__)
+ self.debug(
+ "unregister %s from %s['%s']", registry.objname(obj), registryname, obj.__regid__
+ )
def register_and_replace(self, obj, replaced, registryname=None):
"""register `obj` object into `registryname` or
@@ -947,13 +992,19 @@ class RegistryStore(dict):
for registryname in obj_registries(obj, registryname):
registry = self[registryname]
registry.register_and_replace(obj, replaced)
- self.debug("register %s in %s['%s'] instead of %s",
- registry.objname(obj), registryname, obj.__regid__,
- registry.objname(replaced))
+ self.debug(
+ "register %s in %s['%s'] instead of %s",
+ registry.objname(obj),
+ registryname,
+ obj.__regid__,
+ registry.objname(replaced),
+ )
# initialization methods ###################################################
- def init_registration(self, path: List[str], extrapath: Optional[Any] = None) -> List[Tuple[str, str]]:
+ def init_registration(
+ self, path: List[str], extrapath: Optional[Any] = None
+ ) -> List[Tuple[str, str]]:
"""reset registry and walk down path to return list of (path, name)
file modules to be loaded"""
# XXX make this private by renaming it to _init_registration ?
@@ -966,7 +1017,7 @@ class RegistryStore(dict):
self._loadedmods: Dict[str, Dict[str, type]] = {}
return filemods
- @deprecated('use register_modnames() instead')
+ @deprecated("use register_modnames() instead")
def register_objects(self, path: List[str], extrapath: Optional[Any] = None) -> None:
"""register all objects found walking down <path>"""
# load views from each directory in the instance's path
@@ -988,7 +1039,7 @@ class RegistryStore(dict):
# mypy: "Loader" has no attribute "get_filename"
# the selected class has one
filepath = loader.get_filename() # type: ignore
- if filepath[-4:] in ('.pyc', '.pyo'):
+ if filepath[-4:] in (".pyc", ".pyo"):
# The source file *must* exists
filepath = filepath[:-1]
self._toloadmods[modname] = filepath
@@ -1008,8 +1059,7 @@ class RegistryStore(dict):
return stat(filepath)[-2]
except OSError:
# this typically happens on emacs backup files (.#foo.py)
- self.warning('Unable to load %s. It is likely to be a backup file',
- filepath)
+ self.warning("Unable to load %s. It is likely to be a backup file", filepath)
return None
def is_reload_needed(self, path):
@@ -1018,19 +1068,18 @@ class RegistryStore(dict):
"""
lastmodifs = self._lastmodifs
for fileordir in path:
- if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
- if self.is_reload_needed([join(fileordir, fname)
- for fname in listdir(fileordir)]):
+ if isdir(fileordir) and exists(join(fileordir, "__init__.py")):
+ if self.is_reload_needed([join(fileordir, fname) for fname in listdir(fileordir)]):
return True
- elif fileordir[-3:] == '.py':
+ elif fileordir[-3:] == ".py":
mdate = self._mdate(fileordir)
if mdate is None:
- continue # backup file, see _mdate implementation
+ continue # backup file, see _mdate implementation
elif "flymake" in fileordir:
# flymake + pylint in use, don't consider these they will corrupt the registry
continue
if fileordir not in lastmodifs or lastmodifs[fileordir] < mdate:
- self.info('File %s changed since last visit', fileordir)
+ self.info("File %s changed since last visit", fileordir)
return True
return False
@@ -1041,7 +1090,7 @@ class RegistryStore(dict):
self._loadedmods[modname] = {}
mdate = self._mdate(filepath)
if mdate is None:
- return # backup file, see _mdate implementation
+ return # backup file, see _mdate implementation
elif "flymake" in filepath:
# flymake + pylint in use, don't consider these they will corrupt the registry
return
@@ -1052,7 +1101,7 @@ class RegistryStore(dict):
# load the module
if sys.version_info < (3,) and not isinstance(modname, str):
modname = str(modname)
- module = __import__(modname, fromlist=modname.split('.')[:-1])
+ module = __import__(modname, fromlist=modname.split(".")[:-1])
self.load_module(module)
def load_module(self, module: ModuleType) -> None:
@@ -1074,15 +1123,17 @@ class RegistryStore(dict):
- object class needs to have registries and identifier properly set to a
non empty string to be registered.
"""
- self.info('loading %s from %s', module.__name__, module.__file__)
- if hasattr(module, 'registration_callback'):
+ self.info("loading %s from %s", module.__name__, module.__file__)
+ if hasattr(module, "registration_callback"):
# mypy: Module has no attribute "registration_callback"
# we check that before
module.registration_callback(self) # type: ignore
else:
self.register_all(vars(module).values(), module.__name__)
- def _load_ancestors_then_object(self, modname: str, objectcls: type, butclasses: Sequence[Any] = ()) -> None:
+ def _load_ancestors_then_object(
+ self, modname: str, objectcls: type, butclasses: Sequence[Any] = ()
+ ) -> None:
"""handle class registration according to rules defined in
:meth:`load_module`
"""
@@ -1103,7 +1154,7 @@ class RegistryStore(dict):
self.load_file(self._toloadmods[objmodname], objmodname)
return
# ensure object hasn't been already processed
- clsid = '%s.%s' % (modname, objectcls.__name__)
+ clsid = "%s.%s" % (modname, objectcls.__name__)
if clsid in self._loadedmods[modname]:
return
self._loadedmods[modname][clsid] = objectcls
@@ -1115,10 +1166,13 @@ class RegistryStore(dict):
return
# backward compat
reg = self.setdefault(obj_registries(objectcls)[0])
- if reg.objname(objectcls)[0] == '_':
- warn("[lgc 0.59] object whose name start with '_' won't be "
- "skipped anymore at some point, use __abstract__ = True "
- "instead (%s)" % objectcls, DeprecationWarning)
+ if reg.objname(objectcls)[0] == "_":
+ warn(
+ "[lgc 0.59] object whose name start with '_' won't be "
+ "skipped anymore at some point, use __abstract__ = True "
+ "instead (%s)" % objectcls,
+ DeprecationWarning,
+ )
return
# register, finally
self.register(objectcls)
@@ -1133,9 +1187,11 @@ class RegistryStore(dict):
if isinstance(obj, type):
if not issubclass(obj, RegistrableObject):
# ducktyping backward compat
- if not (getattr(obj, '__registries__', None)
- and getattr(obj, '__regid__', None)
- and getattr(obj, '__select__', None)):
+ if not (
+ getattr(obj, "__registries__", None)
+ and getattr(obj, "__regid__", None)
+ and getattr(obj, "__select__", None)
+ ):
return False
elif issubclass(obj, RegistrableInstance):
return False
@@ -1144,26 +1200,26 @@ class RegistryStore(dict):
return False
if not obj.__regid__:
- return False # no regid
+ return False # no regid
registries = obj.__registries__
if not registries:
- return False # no registries
+ return False # no registries
selector = obj.__select__
if not selector:
- return False # no selector
+ return False # no selector
- if obj.__dict__.get('__abstract__', False):
+ if obj.__dict__.get("__abstract__", False):
return False
# then detect potential problems that should be warned
if not isinstance(registries, (tuple, list)):
- cls.warning('%s has __registries__ which is not a list or tuple', obj)
+ cls.warning("%s has __registries__ which is not a list or tuple", obj)
return False
if not callable(selector):
- cls.warning('%s has not callable __select__', obj)
+ cls.warning("%s has not callable __select__", obj)
return False
return True
@@ -1174,32 +1230,37 @@ class RegistryStore(dict):
# init logging
-set_log_methods(RegistryStore, getLogger('registry.store'))
-set_log_methods(Registry, getLogger('registry'))
+set_log_methods(RegistryStore, getLogger("registry.store"))
+set_log_methods(Registry, getLogger("registry"))
# helpers for debugging selectors
TRACED_OIDS = None
+
def _trace_selector(cls, selector, args, ret):
vobj = args[0]
- if TRACED_OIDS == 'all' or vobj.__regid__ in TRACED_OIDS:
- print('%s -> %s for %s(%s)' % (cls, ret, vobj, vobj.__regid__))
+ if TRACED_OIDS == "all" or vobj.__regid__ in TRACED_OIDS:
+ print("%s -> %s for %s(%s)" % (cls, ret, vobj, vobj.__regid__))
+
def _lltrace(selector):
"""use this decorator on your predicates so they become traceable with
:class:`traced_selection`
"""
+
def traced(cls, *args, **kwargs):
ret = selector(cls, *args, **kwargs)
if TRACED_OIDS is not None:
_trace_selector(cls, selector, args, ret)
return ret
+
traced.__name__ = selector.__name__
traced.__doc__ = selector.__doc__
return traced
-class traced_selection(object): # pylint: disable=C0103
+
+class traced_selection(object): # pylint: disable=C0103
"""
Typical usage is :
@@ -1227,7 +1288,7 @@ class traced_selection(object): # pylint: disable=C0103
the `logilab.common.registry.Registry.select` method body.
"""
- def __init__(self, traced='all'):
+ def __init__(self, traced="all"):
self.traced = traced
def __enter__(self):
diff --git a/logilab/common/shellutils.py b/logilab/common/shellutils.py
index 2764723..557e45d 100644
--- a/logilab/common/shellutils.py
+++ b/logilab/common/shellutils.py
@@ -46,7 +46,6 @@ from logilab.common.deprecation import deprecated
class tempdir(object):
-
def __enter__(self):
self.path = tempfile.mkdtemp()
return self.path
@@ -82,7 +81,8 @@ def chown(path, login=None, group=None):
try:
uid = int(login)
except ValueError:
- import pwd # Platforms: Unix
+ import pwd # Platforms: Unix
+
uid = pwd.getpwnam(login).pw_uid
if group is None:
gid = -1
@@ -91,9 +91,11 @@ def chown(path, login=None, group=None):
gid = int(group)
except ValueError:
import grp
+
gid = grp.getgrnam(group).gr_gid
os.chown(path, uid, gid)
+
def mv(source, destination, _action=shutil.move):
"""A shell-like mv, supporting wildcards.
"""
@@ -106,14 +108,14 @@ def mv(source, destination, _action=shutil.move):
try:
source = sources[0]
except IndexError:
- raise OSError('No file matching %s' % source)
+ raise OSError("No file matching %s" % source)
if isdir(destination) and exists(destination):
destination = join(destination, basename(source))
try:
_action(source, destination)
except OSError as ex:
- raise OSError('Unable to move %r to %r (%s)' % (
- source, destination, ex))
+ raise OSError("Unable to move %r to %r (%s)" % (source, destination, ex))
+
def rm(*files):
"""A shell-like rm, supporting wildcards.
@@ -127,12 +129,19 @@ def rm(*files):
else:
os.remove(filename)
+
def cp(source, destination):
"""A shell-like cp, supporting wildcards.
"""
mv(source, destination, _action=shutil.copy)
-def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = False, blacklist: Tuple[str, ...] = STD_BLACKLIST) -> List[str]:
+
+def find(
+ directory: str,
+ exts: Union[Tuple[str, ...], str],
+ exclude: bool = False,
+ blacklist: Tuple[str, ...] = STD_BLACKLIST,
+) -> List[str]:
"""Recursively find files ending with the given extensions from the directory.
:type directory: str
@@ -160,17 +169,21 @@ def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = Fals
if isinstance(exts, str):
exts = (exts,)
if exclude:
+
def match(filename: str, exts: Tuple[str, ...]) -> bool:
for ext in exts:
if filename.endswith(ext):
return False
return True
+
else:
+
def match(filename: str, exts: Tuple[str, ...]) -> bool:
for ext in exts:
if filename.endswith(ext):
return True
return False
+
files = []
for dirpath, dirnames, filenames in os.walk(directory):
_handle_blacklist(blacklist, dirnames, filenames)
@@ -182,7 +195,11 @@ def find(directory: str, exts: Union[Tuple[str, ...], str], exclude: bool = Fals
return files
-def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST) -> Iterator[str]:
+def globfind(
+ directory: str,
+ pattern: str,
+ blacklist: Tuple[str, str, str, str, str, str, str, str] = STD_BLACKLIST,
+) -> Iterator[str]:
"""Recursively finds files matching glob `pattern` under `directory`.
This is an alternative to `logilab.common.shellutils.find`.
@@ -209,21 +226,23 @@ def globfind(directory: str, pattern: str, blacklist: Tuple[str, str, str, str,
for fname in fnmatch.filter(filenames, pattern):
yield join(curdir, fname)
+
def unzip(archive, destdir):
import zipfile
+
if not exists(destdir):
os.mkdir(destdir)
zfobj = zipfile.ZipFile(archive)
for name in zfobj.namelist():
- if name.endswith('/'):
+ if name.endswith("/"):
os.mkdir(join(destdir, name))
else:
- outfile = open(join(destdir, name), 'wb')
+ outfile = open(join(destdir, name), "wb")
outfile.write(zfobj.read(name))
outfile.close()
-@deprecated('Use subprocess.Popen instead')
+@deprecated("Use subprocess.Popen instead")
class Execute:
"""This is a deadlock safe version of popen2 (no stdin), that returns
an object with errorlevel, out and err.
@@ -238,11 +257,13 @@ class Execute:
class ProgressBar(object):
"""A simple text progression bar."""
- def __init__(self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = '') -> None:
+ def __init__(
+ self, nbops: int, size: int = 20, stream: StringIO = sys.stdout, title: str = ""
+ ) -> None:
if title:
- self._fstr = '\r%s [%%-%ss]' % (title, int(size))
+ self._fstr = "\r%s [%%-%ss]" % (title, int(size))
else:
- self._fstr = '\r[%%-%ss]' % int(size)
+ self._fstr = "\r[%%-%ss]" % int(size)
self._stream = stream
self._total = nbops
self._size = size
@@ -280,42 +301,45 @@ class ProgressBar(object):
else:
self._current += offset
- progress = int((float(self._current)/float(self._total))*self._size)
+ progress = int((float(self._current) / float(self._total)) * self._size)
if progress > self._progress:
self._progress = progress
self.refresh()
def refresh(self) -> None:
"""Refresh the progression bar display."""
- self._stream.write(self._fstr % ('=' * min(self._progress, self._size)) )
+ self._stream.write(self._fstr % ("=" * min(self._progress, self._size)))
if self._last_text_write_size or self._current_text:
- template = ' %%-%is' % (self._last_text_write_size)
+ template = " %%-%is" % (self._last_text_write_size)
text = self._current_text
if text is None:
- text = ''
+ text = ""
self._stream.write(template % text)
self._last_text_write_size = len(text.rstrip())
self._stream.flush()
def finish(self):
- self._stream.write('\n')
+ self._stream.write("\n")
self._stream.flush()
class DummyProgressBar(object):
- __slots__ = ('text',)
+ __slots__ = ("text",)
def refresh(self):
pass
+
def update(self):
pass
+
def finish(self):
pass
_MARKER = object()
-class progress(object):
+
+class progress(object):
def __init__(self, nbops=_MARKER, size=_MARKER, stream=_MARKER, title=_MARKER, enabled=True):
self.nbops = nbops
self.size = size
@@ -326,26 +350,30 @@ class progress(object):
def __enter__(self):
if self.enabled:
kwargs = {}
- for attr in ('nbops', 'size', 'stream', 'title'):
+ for attr in ("nbops", "size", "stream", "title"):
value = getattr(self, attr)
if value is not _MARKER:
kwargs[attr] = value
self.pb = ProgressBar(**kwargs)
else:
- self.pb = DummyProgressBar()
+ self.pb = DummyProgressBar()
return self.pb
def __exit__(self, exc_type, exc_val, exc_tb):
self.pb.finish()
-class RawInput(object):
- def __init__(self, input_function: Optional[Callable] = None, printer: Optional[Callable] = None, **kwargs: Any) -> None:
- if 'input' in kwargs:
- input_function = kwargs.pop('input')
+class RawInput(object):
+ def __init__(
+ self,
+ input_function: Optional[Callable] = None,
+ printer: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> None:
+ if "input" in kwargs:
+ input_function = kwargs.pop("input")
warnings.warn(
- "'input' argument is deprecated,"
- "use 'input_function' instead",
+ "'input' argument is deprecated," "use 'input_function' instead",
DeprecationWarning,
)
self._input = input_function or input
@@ -360,35 +388,36 @@ class RawInput(object):
else:
label = option[0].lower()
if len(option) > 1:
- label += '(%s)' % option[1:].lower()
+ label += "(%s)" % option[1:].lower()
choices.append((option, label))
- prompt = "%s [%s]: " % (question,
- '/'.join([opt[1] for opt in choices]))
+ prompt = "%s [%s]: " % (question, "/".join([opt[1] for opt in choices]))
tries = 3
while tries > 0:
answer = self._input(prompt).strip().lower()
if not answer:
return default
- possible = [option for option, label in choices
- if option.lower().startswith(answer)]
+ possible = [option for option, label in choices if option.lower().startswith(answer)]
if len(possible) == 1:
return possible[0]
elif len(possible) == 0:
- msg = '%s is not an option.' % answer
+ msg = "%s is not an option." % answer
else:
- msg = ('%s is an ambiguous answer, do you mean %s ?' % (
- answer, ' or '.join(possible)))
+ msg = "%s is an ambiguous answer, do you mean %s ?" % (
+ answer,
+ " or ".join(possible),
+ )
if self._print:
self._print(msg)
else:
print(msg)
tries -= 1
- raise Exception('unable to get a sensible answer')
+ raise Exception("unable to get a sensible answer")
def confirm(self, question: str, default_is_yes: bool = True) -> bool:
- default = default_is_yes and 'y' or 'n'
- answer = self.ask(question, ('y', 'n'), default)
- return answer == 'y'
+ default = default_is_yes and "y" or "n"
+ answer = self.ask(question, ("y", "n"), default)
+ return answer == "y"
+
ASK = RawInput()
@@ -398,15 +427,17 @@ def getlogin():
(man 3 getlogin)
Another solution would be to use $LOGNAME, $USER or $USERNAME
"""
- if sys.platform != 'win32':
- import pwd # Platforms: Unix
+ if sys.platform != "win32":
+ import pwd # Platforms: Unix
+
return pwd.getpwuid(os.getuid())[0]
else:
- return os.environ['USERNAME']
+ return os.environ["USERNAME"]
+
def generate_password(length=8, vocab=string.ascii_letters + string.digits):
"""dumb password generation function"""
- pwd = ''
+ pwd = ""
for i in range(length):
pwd += random.choice(vocab)
return pwd
diff --git a/logilab/common/sphinx_ext.py b/logilab/common/sphinx_ext.py
index a24608c..4ca30f7 100644
--- a/logilab/common/sphinx_ext.py
+++ b/logilab/common/sphinx_ext.py
@@ -19,30 +19,41 @@ from logilab.common.decorators import monkeypatch
from sphinx.ext import autodoc
+
class DocstringOnlyModuleDocumenter(autodoc.ModuleDocumenter):
- objtype = 'docstring'
+ objtype = "docstring"
+
def format_signature(self):
pass
+
def add_directive_header(self, sig):
pass
+
def document_members(self, all_members=False):
pass
def resolve_name(self, modname, parents, path, base):
if modname is not None:
return modname, parents + [base]
- return (path or '') + base, []
+ return (path or "") + base, []
+
+# autodoc.add_documenter(DocstringOnlyModuleDocumenter)
-#autodoc.add_documenter(DocstringOnlyModuleDocumenter)
def setup(app):
app.add_autodocumenter(DocstringOnlyModuleDocumenter)
+from sphinx.ext.autodoc import (
+ ViewList,
+ Options,
+ AutodocReporter,
+ nodes,
+ assemble_option_dict,
+ nested_parse_with_titles,
+)
-from sphinx.ext.autodoc import (ViewList, Options, AutodocReporter, nodes,
- assemble_option_dict, nested_parse_with_titles)
@monkeypatch(autodoc.AutoDirective)
def run(self):
@@ -56,8 +67,7 @@ def run(self):
objtype = self.name[4:]
doc_class = self._registry[objtype]
# process the options with the selected documenter's option_spec
- self.genopt = Options(assemble_option_dict(
- self.options.items(), doc_class.option_spec))
+ self.genopt = Options(assemble_option_dict(self.options.items(), doc_class.option_spec))
# generate the output
documenter = doc_class(self, self.arguments[0])
documenter.generate(more_content=self.content)
@@ -72,9 +82,8 @@ def run(self):
# use a custom reporter that correctly assigns lines to source
# filename/description and lineno
old_reporter = self.state.memo.reporter
- self.state.memo.reporter = AutodocReporter(self.result,
- self.state.memo.reporter)
- if self.name in ('automodule', 'autodocstring'):
+ self.state.memo.reporter = AutodocReporter(self.result, self.state.memo.reporter)
+ if self.name in ("automodule", "autodocstring"):
node = nodes.section()
# necessary so that the child nodes get the right source/line set
node.document = self.state.document
diff --git a/logilab/common/sphinxutils.py b/logilab/common/sphinxutils.py
index ab6e8a1..350188d 100644
--- a/logilab/common/sphinxutils.py
+++ b/logilab/common/sphinxutils.py
@@ -37,18 +37,24 @@ from logilab.common import STD_BLACKLIST
from logilab.common.shellutils import globfind
from logilab.common.modutils import load_module_from_file, modpath_from_file
+
def module_members(module):
members = []
for name, value in inspect.getmembers(module):
- if getattr(value, '__module__', None) == module.__name__:
- members.append( (name, value) )
+ if getattr(value, "__module__", None) == module.__name__:
+ members.append((name, value))
return sorted(members)
def class_members(klass):
- return sorted([name for name in vars(klass)
- if name not in ('__doc__', '__module__',
- '__dict__', '__weakref__')])
+ return sorted(
+ [
+ name
+ for name in vars(klass)
+ if name not in ("__doc__", "__module__", "__dict__", "__weakref__")
+ ]
+ )
+
class ModuleGenerator:
file_header = """.. -*- coding: utf-8 -*-\n\n%s\n"""
@@ -72,7 +78,7 @@ class ModuleGenerator:
def generate(self, dest_file, exclude_dirs=STD_BLACKLIST):
"""make the module file"""
- self.fn = open(dest_file, 'w')
+ self.fn = open(dest_file, "w")
num = len(self.title) + 6
title = "=" * num + "\n %s API\n" % self.title + "=" * num
self.fn.write(self.file_header % title)
@@ -88,35 +94,34 @@ class ModuleGenerator:
for objname, obj in module_members(module):
if inspect.isclass(obj):
classmembers = class_members(obj)
- classes.append( (objname, classmembers) )
+ classes.append((objname, classmembers))
else:
modmembers.append(objname)
- self.fn.write(self.module_def % (modname, '=' * len(modname),
- modname,
- ', '.join(modmembers)))
+ self.fn.write(
+ self.module_def % (modname, "=" * len(modname), modname, ", ".join(modmembers))
+ )
for klass, members in classes:
- self.fn.write(self.class_def % (klass, ', '.join(members)))
+ self.fn.write(self.class_def % (klass, ", ".join(members)))
def find_modules(self, exclude_dirs):
basepath = osp.dirname(self.code_dir)
basedir = osp.basename(basepath) + osp.sep
if basedir not in sys.path:
sys.path.insert(1, basedir)
- for filepath in globfind(self.code_dir, '*.py', exclude_dirs):
- if osp.basename(filepath) in ('setup.py', '__pkginfo__.py'):
+ for filepath in globfind(self.code_dir, "*.py", exclude_dirs):
+ if osp.basename(filepath) in ("setup.py", "__pkginfo__.py"):
continue
try:
module = load_module_from_file(filepath)
- except: # module might be broken or magic
+ except: # module might be broken or magic
dotted_path = modpath_from_file(filepath)
- module = type('.'.join(dotted_path), (), {}) # mock it
+ module = type(".".join(dotted_path), (), {}) # mock it
yield module
-if __name__ == '__main__':
+if __name__ == "__main__":
# example :
title, code_dir, outfile = sys.argv[1:]
generator = ModuleGenerator(title, code_dir)
# XXX modnames = ['logilab']
- generator.generate(outfile, ('test', 'tests', 'examples',
- 'data', 'doc', '.hg', 'migration'))
+ generator.generate(outfile, ("test", "tests", "examples", "data", "doc", ".hg", "migration"))
diff --git a/logilab/common/table.py b/logilab/common/table.py
index e7b9195..983708b 100644
--- a/logilab/common/table.py
+++ b/logilab/common/table.py
@@ -34,7 +34,12 @@ class Table(object):
forall(self.data, lambda x: len(x) <= len(self.col_names))
"""
- def __init__(self, default_value: int = 0, col_names: Optional[List[str]] = None, row_names: Optional[Any] = None) -> None:
+ def __init__(
+ self,
+ default_value: int = 0,
+ col_names: Optional[List[str]] = None,
+ row_names: Optional[Any] = None,
+ ) -> None:
self.col_names: List = []
self.row_names: List = []
self.data: List = []
@@ -45,7 +50,7 @@ class Table(object):
self.create_rows(row_names)
def _next_row_name(self) -> str:
- return 'row%s' % (len(self.row_names)+1)
+ return "row%s" % (len(self.row_names) + 1)
def __iter__(self) -> Iterator:
return iter(self.data)
@@ -83,7 +88,7 @@ class Table(object):
"""
self.row_names.extend(row_names)
for row_name in row_names:
- self.data.append([self.default_value]*len(self.col_names))
+ self.data.append([self.default_value] * len(self.col_names))
def create_columns(self, col_names: List[str]) -> None:
"""Appends col_names to the list of existing columns
@@ -96,8 +101,7 @@ class Table(object):
"""
row_name = row_name or self._next_row_name()
self.row_names.append(row_name)
- self.data.append([self.default_value]*len(self.col_names))
-
+ self.data.append([self.default_value] * len(self.col_names))
def create_column(self, col_name: str) -> None:
"""Creates a colname to the col_names list
@@ -107,7 +111,7 @@ class Table(object):
row.append(self.default_value)
## Sort by column ##########################################################
- def sort_by_column_id(self, col_id: str, method: str = 'asc') -> None:
+ def sort_by_column_id(self, col_id: str, method: str = "asc") -> None:
"""Sorts the table (in-place) according to data stored in col_id
"""
try:
@@ -116,17 +120,17 @@ class Table(object):
except ValueError:
raise KeyError("Col (%s) not found in table" % (col_id))
-
- def sort_by_column_index(self, col_index: int, method: str = 'asc') -> None:
+ def sort_by_column_index(self, col_index: int, method: str = "asc") -> None:
"""Sorts the table 'in-place' according to data stored in col_index
method should be in ('asc', 'desc')
"""
- sort_list = sorted([(row[col_index], row, row_name)
- for row, row_name in zip(self.data, self.row_names)])
+ sort_list = sorted(
+ [(row[col_index], row, row_name) for row, row_name in zip(self.data, self.row_names)]
+ )
# Sorting sort_list will sort according to col_index
# If we want reverse sort, then reverse list
- if method.lower() == 'desc':
+ if method.lower() == "desc":
sort_list.reverse()
# Rebuild data / row names
@@ -136,8 +140,9 @@ class Table(object):
self.data.append(row)
self.row_names.append(row_name)
- def groupby(self, colname: str, *others: str) -> Union[Dict[str, Dict[str, 'Table']],
- Dict[str, 'Table']]:
+ def groupby(
+ self, colname: str, *others: str
+ ) -> Union[Dict[str, Dict[str, "Table"]], Dict[str, "Table"]]:
"""builds indexes of data
:returns: nested dictionaries pointing to actual rows
"""
@@ -148,13 +153,14 @@ class Table(object):
ptr = groups
for col_index in col_indexes[:-1]:
ptr = ptr.setdefault(row[col_index], {})
- table = ptr.setdefault(row[col_indexes[-1]],
- Table(default_value=self.default_value,
- col_names=self.col_names))
+ table = ptr.setdefault(
+ row[col_indexes[-1]],
+ Table(default_value=self.default_value, col_names=self.col_names),
+ )
table.append_row(tuple(row))
return groups
- def select(self, colname: str, value: str) -> 'Table':
+ def select(self, colname: str, value: str) -> "Table":
grouped = self.groupby(colname)
try:
# mypy: Incompatible return value type (got "Union[Dict[str, Table], Table]",
@@ -170,14 +176,12 @@ class Table(object):
if row[col_index] == value:
self.data.remove(row)
-
## The 'setter' part #######################################################
def set_cell(self, row_index: int, col_index: int, data: int) -> None:
"""sets value of cell 'row_indew', 'col_index' to data
"""
self.data[row_index][col_index] = data
-
def set_cell_by_ids(self, row_id: str, col_id: str, data: Union[int, str]) -> None:
"""sets value of cell mapped by row_id and col_id to data
Raises a KeyError if row_id or col_id are not found in the table
@@ -193,7 +197,6 @@ class Table(object):
except ValueError:
raise KeyError("Column (%s) not found in table" % (col_id))
-
def set_row(self, row_index: int, row_data: Union[List[float], List[int], List[str]]) -> None:
"""sets the 'row_index' row
pre::
@@ -203,7 +206,6 @@ class Table(object):
"""
self.data[row_index] = row_data
-
def set_row_by_id(self, row_id: str, row_data: List[str]) -> None:
"""sets the 'row_id' column
pre::
@@ -217,10 +219,11 @@ class Table(object):
row_index = self.row_names.index(row_id)
self.set_row(row_index, row_data)
except ValueError:
- raise KeyError('Row (%s) not found in table' % (row_id))
-
+ raise KeyError("Row (%s) not found in table" % (row_id))
- def append_row(self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None) -> int:
+ def append_row(
+ self, row_data: Union[List[Union[float, str]], List[int]], row_name: Optional[str] = None
+ ) -> int:
"""Appends a row to the table
pre::
@@ -245,7 +248,6 @@ class Table(object):
self.row_names.insert(index, row_name)
self.data.insert(index, row_data)
-
def delete_row(self, index: int) -> List[str]:
"""Deletes the 'index' row in the table, and returns it.
Raises an IndexError if index is out of range
@@ -253,7 +255,6 @@ class Table(object):
self.row_names.pop(index)
return self.data.pop(index)
-
def delete_row_by_id(self, row_id: str) -> None:
"""Deletes the 'row_id' row in the table.
Raises a KeyError if row_id was not found.
@@ -262,8 +263,7 @@ class Table(object):
row_index = self.row_names.index(row_id)
self.delete_row(row_index)
except ValueError:
- raise KeyError('Row (%s) not found in table' % (row_id))
-
+ raise KeyError("Row (%s) not found in table" % (row_id))
def set_column(self, col_index: int, col_data: Union[List[int], range]) -> None:
"""sets the 'col_index' column
@@ -276,7 +276,6 @@ class Table(object):
for row_index, cell_data in enumerate(col_data):
self.data[row_index][col_index] = cell_data
-
def set_column_by_id(self, col_id: str, col_data: Union[List[int], range]) -> None:
"""sets the 'col_id' column
pre::
@@ -290,8 +289,7 @@ class Table(object):
col_index = self.col_names.index(col_id)
self.set_column(col_index, col_data)
except ValueError:
- raise KeyError('Column (%s) not found in table' % (col_id))
-
+ raise KeyError("Column (%s) not found in table" % (col_id))
def append_column(self, col_data: range, col_name: str) -> None:
"""Appends the 'col_index' column
@@ -304,7 +302,6 @@ class Table(object):
for row_index, cell_data in enumerate(col_data):
self.data[row_index].append(cell_data)
-
def insert_column(self, index: int, col_data: range, col_name: str) -> None:
"""Appends col_data before 'index' in the table. To make 'insert'
behave like 'list.insert', inserting in an out of range index will
@@ -318,7 +315,6 @@ class Table(object):
for row_index, cell_data in enumerate(col_data):
self.data[row_index].insert(index, cell_data)
-
def delete_column(self, index: int) -> List[int]:
"""Deletes the 'index' column in the table, and returns it.
Raises an IndexError if index is out of range
@@ -326,7 +322,6 @@ class Table(object):
self.col_names.pop(index)
return [row.pop(index) for row in self.data]
-
def delete_column_by_id(self, col_id: str) -> None:
"""Deletes the 'col_id' col in the table.
Raises a KeyError if col_id was not found.
@@ -335,8 +330,7 @@ class Table(object):
col_index = self.col_names.index(col_id)
self.delete_column(col_index)
except ValueError:
- raise KeyError('Column (%s) not found in table' % (col_id))
-
+ raise KeyError("Column (%s) not found in table" % (col_id))
## The 'getter' part #######################################################
@@ -344,9 +338,12 @@ class Table(object):
"""Returns a tuple which represents the table's shape
"""
return len(self.row_names), len(self.col_names)
+
shape = property(get_shape)
- def __getitem__(self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice]) -> Any:
+ def __getitem__(
+ self, indices: Union[Tuple[Union[int, slice, str], Union[int, str]], int, slice]
+ ) -> Any:
"""provided for convenience"""
multirows: bool = False
multicols: bool = False
@@ -402,7 +399,7 @@ class Table(object):
for idx, row in enumerate(self.data[rows]):
tab.set_row(idx, row[cols])
- if multirows :
+ if multirows:
if multicols:
return tab
else:
@@ -457,14 +454,13 @@ class Table(object):
col = list(set(col))
return col
- def apply_stylesheet(self, stylesheet: 'TableStyleSheet') -> None:
+ def apply_stylesheet(self, stylesheet: "TableStyleSheet") -> None:
"""Applies the stylesheet to this table
"""
for instruction in stylesheet.instructions:
eval(instruction)
-
- def transpose(self) -> 'Table':
+ def transpose(self) -> "Table":
"""Keeps the self object intact, and returns the transposed (rotated)
table.
"""
@@ -475,7 +471,6 @@ class Table(object):
transposed.set_row(col_index, column)
return transposed
-
def pprint(self) -> str:
"""returns a string representing the table in a pretty
printed 'text' format.
@@ -490,10 +485,10 @@ class Table(object):
lines = []
# Build the 'first' line <=> the col_names one
# The first cell <=> an empty one
- col_names_line = [' '*col_start]
+ col_names_line = [" " * col_start]
for col_name in self.col_names:
- col_names_line.append(col_name + ' '*5)
- lines.append('|' + '|'.join(col_names_line) + '|')
+ col_names_line.append(col_name + " " * 5)
+ lines.append("|" + "|".join(col_names_line) + "|")
max_line_length = len(lines[0])
# Build the table
@@ -501,22 +496,21 @@ class Table(object):
line = []
# First, build the row_name's cell
row_name = self.row_names[row_index]
- line.append(row_name + ' '*(col_start-len(row_name)))
+ line.append(row_name + " " * (col_start - len(row_name)))
# Then, build all the table's cell for this line.
for col_index, cell in enumerate(row):
col_name_length = len(self.col_names[col_index]) + 5
data = str(cell)
- line.append(data + ' '*(col_name_length - len(data)))
- lines.append('|' + '|'.join(line) + '|')
+ line.append(data + " " * (col_name_length - len(data)))
+ lines.append("|" + "|".join(line) + "|")
if len(lines[-1]) > max_line_length:
max_line_length = len(lines[-1])
# Wrap the table with '-' to make a frame
- lines.insert(0, '-'*max_line_length)
- lines.append('-'*max_line_length)
- return '\n'.join(lines)
-
+ lines.insert(0, "-" * max_line_length)
+ lines.append("-" * max_line_length)
+ return "\n".join(lines)
def __repr__(self) -> str:
return repr(self.data)
@@ -526,9 +520,8 @@ class Table(object):
# We must convert cells into strings before joining them
for row in self.data:
data.append([str(cell) for cell in row])
- lines = ['\t'.join(row) for row in data]
- return '\n'.join(lines)
-
+ lines = ["\t".join(row) for row in data]
+ return "\n".join(lines)
class TableStyle:
@@ -538,18 +531,17 @@ class TableStyle:
def __init__(self, table: Table) -> None:
self._table = table
- self.size = dict([(col_name, '1*') for col_name in table.col_names])
+ self.size = dict([(col_name, "1*") for col_name in table.col_names])
# __row_column__ is a special key to define the first column which
# actually has no name (<=> left most column <=> row names column)
- self.size['__row_column__'] = '1*'
- self.alignment = dict([(col_name, 'right')
- for col_name in table.col_names])
- self.alignment['__row_column__'] = 'right'
+ self.size["__row_column__"] = "1*"
+ self.alignment = dict([(col_name, "right") for col_name in table.col_names])
+ self.alignment["__row_column__"] = "right"
# We shouldn't have to create an entry for
# the 1st col (the row_column one)
- self.units = dict([(col_name, '') for col_name in table.col_names])
- self.units['__row_column__'] = ''
+ self.units = dict([(col_name, "") for col_name in table.col_names])
+ self.units["__row_column__"] = ""
# XXX FIXME : params order should be reversed for all set() methods
def set_size(self, value: str, col_id: str) -> None:
@@ -563,38 +555,34 @@ class TableStyle:
BE CAREFUL : the '0' column is the '__row_column__' one !
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
self.size[col_id] = value
-
def set_alignment(self, value: str, col_id: str) -> None:
"""sets the alignment of the specified col_id to value
"""
self.alignment[col_id] = value
-
def set_alignment_by_index(self, value: str, col_index: int) -> None:
"""Allows to set the alignment according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
self.alignment[col_id] = value
-
def set_unit(self, value: str, col_id: str) -> None:
"""sets the unit of the specified col_id to value
"""
self.units[col_id] = value
-
def set_unit_by_index(self, value: str, col_index: int) -> None:
"""Allows to set the unit according to the column index rather than
using the column's id.
@@ -603,73 +591,69 @@ class TableStyle:
for the 1st column (the __row__column__ one))
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
self.units[col_id] = value
-
def get_size(self, col_id: str) -> str:
"""Returns the size of the specified col_id
"""
return self.size[col_id]
-
def get_size_by_index(self, col_index: int) -> str:
"""Allows to get the size according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
return self.size[col_id]
-
def get_alignment(self, col_id: str) -> str:
"""Returns the alignment of the specified col_id
"""
return self.alignment[col_id]
-
def get_alignment_by_index(self, col_index: int) -> str:
"""Allors to get the alignment according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
return self.alignment[col_id]
-
def get_unit(self, col_id: str) -> str:
"""Returns the unit of the specified col_id
"""
return self.units[col_id]
-
def get_unit_by_index(self, col_index: int) -> str:
"""Allors to get the unit according to the column index rather than
using the column's id.
BE CAREFUL : the '0' column is the '__row_column__' one !
"""
if col_index == 0:
- col_id = '__row_column__'
+ col_id = "__row_column__"
else:
- col_id = self._table.col_names[col_index-1]
+ col_id = self._table.col_names[col_index - 1]
return self.units[col_id]
import re
+
CELL_PROG = re.compile("([0-9]+)_([0-9]+)")
+
class TableStyleSheet:
"""A simple Table stylesheet
Rules are expressions where cells are defined by the row_index
@@ -694,21 +678,20 @@ class TableStyleSheet:
for rule in rules:
self.add_rule(rule)
-
def add_rule(self, rule: str) -> None:
"""Adds a rule to the stylesheet rules
"""
try:
- source_code = ['from math import *']
- source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule))
- self.instructions.append(compile('\n'.join(source_code),
- 'table.py', 'exec'))
+ source_code = ["from math import *"]
+ source_code.append(CELL_PROG.sub(r"self.data[\1][\2]", rule))
+ self.instructions.append(compile("\n".join(source_code), "table.py", "exec"))
self.rules.append(rule)
except SyntaxError:
print("Bad Stylesheet Rule : %s [skipped]" % rule)
-
- def add_rowsum_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None:
+ def add_rowsum_rule(
+ self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int
+ ) -> None:
"""Creates and adds a rule to sum over the row at row_index from
start_col to end_col.
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -718,13 +701,13 @@ class TableStyleSheet:
start_col >= 0
end_col > start_col
"""
- cell_list = ['%d_%d'%(row_index, index) for index in range(start_col,
- end_col + 1)]
- rule = '%d_%d=' % dest_cell + '+'.join(cell_list)
+ cell_list = ["%d_%d" % (row_index, index) for index in range(start_col, end_col + 1)]
+ rule = "%d_%d=" % dest_cell + "+".join(cell_list)
self.add_rule(rule)
-
- def add_rowavg_rule(self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int) -> None:
+ def add_rowavg_rule(
+ self, dest_cell: Tuple[int, int], row_index: int, start_col: int, end_col: int
+ ) -> None:
"""Creates and adds a rule to make the row average (from start_col
to end_col)
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -734,14 +717,14 @@ class TableStyleSheet:
start_col >= 0
end_col > start_col
"""
- cell_list = ['%d_%d'%(row_index, index) for index in range(start_col,
- end_col + 1)]
- num = (end_col - start_col + 1)
- rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num
+ cell_list = ["%d_%d" % (row_index, index) for index in range(start_col, end_col + 1)]
+ num = end_col - start_col + 1
+ rule = "%d_%d=" % dest_cell + "(" + "+".join(cell_list) + ")/%f" % num
self.add_rule(rule)
-
- def add_colsum_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None:
+ def add_colsum_rule(
+ self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int
+ ) -> None:
"""Creates and adds a rule to sum over the col at col_index from
start_row to end_row.
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -751,13 +734,13 @@ class TableStyleSheet:
start_row >= 0
end_row > start_row
"""
- cell_list = ['%d_%d'%(index, col_index) for index in range(start_row,
- end_row + 1)]
- rule = '%d_%d=' % dest_cell + '+'.join(cell_list)
+ cell_list = ["%d_%d" % (index, col_index) for index in range(start_row, end_row + 1)]
+ rule = "%d_%d=" % dest_cell + "+".join(cell_list)
self.add_rule(rule)
-
- def add_colavg_rule(self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int) -> None:
+ def add_colavg_rule(
+ self, dest_cell: Tuple[int, int], col_index: int, start_row: int, end_row: int
+ ) -> None:
"""Creates and adds a rule to make the col average (from start_row
to end_row)
dest_cell is a tuple of two elements (x,y) of the destination cell
@@ -767,14 +750,12 @@ class TableStyleSheet:
start_row >= 0
end_row > start_row
"""
- cell_list = ['%d_%d'%(index, col_index) for index in range(start_row,
- end_row + 1)]
- num = (end_row - start_row + 1)
- rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num
+ cell_list = ["%d_%d" % (index, col_index) for index in range(start_row, end_row + 1)]
+ num = end_row - start_row + 1
+ rule = "%d_%d=" % dest_cell + "(" + "+".join(cell_list) + ")/%f" % num
self.add_rule(rule)
-
class TableCellRenderer:
"""Defines a simple text renderer
"""
@@ -789,35 +770,36 @@ class TableCellRenderer:
"""
self.properties = properties
-
- def render_cell(self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle) -> Union[str, int]:
+ def render_cell(
+ self, cell_coord: Tuple[int, int], table: Table, table_style: TableStyle
+ ) -> Union[str, int]:
"""Renders the cell at 'cell_coord' in the table, using table_style
"""
row_index, col_index = cell_coord
cell_value = table.data[row_index][col_index]
- final_content = self._make_cell_content(cell_value,
- table_style, col_index +1)
- return self._render_cell_content(final_content,
- table_style, col_index + 1)
-
+ final_content = self._make_cell_content(cell_value, table_style, col_index + 1)
+ return self._render_cell_content(final_content, table_style, col_index + 1)
- def render_row_cell(self, row_name: str, table: Table, table_style: TableStyle) -> Union[str, int]:
+ def render_row_cell(
+ self, row_name: str, table: Table, table_style: TableStyle
+ ) -> Union[str, int]:
"""Renders the cell for 'row_id' row
"""
cell_value = row_name
return self._render_cell_content(cell_value, table_style, 0)
-
- def render_col_cell(self, col_name: str, table: Table, table_style: TableStyle) -> Union[str, int]:
+ def render_col_cell(
+ self, col_name: str, table: Table, table_style: TableStyle
+ ) -> Union[str, int]:
"""Renders the cell for 'col_id' row
"""
cell_value = col_name
col_index = table.col_names.index(col_name)
- return self._render_cell_content(cell_value, table_style, col_index +1)
+ return self._render_cell_content(cell_value, table_style, col_index + 1)
-
-
- def _render_cell_content(self, content: Union[str, int], table_style: TableStyle, col_index: int) -> Union[str, int]:
+ def _render_cell_content(
+ self, content: Union[str, int], table_style: TableStyle, col_index: int
+ ) -> Union[str, int]:
"""Makes the appropriate rendering for this cell content.
Rendering properties will be searched using the
*table_style.get_xxx_by_index(col_index)' methods
@@ -826,31 +808,30 @@ class TableCellRenderer:
"""
return content
-
- def _make_cell_content(self, cell_content: int, table_style: TableStyle, col_index: int) -> Union[int, str]:
+ def _make_cell_content(
+ self, cell_content: int, table_style: TableStyle, col_index: int
+ ) -> Union[int, str]:
"""Makes the cell content (adds decoration data, like units for
example)
"""
final_content: Union[int, str] = cell_content
- if 'skip_zero' in self.properties:
- replacement_char = self.properties['skip_zero']
+ if "skip_zero" in self.properties:
+ replacement_char = self.properties["skip_zero"]
else:
replacement_char = 0
if replacement_char and final_content == 0:
return replacement_char
try:
- units_on = self.properties['units']
+ units_on = self.properties["units"]
if units_on:
- final_content = self._add_unit(
- cell_content, table_style, col_index)
+ final_content = self._add_unit(cell_content, table_style, col_index)
except KeyError:
pass
return final_content
-
def _add_unit(self, cell_content: int, table_style: TableStyle, col_index: int) -> str:
"""Adds unit to the cell_content if needed
"""
@@ -858,7 +839,6 @@ class TableCellRenderer:
return str(cell_content) + " " + unit
-
class DocbookRenderer(TableCellRenderer):
"""Defines how to render a cell for a docboook table
"""
@@ -867,21 +847,20 @@ class DocbookRenderer(TableCellRenderer):
"""Computes the colspec element according to the style
"""
size = table_style.get_size_by_index(col_index)
- return '<colspec colname="c%d" colwidth="%s"/>\n' % \
- (col_index, size)
-
+ return '<colspec colname="c%d" colwidth="%s"/>\n' % (col_index, size)
- def _render_cell_content(self, cell_content: Union[int, str], table_style: TableStyle, col_index: int) -> str:
+ def _render_cell_content(
+ self, cell_content: Union[int, str], table_style: TableStyle, col_index: int
+ ) -> str:
"""Makes the appropriate rendering for this cell content.
Rendering properties will be searched using the
table_style.get_xxx_by_index(col_index)' methods.
"""
try:
- align_on = self.properties['alignment']
+ align_on = self.properties["alignment"]
alignment = table_style.get_alignment_by_index(col_index)
if align_on:
- return "<entry align='%s'>%s</entry>\n" % \
- (alignment, cell_content)
+ return "<entry align='%s'>%s</entry>\n" % (alignment, cell_content)
except KeyError:
# KeyError <=> Default alignment
return "<entry>%s</entry>\n" % cell_content
@@ -894,39 +873,36 @@ class TableWriter:
"""A class to write tables
"""
- def __init__(self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any) -> None:
+ def __init__(
+ self, stream: StringIO, table: Table, style: Optional[Any], **properties: Any
+ ) -> None:
self._stream = stream
self.style = style or TableStyle(table)
self._table = table
self.properties = properties
self.renderer: Optional[DocbookRenderer] = None
-
def set_style(self, style):
"""sets the table's associated style
"""
self.style = style
-
def set_renderer(self, renderer: DocbookRenderer) -> None:
"""sets the way to render cell
"""
self.renderer = renderer
-
def update_properties(self, **properties):
"""Updates writer's properties (for cell rendering)
"""
self.properties.update(properties)
-
def write_table(self, title: str = "") -> None:
"""Writes the table
"""
raise NotImplementedError("write_table must be implemented !")
-
class DocbookTableWriter(TableWriter):
"""Defines an implementation of TableWriter to write a table in Docbook
"""
@@ -937,56 +913,48 @@ class DocbookTableWriter(TableWriter):
assert self.renderer is not None
# Define col_headers (colstpec elements)
- for col_index in range(len(self._table.col_names)+1):
- self._stream.write(self.renderer.define_col_header(col_index,
- self.style))
+ for col_index in range(len(self._table.col_names) + 1):
+ self._stream.write(self.renderer.define_col_header(col_index, self.style))
self._stream.write("<thead>\n<row>\n")
# XXX FIXME : write an empty entry <=> the first (__row_column) column
- self._stream.write('<entry></entry>\n')
+ self._stream.write("<entry></entry>\n")
for col_name in self._table.col_names:
- self._stream.write(self.renderer.render_col_cell(
- col_name, self._table,
- self.style))
+ self._stream.write(self.renderer.render_col_cell(col_name, self._table, self.style))
self._stream.write("</row>\n</thead>\n")
-
def _write_body(self) -> None:
"""Writes the table body
"""
assert self.renderer is not None
- self._stream.write('<tbody>\n')
+ self._stream.write("<tbody>\n")
for row_index, row in enumerate(self._table.data):
- self._stream.write('<row>\n')
+ self._stream.write("<row>\n")
row_name = self._table.row_names[row_index]
# Write the first entry (row_name)
- self._stream.write(self.renderer.render_row_cell(row_name,
- self._table,
- self.style))
+ self._stream.write(self.renderer.render_row_cell(row_name, self._table, self.style))
for col_index, cell in enumerate(row):
- self._stream.write(self.renderer.render_cell(
- (row_index, col_index),
- self._table, self.style))
+ self._stream.write(
+ self.renderer.render_cell((row_index, col_index), self._table, self.style)
+ )
- self._stream.write('</row>\n')
-
- self._stream.write('</tbody>\n')
+ self._stream.write("</row>\n")
+ self._stream.write("</tbody>\n")
def write_table(self, title: str = "") -> None:
"""Writes the table
"""
- self._stream.write('<table>\n<title>%s></title>\n'%(title))
+ self._stream.write("<table>\n<title>%s></title>\n" % (title))
self._stream.write(
- '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n'%
- (len(self._table.col_names)+1))
+ '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n'
+ % (len(self._table.col_names) + 1)
+ )
self._write_headers()
self._write_body()
- self._stream.write('</tgroup>\n</table>\n')
-
-
+ self._stream.write("</tgroup>\n</table>\n")
diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py
index 4e3434e..0d4889d 100644
--- a/logilab/common/tasksqueue.py
+++ b/logilab/common/tasksqueue.py
@@ -29,22 +29,21 @@ MEDIUM = 10
HIGH = 100
PRIORITY = {
- 'LOW': LOW,
- 'MEDIUM': MEDIUM,
- 'HIGH': HIGH,
- }
+ "LOW": LOW,
+ "MEDIUM": MEDIUM,
+ "HIGH": HIGH,
+}
REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items())
class PrioritizedTasksQueue(queue.Queue):
-
def _init(self, maxsize: int) -> None:
"""Initialize the queue representation"""
self.maxsize = maxsize
# ordered list of task, from the lowest to the highest priority
- self.queue: List['Task'] = [] # type: ignore
+ self.queue: List["Task"] = [] # type: ignore
- def _put(self, item: 'Task') -> None:
+ def _put(self, item: "Task") -> None:
"""Put a new item in the queue"""
for i, task in enumerate(self.queue):
# equivalent task
@@ -60,11 +59,11 @@ class PrioritizedTasksQueue(queue.Queue):
return
insort_left(self.queue, item)
- def _get(self) -> 'Task':
+ def _get(self) -> "Task":
"""Get an item from the queue"""
return self.queue.pop()
- def __iter__(self) -> Iterator['Task']:
+ def __iter__(self) -> Iterator["Task"]:
return iter(self.queue)
def remove(self, tid: str) -> None:
@@ -74,7 +73,8 @@ class PrioritizedTasksQueue(queue.Queue):
if task.id == tid:
self.queue.pop(i)
return
- raise ValueError('not task of id %s in queue' % tid)
+ raise ValueError("not task of id %s in queue" % tid)
+
class Task:
def __init__(self, tid: str, priority: int = LOW) -> None:
@@ -84,9 +84,9 @@ class Task:
self.priority = priority
def __repr__(self) -> str:
- return '<Task %s @%#x>' % (self.id, id(self))
+ return "<Task %s @%#x>" % (self.id, id(self))
- def __lt__(self, other: 'Task') -> bool:
+ def __lt__(self, other: "Task") -> bool:
return self.priority < other.priority
def __eq__(self, other: object) -> bool:
@@ -94,5 +94,5 @@ class Task:
__hash__ = object.__hash__
- def merge(self, other: 'Task') -> None:
+ def merge(self, other: "Task") -> None:
pass
diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py
index 8348900..f8401c4 100644
--- a/logilab/common/testlib.py
+++ b/logilab/common/testlib.py
@@ -64,6 +64,7 @@ import configparser
from logilab.common.deprecation import class_deprecated, deprecated
import unittest as unittest_legacy
+
if not getattr(unittest_legacy, "__package__", None):
try:
import unittest2 as unittest
@@ -83,22 +84,22 @@ from logilab.common.decorators import cached, classproperty
from logilab.common import textutils
-__all__ = ['unittest_main', 'find_tests', 'nocoverage', 'pause_trace']
+__all__ = ["unittest_main", "find_tests", "nocoverage", "pause_trace"]
-DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest',
- 'func', 'validation')
+DEFAULT_PREFIXES = ("test", "regrtest", "smoketest", "unittest", "func", "validation")
-is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction)
+is_generator = deprecated("[lgc 0.63] use inspect.isgeneratorfunction")(isgeneratorfunction)
# used by unittest to count the number of relevant levels in the traceback
__unittest = 1
-@deprecated('with_tempdir is deprecated, use tempfile.TemporaryDirectory.')
+@deprecated("with_tempdir is deprecated, use tempfile.TemporaryDirectory.")
def with_tempdir(callable: Callable) -> Callable:
"""A decorator ensuring no temporary file left when the function return
Work only for temporary file created with the tempfile module"""
if isgeneratorfunction(callable):
+
def proxy(*args: Any, **kwargs: Any) -> Iterator[Union[Iterator, Iterator[str]]]:
old_tmpdir = tempfile.gettempdir()
new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
@@ -111,9 +112,11 @@ def with_tempdir(callable: Callable) -> Callable:
rmtree(new_tmpdir, ignore_errors=True)
finally:
tempfile.tempdir = old_tmpdir
+
return proxy
else:
+
@wraps(callable)
def proxy(*args: Any, **kargs: Any) -> Any:
@@ -127,11 +130,14 @@ def with_tempdir(callable: Callable) -> Callable:
rmtree(new_tmpdir, ignore_errors=True)
finally:
tempfile.tempdir = old_tmpdir
+
return proxy
+
def in_tempdir(callable):
"""A decorator moving the enclosed function inside the tempfile.tempfdir
"""
+
@wraps(callable)
def proxy(*args, **kargs):
@@ -141,8 +147,10 @@ def in_tempdir(callable):
return callable(*args, **kargs)
finally:
os.chdir(old_cwd)
+
return proxy
+
def within_tempdir(callable):
"""A decorator run the enclosed function inside a tmpdir removed after execution
"""
@@ -150,10 +158,8 @@ def within_tempdir(callable):
proxy.__name__ = callable.__name__
return proxy
-def find_tests(testdir,
- prefixes=DEFAULT_PREFIXES, suffix=".py",
- excludes=(),
- remove_suffix=True):
+
+def find_tests(testdir, prefixes=DEFAULT_PREFIXES, suffix=".py", excludes=(), remove_suffix=True):
"""
Return a list of all applicable test modules.
"""
@@ -163,7 +169,7 @@ def find_tests(testdir,
for prefix in prefixes:
if name.startswith(prefix):
if remove_suffix and name.endswith(suffix):
- name = name[:-len(suffix)]
+ name = name[: -len(suffix)]
if name not in excludes:
tests.append(name)
tests.sort()
@@ -184,13 +190,12 @@ def start_interactive_mode(result):
testindex = 0
print("Choose a test to debug:")
# order debuggers in the same way than errors were printed
- print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr)
- in enumerate(descrs)]))
+ print("\n".join(["\t%s : %s" % (i, descr) for i, (_, descr) in enumerate(descrs)]))
print("Type 'exit' (or ^D) to quit")
print()
try:
- todebug = input('Enter a test name: ')
- if todebug.strip().lower() == 'exit':
+ todebug = input("Enter a test name: ")
+ if todebug.strip().lower() == "exit":
print()
break
else:
@@ -198,7 +203,7 @@ def start_interactive_mode(result):
testindex = int(todebug)
debugger = debuggers[descrs[testindex][0]]
except (ValueError, IndexError):
- print("ERROR: invalid test number %r" % (todebug, ))
+ print("ERROR: invalid test number %r" % (todebug,))
else:
debugger.start()
except (EOFError, KeyboardInterrupt):
@@ -208,6 +213,7 @@ def start_interactive_mode(result):
# coverage pausing tools #####################################################
+
@contextmanager
def replace_trace(trace: Optional[Callable] = None) -> Iterator:
"""A context manager that temporary replaces the trace function"""
@@ -218,8 +224,7 @@ def replace_trace(trace: Optional[Callable] = None) -> Iterator:
finally:
# specific hack to work around a bug in pycoverage, see
# https://bitbucket.org/ned/coveragepy/issue/123
- if (oldtrace is not None and not callable(oldtrace) and
- hasattr(oldtrace, 'pytrace')):
+ if oldtrace is not None and not callable(oldtrace) and hasattr(oldtrace, "pytrace"):
oldtrace = oldtrace.pytrace
sys.settrace(oldtrace)
@@ -229,7 +234,7 @@ pause_trace = replace_trace
def nocoverage(func: Callable) -> Callable:
"""Function decorator that pauses tracing functions"""
- if hasattr(func, 'uncovered'):
+ if hasattr(func, "uncovered"):
return func
# mypy: "Callable[..., Any]" has no attribute "uncovered"
# dynamic attribute for magic
@@ -238,6 +243,7 @@ def nocoverage(func: Callable) -> Callable:
def not_covered(*args: Any, **kwargs: Any) -> Any:
with pause_trace():
return func(*args, **kwargs)
+
# mypy: "Callable[[VarArg(Any), KwArg(Any)], NoReturn]" has no attribute "uncovered"
# dynamic attribute for magic
not_covered.uncovered = True # type: ignore
@@ -249,49 +255,56 @@ def nocoverage(func: Callable) -> Callable:
# Add deprecation warnings about new api used by module level fixtures in unittest2
# http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule
-class _DebugResult(object): # simplify import statement among unittest flavors..
+class _DebugResult(object): # simplify import statement among unittest flavors..
"Used by the TestSuite to hold previous class when running in debug."
_previousTestClass = None
_moduleSetUpFailed = False
shouldStop = False
+
# backward compatibility: TestSuite might be imported from lgc.testlib
TestSuite = unittest.TestSuite
+
class keywords(dict):
"""Keyword args (**kwargs) support for generative tests."""
+
class starargs(tuple):
"""Variable arguments (*args) for generative tests."""
+
def __new__(cls, *args):
return tuple.__new__(cls, args)
+
unittest_main = unittest.main
class InnerTestSkipped(SkipTest):
"""raised when a test is skipped"""
+
pass
+
def parse_generative_args(params: Tuple[int, ...]) -> Tuple[Union[List[bool], List[int]], Dict]:
args = []
varargs = ()
kwargs: Dict = {}
- flags = 0 # 2 <=> starargs, 4 <=> kwargs
+ flags = 0 # 2 <=> starargs, 4 <=> kwargs
for param in params:
if isinstance(param, starargs):
varargs = param
if flags:
- raise TypeError('found starargs after keywords !')
+ raise TypeError("found starargs after keywords !")
flags |= 2
args += list(varargs)
elif isinstance(param, keywords):
kwargs = param
if flags & 4:
- raise TypeError('got multiple keywords parameters')
+ raise TypeError("got multiple keywords parameters")
flags |= 4
elif flags & 2 or flags & 4:
- raise TypeError('found parameters after kwargs or args')
+ raise TypeError("found parameters after kwargs or args")
else:
args.append(param)
@@ -304,13 +317,14 @@ class InnerTest(tuple):
instance.name = name
return instance
+
class Tags(set):
"""A set of tag able validate an expression"""
def __init__(self, *tags: str, **kwargs: Any) -> None:
- self.inherit = kwargs.pop('inherit', True)
+ self.inherit = kwargs.pop("inherit", True)
if kwargs:
- raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
+ raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
if len(tags) == 1 and not isinstance(tags[0], str):
tags = tags[0]
@@ -328,25 +342,26 @@ class Tags(set):
# mypy: Argument 1 of "__or__" is incompatible with supertype "AbstractSet";
# mypy: supertype defines the argument type as "AbstractSet[_T]"
# not sure how to fix this one
- def __or__(self, other: 'Tags') -> 'Tags': # type: ignore
+ def __or__(self, other: "Tags") -> "Tags": # type: ignore
return Tags(*super(Tags, self).__or__(other))
# duplicate definition from unittest2 of the _deprecate decorator
def _deprecate(original_func):
def deprecated_func(*args, **kwargs):
- warnings.warn(
- ('Please use %s instead.' % original_func.__name__),
- DeprecationWarning, 2)
+ warnings.warn(("Please use %s instead." % original_func.__name__), DeprecationWarning, 2)
return original_func(*args, **kwargs)
+
return deprecated_func
+
class TestCase(unittest.TestCase):
"""A unittest.TestCase extension with some additional methods."""
+
maxDiff = None
tags = Tags()
- def __init__(self, methodName: str = 'runTest') -> None:
+ def __init__(self, methodName: str = "runTest") -> None:
super(TestCase, self).__init__(methodName)
self.__exc_info = sys.exc_info
self.__testMethodName = self._testMethodName
@@ -355,13 +370,14 @@ class TestCase(unittest.TestCase):
@classproperty
@cached
- def datadir(cls) -> str: # pylint: disable=E0213
+ def datadir(cls) -> str: # pylint: disable=E0213
"""helper attribute holding the standard test's data directory
NOTE: this is a logilab's standard
"""
mod = sys.modules[cls.__module__]
- return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data')
+ return osp.join(osp.dirname(osp.abspath(mod.__file__)), "data")
+
# cache it (use a class method to cache on class since TestCase is
# instantiated for each test run)
@@ -392,11 +408,12 @@ class TestCase(unittest.TestCase):
except (KeyboardInterrupt, SystemExit):
raise
except unittest.SkipTest as e:
- if hasattr(result, 'addSkip'):
+ if hasattr(result, "addSkip"):
result.addSkip(self, str(e))
else:
- warnings.warn("TestResult has no addSkip method, skips not reported",
- RuntimeWarning, 2)
+ warnings.warn(
+ "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2
+ )
result.addSuccess(self)
return False
except:
@@ -423,23 +440,26 @@ class TestCase(unittest.TestCase):
# if result.cvg:
# result.cvg.start()
testMethod = self._get_test_method()
- if (getattr(self.__class__, "__unittest_skip__", False) or
- getattr(testMethod, "__unittest_skip__", False)):
+ if getattr(self.__class__, "__unittest_skip__", False) or getattr(
+ testMethod, "__unittest_skip__", False
+ ):
# If the class or method was skipped.
try:
- skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
- or getattr(testMethod, '__unittest_skip_why__', ''))
- if hasattr(result, 'addSkip'):
+ skip_why = getattr(self.__class__, "__unittest_skip_why__", "") or getattr(
+ testMethod, "__unittest_skip_why__", ""
+ )
+ if hasattr(result, "addSkip"):
result.addSkip(self, skip_why)
else:
- warnings.warn("TestResult has no addSkip method, skips not reported",
- RuntimeWarning, 2)
+ warnings.warn(
+ "TestResult has no addSkip method, skips not reported", RuntimeWarning, 2
+ )
result.addSuccess(self)
finally:
result.stopTest(self)
return
if runcondition and not runcondition(testMethod):
- return # test is skipped
+ return # test is skipped
result.startTest(self)
try:
if not self.quiet_run(result, self.setUp):
@@ -447,11 +467,10 @@ class TestCase(unittest.TestCase):
generative = isgeneratorfunction(testMethod)
# generative tests
if generative:
- self._proceed_generative(result, testMethod,
- runcondition)
+ self._proceed_generative(result, testMethod, runcondition)
else:
status = self._proceed(result, testMethod)
- success = (status == 0)
+ success = status == 0
if not self.quiet_run(result, self.tearDown):
return
if not generative and success:
@@ -461,19 +480,19 @@ class TestCase(unittest.TestCase):
# result.cvg.stop()
result.stopTest(self)
- def _proceed_generative(self, result: Any, testfunc: Callable, runcondition: Callable = None) -> bool:
+ def _proceed_generative(
+ self, result: Any, testfunc: Callable, runcondition: Callable = None
+ ) -> bool:
# cancel startTest()'s increment
result.testsRun -= 1
success = True
try:
for params in testfunc():
- if runcondition and not runcondition(testfunc,
- skipgenerator=False):
- if not (isinstance(params, InnerTest)
- and runcondition(params)):
+ if runcondition and not runcondition(testfunc, skipgenerator=False):
+ if not (isinstance(params, InnerTest) and runcondition(params)):
continue
if not isinstance(params, (tuple, list)):
- params = (params, )
+ params = (params,)
func = params[0]
args, kwargs = parse_generative_args(params[1:])
# increment test counter manually
@@ -485,9 +504,9 @@ class TestCase(unittest.TestCase):
else:
success = False
# XXX Don't stop anymore if an error occured
- #if status == 2:
+ # if status == 2:
# result.shouldStop = True
- if result.shouldStop: # either on error or on exitfirst + error
+ if result.shouldStop: # either on error or on exitfirst + error
break
except self.failureException:
result.addFailure(self, self.__exc_info())
@@ -500,7 +519,13 @@ class TestCase(unittest.TestCase):
success = False
return success
- def _proceed(self, result: Any, testfunc: Callable, args: Union[List[bool], List[int], Tuple[()]] = (), kwargs: Optional[Dict] = None) -> int:
+ def _proceed(
+ self,
+ result: Any,
+ testfunc: Callable,
+ args: Union[List[bool], List[int], Tuple[()]] = (),
+ kwargs: Optional[Dict] = None,
+ ) -> int:
"""proceed the actual test
returns 0 on success, 1 on failure, 2 on error
@@ -529,39 +554,40 @@ class TestCase(unittest.TestCase):
def innerSkip(self, msg: str = None) -> NoReturn:
"""mark a generative test as skipped for the <msg> reason"""
- msg = msg or 'test was skipped'
+ msg = msg or "test was skipped"
raise InnerTestSkipped(msg)
- if sys.version_info >= (3,2):
+ if sys.version_info >= (3, 2):
assertItemsEqual = unittest.TestCase.assertCountEqual
else:
assertCountEqual = unittest.TestCase.assertItemsEqual
-TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')(
- TestCase.assertItemsEqual)
+
+TestCase.assertItemsEqual = deprecated("assertItemsEqual is deprecated, use assertCountEqual")(
+ TestCase.assertItemsEqual
+)
import doctest
+
class SkippedSuite(unittest.TestSuite):
def test(self):
"""just there to trigger test execution"""
- self.skipped_test('doctest module has no DocTestSuite class')
+ self.skipped_test("doctest module has no DocTestSuite class")
class DocTestFinder(doctest.DocTestFinder):
-
def __init__(self, *args, **kwargs):
- self.skipped = kwargs.pop('skipped', ())
+ self.skipped = kwargs.pop("skipped", ())
doctest.DocTestFinder.__init__(self, *args, **kwargs)
def _get_test(self, obj, name, module, globs, source_lines):
"""override default _get_test method to be able to skip tests
according to skipped attribute's value
"""
- if getattr(obj, '__name__', '') in self.skipped:
+ if getattr(obj, "__name__", "") in self.skipped:
return None
- return doctest.DocTestFinder._get_test(self, obj, name, module,
- globs, source_lines)
+ return doctest.DocTestFinder._get_test(self, obj, name, module, globs, source_lines)
# mypy error: Invalid metaclass 'class_deprecated'
@@ -571,10 +597,11 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore
I don't know how to make unittest.main consider the DocTestSuite instance
without this hack
"""
- __deprecation_warning__ = 'use stdlib doctest module with unittest API directly'
+
+ __deprecation_warning__ = "use stdlib doctest module with unittest API directly"
skipped = ()
- def __call__(self, result=None, runcondition=None, options=None):\
- # pylint: disable=W0613
+
+ def __call__(self, result=None, runcondition=None, options=None): # pylint: disable=W0613
try:
finder = DocTestFinder(skipped=self.skipped)
suite = doctest.DocTestSuite(self.module, test_finder=finder)
@@ -590,6 +617,7 @@ class DocTest(TestCase, metaclass=class_deprecated): # type: ignore
finally:
builtins.__dict__.clear()
builtins.__dict__.update(old_builtins)
+
run = __call__
def test(self):
@@ -607,21 +635,27 @@ class MockConnection:
def cursor(self):
"""Mock cursor method"""
return self
+
def execute(self, query, args=None):
"""Mock execute method"""
- self.received.append( (query, args) )
+ self.received.append((query, args))
+
def fetchone(self):
"""Mock fetchone method"""
return self.results[0]
+
def fetchall(self):
"""Mock fetchall method"""
return self.results
+
def commit(self):
"""Mock commiy method"""
- self.states.append( ('commit', len(self.received)) )
+ self.states.append(("commit", len(self.received)))
+
def rollback(self):
"""Mock rollback method"""
- self.states.append( ('rollback', len(self.received)) )
+ self.states.append(("rollback", len(self.received)))
+
def close(self):
"""Mock close method"""
pass
@@ -629,7 +663,7 @@ class MockConnection:
# mypy error: Name 'Mock' is not defined
# dynamic class created by this class
-def mock_object(**params: Any) -> 'Mock': # type: ignore
+def mock_object(**params: Any) -> "Mock": # type: ignore
"""creates an object using params to set attributes
>>> option = mock_object(verbose=False, index=range(5))
>>> option.verbose
@@ -637,7 +671,7 @@ def mock_object(**params: Any) -> 'Mock': # type: ignore
>>> option.index
[0, 1, 2, 3, 4]
"""
- return type('Mock', (), params)()
+ return type("Mock", (), params)()
def create_files(paths: List[str], chroot: str) -> None:
@@ -664,7 +698,7 @@ def create_files(paths: List[str], chroot: str) -> None:
path = osp.join(chroot, path)
filename = osp.basename(path)
# path is a directory path
- if filename == '':
+ if filename == "":
dirs.add(path)
# path is a filename path
else:
@@ -674,54 +708,69 @@ def create_files(paths: List[str], chroot: str) -> None:
if not osp.isdir(dirpath):
os.makedirs(dirpath)
for filepath in files:
- open(filepath, 'w').close()
+ open(filepath, "w").close()
-class AttrObject: # XXX cf mock_object
+class AttrObject: # XXX cf mock_object
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
+
def tag(*args: str, **kwargs: Any) -> Callable:
"""descriptor adding tag to a function"""
+
def desc(func: Callable) -> Callable:
- assert not hasattr(func, 'tags')
+ assert not hasattr(func, "tags")
# mypy: "Callable[..., Any]" has no attribute "tags"
# dynamic magic attribute
func.tags = Tags(*args, **kwargs) # type: ignore
return func
+
return desc
+
def require_version(version: str) -> Callable:
""" Compare version of python interpreter to the given one. Skip the test
if older.
"""
+
def check_require_version(f: Callable) -> Callable:
- version_elements = version.split('.')
+ version_elements = version.split(".")
try:
compare = tuple([int(v) for v in version_elements])
except ValueError:
- raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version)
+ raise ValueError("%s is not a correct version : should be X.Y[.Z]." % version)
current = sys.version_info[:3]
if current < compare:
+
def new_f(self, *args, **kwargs):
- self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current])))
+ self.skipTest(
+ "Need at least %s version of python. Current version is %s."
+ % (version, ".".join([str(element) for element in current]))
+ )
+
new_f.__name__ = f.__name__
return new_f
else:
return f
+
return check_require_version
+
def require_module(module: str) -> Callable:
""" Check if the given module is loaded. Skip the test if not.
"""
+
def check_require_module(f: Callable) -> Callable:
try:
__import__(module)
return f
except ImportError:
+
def new_f(self, *args, **kwargs):
- self.skipTest('%s can not be imported.' % module)
+ self.skipTest("%s can not be imported." % module)
+
new_f.__name__ = f.__name__
return new_f
- return check_require_module
+ return check_require_module
diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py
index 4b6ea98..b988c7a 100644
--- a/logilab/common/textutils.py
+++ b/logilab/common/textutils.py
@@ -50,33 +50,37 @@ from re import Pattern, Match
from warnings import warn
from unicodedata import normalize as _uninormalize
from typing import Any, Optional, Tuple, List, Callable, Dict, Union
+
try:
from os import linesep
except ImportError:
- linesep = '\n' # gae
+ linesep = "\n" # gae
from logilab.common.deprecation import deprecated
MANUAL_UNICODE_MAP = {
- u'\xa1': u'!', # INVERTED EXCLAMATION MARK
- u'\u0142': u'l', # LATIN SMALL LETTER L WITH STROKE
- u'\u2044': u'/', # FRACTION SLASH
- u'\xc6': u'AE', # LATIN CAPITAL LETTER AE
- u'\xa9': u'(c)', # COPYRIGHT SIGN
- u'\xab': u'"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK
- u'\xe6': u'ae', # LATIN SMALL LETTER AE
- u'\xae': u'(r)', # REGISTERED SIGN
- u'\u0153': u'oe', # LATIN SMALL LIGATURE OE
- u'\u0152': u'OE', # LATIN CAPITAL LIGATURE OE
- u'\xd8': u'O', # LATIN CAPITAL LETTER O WITH STROKE
- u'\xf8': u'o', # LATIN SMALL LETTER O WITH STROKE
- u'\xbb': u'"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
- u'\xdf': u'ss', # LATIN SMALL LETTER SHARP S
- u'\u2013': u'-', # HYPHEN
- u'\u2019': u"'", # SIMPLE QUOTE
- }
-
-def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None) -> str:
+ "\xa1": "!", # INVERTED EXCLAMATION MARK
+ "\u0142": "l", # LATIN SMALL LETTER L WITH STROKE
+ "\u2044": "/", # FRACTION SLASH
+ "\xc6": "AE", # LATIN CAPITAL LETTER AE
+ "\xa9": "(c)", # COPYRIGHT SIGN
+ "\xab": '"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK
+ "\xe6": "ae", # LATIN SMALL LETTER AE
+ "\xae": "(r)", # REGISTERED SIGN
+ "\u0153": "oe", # LATIN SMALL LIGATURE OE
+ "\u0152": "OE", # LATIN CAPITAL LIGATURE OE
+ "\xd8": "O", # LATIN CAPITAL LETTER O WITH STROKE
+ "\xf8": "o", # LATIN SMALL LETTER O WITH STROKE
+ "\xbb": '"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
+ "\xdf": "ss", # LATIN SMALL LETTER SHARP S
+ "\u2013": "-", # HYPHEN
+ "\u2019": "'", # SIMPLE QUOTE
+}
+
+
+def unormalize(
+ ustring: str, ignorenonascii: Optional[Any] = None, substitute: Optional[str] = None
+) -> str:
"""replace diacritical characters with their corresponding ascii characters
Convert the unicode string to its long normalized form (unicode character
@@ -92,22 +96,26 @@ def unormalize(ustring: str, ignorenonascii: Optional[Any] = None, substitute: O
"""
# backward compatibility, ignorenonascii was a boolean
if ignorenonascii is not None:
- warn("ignorenonascii is deprecated, use substitute named parameter instead",
- DeprecationWarning, stacklevel=2)
+ warn(
+ "ignorenonascii is deprecated, use substitute named parameter instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
if ignorenonascii:
- substitute = ''
+ substitute = ""
res = []
for letter in ustring[:]:
try:
replacement = MANUAL_UNICODE_MAP[letter]
except KeyError:
- replacement = _uninormalize('NFKD', letter)[0]
+ replacement = _uninormalize("NFKD", letter)[0]
if ord(replacement) >= 2 ** 7:
if substitute is None:
raise ValueError("can't deal with non-ascii based characters")
replacement = substitute
res.append(replacement)
- return u''.join(res)
+ return "".join(res)
+
def unquote(string: str) -> str:
"""remove optional quotes (simple or double) from the string
@@ -120,17 +128,18 @@ def unquote(string: str) -> str:
"""
if not string:
return string
- if string[0] in '"\'':
+ if string[0] in "\"'":
string = string[1:]
- if string[-1] in '"\'':
+ if string[-1] in "\"'":
string = string[:-1]
return string
-_BLANKLINES_RGX = re.compile('\r?\n\r?\n')
-_NORM_SPACES_RGX = re.compile('\s+')
+_BLANKLINES_RGX = re.compile("\r?\n\r?\n")
+_NORM_SPACES_RGX = re.compile("\s+")
+
-def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool = False) -> str:
+def normalize_text(text: str, line_len: int = 80, indent: str = "", rest: bool = False) -> str:
"""normalize a text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized but blank
lines are kept. The indentation string may be used to insert a
@@ -158,10 +167,10 @@ def normalize_text(text: str, line_len: int = 80, indent: str = '', rest: bool =
result = []
for text in _BLANKLINES_RGX.split(text):
result.append(normp(text, line_len, indent))
- return ('%s%s%s' % (linesep, indent, linesep)).join(result)
+ return ("%s%s%s" % (linesep, indent, linesep)).join(result)
-def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
+def normalize_paragraph(text: str, line_len: int = 80, indent: str = "") -> str:
"""normalize a text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized. The
indentation string may be used top insert a comment mark for
@@ -182,7 +191,7 @@ def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
inferior to `line_len`, and optionally prefixed by an
indentation string
"""
- text = _NORM_SPACES_RGX.sub(' ', text)
+ text = _NORM_SPACES_RGX.sub(" ", text)
line_len = line_len - len(indent)
lines = []
while text:
@@ -190,7 +199,8 @@ def normalize_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
lines.append(indent + aline)
return linesep.join(lines)
-def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') -> str:
+
+def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = "") -> str:
"""normalize a ReST text to display it with a maximum line size and
optionally arbitrary indentation. Line jumps are normalized. The
indentation string may be used top insert a comment mark for
@@ -211,21 +221,21 @@ def normalize_rest_paragraph(text: str, line_len: int = 80, indent: str = '') ->
inferior to `line_len`, and optionally prefixed by an
indentation string
"""
- toreport = ''
+ toreport = ""
lines = []
line_len = line_len - len(indent)
for line in text.splitlines():
- line = toreport + _NORM_SPACES_RGX.sub(' ', line.strip())
- toreport = ''
+ line = toreport + _NORM_SPACES_RGX.sub(" ", line.strip())
+ toreport = ""
while len(line) > line_len:
# too long line, need split
line, toreport = splittext(line, line_len)
lines.append(indent + line)
if toreport:
- line = toreport + ' '
- toreport = ''
+ line = toreport + " "
+ toreport = ""
else:
- line = ''
+ line = ""
if line:
lines.append(indent + line.strip())
return linesep.join(lines)
@@ -239,18 +249,18 @@ def splittext(text: str, line_len: int) -> Tuple[str, str]:
* the rest of the text which has to be reported on another line
"""
if len(text) <= line_len:
- return text, ''
- pos = min(len(text)-1, line_len)
- while pos > 0 and text[pos] != ' ':
+ return text, ""
+ pos = min(len(text) - 1, line_len)
+ while pos > 0 and text[pos] != " ":
pos -= 1
if pos == 0:
pos = min(len(text), line_len)
- while len(text) > pos and text[pos] != ' ':
+ while len(text) > pos and text[pos] != " ":
pos += 1
- return text[:pos], text[pos+1:].strip()
+ return text[:pos], text[pos + 1 :].strip()
-def splitstrip(string: str, sep: str = ',') -> List[str]:
+def splitstrip(string: str, sep: str = ",") -> List[str]:
"""return a list of stripped string by splitting the string given as
argument on `sep` (',' by default). Empty string are discarded.
@@ -271,15 +281,16 @@ def splitstrip(string: str, sep: str = ',') -> List[str]:
"""
return [word.strip() for word in string.split(sep) if word.strip()]
-get_csv = deprecated('get_csv is deprecated, use splitstrip')(splitstrip)
+
+get_csv = deprecated("get_csv is deprecated, use splitstrip")(splitstrip)
def split_url_or_path(url_or_path):
"""return the latest component of a string containing either an url of the
form <scheme>://<path> or a local file system path
"""
- if '://' in url_or_path:
- return url_or_path.rstrip('/').rsplit('/', 1)
+ if "://" in url_or_path:
+ return url_or_path.rstrip("/").rsplit("/", 1)
return osp.split(url_or_path.rstrip(osp.sep))
@@ -303,8 +314,8 @@ def text_to_dict(text):
return res
for line in text.splitlines():
line = line.strip()
- if line and not line.startswith('#'):
- key, value = [w.strip() for w in line.split('=', 1)]
+ if line and not line.startswith("#"):
+ key, value = [w.strip() for w in line.split("=", 1)]
if key in res:
try:
res[key].append(value)
@@ -315,13 +326,12 @@ def text_to_dict(text):
return res
-_BLANK_URE = r'(\s|,)+'
+_BLANK_URE = r"(\s|,)+"
_BLANK_RE = re.compile(_BLANK_URE)
-__VALUE_URE = r'-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))'
-__UNITS_URE = r'[a-zA-Z]+'
-_VALUE_RE = re.compile(r'(?P<value>%s)(?P<unit>%s)?'%(__VALUE_URE, __UNITS_URE))
-_VALIDATION_RE = re.compile(r'^((%s)(%s))*(%s)?$' % (__VALUE_URE, __UNITS_URE,
- __VALUE_URE))
+__VALUE_URE = r"-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))"
+__UNITS_URE = r"[a-zA-Z]+"
+_VALUE_RE = re.compile(r"(?P<value>%s)(?P<unit>%s)?" % (__VALUE_URE, __UNITS_URE))
+_VALIDATION_RE = re.compile(r"^((%s)(%s))*(%s)?$" % (__VALUE_URE, __UNITS_URE, __VALUE_URE))
BYTE_UNITS = {
"b": 1,
@@ -336,11 +346,18 @@ TIME_UNITS = {
"s": 1,
"min": 60,
"h": 60 * 60,
- "d": 60 * 60 *24,
+ "d": 60 * 60 * 24,
}
-def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None, type] = None, final: type = float, blank_reg: Pattern = _BLANK_RE,
- value_reg: Pattern = _VALUE_RE) -> Union[float, int]:
+
+def apply_units(
+ string: str,
+ units: Dict[str, int],
+ inter: Union[Callable, None, type] = None,
+ final: type = float,
+ blank_reg: Pattern = _BLANK_RE,
+ value_reg: Pattern = _VALUE_RE,
+) -> Union[float, int]:
"""Parse the string applying the units defined in units
(e.g.: "1.5m",{'m',60} -> 80).
@@ -361,7 +378,7 @@ def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None,
"""
if inter is None:
inter = final
- fstring = _BLANK_RE.sub('', string)
+ fstring = _BLANK_RE.sub("", string)
if not (fstring and _VALIDATION_RE.match(fstring)):
raise ValueError("Invalid unit string: %r." % string)
values = []
@@ -373,15 +390,15 @@ def apply_units(string: str, units: Dict[str, int], inter: Union[Callable, None,
try:
value *= units[unit.lower()]
except KeyError:
- raise ValueError('invalid unit %s. valid units are %s' %
- (unit, list(units.keys())))
+ raise ValueError("invalid unit %s. valid units are %s" % (unit, list(units.keys())))
values.append(value)
return final(sum(values))
-_LINE_RGX = re.compile('\r\n|\r+|\n')
+_LINE_RGX = re.compile("\r\n|\r+|\n")
+
-def pretty_match(match: Match, string: str, underline_char: str = '^') -> str:
+def pretty_match(match: Match, string: str, underline_char: str = "^") -> str:
"""return a string with the match location underlined:
>>> import re
@@ -419,7 +436,7 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str:
result = [string[:start_line_pos]]
start_line_pos += len(linesep)
offset = start - start_line_pos
- underline = ' ' * offset + underline_char * (end - start)
+ underline = " " * offset + underline_char * (end - start)
end_line_pos = string.find(linesep, end)
if end_line_pos == -1:
string = string[start_line_pos:]
@@ -429,7 +446,7 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str:
# mypy: Incompatible types in assignment (expression has type "str",
# mypy: variable has type "int")
# but it's a str :|
- end = string[end_line_pos + len(linesep):] # type: ignore
+ end = string[end_line_pos + len(linesep) :] # type: ignore
string = string[start_line_pos:end_line_pos]
result.append(string)
result.append(underline)
@@ -439,30 +456,31 @@ def pretty_match(match: Match, string: str, underline_char: str = '^') -> str:
# Ansi colorization ###########################################################
-ANSI_PREFIX = '\033['
-ANSI_END = 'm'
-ANSI_RESET = '\033[0m'
+ANSI_PREFIX = "\033["
+ANSI_END = "m"
+ANSI_RESET = "\033[0m"
ANSI_STYLES = {
- 'reset': "0",
- 'bold': "1",
- 'italic': "3",
- 'underline': "4",
- 'blink': "5",
- 'inverse': "7",
- 'strike': "9",
+ "reset": "0",
+ "bold": "1",
+ "italic": "3",
+ "underline": "4",
+ "blink": "5",
+ "inverse": "7",
+ "strike": "9",
}
ANSI_COLORS = {
- 'reset': "0",
- 'black': "30",
- 'red': "31",
- 'green': "32",
- 'yellow': "33",
- 'blue': "34",
- 'magenta': "35",
- 'cyan': "36",
- 'white': "37",
+ "reset": "0",
+ "black": "30",
+ "red": "31",
+ "green": "32",
+ "yellow": "33",
+ "blue": "34",
+ "magenta": "35",
+ "cyan": "36",
+ "white": "37",
}
+
def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) -> str:
"""return ansi escape code corresponding to color and style
@@ -488,13 +506,14 @@ def _get_ansi_code(color: Optional[str] = None, style: Optional[str] = None) ->
ansi_code.append(ANSI_STYLES[effect])
if color:
if color.isdigit():
- ansi_code.extend(['38', '5'])
+ ansi_code.extend(["38", "5"])
ansi_code.append(color)
else:
ansi_code.append(ANSI_COLORS[color])
if ansi_code:
- return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END
- return ''
+ return ANSI_PREFIX + ";".join(ansi_code) + ANSI_END
+ return ""
+
def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] = None) -> str:
"""colorize message by wrapping it with ansi escape codes
@@ -522,23 +541,24 @@ def colorize_ansi(msg: str, color: Optional[str] = None, style: Optional[str] =
escape_code = _get_ansi_code(color, style)
# If invalid (or unknown) color, don't wrap msg with ansi codes
if escape_code:
- return '%s%s%s' % (escape_code, msg, ANSI_RESET)
+ return "%s%s%s" % (escape_code, msg, ANSI_RESET)
return msg
-DIFF_STYLE = {'separator': 'cyan', 'remove': 'red', 'add': 'green'}
+
+DIFF_STYLE = {"separator": "cyan", "remove": "red", "add": "green"}
+
def diff_colorize_ansi(lines, out=sys.stdout, style=DIFF_STYLE):
for line in lines:
- if line[:4] in ('--- ', '+++ '):
- out.write(colorize_ansi(line, style['separator']))
- elif line[0] == '-':
- out.write(colorize_ansi(line, style['remove']))
- elif line[0] == '+':
- out.write(colorize_ansi(line, style['add']))
- elif line[:4] == '--- ':
- out.write(colorize_ansi(line, style['separator']))
- elif line[:4] == '+++ ':
- out.write(colorize_ansi(line, style['separator']))
+ if line[:4] in ("--- ", "+++ "):
+ out.write(colorize_ansi(line, style["separator"]))
+ elif line[0] == "-":
+ out.write(colorize_ansi(line, style["remove"]))
+ elif line[0] == "+":
+ out.write(colorize_ansi(line, style["add"]))
+ elif line[:4] == "--- ":
+ out.write(colorize_ansi(line, style["separator"]))
+ elif line[:4] == "+++ ":
+ out.write(colorize_ansi(line, style["separator"]))
else:
out.write(line)
-
diff --git a/logilab/common/tree.py b/logilab/common/tree.py
index 1fc5a21..dbde2eb 100644
--- a/logilab/common/tree.py
+++ b/logilab/common/tree.py
@@ -32,9 +32,11 @@ from typing import Optional, Any, Union, List, Callable, TypeVar
## Exceptions #################################################################
+
class NodeNotFound(Exception):
"""raised when a node has not been found"""
+
EX_SIBLING_NOT_FOUND: str = "No such sibling as '%s'"
EX_CHILD_NOT_FOUND: str = "No such child as '%s'"
EX_NODE_NOT_FOUND: str = "No such node as '%s'"
@@ -49,7 +51,7 @@ NodeType = Any
class Node(object):
"""a basic tree node, characterized by an id"""
- def __init__(self, nid: Optional[str] = None) -> None :
+ def __init__(self, nid: Optional[str] = None) -> None:
self.id = nid
# navigation
# should be something like Optional[type(self)] for subclasses but that's not possible?
@@ -61,14 +63,14 @@ class Node(object):
return iter(self.children)
def __str__(self, indent=0):
- s = ['%s%s %s' % (' '*indent, self.__class__.__name__, self.id)]
+ s = ["%s%s %s" % (" " * indent, self.__class__.__name__, self.id)]
indent += 2
for child in self.children:
try:
s.append(child.__str__(indent))
except TypeError:
s.append(child.__str__())
- return '\n'.join(s)
+ return "\n".join(s)
def is_leaf(self):
return not self.children
@@ -103,7 +105,7 @@ class Node(object):
try:
assert self.parent is not None
return self.parent.get_child_by_id(nid)
- except NodeNotFound :
+ except NodeNotFound:
raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid)
def next_sibling(self):
@@ -116,7 +118,7 @@ class Node(object):
return None
index = parent.children.index(self)
try:
- return parent.children[index+1]
+ return parent.children[index + 1]
except IndexError:
return None
@@ -130,7 +132,7 @@ class Node(object):
return None
index = parent.children.index(self)
if index > 0:
- return parent.children[index-1]
+ return parent.children[index - 1]
return None
def get_node_by_id(self, nid: str) -> NodeType:
@@ -140,7 +142,7 @@ class Node(object):
root = self.root()
try:
return root.get_child_by_id(nid, 1)
- except NodeNotFound :
+ except NodeNotFound:
raise NodeNotFound(EX_NODE_NOT_FOUND % nid)
def get_child_by_id(self, nid: str, recurse: Optional[bool] = None) -> NodeType:
@@ -149,13 +151,13 @@ class Node(object):
"""
if self.id == nid:
return self
- for c in self.children :
+ for c in self.children:
if recurse:
try:
return c.get_child_by_id(nid, 1)
- except NodeNotFound :
+ except NodeNotFound:
continue
- if c.id == nid :
+ if c.id == nid:
return c
raise NodeNotFound(EX_CHILD_NOT_FOUND % nid)
@@ -164,13 +166,13 @@ class Node(object):
return child of given path (path is a list of ids)
"""
if len(path) > 0 and path[0] == self.id:
- if len(path) == 1 :
+ if len(path) == 1:
return self
- else :
- for c in self.children :
+ else:
+ for c in self.children:
try:
return c.get_child_by_path(path[1:])
- except NodeNotFound :
+ except NodeNotFound:
pass
raise NodeNotFound(EX_CHILD_NOT_FOUND % path)
@@ -180,7 +182,7 @@ class Node(object):
"""
if self.parent is not None:
return 1 + self.parent.depth()
- else :
+ else:
return 0
def depth_down(self) -> int:
@@ -237,6 +239,7 @@ class Node(object):
lst.extend(self.parent.lineage())
return lst
+
class VNode(Node, VisitedMixIn):
# we should probably merge this VisitedMixIn here because it's only used here
"""a visitable node
@@ -247,7 +250,8 @@ class VNode(Node, VisitedMixIn):
class BinaryNode(VNode):
"""a binary node (i.e. only two children
"""
- def __init__(self, lhs=None, rhs=None) :
+
+ def __init__(self, lhs=None, rhs=None):
VNode.__init__(self)
if lhs is not None or rhs is not None:
assert lhs and rhs
@@ -267,24 +271,29 @@ class BinaryNode(VNode):
return self.children[0], self.children[1]
-
if sys.version_info[0:2] >= (2, 2):
list_class = list
else:
from UserList import UserList
+
list_class = UserList
+
class ListNode(VNode, list_class):
"""Used to manipulate Nodes as Lists
"""
+
def __init__(self):
list_class.__init__(self)
VNode.__init__(self)
self.children = self
def __str__(self, indent=0):
- return '%s%s %s' % (indent*' ', self.__class__.__name__,
- ', '.join([str(v) for v in self]))
+ return "%s%s %s" % (
+ indent * " ",
+ self.__class__.__name__,
+ ", ".join([str(v) for v in self]),
+ )
def append(self, child):
"""add a node to children"""
@@ -309,8 +318,10 @@ class ListNode(VNode, list_class):
def __iter__(self):
return list_class.__iter__(self)
+
# construct list from tree ####################################################
+
def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]:
"""
create a list with tree nodes for which the <filter> function returned true
@@ -339,6 +350,7 @@ def post_order_list(node: Optional[Node], filter_func: Callable = no_filter) ->
poped = 1
return l
+
def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> List[Node]:
"""
create a list with tree nodes for which the <filter> function returned true
@@ -368,15 +380,18 @@ def pre_order_list(node: Optional[Node], filter_func: Callable = no_filter) -> L
poped = 1
return l
+
class PostfixedDepthFirstIterator(FilteredIterator):
"""a postfixed depth first iterator, designed to be used with visitors
"""
+
def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None:
FilteredIterator.__init__(self, node, post_order_list, filter_func)
+
class PrefixedDepthFirstIterator(FilteredIterator):
"""a prefixed depth first iterator, designed to be used with visitors
"""
+
def __init__(self, node: Node, filter_func: Optional[Any] = None) -> None:
FilteredIterator.__init__(self, node, pre_order_list, filter_func)
-
diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py
index 77a6272..a759003 100644
--- a/logilab/common/umessage.py
+++ b/logilab/common/umessage.py
@@ -40,14 +40,14 @@ import logilab.common as lgc
def decode_QP(string: str) -> str:
parts: List[str] = []
for maybe_decoded, charset in decode_header(string):
- if not charset :
- charset = 'iso-8859-15'
+ if not charset:
+ charset = "iso-8859-15"
# python 3 sometimes returns str and sometimes bytes.
# the 'official' fix is to use the new 'policy' APIs
# https://bugs.python.org/issue24797
# let's just handle this bug ourselves for now
if isinstance(maybe_decoded, bytes):
- decoded = maybe_decoded.decode(charset, 'replace')
+ decoded = maybe_decoded.decode(charset, "replace")
else:
decoded = maybe_decoded
@@ -57,21 +57,24 @@ def decode_QP(string: str) -> str:
if sys.version_info < (3, 3):
# decoding was non-RFC compliant wrt to whitespace handling
# see http://bugs.python.org/issue1079
- return u' '.join(parts)
+ return " ".join(parts)
+
+ return "".join(parts)
- return u''.join(parts)
def message_from_file(fd):
try:
return UMessage(email.message_from_file(fd))
except email.errors.MessageParseError:
- return ''
+ return ""
+
-def message_from_string(string: str) -> Union['UMessage', str]:
+def message_from_string(string: str) -> Union["UMessage", str]:
try:
return UMessage(email.message_from_string(string))
except email.errors.MessageParseError:
- return ''
+ return ""
+
class UMessage:
"""Encapsulates an email.Message instance and returns only unicode objects.
@@ -92,8 +95,7 @@ class UMessage:
return self.get(header)
def get_all(self, header: str, default: Tuple[()] = ()) -> List[str]:
- return [decode_QP(val) for val in self.message.get_all(header, default)
- if val is not None]
+ return [decode_QP(val) for val in self.message.get_all(header, default) if val is not None]
def is_multipart(self):
return self.message.is_multipart()
@@ -105,7 +107,9 @@ class UMessage:
for part in self.message.walk():
yield UMessage(part)
- def get_payload(self, index: Optional[Any] = None, decode: bool = False) -> Union[str, 'UMessage', List['UMessage']]:
+ def get_payload(
+ self, index: Optional[Any] = None, decode: bool = False
+ ) -> Union[str, "UMessage", List["UMessage"]]:
message = self.message
if index is None:
@@ -118,17 +122,17 @@ class UMessage:
if isinstance(payload, list):
return [UMessage(msg) for msg in payload]
- if message.get_content_maintype() != 'text':
+ if message.get_content_maintype() != "text":
return payload
if isinstance(payload, str):
return payload
- charset = message.get_content_charset() or 'iso-8859-1'
+ charset = message.get_content_charset() or "iso-8859-1"
if search_function(charset) is None:
- charset = 'iso-8859-1'
+ charset = "iso-8859-1"
- return str(payload or b'', charset, "replace")
+ return str(payload or b"", charset, "replace")
else:
payload = UMessage(message.get_payload(index, decode))
@@ -147,7 +151,7 @@ class UMessage:
try:
return str(value)
except UnicodeDecodeError:
- return u'error decoding filename'
+ return "error decoding filename"
# other convenience methods ###############################################
@@ -155,8 +159,8 @@ class UMessage:
"""return an unicode string containing all the message's headers"""
values = []
for header in self.message.keys():
- values.append(u'%s: %s' % (header, self.get(header)))
- return '\n'.join(values)
+ values.append("%s: %s" % (header, self.get(header)))
+ return "\n".join(values)
def multi_addrs(self, header):
"""return a list of 2-uple (name, address) for the given address (which
@@ -172,7 +176,7 @@ class UMessage:
"""return a datetime object for the email's date or None if no date is
set or if it can't be parsed
"""
- value = self.get('date')
+ value = self.get("date")
if value is None and alternative_source:
unix_from = self.message.get_unixfrom()
if unix_from is not None:
diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py
index 9c0f1df..a539150 100644
--- a/logilab/common/ureports/__init__.py
+++ b/logilab/common/ureports/__init__.py
@@ -46,14 +46,14 @@ def layout_title(layout):
"""
for child in layout.children:
if isinstance(child, Title):
- return u' '.join([node.data for node in get_nodes(child, Text)])
+ return " ".join([node.data for node in get_nodes(child, Text)])
def build_summary(layout, level=1):
"""make a summary for the report, including X level"""
assert level > 0
level -= 1
- summary = List(klass=u'summary')
+ summary = List(klass="summary")
for child in layout.children:
if not isinstance(child, Section):
continue
@@ -61,8 +61,8 @@ def build_summary(layout, level=1):
if not label and not child.id:
continue
if not child.id:
- child.id = label.replace(' ', '-')
- node = Link(u'#'+child.id, label=label or child.id)
+ child.id = label.replace(" ", "-")
+ node = Link("#" + child.id, label=label or child.id)
# FIXME: Three following lines produce not very compliant
# docbook: there are some useless <para><para>. They might be
# replaced by the three commented lines but this then produces
@@ -70,16 +70,21 @@ def build_summary(layout, level=1):
if level and [n for n in child.children if isinstance(n, Section)]:
node = Paragraph([node, build_summary(child, level)])
summary.append(node)
-# summary.append(node)
-# if level and [n for n in child.children if isinstance(n, Section)]:
-# summary.append(build_summary(child, level))
+ # summary.append(node)
+ # if level and [n for n in child.children if isinstance(n, Section)]:
+ # summary.append(build_summary(child, level))
return summary
class BaseWriter(object):
"""base class for ureport writers"""
- def format(self, layout: Any, stream: Optional[Union[StringIO, TextIO]] = None, encoding: Optional[Any] = None) -> None:
+ def format(
+ self,
+ layout: Any,
+ stream: Optional[Union[StringIO, TextIO]] = None,
+ encoding: Optional[Any] = None,
+ ) -> None:
"""format and write the given layout into the stream object
unicode policy: unicode strings may be found in the layout;
@@ -89,22 +94,22 @@ class BaseWriter(object):
if stream is None:
stream = sys.stdout
if not encoding:
- encoding = getattr(stream, 'encoding', 'UTF-8')
- self.encoding = encoding or 'UTF-8'
+ encoding = getattr(stream, "encoding", "UTF-8")
+ self.encoding = encoding or "UTF-8"
self.__compute_funcs: List[Tuple[Callable[[str], Any], Callable[[str], Any]]] = []
self.out = stream
self.begin_format(layout)
layout.accept(self)
self.end_format(layout)
- def format_children(self, layout: Union['Paragraph', 'Section', 'Title']) -> None:
+ def format_children(self, layout: Union["Paragraph", "Section", "Title"]) -> None:
"""recurse on the layout children and call their accept method
(see the Visitor pattern)
"""
- for child in getattr(layout, 'children', ()):
+ for child in getattr(layout, "children", ()):
child.accept(self)
- def writeln(self, string: str = u'') -> None:
+ def writeln(self, string: str = "") -> None:
"""write a line in the output buffer"""
self.write(string + linesep)
@@ -146,7 +151,7 @@ class BaseWriter(object):
# fill missing cells
while len(result[-1]) < cols:
- result[-1].append(u'')
+ result[-1].append("")
return result
@@ -166,13 +171,13 @@ class BaseWriter(object):
# error from porting to python3?
stream.write(data.encode(self.encoding)) # type: ignore
- def writeln(data: str = u'') -> None:
+ def writeln(data: str = "") -> None:
try:
- stream.write(data+linesep)
+ stream.write(data + linesep)
except UnicodeEncodeError:
# mypy: Unsupported operand types for + ("bytes" and "str")
# error from porting to python3?
- stream.write(data.encode(self.encoding)+linesep) # type: ignore
+ stream.write(data.encode(self.encoding) + linesep) # type: ignore
# mypy: Cannot assign to a method
# this really looks like black dirty magic since self.write is reused elsewhere in the code
@@ -202,6 +207,7 @@ class BaseWriter(object):
del self.write
del self.writeln
+
# mypy error: Incompatible import of "Table" (imported name has type
# mypy error: "Type[logilab.common.ureports.nodes.Table]", local name has type
# mypy error: "Type[logilab.common.table.Table]")
diff --git a/logilab/common/ureports/docbook_writer.py b/logilab/common/ureports/docbook_writer.py
index f28474e..7e7564f 100644
--- a/logilab/common/ureports/docbook_writer.py
+++ b/logilab/common/ureports/docbook_writer.py
@@ -29,15 +29,17 @@ class DocbookWriter(HTMLWriter):
super(HTMLWriter, self).begin_format(layout)
if self.snippet is None:
self.writeln('<?xml version="1.0" encoding="ISO-8859-1"?>')
- self.writeln("""
+ self.writeln(
+ """
<book xmlns:xi='http://www.w3.org/2001/XInclude'
lang='fr'>
-""")
+"""
+ )
def end_format(self, layout):
"""finished to format a layout"""
if self.snippet is None:
- self.writeln('</book>')
+ self.writeln("</book>")
def visit_section(self, layout):
"""display a section (using <chapter> (level 0) or <section>)"""
@@ -46,98 +48,95 @@ class DocbookWriter(HTMLWriter):
else:
tag = "section"
self.section += 1
- self.writeln(self._indent('<%s%s>' % (tag, self.handle_attrs(layout))))
+ self.writeln(self._indent("<%s%s>" % (tag, self.handle_attrs(layout))))
self.format_children(layout)
- self.writeln(self._indent('</%s>' % tag))
+ self.writeln(self._indent("</%s>" % tag))
self.section -= 1
def visit_title(self, layout):
"""display a title using <title>"""
- self.write(self._indent(' <title%s>' % self.handle_attrs(layout)))
+ self.write(self._indent(" <title%s>" % self.handle_attrs(layout)))
self.format_children(layout)
- self.writeln('</title>')
+ self.writeln("</title>")
def visit_table(self, layout):
"""display a table as html"""
self.writeln(
- self._indent(' <table%s><title>%s</title>' % (
- self.handle_attrs(layout), layout.title)))
+ self._indent(" <table%s><title>%s</title>" % (self.handle_attrs(layout), layout.title))
+ )
self.writeln(self._indent(' <tgroup cols="%s">' % layout.cols))
for i in range(layout.cols):
- self.writeln(
- self._indent(
- ' <colspec colname="c%s" colwidth="1*"/>' % i))
+ self.writeln(self._indent(' <colspec colname="c%s" colwidth="1*"/>' % i))
table_content = self.get_table_content(layout)
# write headers
if layout.cheaders:
- self.writeln(self._indent(' <thead>'))
+ self.writeln(self._indent(" <thead>"))
self._write_row(table_content[0])
- self.writeln(self._indent(' </thead>'))
+ self.writeln(self._indent(" </thead>"))
table_content = table_content[1:]
elif layout.rcheaders:
- self.writeln(self._indent(' <thead>'))
+ self.writeln(self._indent(" <thead>"))
self._write_row(table_content[-1])
- self.writeln(self._indent(' </thead>'))
+ self.writeln(self._indent(" </thead>"))
table_content = table_content[:-1]
# write body
- self.writeln(self._indent(' <tbody>'))
+ self.writeln(self._indent(" <tbody>"))
for i in range(len(table_content)):
row = table_content[i]
- self.writeln(self._indent(' <row>'))
+ self.writeln(self._indent(" <row>"))
for j in range(len(row)):
- cell = row[j] or '&#160;'
- self.writeln(
- self._indent(' <entry>%s</entry>' % cell))
- self.writeln(self._indent(' </row>'))
- self.writeln(self._indent(' </tbody>'))
- self.writeln(self._indent(' </tgroup>'))
- self.writeln(self._indent(' </table>'))
+ cell = row[j] or "&#160;"
+ self.writeln(self._indent(" <entry>%s</entry>" % cell))
+ self.writeln(self._indent(" </row>"))
+ self.writeln(self._indent(" </tbody>"))
+ self.writeln(self._indent(" </tgroup>"))
+ self.writeln(self._indent(" </table>"))
def _write_row(self, row):
"""write content of row (using <row> <entry>)"""
- self.writeln(' <row>')
+ self.writeln(" <row>")
for j in range(len(row)):
- cell = row[j] or '&#160;'
- self.writeln(' <entry>%s</entry>' % cell)
- self.writeln(self._indent(' </row>'))
+ cell = row[j] or "&#160;"
+ self.writeln(" <entry>%s</entry>" % cell)
+ self.writeln(self._indent(" </row>"))
def visit_list(self, layout):
"""display a list (using <itemizedlist>)"""
- self.writeln(self._indent(' <itemizedlist%s>'
- '' % self.handle_attrs(layout)))
+ self.writeln(self._indent(" <itemizedlist%s>" "" % self.handle_attrs(layout)))
for row in list(self.compute_content(layout)):
- self.writeln(' <listitem><para>%s</para></listitem>' % row)
- self.writeln(self._indent(' </itemizedlist>'))
+ self.writeln(" <listitem><para>%s</para></listitem>" % row)
+ self.writeln(self._indent(" </itemizedlist>"))
def visit_paragraph(self, layout):
"""display links (using <para>)"""
- self.write(self._indent(' <para>'))
+ self.write(self._indent(" <para>"))
self.format_children(layout)
- self.writeln('</para>')
+ self.writeln("</para>")
def visit_span(self, layout):
"""display links (using <p>)"""
# TODO: translate in docbook
- self.write('<literal %s>' % self.handle_attrs(layout))
+ self.write("<literal %s>" % self.handle_attrs(layout))
self.format_children(layout)
- self.write('</literal>')
+ self.write("</literal>")
def visit_link(self, layout):
"""display links (using <ulink>)"""
- self.write('<ulink url="%s"%s>%s</ulink>' % (
- layout.url, self.handle_attrs(layout), layout.label))
+ self.write(
+ '<ulink url="%s"%s>%s</ulink>' % (layout.url, self.handle_attrs(layout), layout.label)
+ )
def visit_verbatimtext(self, layout):
"""display verbatim text (using <programlisting>)"""
- self.writeln(self._indent(' <programlisting>'))
- self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))
- self.writeln(self._indent(' </programlisting>'))
+ self.writeln(self._indent(" <programlisting>"))
+ self.write(layout.data.replace("&", "&amp;").replace("<", "&lt;"))
+ self.writeln(self._indent(" </programlisting>"))
def visit_text(self, layout):
"""add some text"""
- self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))
+ self.write(layout.data.replace("&", "&amp;").replace("<", "&lt;"))
def _indent(self, string):
"""correctly indent string according to section"""
- return ' ' * 2*(self.section) + string
+ return " " * 2 * (self.section) + string
diff --git a/logilab/common/ureports/html_writer.py b/logilab/common/ureports/html_writer.py
index 0783075..23ff588 100644
--- a/logilab/common/ureports/html_writer.py
+++ b/logilab/common/ureports/html_writer.py
@@ -20,8 +20,16 @@ __docformat__ = "restructuredtext en"
from logilab.common.ureports import BaseWriter
-from logilab.common.ureports.nodes import (Section, Title, Table, List,
- Paragraph, Link, VerbatimText, Text)
+from logilab.common.ureports.nodes import (
+ Section,
+ Title,
+ Table,
+ List,
+ Paragraph,
+ Link,
+ VerbatimText,
+ Text,
+)
from typing import Any
@@ -34,100 +42,100 @@ class HTMLWriter(BaseWriter):
def handle_attrs(self, layout: Any) -> str:
"""get an attribute string from layout member attributes"""
- attrs = u''
- klass = getattr(layout, 'klass', None)
+ attrs = ""
+ klass = getattr(layout, "klass", None)
if klass:
- attrs += u' class="%s"' % klass
- nid = getattr(layout, 'id', None)
+ attrs += ' class="%s"' % klass
+ nid = getattr(layout, "id", None)
if nid:
- attrs += u' id="%s"' % nid
+ attrs += ' id="%s"' % nid
return attrs
def begin_format(self, layout: Any) -> None:
"""begin to format a layout"""
super(HTMLWriter, self).begin_format(layout)
if self.snippet is None:
- self.writeln(u'<html>')
- self.writeln(u'<body>')
+ self.writeln("<html>")
+ self.writeln("<body>")
def end_format(self, layout: Any) -> None:
"""finished to format a layout"""
if self.snippet is None:
- self.writeln(u'</body>')
- self.writeln(u'</html>')
+ self.writeln("</body>")
+ self.writeln("</html>")
def visit_section(self, layout: Section) -> None:
"""display a section as html, using div + h[section level]"""
self.section += 1
- self.writeln(u'<div%s>' % self.handle_attrs(layout))
+ self.writeln("<div%s>" % self.handle_attrs(layout))
self.format_children(layout)
- self.writeln(u'</div>')
+ self.writeln("</div>")
self.section -= 1
def visit_title(self, layout: Title) -> None:
"""display a title using <hX>"""
- self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout)))
+ self.write("<h%s%s>" % (self.section, self.handle_attrs(layout)))
self.format_children(layout)
- self.writeln(u'</h%s>' % self.section)
+ self.writeln("</h%s>" % self.section)
def visit_table(self, layout: Table) -> None:
"""display a table as html"""
- self.writeln(u'<table%s>' % self.handle_attrs(layout))
+ self.writeln("<table%s>" % self.handle_attrs(layout))
table_content = self.get_table_content(layout)
for i in range(len(table_content)):
row = table_content[i]
if i == 0 and layout.rheaders:
- self.writeln(u'<tr class="header">')
- elif i+1 == len(table_content) and layout.rrheaders:
- self.writeln(u'<tr class="header">')
+ self.writeln('<tr class="header">')
+ elif i + 1 == len(table_content) and layout.rrheaders:
+ self.writeln('<tr class="header">')
else:
- self.writeln(u'<tr class="%s">' % (i % 2 and 'even' or 'odd'))
+ self.writeln('<tr class="%s">' % (i % 2 and "even" or "odd"))
for j in range(len(row)):
- cell = row[j] or u'&#160;'
- if (layout.rheaders and i == 0) or \
- (layout.cheaders and j == 0) or \
- (layout.rrheaders and i+1 == len(table_content)) or \
- (layout.rcheaders and j+1 == len(row)):
- self.writeln(u'<th>%s</th>' % cell)
+ cell = row[j] or "&#160;"
+ if (
+ (layout.rheaders and i == 0)
+ or (layout.cheaders and j == 0)
+ or (layout.rrheaders and i + 1 == len(table_content))
+ or (layout.rcheaders and j + 1 == len(row))
+ ):
+ self.writeln("<th>%s</th>" % cell)
else:
- self.writeln(u'<td>%s</td>' % cell)
- self.writeln(u'</tr>')
- self.writeln(u'</table>')
+ self.writeln("<td>%s</td>" % cell)
+ self.writeln("</tr>")
+ self.writeln("</table>")
def visit_list(self, layout: List) -> None:
"""display a list as html"""
- self.writeln(u'<ul%s>' % self.handle_attrs(layout))
+ self.writeln("<ul%s>" % self.handle_attrs(layout))
for row in list(self.compute_content(layout)):
- self.writeln(u'<li>%s</li>' % row)
- self.writeln(u'</ul>')
+ self.writeln("<li>%s</li>" % row)
+ self.writeln("</ul>")
def visit_paragraph(self, layout: Paragraph) -> None:
"""display links (using <p>)"""
- self.write(u'<p>')
+ self.write("<p>")
self.format_children(layout)
- self.write(u'</p>')
+ self.write("</p>")
def visit_span(self, layout):
"""display links (using <p>)"""
- self.write(u'<span%s>' % self.handle_attrs(layout))
+ self.write("<span%s>" % self.handle_attrs(layout))
self.format_children(layout)
- self.write(u'</span>')
+ self.write("</span>")
def visit_link(self, layout: Link) -> None:
"""display links (using <a>)"""
- self.write(u' <a href="%s"%s>%s</a>' % (layout.url,
- self.handle_attrs(layout),
- layout.label))
+ self.write(' <a href="%s"%s>%s</a>' % (layout.url, self.handle_attrs(layout), layout.label))
def visit_verbatimtext(self, layout: VerbatimText) -> None:
"""display verbatim text (using <pre>)"""
- self.write(u'<pre>')
- self.write(layout.data.replace(u'&', u'&amp;').replace(u'<', u'&lt;'))
- self.write(u'</pre>')
+ self.write("<pre>")
+ self.write(layout.data.replace("&", "&amp;").replace("<", "&lt;"))
+ self.write("</pre>")
def visit_text(self, layout: Text) -> None:
"""add some text"""
data = layout.data
if layout.escaped:
- data = data.replace(u'&', u'&amp;').replace(u'<', u'&lt;')
+ data = data.replace("&", "&amp;").replace("<", "&lt;")
self.write(data)
diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py
index d086faf..26c6715 100644
--- a/logilab/common/ureports/nodes.py
+++ b/logilab/common/ureports/nodes.py
@@ -23,6 +23,7 @@ __docformat__ = "restructuredtext en"
from logilab.common.tree import VNode
from typing import Optional
+
# from logilab.common.ureports.nodes import List
# from logilab.common.ureports.nodes import Paragraph
# from logilab.common.ureports.nodes import Text
@@ -39,6 +40,7 @@ class BaseComponent(VNode):
* id : the component's optional id
* klass : the component's optional klass
"""
+
def __init__(self, id: Optional[str] = None, klass: Optional[str] = None) -> None:
VNode.__init__(self, id)
self.klass = klass
@@ -51,11 +53,16 @@ class BaseLayout(BaseComponent):
* BaseComponent attributes
* children : components in this table (i.e. the table's cells)
"""
- def __init__(self,
- children: Union[TypingList['Text'],
- Tuple[Union['Paragraph', str],
- Union[TypingList, str]], Tuple[str, ...]] = (),
- **kwargs: Any) -> None:
+
+ def __init__(
+ self,
+ children: Union[
+ TypingList["Text"],
+ Tuple[Union["Paragraph", str], Union[TypingList, str]],
+ Tuple[str, ...],
+ ] = (),
+ **kwargs: Any,
+ ) -> None:
super(BaseLayout, self).__init__(**kwargs)
@@ -87,6 +94,7 @@ class BaseLayout(BaseComponent):
# non container nodes #########################################################
+
class Text(BaseComponent):
"""a text portion
@@ -94,6 +102,7 @@ class Text(BaseComponent):
* BaseComponent attributes
* data : the text value as an encoded or unicode string
"""
+
def __init__(self, data: str, escaped: bool = True, **kwargs: Any) -> None:
super(Text, self).__init__(**kwargs)
# if isinstance(data, unicode):
@@ -120,6 +129,7 @@ class Link(BaseComponent):
* url : the link's target (REQUIRED)
* label : the link's label as a string (use the url by default)
"""
+
def __init__(self, url: str, label: str = None, **kwargs: Any) -> None:
super(Link, self).__init__(**kwargs)
assert url
@@ -136,6 +146,7 @@ class Image(BaseComponent):
* stream : the stream object containing the image data (REQUIRED)
* title : the image's optional title
"""
+
def __init__(self, filename, stream, title=None, **kwargs):
super(Image, self).__init__(**kwargs)
assert filename
@@ -147,6 +158,7 @@ class Image(BaseComponent):
# container nodes #############################################################
+
class Section(BaseLayout):
"""a section
@@ -158,6 +170,7 @@ class Section(BaseLayout):
a description may also be given to the constructor, it'll be added
as a first paragraph
"""
+
def __init__(self, title: str = None, description: str = None, **kwargs: Any) -> None:
super(Section, self).__init__(**kwargs)
if description:
@@ -206,9 +219,17 @@ class Table(BaseLayout):
* cheaders : the first col's elements are table's header
* title : the table's optional title
"""
- def __init__(self, cols: int, title: Optional[Any] = None,
- rheaders: int = 0, cheaders: int = 0, rrheaders: int = 0, rcheaders: int = 0,
- **kwargs: Any) -> None:
+
+ def __init__(
+ self,
+ cols: int,
+ title: Optional[Any] = None,
+ rheaders: int = 0,
+ cheaders: int = 0,
+ rrheaders: int = 0,
+ rcheaders: int = 0,
+ **kwargs: Any,
+ ) -> None:
super(Table, self).__init__(**kwargs)
assert isinstance(cols, int)
self.cols = cols
diff --git a/logilab/common/ureports/text_writer.py b/logilab/common/ureports/text_writer.py
index f75d7c9..efe85b7 100644
--- a/logilab/common/ureports/text_writer.py
+++ b/logilab/common/ureports/text_writer.py
@@ -24,18 +24,27 @@ __docformat__ = "restructuredtext en"
from logilab.common.textutils import linesep
from logilab.common.ureports import BaseWriter
-from logilab.common.ureports.nodes import (Section, Title, Table, List as NodeList,
- Paragraph, Link, VerbatimText, Text)
+from logilab.common.ureports.nodes import (
+ Section,
+ Title,
+ Table,
+ List as NodeList,
+ Paragraph,
+ Link,
+ VerbatimText,
+ Text,
+)
-TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^']
-BULLETS = [u'*', u'-']
+TITLE_UNDERLINES = ["", "=", "-", "`", ".", "~", "^"]
+BULLETS = ["*", "-"]
class TextWriter(BaseWriter):
"""format layouts as text
(ReStructured inspiration but not totally handled yet)
"""
+
def begin_format(self, layout: Any) -> None:
super(TextWriter, self).begin_format(layout)
self.list_level = 0
@@ -50,20 +59,20 @@ class TextWriter(BaseWriter):
if self.pending_urls:
self.writeln()
for label, url in self.pending_urls:
- self.writeln(u'.. _`%s`: %s' % (label, url))
+ self.writeln(".. _`%s`: %s" % (label, url))
self.pending_urls = []
self.section -= 1
self.writeln()
def visit_title(self, layout: Title) -> None:
- title = u''.join(list(self.compute_content(layout)))
+ title = "".join(list(self.compute_content(layout)))
self.writeln(title)
try:
self.writeln(TITLE_UNDERLINES[self.section] * len(title))
except IndexError:
print("FIXME TITLE TOO DEEP. TURNING TITLE INTO TEXT")
- def visit_paragraph(self, layout: 'Paragraph') -> None:
+ def visit_paragraph(self, layout: "Paragraph") -> None:
"""enter a paragraph"""
self.format_children(layout)
self.writeln()
@@ -76,64 +85,67 @@ class TextWriter(BaseWriter):
"""display a table as text"""
table_content = self.get_table_content(layout)
# get columns width
- cols_width = [0]*len(table_content[0])
+ cols_width = [0] * len(table_content[0])
for row in table_content:
for index in range(len(row)):
col = row[index]
cols_width[index] = max(cols_width[index], len(col))
- if layout.klass == 'field':
+ if layout.klass == "field":
self.field_table(layout, table_content, cols_width)
else:
self.default_table(layout, table_content, cols_width)
self.writeln()
- def default_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None:
+ def default_table(
+ self, layout: Table, table_content: List[List[str]], cols_width: List[int]
+ ) -> None:
"""format a table"""
- cols_width = [size+1 for size in cols_width]
+ cols_width = [size + 1 for size in cols_width]
- format_strings = u' '.join([u'%%-%ss'] * len(cols_width))
+ format_strings = " ".join(["%%-%ss"] * len(cols_width))
format_strings = format_strings % tuple(cols_width)
- format_strings_list = format_strings.split(' ')
+ format_strings_list = format_strings.split(" ")
- table_linesep = (
- u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n')
- headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n'
+ table_linesep = "\n+" + "+".join(["-" * w for w in cols_width]) + "+\n"
+ headsep = "\n+" + "+".join(["=" * w for w in cols_width]) + "+\n"
# FIXME: layout.cheaders
self.write(table_linesep)
for i in range(len(table_content)):
- self.write(u'|')
+ self.write("|")
line = table_content[i]
for j in range(len(line)):
self.write(format_strings_list[j] % line[j])
- self.write(u'|')
+ self.write("|")
if i == 0 and layout.rheaders:
self.write(headsep)
else:
self.write(table_linesep)
- def field_table(self, layout: Table, table_content: List[List[str]], cols_width: List[int]) -> None:
+ def field_table(
+ self, layout: Table, table_content: List[List[str]], cols_width: List[int]
+ ) -> None:
"""special case for field table"""
assert layout.cols == 2
- format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0])
+ format_string = "%s%%-%ss: %%s" % (linesep, cols_width[0])
for field, value in table_content:
self.write(format_string % (field, value))
def visit_list(self, layout: NodeList) -> None:
"""display a list layout as text"""
bullet = BULLETS[self.list_level % len(BULLETS)]
- indent = ' ' * self.list_level
+ indent = " " * self.list_level
self.list_level += 1
for child in layout.children:
- self.write(u'%s%s%s ' % (linesep, indent, bullet))
+ self.write("%s%s%s " % (linesep, indent, bullet))
child.accept(self)
self.list_level -= 1
def visit_link(self, layout: Link) -> None:
"""add a hyperlink"""
if layout.label != layout.url:
- self.write(u'`%s`_' % layout.label)
+ self.write("`%s`_" % layout.label)
self.pending_urls.append((layout.label, layout.url))
else:
self.write(layout.url)
@@ -141,11 +153,11 @@ class TextWriter(BaseWriter):
def visit_verbatimtext(self, layout: VerbatimText) -> None:
"""display a verbatim layout as text (so difficult ;)
"""
- self.writeln(u'::\n')
+ self.writeln("::\n")
for line in layout.data.splitlines():
- self.writeln(u' ' + line)
+ self.writeln(" " + line)
self.writeln()
def visit_text(self, layout: Text) -> None:
"""add some text"""
- self.write(u'%s' % layout.data)
+ self.write("%s" % layout.data)
diff --git a/logilab/common/urllib2ext.py b/logilab/common/urllib2ext.py
index 339aec0..dfbafc1 100644
--- a/logilab/common/urllib2ext.py
+++ b/logilab/common/urllib2ext.py
@@ -5,22 +5,27 @@ import urllib2
import kerberos as krb
+
class GssapiAuthError(Exception):
"""raised on error during authentication process"""
+
import re
-RGX = re.compile('(?:.*,)*\s*Negotiate\s*([^,]*),?', re.I)
+
+RGX = re.compile("(?:.*,)*\s*Negotiate\s*([^,]*),?", re.I)
+
def get_negociate_value(headers):
- for authreq in headers.getheaders('www-authenticate'):
+ for authreq in headers.getheaders("www-authenticate"):
match = RGX.search(authreq)
if match:
return match.group(1)
+
class HTTPGssapiAuthHandler(urllib2.BaseHandler):
"""Negotiate HTTP authentication using context from GSSAPI"""
- handler_order = 400 # before Digest Auth
+ handler_order = 400 # before Digest Auth
def __init__(self):
self._reset()
@@ -36,15 +41,16 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler):
def http_error_401(self, req, fp, code, msg, headers):
try:
if self._retried > 5:
- raise urllib2.HTTPError(req.get_full_url(), 401,
- "negotiate auth failed", headers, None)
+ raise urllib2.HTTPError(
+ req.get_full_url(), 401, "negotiate auth failed", headers, None
+ )
self._retried += 1
- logging.debug('gssapi handler, try %s' % self._retried)
+ logging.debug("gssapi handler, try %s" % self._retried)
negotiate = get_negociate_value(headers)
if negotiate is None:
- logging.debug('no negociate found in a www-authenticate header')
+ logging.debug("no negociate found in a www-authenticate header")
return None
- logging.debug('HTTPGssapiAuthHandler: negotiate 1 is %r' % negotiate)
+ logging.debug("HTTPGssapiAuthHandler: negotiate 1 is %r" % negotiate)
result, self._context = krb.authGSSClientInit("HTTP@%s" % req.get_host())
if result < 1:
raise GssapiAuthError("HTTPGssapiAuthHandler: init failed with %d" % result)
@@ -52,14 +58,14 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler):
if result < 0:
raise GssapiAuthError("HTTPGssapiAuthHandler: step 1 failed with %d" % result)
client_response = krb.authGSSClientResponse(self._context)
- logging.debug('HTTPGssapiAuthHandler: client response is %s...' % client_response[:10])
- req.add_unredirected_header('Authorization', "Negotiate %s" % client_response)
+ logging.debug("HTTPGssapiAuthHandler: client response is %s..." % client_response[:10])
+ req.add_unredirected_header("Authorization", "Negotiate %s" % client_response)
server_response = self.parent.open(req)
negotiate = get_negociate_value(server_response.info())
if negotiate is None:
- logging.warning('HTTPGssapiAuthHandler: failed to authenticate server')
+ logging.warning("HTTPGssapiAuthHandler: failed to authenticate server")
else:
- logging.debug('HTTPGssapiAuthHandler negotiate 2: %s' % negotiate)
+ logging.debug("HTTPGssapiAuthHandler negotiate 2: %s" % negotiate)
result = krb.authGSSClientStep(self._context, negotiate)
if result < 1:
raise GssapiAuthError("HTTPGssapiAuthHandler: step 2 failed with %d" % result)
@@ -70,20 +76,25 @@ class HTTPGssapiAuthHandler(urllib2.BaseHandler):
self.clean_context()
self._reset()
-if __name__ == '__main__':
+
+if __name__ == "__main__":
import sys
+
# debug
import httplib
+
httplib.HTTPConnection.debuglevel = 1
httplib.HTTPSConnection.debuglevel = 1
# debug
import logging
+
logging.basicConfig(level=logging.DEBUG)
# handle cookies
import cookielib
+
cj = cookielib.CookieJar()
ch = urllib2.HTTPCookieProcessor(cj)
# test with url sys.argv[1]
h = HTTPGssapiAuthHandler()
response = urllib2.build_opener(h, ch).open(sys.argv[1])
- print('\nresponse: %s\n--------------\n' % response.code, response.info())
+ print("\nresponse: %s\n--------------\n" % response.code, response.info())
diff --git a/logilab/common/vcgutils.py b/logilab/common/vcgutils.py
index 9cd2acd..cd2b73a 100644
--- a/logilab/common/vcgutils.py
+++ b/logilab/common/vcgutils.py
@@ -33,101 +33,141 @@ __docformat__ = "restructuredtext en"
import string
ATTRS_VAL = {
- 'algos': ('dfs', 'tree', 'minbackward',
- 'left_to_right', 'right_to_left',
- 'top_to_bottom', 'bottom_to_top',
- 'maxdepth', 'maxdepthslow', 'mindepth', 'mindepthslow',
- 'mindegree', 'minindegree', 'minoutdegree',
- 'maxdegree', 'maxindegree', 'maxoutdegree'),
- 'booleans': ('yes', 'no'),
- 'colors': ('black', 'white', 'blue', 'red', 'green', 'yellow',
- 'magenta', 'lightgrey',
- 'cyan', 'darkgrey', 'darkblue', 'darkred', 'darkgreen',
- 'darkyellow', 'darkmagenta', 'darkcyan', 'gold',
- 'lightblue', 'lightred', 'lightgreen', 'lightyellow',
- 'lightmagenta', 'lightcyan', 'lilac', 'turquoise',
- 'aquamarine', 'khaki', 'purple', 'yellowgreen', 'pink',
- 'orange', 'orchid'),
- 'shapes': ('box', 'ellipse', 'rhomb', 'triangle'),
- 'textmodes': ('center', 'left_justify', 'right_justify'),
- 'arrowstyles': ('solid', 'line', 'none'),
- 'linestyles': ('continuous', 'dashed', 'dotted', 'invisible'),
- }
+ "algos": (
+ "dfs",
+ "tree",
+ "minbackward",
+ "left_to_right",
+ "right_to_left",
+ "top_to_bottom",
+ "bottom_to_top",
+ "maxdepth",
+ "maxdepthslow",
+ "mindepth",
+ "mindepthslow",
+ "mindegree",
+ "minindegree",
+ "minoutdegree",
+ "maxdegree",
+ "maxindegree",
+ "maxoutdegree",
+ ),
+ "booleans": ("yes", "no"),
+ "colors": (
+ "black",
+ "white",
+ "blue",
+ "red",
+ "green",
+ "yellow",
+ "magenta",
+ "lightgrey",
+ "cyan",
+ "darkgrey",
+ "darkblue",
+ "darkred",
+ "darkgreen",
+ "darkyellow",
+ "darkmagenta",
+ "darkcyan",
+ "gold",
+ "lightblue",
+ "lightred",
+ "lightgreen",
+ "lightyellow",
+ "lightmagenta",
+ "lightcyan",
+ "lilac",
+ "turquoise",
+ "aquamarine",
+ "khaki",
+ "purple",
+ "yellowgreen",
+ "pink",
+ "orange",
+ "orchid",
+ ),
+ "shapes": ("box", "ellipse", "rhomb", "triangle"),
+ "textmodes": ("center", "left_justify", "right_justify"),
+ "arrowstyles": ("solid", "line", "none"),
+ "linestyles": ("continuous", "dashed", "dotted", "invisible"),
+}
# meaning of possible values:
# O -> string
# 1 -> int
# list -> value in list
GRAPH_ATTRS = {
- 'title': 0,
- 'label': 0,
- 'color': ATTRS_VAL['colors'],
- 'textcolor': ATTRS_VAL['colors'],
- 'bordercolor': ATTRS_VAL['colors'],
- 'width': 1,
- 'height': 1,
- 'borderwidth': 1,
- 'textmode': ATTRS_VAL['textmodes'],
- 'shape': ATTRS_VAL['shapes'],
- 'shrink': 1,
- 'stretch': 1,
- 'orientation': ATTRS_VAL['algos'],
- 'vertical_order': 1,
- 'horizontal_order': 1,
- 'xspace': 1,
- 'yspace': 1,
- 'layoutalgorithm': ATTRS_VAL['algos'],
- 'late_edge_labels': ATTRS_VAL['booleans'],
- 'display_edge_labels': ATTRS_VAL['booleans'],
- 'dirty_edge_labels': ATTRS_VAL['booleans'],
- 'finetuning': ATTRS_VAL['booleans'],
- 'manhattan_edges': ATTRS_VAL['booleans'],
- 'smanhattan_edges': ATTRS_VAL['booleans'],
- 'port_sharing': ATTRS_VAL['booleans'],
- 'edges': ATTRS_VAL['booleans'],
- 'nodes': ATTRS_VAL['booleans'],
- 'splines': ATTRS_VAL['booleans'],
- }
+ "title": 0,
+ "label": 0,
+ "color": ATTRS_VAL["colors"],
+ "textcolor": ATTRS_VAL["colors"],
+ "bordercolor": ATTRS_VAL["colors"],
+ "width": 1,
+ "height": 1,
+ "borderwidth": 1,
+ "textmode": ATTRS_VAL["textmodes"],
+ "shape": ATTRS_VAL["shapes"],
+ "shrink": 1,
+ "stretch": 1,
+ "orientation": ATTRS_VAL["algos"],
+ "vertical_order": 1,
+ "horizontal_order": 1,
+ "xspace": 1,
+ "yspace": 1,
+ "layoutalgorithm": ATTRS_VAL["algos"],
+ "late_edge_labels": ATTRS_VAL["booleans"],
+ "display_edge_labels": ATTRS_VAL["booleans"],
+ "dirty_edge_labels": ATTRS_VAL["booleans"],
+ "finetuning": ATTRS_VAL["booleans"],
+ "manhattan_edges": ATTRS_VAL["booleans"],
+ "smanhattan_edges": ATTRS_VAL["booleans"],
+ "port_sharing": ATTRS_VAL["booleans"],
+ "edges": ATTRS_VAL["booleans"],
+ "nodes": ATTRS_VAL["booleans"],
+ "splines": ATTRS_VAL["booleans"],
+}
NODE_ATTRS = {
- 'title': 0,
- 'label': 0,
- 'color': ATTRS_VAL['colors'],
- 'textcolor': ATTRS_VAL['colors'],
- 'bordercolor': ATTRS_VAL['colors'],
- 'width': 1,
- 'height': 1,
- 'borderwidth': 1,
- 'textmode': ATTRS_VAL['textmodes'],
- 'shape': ATTRS_VAL['shapes'],
- 'shrink': 1,
- 'stretch': 1,
- 'vertical_order': 1,
- 'horizontal_order': 1,
- }
+ "title": 0,
+ "label": 0,
+ "color": ATTRS_VAL["colors"],
+ "textcolor": ATTRS_VAL["colors"],
+ "bordercolor": ATTRS_VAL["colors"],
+ "width": 1,
+ "height": 1,
+ "borderwidth": 1,
+ "textmode": ATTRS_VAL["textmodes"],
+ "shape": ATTRS_VAL["shapes"],
+ "shrink": 1,
+ "stretch": 1,
+ "vertical_order": 1,
+ "horizontal_order": 1,
+}
EDGE_ATTRS = {
- 'sourcename': 0,
- 'targetname': 0,
- 'label': 0,
- 'linestyle': ATTRS_VAL['linestyles'],
- 'class': 1,
- 'thickness': 0,
- 'color': ATTRS_VAL['colors'],
- 'textcolor': ATTRS_VAL['colors'],
- 'arrowcolor': ATTRS_VAL['colors'],
- 'backarrowcolor': ATTRS_VAL['colors'],
- 'arrowsize': 1,
- 'backarrowsize': 1,
- 'arrowstyle': ATTRS_VAL['arrowstyles'],
- 'backarrowstyle': ATTRS_VAL['arrowstyles'],
- 'textmode': ATTRS_VAL['textmodes'],
- 'priority': 1,
- 'anchor': 1,
- 'horizontal_order': 1,
- }
+ "sourcename": 0,
+ "targetname": 0,
+ "label": 0,
+ "linestyle": ATTRS_VAL["linestyles"],
+ "class": 1,
+ "thickness": 0,
+ "color": ATTRS_VAL["colors"],
+ "textcolor": ATTRS_VAL["colors"],
+ "arrowcolor": ATTRS_VAL["colors"],
+ "backarrowcolor": ATTRS_VAL["colors"],
+ "arrowsize": 1,
+ "backarrowsize": 1,
+ "arrowstyle": ATTRS_VAL["arrowstyles"],
+ "backarrowstyle": ATTRS_VAL["arrowstyles"],
+ "textmode": ATTRS_VAL["textmodes"],
+ "priority": 1,
+ "anchor": 1,
+ "horizontal_order": 1,
+}
# Misc utilities ###############################################################
+
def latin_to_vcg(st):
"""Convert latin characters using vcg escape sequence.
"""
@@ -136,7 +176,7 @@ def latin_to_vcg(st):
try:
num = ord(char)
if num >= 192:
- st = st.replace(char, r'\fi%d'%ord(char))
+ st = st.replace(char, r"\fi%d" % ord(char))
except:
pass
return st
@@ -148,12 +188,12 @@ class VCGPrinter:
def __init__(self, output_stream):
self._stream = output_stream
- self._indent = ''
+ self._indent = ""
def open_graph(self, **args):
"""open a vcg graph
"""
- self._stream.write('%sgraph:{\n'%self._indent)
+ self._stream.write("%sgraph:{\n" % self._indent)
self._inc_indent()
self._write_attributes(GRAPH_ATTRS, **args)
@@ -161,26 +201,24 @@ class VCGPrinter:
"""close a vcg graph
"""
self._dec_indent()
- self._stream.write('%s}\n'%self._indent)
-
+ self._stream.write("%s}\n" % self._indent)
def node(self, title, **args):
"""draw a node
"""
self._stream.write('%snode: {title:"%s"' % (self._indent, title))
self._write_attributes(NODE_ATTRS, **args)
- self._stream.write('}\n')
+ self._stream.write("}\n")
-
- def edge(self, from_node, to_node, edge_type='', **args):
+ def edge(self, from_node, to_node, edge_type="", **args):
"""draw an edge from a node to another.
"""
self._stream.write(
- '%s%sedge: {sourcename:"%s" targetname:"%s"' % (
- self._indent, edge_type, from_node, to_node))
+ '%s%sedge: {sourcename:"%s" targetname:"%s"'
+ % (self._indent, edge_type, from_node, to_node)
+ )
self._write_attributes(EDGE_ATTRS, **args)
- self._stream.write('}\n')
-
+ self._stream.write("}\n")
# private ##################################################################
@@ -189,26 +227,31 @@ class VCGPrinter:
"""
for key, value in args.items():
try:
- _type = attributes_dict[key]
+ _type = attributes_dict[key]
except KeyError:
- raise Exception('''no such attribute %s
-possible attributes are %s''' % (key, attributes_dict.keys()))
+ raise Exception(
+ """no such attribute %s
+possible attributes are %s"""
+ % (key, attributes_dict.keys())
+ )
if not _type:
self._stream.write('%s%s:"%s"\n' % (self._indent, key, value))
elif _type == 1:
- self._stream.write('%s%s:%s\n' % (self._indent, key,
- int(value)))
+ self._stream.write("%s%s:%s\n" % (self._indent, key, int(value)))
elif value in _type:
- self._stream.write('%s%s:%s\n' % (self._indent, key, value))
+ self._stream.write("%s%s:%s\n" % (self._indent, key, value))
else:
- raise Exception('''value %s isn\'t correct for attribute %s
-correct values are %s''' % (value, key, _type))
+ raise Exception(
+ """value %s isn\'t correct for attribute %s
+correct values are %s"""
+ % (value, key, _type)
+ )
def _inc_indent(self):
"""increment indentation
"""
- self._indent = ' %s' % self._indent
+ self._indent = " %s" % self._indent
def _dec_indent(self):
"""decrement indentation
diff --git a/logilab/common/visitor.py b/logilab/common/visitor.py
index 0698bae..8d80d54 100644
--- a/logilab/common/visitor.py
+++ b/logilab/common/visitor.py
@@ -23,15 +23,16 @@
"""
from typing import Any, Callable, Optional, Union
from logilab.common.types import Node, HTMLWriter, TextWriter
+
__docformat__ = "restructuredtext en"
def no_filter(_: Node) -> int:
return 1
+
# Iterators ###################################################################
class FilteredIterator(object):
-
def __init__(self, node: Node, list_func: Callable, filter_func: Optional[Any] = None) -> None:
self._next = [(node, 0)]
if filter_func is None:
@@ -41,14 +42,14 @@ class FilteredIterator(object):
def __next__(self) -> Optional[Node]:
try:
return self._list.pop(0)
- except :
+ except:
return None
next = __next__
+
# Base Visitor ################################################################
class Visitor(object):
-
def __init__(self, iterator_class, filter_func=None):
self._iter_class = iterator_class
self.filter = filter_func
@@ -87,11 +88,13 @@ class Visitor(object):
"""
return result
+
# standard visited mixin ######################################################
class VisitedMixIn(object):
"""
Visited interface allow node visitors to use the node
"""
+
def get_visit_name(self) -> str:
"""
return the visit name for the mixed class. When calling 'accept', the
@@ -101,14 +104,16 @@ class VisitedMixIn(object):
try:
# mypy: "VisitedMixIn" has no attribute "TYPE"
# dynamic attribute
- return self.TYPE.replace('-', '_') # type: ignore
+ return self.TYPE.replace("-", "_") # type: ignore
except:
return self.__class__.__name__.lower()
- def accept(self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any) -> Optional[Any]:
- func = getattr(visitor, 'visit_%s' % self.get_visit_name())
+ def accept(
+ self, visitor: Union[HTMLWriter, TextWriter], *args: Any, **kwargs: Any
+ ) -> Optional[Any]:
+ func = getattr(visitor, "visit_%s" % self.get_visit_name())
return func(self, *args, **kwargs)
def leave(self, visitor, *args, **kwargs):
- func = getattr(visitor, 'leave_%s' % self.get_visit_name())
+ func = getattr(visitor, "leave_%s" % self.get_visit_name())
return func(self, *args, **kwargs)
diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py
index 7b12c45..14e3762 100644
--- a/logilab/common/xmlutils.py
+++ b/logilab/common/xmlutils.py
@@ -34,6 +34,7 @@ from typing import Dict, Optional, Union
RE_DOUBLE_QUOTE = re.compile('([\w\-\.]+)="([^"]+)"')
RE_SIMPLE_QUOTE = re.compile("([\w\-\.]+)='([^']+)'")
+
def parse_pi_data(pi_data: str) -> Dict[str, Optional[str]]:
"""
Utility function that parses the data contained in an XML
diff --git a/setup.py b/setup.py
index c11b59b..f42c35a 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@ from os import path
here = path.abspath(path.dirname(__file__))
pkginfo = {}
-with open(path.join(here, '__pkginfo__.py')) as f:
+with open(path.join(here, "__pkginfo__.py")) as f:
exec(f.read(), pkginfo)
# Get the long description from the relevant file
@@ -36,20 +36,20 @@ with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
long_description = f.read()
setup(
- name=pkginfo['distname'],
- version=pkginfo['version'],
- description=pkginfo['description'],
+ name=pkginfo["distname"],
+ version=pkginfo["version"],
+ description=pkginfo["description"],
long_description=long_description,
- url=pkginfo['web'],
- author=pkginfo['author'],
- author_email=pkginfo['author_email'],
- license=pkginfo['license'],
+ url=pkginfo["web"],
+ author=pkginfo["author"],
+ author_email=pkginfo["author_email"],
+ license=pkginfo["license"],
# See https://pypi.python.org/pypi?%3Aaction=list_classifiers
- classifiers=pkginfo['classifiers'],
- packages=find_packages(exclude=['contrib', 'docs', 'test*']),
- namespace_packages=[pkginfo['subpackage_of']],
- python_requires='>=3.3',
- install_requires=pkginfo['install_requires'],
- tests_require=pkginfo['tests_require'],
- scripts=pkginfo['scripts'],
+ classifiers=pkginfo["classifiers"],
+ packages=find_packages(exclude=["contrib", "docs", "test*"]),
+ namespace_packages=[pkginfo["subpackage_of"]],
+ python_requires=">=3.3",
+ install_requires=pkginfo["install_requires"],
+ tests_require=pkginfo["tests_require"],
+ scripts=pkginfo["scripts"],
)
diff --git a/test/data/__pkginfo__.py b/test/data/__pkginfo__.py
index d1f5731..51f36ec 100644
--- a/test/data/__pkginfo__.py
+++ b/test/data/__pkginfo__.py
@@ -20,15 +20,15 @@ __docformat__ = "restructuredtext en"
import sys
import os
-distname = 'logilab-common'
-modname = 'common'
-subpackage_of = 'logilab'
+distname = "logilab-common"
+modname = "common"
+subpackage_of = "logilab"
subpackage_master = True
numversion = (0, 63, 2)
-version = '.'.join([str(num) for num in numversion])
+version = ".".join([str(num) for num in numversion])
-license = 'LGPL' # 2.1 or later
+license = "LGPL" # 2.1 or later
description = "collection of low-level Python packages and modules used by Logilab projects"
web = "http://www.logilab.org/project/%s" % distname
mailinglist = "mailto://python-projects@lists.logilab.org"
@@ -37,19 +37,21 @@ author_email = "contact@logilab.fr"
from os.path import join
-scripts = [join('bin', 'logilab-pytest')]
-include_dirs = [join('test', 'data')]
+
+scripts = [join("bin", "logilab-pytest")]
+include_dirs = [join("test", "data")]
install_requires = []
-tests_require = ['pytz']
+tests_require = ["pytz"]
if sys.version_info < (2, 7):
- install_requires.append('unittest2 >= 0.5.1')
-if os.name == 'nt':
- install_requires.append('colorama')
-
-classifiers = ["Topic :: Utilities",
- "Programming Language :: Python",
- "Programming Language :: Python :: 2",
- "Programming Language :: Python :: 3",
- ]
+ install_requires.append("unittest2 >= 0.5.1")
+if os.name == "nt":
+ install_requires.append("colorama")
+
+classifiers = [
+ "Topic :: Utilities",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 3",
+]
diff --git a/test/data/deprecation.py b/test/data/deprecation.py
index be3b103..3ef3c0f 100644
--- a/test/data/deprecation.py
+++ b/test/data/deprecation.py
@@ -1,4 +1,5 @@
# placeholder used by unittest_deprecation
+
def moving_target():
pass
diff --git a/test/data/lmfp/foo.py b/test/data/lmfp/foo.py
index 8f7de1e..841ea47 100644
--- a/test/data/lmfp/foo.py
+++ b/test/data/lmfp/foo.py
@@ -1,5 +1,6 @@
import sys
-if not getattr(sys, 'bar', None):
+
+if not getattr(sys, "bar", None):
sys.just_once = []
# there used to be two numbers here because
# of a load_module_from_path bug
diff --git a/test/data/module.py b/test/data/module.py
index 493e676..3b83811 100644
--- a/test/data/module.py
+++ b/test/data/module.py
@@ -21,11 +21,14 @@ def global_access(key, val):
else:
break
else:
- print('!!!')
+ print("!!!")
+
class YO:
"""hehe"""
- a=1
+
+ a = 1
+
def __init__(self):
try:
self.yo = 1
@@ -36,7 +39,8 @@ class YO:
except:
raise
-#print('*****>',YO.__dict__)
+
+# print('*****>',YO.__dict__)
class YOUPI(YO):
class_attr = None
@@ -51,19 +55,21 @@ class YOUPI(YO):
local = None
autre = [a for a, b in MY_DICT if b]
if b in autre:
- print('yo', end=' ')
+ print("yo", end=" ")
elif a in autre:
- print('hehe')
+ print("hehe")
global_access(local, val=autre)
finally:
return local
def static_method():
"""static method test"""
- assert MY_DICT, '???'
+ assert MY_DICT, "???"
+
static_method = staticmethod(static_method)
def class_method(cls):
"""class method test"""
exec(a, b)
+
class_method = classmethod(class_method)
diff --git a/test/data/module2.py b/test/data/module2.py
index 51509f3..7192904 100644
--- a/test/data/module2.py
+++ b/test/data/module2.py
@@ -1,51 +1,76 @@
from data.module import YO, YOUPI
import data
-class Specialization(YOUPI, YO): pass
-class Metaclass(type): pass
+class Specialization(YOUPI, YO):
+ pass
-class Interface: pass
-class MyIFace(Interface): pass
+class Metaclass(type):
+ pass
-class AnotherIFace(Interface): pass
-class MyException(Exception): pass
-class MyError(MyException): pass
+class Interface:
+ pass
-class AbstractClass(object):
+class MyIFace(Interface):
+ pass
+
+
+class AnotherIFace(Interface):
+ pass
+
+
+class MyException(Exception):
+ pass
+
+
+class MyError(MyException):
+ pass
+
+
+class AbstractClass(object):
def to_override(self, whatever):
raise NotImplementedError()
def return_something(self, param):
if param:
- return 'toto'
+ return "toto"
return
+
class Concrete0:
__implements__ = MyIFace
+
+
class Concrete1:
__implements__ = MyIFace, AnotherIFace
+
+
class Concrete2:
- __implements__ = (MyIFace,
- AnotherIFace)
-class Concrete23(Concrete1): pass
+ __implements__ = (MyIFace, AnotherIFace)
+
+
+class Concrete23(Concrete1):
+ pass
+
del YO.member
del YO
[SYN1, SYN2] = Concrete0, Concrete1
-assert '1'
+assert "1"
b = 1 | 2 & 3 ^ 8
-exec('c = 3')
-exec('c = 3', {}, {})
+exec("c = 3")
+exec("c = 3", {}, {})
+
def raise_string(a=2, *args, **kwargs):
- raise 'pas glop'
- raise Exception('yo')
- yield 'coucou'
+ raise "pas glop"
+ raise Exception("yo")
+ yield "coucou"
+
a = b + 2
c = b * 2
@@ -66,12 +91,14 @@ e = d[a:b:c]
raise_string(*args, **kwargs)
-print >> stream, 'bonjour'
-print >> stream, 'salut',
+print >> stream, "bonjour"
+print >> stream, "salut",
def make_class(any, base=data.module.YO, *args, **kwargs):
"""check base is correctly resolved to Concrete0"""
+
class Aaaa(base):
"""dynamic class"""
+
return Aaaa
diff --git a/test/data/noendingnewline.py b/test/data/noendingnewline.py
index 110f902..f309715 100644
--- a/test/data/noendingnewline.py
+++ b/test/data/noendingnewline.py
@@ -4,11 +4,9 @@ import unittest
class TestCase(unittest.TestCase):
-
def setUp(self):
unittest.TestCase.setUp(self)
-
def tearDown(self):
unittest.TestCase.tearDown(self)
@@ -16,11 +14,10 @@ class TestCase(unittest.TestCase):
self.a = 10
self.xxx()
-
def xxx(self):
if False:
pass
- print('a')
+ print("a")
if False:
pass
@@ -28,9 +25,9 @@ class TestCase(unittest.TestCase):
if False:
pass
- print('rara')
+ print("rara")
-if __name__ == '__main__':
- print('test2')
+if __name__ == "__main__":
+ print("test2")
unittest.main()
diff --git a/test/data/nonregr.py b/test/data/nonregr.py
index a4b5ef7..a8747a2 100644
--- a/test/data/nonregr.py
+++ b/test/data/nonregr.py
@@ -11,6 +11,7 @@ except NameError:
yield i, val
i += 1
+
def toto(value):
for k, v in value:
- print(v.get('yo'))
+ print(v.get("yo"))
diff --git a/test/data/regobjects.py b/test/data/regobjects.py
index 6cea558..4ad0e94 100644
--- a/test/data/regobjects.py
+++ b/test/data/regobjects.py
@@ -1,22 +1,29 @@
"""unittest_registry data file"""
from logilab.common.registry import yes, RegistrableObject, RegistrableInstance
+
class Proxy(object):
"""annoying object should that not be registered, nor cause error"""
+
def __getattr__(self, attr):
return 1
+
trap = Proxy()
+
class AppObjectClass(RegistrableObject):
- __registry__ = 'zereg'
- __regid__ = 'appobject1'
+ __registry__ = "zereg"
+ __regid__ = "appobject1"
__select__ = yes()
+
class AppObjectInstance(RegistrableInstance):
- __registry__ = 'zereg'
+ __registry__ = "zereg"
__select__ = yes()
+
def __init__(self, regid):
self.__regid__ = regid
-appobject2 = AppObjectInstance('appobject2')
+
+appobject2 = AppObjectInstance("appobject2")
diff --git a/test/data/regobjects2.py b/test/data/regobjects2.py
index 091b9f7..b6d5781 100644
--- a/test/data/regobjects2.py
+++ b/test/data/regobjects2.py
@@ -1,8 +1,10 @@
from logilab.common.registry import RegistrableObject, RegistrableInstance, yes
+
class MyRegistrableInstance(RegistrableInstance):
- __regid__ = 'appobject3'
+ __regid__ = "appobject3"
__select__ = yes()
- __registry__ = 'zereg'
+ __registry__ = "zereg"
+
instance = MyRegistrableInstance(__module__=__name__)
diff --git a/test/data/sub/momo.py b/test/data/sub/momo.py
index 746b5d0..ecf4ab5 100644
--- a/test/data/sub/momo.py
+++ b/test/data/sub/momo.py
@@ -1,3 +1,3 @@
from __future__ import print_function
-print('yo')
+print("yo")
diff --git a/test/test_cache.py b/test/test_cache.py
index 459f172..8e169c4 100644
--- a/test/test_cache.py
+++ b/test/test_cache.py
@@ -20,109 +20,112 @@
from logilab.common.testlib import TestCase, unittest_main, TestSuite
from logilab.common.cache import Cache
-class CacheTestCase(TestCase):
+class CacheTestCase(TestCase):
def setUp(self):
self.cache = Cache(5)
self.testdict = {}
def test_setitem1(self):
"""Checks that the setitem method works"""
- self.cache[1] = 'foo'
- self.assertEqual(self.cache[1], 'foo', "1:foo is not in cache")
+ self.cache[1] = "foo"
+ self.assertEqual(self.cache[1], "foo", "1:foo is not in cache")
self.assertEqual(len(self.cache._usage), 1)
- self.assertEqual(self.cache._usage[-1], 1,
- '1 is not the most recently used key')
- self.assertCountEqual(self.cache._usage,
- self.cache.keys(),
- "usage list and data keys are different")
+ self.assertEqual(self.cache._usage[-1], 1, "1 is not the most recently used key")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys(), "usage list and data keys are different"
+ )
def test_setitem2(self):
"""Checks that the setitem method works for multiple items"""
- self.cache[1] = 'foo'
- self.cache[2] = 'bar'
- self.assertEqual(self.cache[2], 'bar',
- "2 : 'bar' is not in cache.data")
- self.assertEqual(len(self.cache._usage), 2,
- "lenght of usage list is not 2")
- self.assertEqual(self.cache._usage[-1], 2,
- '1 is not the most recently used key')
- self.assertCountEqual(self.cache._usage,
- self.cache.keys())# usage list and data keys are different
+ self.cache[1] = "foo"
+ self.cache[2] = "bar"
+ self.assertEqual(self.cache[2], "bar", "2 : 'bar' is not in cache.data")
+ self.assertEqual(len(self.cache._usage), 2, "lenght of usage list is not 2")
+ self.assertEqual(self.cache._usage[-1], 2, "1 is not the most recently used key")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys()
+ ) # usage list and data keys are different
def test_setitem3(self):
"""Checks that the setitem method works when replacing an element in the cache"""
- self.cache[1] = 'foo'
- self.cache[1] = 'bar'
- self.assertEqual(self.cache[1], 'bar', "1 : 'bar' is not in cache.data")
+ self.cache[1] = "foo"
+ self.cache[1] = "bar"
+ self.assertEqual(self.cache[1], "bar", "1 : 'bar' is not in cache.data")
self.assertEqual(len(self.cache._usage), 1, "lenght of usage list is not 1")
- self.assertEqual(self.cache._usage[-1], 1, '1 is not the most recently used key')
- self.assertCountEqual(self.cache._usage,
- self.cache.keys())# usage list and data keys are different
+ self.assertEqual(self.cache._usage[-1], 1, "1 is not the most recently used key")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys()
+ ) # usage list and data keys are different
def test_recycling1(self):
"""Checks the removal of old elements"""
- self.cache[1] = 'foo'
- self.cache[2] = 'bar'
- self.cache[3] = 'baz'
- self.cache[4] = 'foz'
- self.cache[5] = 'fuz'
- self.cache[6] = 'spam'
- self.assertTrue(1 not in self.cache,
- 'key 1 has not been suppressed from the cache dictionnary')
- self.assertTrue(1 not in self.cache._usage,
- 'key 1 has not been suppressed from the cache LRU list')
+ self.cache[1] = "foo"
+ self.cache[2] = "bar"
+ self.cache[3] = "baz"
+ self.cache[4] = "foz"
+ self.cache[5] = "fuz"
+ self.cache[6] = "spam"
+ self.assertTrue(
+ 1 not in self.cache, "key 1 has not been suppressed from the cache dictionnary"
+ )
+ self.assertTrue(
+ 1 not in self.cache._usage, "key 1 has not been suppressed from the cache LRU list"
+ )
self.assertEqual(len(self.cache._usage), 5, "lenght of usage list is not 5")
- self.assertEqual(self.cache._usage[-1], 6, '6 is not the most recently used key')
- self.assertCountEqual(self.cache._usage,
- self.cache.keys())# usage list and data keys are different
+ self.assertEqual(self.cache._usage[-1], 6, "6 is not the most recently used key")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys()
+ ) # usage list and data keys are different
def test_recycling2(self):
"""Checks that accessed elements get in the front of the list"""
- self.cache[1] = 'foo'
- self.cache[2] = 'bar'
- self.cache[3] = 'baz'
- self.cache[4] = 'foz'
+ self.cache[1] = "foo"
+ self.cache[2] = "bar"
+ self.cache[3] = "baz"
+ self.cache[4] = "foz"
a = self.cache[1]
- self.assertEqual(a, 'foo')
- self.assertEqual(self.cache._usage[-1], 1, '1 is not the most recently used key')
- self.assertCountEqual(self.cache._usage,
- self.cache.keys())# usage list and data keys are different
+ self.assertEqual(a, "foo")
+ self.assertEqual(self.cache._usage[-1], 1, "1 is not the most recently used key")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys()
+ ) # usage list and data keys are different
def test_delitem(self):
"""Checks that elements are removed from both element dict and element
list.
"""
- self.cache['foo'] = 'bar'
- del self.cache['foo']
- self.assertTrue('foo' not in self.cache.keys(), "Element 'foo' was not removed cache dictionnary")
- self.assertTrue('foo' not in self.cache._usage, "Element 'foo' was not removed usage list")
- self.assertCountEqual(self.cache._usage,
- self.cache.keys())# usage list and data keys are different
-
+ self.cache["foo"] = "bar"
+ del self.cache["foo"]
+ self.assertTrue(
+ "foo" not in self.cache.keys(), "Element 'foo' was not removed cache dictionnary"
+ )
+ self.assertTrue("foo" not in self.cache._usage, "Element 'foo' was not removed usage list")
+ self.assertCountEqual(
+ self.cache._usage, self.cache.keys()
+ ) # usage list and data keys are different
def test_nullsize(self):
"""Checks that a 'NULL' size cache doesn't store anything
"""
null_cache = Cache(0)
- null_cache['foo'] = 'bar'
- self.assertEqual(null_cache.size, 0, 'Cache size should be O, not %d' % \
- null_cache.size)
- self.assertEqual(len(null_cache), 0, 'Cache should be empty !')
+ null_cache["foo"] = "bar"
+ self.assertEqual(null_cache.size, 0, "Cache size should be O, not %d" % null_cache.size)
+ self.assertEqual(len(null_cache), 0, "Cache should be empty !")
# Assert null_cache['foo'] raises a KeyError
- self.assertRaises(KeyError, null_cache.__getitem__, 'foo')
+ self.assertRaises(KeyError, null_cache.__getitem__, "foo")
# Deleting element raises a KeyError
- self.assertRaises(KeyError, null_cache.__delitem__, 'foo')
+ self.assertRaises(KeyError, null_cache.__delitem__, "foo")
def test_getitem(self):
""" Checks that getitem doest not modify the _usage attribute
"""
try:
- self.cache['toto']
+ self.cache["toto"]
except KeyError:
- self.assertTrue('toto' not in self.cache._usage)
+ self.assertTrue("toto" not in self.cache._usage)
else:
- self.fail('excepted KeyError')
+ self.fail("excepted KeyError")
if __name__ == "__main__":
diff --git a/test/test_changelog.py b/test/test_changelog.py
index c2572d7..c251311 100644
--- a/test/test_changelog.py
+++ b/test/test_changelog.py
@@ -26,7 +26,7 @@ from logilab.common.changelog import ChangeLog
class ChangeLogTC(TestCase):
cl_class = ChangeLog
- cl_file = join(dirname(__file__), 'data', 'ChangeLog')
+ cl_file = join(dirname(__file__), "data", "ChangeLog")
def test_round_trip(self):
cl = self.cl_class(self.cl_file)
@@ -36,5 +36,5 @@ class ChangeLogTC(TestCase):
self.assertMultiLineEqual(stream.read(), out.getvalue())
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_configuration.py b/test/test_configuration.py
index 2dee7d0..d8d0a4b 100644
--- a/test/test_configuration.py
+++ b/test/test_configuration.py
@@ -26,109 +26,137 @@ from logilab.common import attrdict
from logilab.common.compat import StringIO
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.optik_ext import OptionValueError
-from logilab.common.configuration import Configuration, OptionError, \
- OptionsManagerMixIn, OptionsProviderMixIn, Method, read_old_config, \
- merge_options
-
-DATA = join(dirname(abspath(__file__)), 'data')
-
-OPTIONS = [('dothis', {'type':'yn', 'action': 'store', 'default': True, 'metavar': '<y or n>'}),
- ('value', {'type': 'string', 'metavar': '<string>', 'short': 'v'}),
- ('multiple', {'type': 'csv', 'default': ['yop', 'yep'],
- 'metavar': '<comma separated values>',
- 'help': 'you can also document the option'}),
- ('number', {'type': 'int', 'default':2, 'metavar':'<int>', 'help': 'boom'}),
- ('bytes', {'type': 'bytes', 'default':'1KB', 'metavar':'<bytes>'}),
- ('choice', {'type': 'choice', 'default':'yo', 'choices': ('yo', 'ye'),
- 'metavar':'<yo|ye>'}),
- ('multiple-choice', {'type': 'multiple_choice', 'default':['yo', 'ye'],
- 'choices': ('yo', 'ye', 'yu', 'yi', 'ya'),
- 'metavar':'<yo|ye>'}),
- ('named', {'type':'named', 'default':Method('get_named'),
- 'metavar': '<key=val>'}),
-
- ('diffgroup', {'type':'string', 'default':'pouet', 'metavar': '<key=val>',
- 'group': 'agroup'}),
- ('reset-value', {'type': 'string', 'metavar': '<string>', 'short': 'r',
- 'dest':'value'}),
-
- ('opt-b-1', {'type': 'string', 'metavar': '<string>', 'group': 'bgroup'}),
- ('opt-b-2', {'type': 'string', 'metavar': '<string>', 'group': 'bgroup'}),
- ]
+from logilab.common.configuration import (
+ Configuration,
+ OptionError,
+ OptionsManagerMixIn,
+ OptionsProviderMixIn,
+ Method,
+ read_old_config,
+ merge_options,
+)
+
+DATA = join(dirname(abspath(__file__)), "data")
+
+OPTIONS = [
+ ("dothis", {"type": "yn", "action": "store", "default": True, "metavar": "<y or n>"}),
+ ("value", {"type": "string", "metavar": "<string>", "short": "v"}),
+ (
+ "multiple",
+ {
+ "type": "csv",
+ "default": ["yop", "yep"],
+ "metavar": "<comma separated values>",
+ "help": "you can also document the option",
+ },
+ ),
+ ("number", {"type": "int", "default": 2, "metavar": "<int>", "help": "boom"}),
+ ("bytes", {"type": "bytes", "default": "1KB", "metavar": "<bytes>"}),
+ ("choice", {"type": "choice", "default": "yo", "choices": ("yo", "ye"), "metavar": "<yo|ye>"}),
+ (
+ "multiple-choice",
+ {
+ "type": "multiple_choice",
+ "default": ["yo", "ye"],
+ "choices": ("yo", "ye", "yu", "yi", "ya"),
+ "metavar": "<yo|ye>",
+ },
+ ),
+ ("named", {"type": "named", "default": Method("get_named"), "metavar": "<key=val>"}),
+ (
+ "diffgroup",
+ {"type": "string", "default": "pouet", "metavar": "<key=val>", "group": "agroup"},
+ ),
+ ("reset-value", {"type": "string", "metavar": "<string>", "short": "r", "dest": "value"}),
+ ("opt-b-1", {"type": "string", "metavar": "<string>", "group": "bgroup"}),
+ ("opt-b-2", {"type": "string", "metavar": "<string>", "group": "bgroup"}),
+]
+
class MyConfiguration(Configuration):
"""test configuration"""
+
def get_named(self):
- return {'key': 'val'}
+ return {"key": "val"}
-class ConfigurationTC(TestCase):
+class ConfigurationTC(TestCase):
def setUp(self):
- self.cfg = MyConfiguration(name='test', options=OPTIONS, usage='Just do it ! (tm)')
+ self.cfg = MyConfiguration(name="test", options=OPTIONS, usage="Just do it ! (tm)")
def test_default(self):
cfg = self.cfg
- self.assertEqual(cfg['dothis'], True)
- self.assertEqual(cfg['value'], None)
- self.assertEqual(cfg['multiple'], ['yop', 'yep'])
- self.assertEqual(cfg['number'], 2)
- self.assertEqual(cfg['bytes'], 1024)
- self.assertIsInstance(cfg['bytes'], int)
- self.assertEqual(cfg['choice'], 'yo')
- self.assertEqual(cfg['multiple-choice'], ['yo', 'ye'])
- self.assertEqual(cfg['named'], {'key': 'val'})
+ self.assertEqual(cfg["dothis"], True)
+ self.assertEqual(cfg["value"], None)
+ self.assertEqual(cfg["multiple"], ["yop", "yep"])
+ self.assertEqual(cfg["number"], 2)
+ self.assertEqual(cfg["bytes"], 1024)
+ self.assertIsInstance(cfg["bytes"], int)
+ self.assertEqual(cfg["choice"], "yo")
+ self.assertEqual(cfg["multiple-choice"], ["yo", "ye"])
+ self.assertEqual(cfg["named"], {"key": "val"})
def test_base(self):
cfg = self.cfg
- cfg.set_option('number', '0')
- self.assertEqual(cfg['number'], 0)
- self.assertRaises(OptionValueError, cfg.set_option, 'number', 'youpi')
- self.assertRaises(OptionValueError, cfg.set_option, 'choice', 'youpi')
- self.assertRaises(OptionValueError, cfg.set_option, 'multiple-choice', ('yo', 'y', 'ya'))
- cfg.set_option('multiple-choice', 'yo, ya')
- self.assertEqual(cfg['multiple-choice'], ['yo', 'ya'])
- self.assertEqual(cfg.get('multiple-choice'), ['yo', 'ya'])
- self.assertEqual(cfg.get('whatever'), None)
+ cfg.set_option("number", "0")
+ self.assertEqual(cfg["number"], 0)
+ self.assertRaises(OptionValueError, cfg.set_option, "number", "youpi")
+ self.assertRaises(OptionValueError, cfg.set_option, "choice", "youpi")
+ self.assertRaises(OptionValueError, cfg.set_option, "multiple-choice", ("yo", "y", "ya"))
+ cfg.set_option("multiple-choice", "yo, ya")
+ self.assertEqual(cfg["multiple-choice"], ["yo", "ya"])
+ self.assertEqual(cfg.get("multiple-choice"), ["yo", "ya"])
+ self.assertEqual(cfg.get("whatever"), None)
def test_load_command_line_configuration(self):
cfg = self.cfg
- args = cfg.load_command_line_configuration(['--choice', 'ye', '--number', '4',
- '--multiple=1,2,3', '--dothis=n',
- '--bytes=10KB',
- 'other', 'arguments'])
- self.assertEqual(args, ['other', 'arguments'])
- self.assertEqual(cfg['dothis'], False)
- self.assertEqual(cfg['multiple'], ['1', '2', '3'])
- self.assertEqual(cfg['number'], 4)
- self.assertEqual(cfg['bytes'], 10240)
- self.assertEqual(cfg['choice'], 'ye')
- self.assertEqual(cfg['value'], None)
- args = cfg.load_command_line_configuration(['-v', 'duh'])
+ args = cfg.load_command_line_configuration(
+ [
+ "--choice",
+ "ye",
+ "--number",
+ "4",
+ "--multiple=1,2,3",
+ "--dothis=n",
+ "--bytes=10KB",
+ "other",
+ "arguments",
+ ]
+ )
+ self.assertEqual(args, ["other", "arguments"])
+ self.assertEqual(cfg["dothis"], False)
+ self.assertEqual(cfg["multiple"], ["1", "2", "3"])
+ self.assertEqual(cfg["number"], 4)
+ self.assertEqual(cfg["bytes"], 10240)
+ self.assertEqual(cfg["choice"], "ye")
+ self.assertEqual(cfg["value"], None)
+ args = cfg.load_command_line_configuration(["-v", "duh"])
self.assertEqual(args, [])
- self.assertEqual(cfg['value'], 'duh')
- self.assertEqual(cfg['dothis'], False)
- self.assertEqual(cfg['multiple'], ['1', '2', '3'])
- self.assertEqual(cfg['number'], 4)
- self.assertEqual(cfg['bytes'], 10240)
- self.assertEqual(cfg['choice'], 'ye')
+ self.assertEqual(cfg["value"], "duh")
+ self.assertEqual(cfg["dothis"], False)
+ self.assertEqual(cfg["multiple"], ["1", "2", "3"])
+ self.assertEqual(cfg["number"], 4)
+ self.assertEqual(cfg["bytes"], 10240)
+ self.assertEqual(cfg["choice"], "ye")
def test_load_configuration(self):
cfg = self.cfg
- args = cfg.load_configuration(choice='ye', number='4',
- multiple='1,2,3', dothis='n',
- multiple_choice=('yo', 'ya'))
- self.assertEqual(cfg['dothis'], False)
- self.assertEqual(cfg['multiple'], ['1', '2', '3'])
- self.assertEqual(cfg['number'], 4)
- self.assertEqual(cfg['choice'], 'ye')
- self.assertEqual(cfg['value'], None)
- self.assertEqual(cfg['multiple-choice'], ('yo', 'ya'))
+ args = cfg.load_configuration(
+ choice="ye", number="4", multiple="1,2,3", dothis="n", multiple_choice=("yo", "ya")
+ )
+ self.assertEqual(cfg["dothis"], False)
+ self.assertEqual(cfg["multiple"], ["1", "2", "3"])
+ self.assertEqual(cfg["number"], 4)
+ self.assertEqual(cfg["choice"], "ye")
+ self.assertEqual(cfg["value"], None)
+ self.assertEqual(cfg["multiple-choice"], ("yo", "ya"))
def test_load_configuration_file_case_insensitive(self):
file = tempfile.mktemp()
- stream = open(file, 'w')
+ stream = open(file, "w")
try:
- stream.write("""[Test]
+ stream.write(
+ """[Test]
dothis=no
@@ -152,13 +180,14 @@ named=key:val
[agroup]
diffgroup=zou
-""")
+"""
+ )
stream.close()
self.cfg.load_file_configuration(file)
- self.assertEqual(self.cfg['dothis'], False)
- self.assertEqual(self.cfg['value'], None)
- self.assertEqual(self.cfg['multiple'], ['yop', 'yepii'])
- self.assertEqual(self.cfg['diffgroup'], 'zou')
+ self.assertEqual(self.cfg["dothis"], False)
+ self.assertEqual(self.cfg["value"], None)
+ self.assertEqual(self.cfg["multiple"], ["yop", "yepii"])
+ self.assertEqual(self.cfg["diffgroup"], "zou")
finally:
os.remove(file)
@@ -167,37 +196,43 @@ diffgroup=zou
and not in the order they are defined in the Configuration object.
"""
file = tempfile.mktemp()
- stream = open(file, 'w')
+ stream = open(file, "w")
try:
- stream.write("""[Test]
+ stream.write(
+ """[Test]
reset-value=toto
value=tata
-""")
+"""
+ )
stream.close()
self.cfg.load_file_configuration(file)
finally:
os.remove(file)
- self.assertEqual(self.cfg['value'], 'tata')
+ self.assertEqual(self.cfg["value"], "tata")
def test_unsupported_options(self):
file = tempfile.mktemp()
- stream = open(file, 'w')
+ stream = open(file, "w")
try:
- stream.write("""[Test]
+ stream.write(
+ """[Test]
whatever=toto
value=tata
-""")
+"""
+ )
stream.close()
self.cfg.load_file_configuration(file)
finally:
os.remove(file)
- self.assertEqual(self.cfg['value'], 'tata')
- self.assertRaises(OptionError, self.cfg.__getitem__, 'whatever')
+ self.assertEqual(self.cfg["value"], "tata")
+ self.assertRaises(OptionError, self.cfg.__getitem__, "whatever")
def test_generate_config(self):
stream = StringIO()
self.cfg.generate_config(stream)
- self.assertMultiLineEqual(stream.getvalue().strip(), """[TEST]
+ self.assertMultiLineEqual(
+ stream.getvalue().strip(),
+ """[TEST]
dothis=yes
@@ -229,13 +264,16 @@ diffgroup=pouet
#opt-b-1=
-#opt-b-2=""")
+#opt-b-2=""",
+ )
def test_generate_config_with_space_string(self):
- self.cfg['value'] = ' '
+ self.cfg["value"] = " "
stream = StringIO()
self.cfg.generate_config(stream)
- self.assertMultiLineEqual(stream.getvalue().strip(), """[TEST]
+ self.assertMultiLineEqual(
+ stream.getvalue().strip(),
+ """[TEST]
dothis=yes
@@ -267,13 +305,16 @@ diffgroup=pouet
#opt-b-1=
-#opt-b-2=""")
+#opt-b-2=""",
+ )
def test_generate_config_with_multiline_string(self):
- self.cfg['value'] = 'line1\nline2\nline3'
+ self.cfg["value"] = "line1\nline2\nline3"
stream = StringIO()
self.cfg.generate_config(stream)
- self.assertMultiLineEqual(stream.getvalue().strip(), """[TEST]
+ self.assertMultiLineEqual(
+ stream.getvalue().strip(),
+ """[TEST]
dothis=yes
@@ -311,47 +352,46 @@ diffgroup=pouet
#opt-b-1=
-#opt-b-2=""")
-
+#opt-b-2=""",
+ )
def test_roundtrip(self):
cfg = self.cfg
f = tempfile.mktemp()
- stream = open(f, 'w')
+ stream = open(f, "w")
try:
- self.cfg['dothis'] = False
- self.cfg['multiple'] = ["toto", "tata"]
- self.cfg['number'] = 3
- self.cfg['bytes'] = 2048
+ self.cfg["dothis"] = False
+ self.cfg["multiple"] = ["toto", "tata"]
+ self.cfg["number"] = 3
+ self.cfg["bytes"] = 2048
cfg.generate_config(stream)
stream.close()
- new_cfg = MyConfiguration(name='test', options=OPTIONS)
+ new_cfg = MyConfiguration(name="test", options=OPTIONS)
new_cfg.load_file_configuration(f)
- self.assertEqual(cfg['dothis'], new_cfg['dothis'])
- self.assertEqual(cfg['multiple'], new_cfg['multiple'])
- self.assertEqual(cfg['number'], new_cfg['number'])
- self.assertEqual(cfg['bytes'], new_cfg['bytes'])
- self.assertEqual(cfg['choice'], new_cfg['choice'])
- self.assertEqual(cfg['value'], new_cfg['value'])
- self.assertEqual(cfg['multiple-choice'], new_cfg['multiple-choice'])
+ self.assertEqual(cfg["dothis"], new_cfg["dothis"])
+ self.assertEqual(cfg["multiple"], new_cfg["multiple"])
+ self.assertEqual(cfg["number"], new_cfg["number"])
+ self.assertEqual(cfg["bytes"], new_cfg["bytes"])
+ self.assertEqual(cfg["choice"], new_cfg["choice"])
+ self.assertEqual(cfg["value"], new_cfg["value"])
+ self.assertEqual(cfg["multiple-choice"], new_cfg["multiple-choice"])
finally:
os.remove(f)
def test_setitem(self):
- self.assertRaises(OptionValueError,
- self.cfg.__setitem__, 'multiple-choice', ('a', 'b'))
- self.cfg['multiple-choice'] = ('yi', 'ya')
- self.assertEqual(self.cfg['multiple-choice'], ('yi', 'ya'))
+ self.assertRaises(OptionValueError, self.cfg.__setitem__, "multiple-choice", ("a", "b"))
+ self.cfg["multiple-choice"] = ("yi", "ya")
+ self.assertEqual(self.cfg["multiple-choice"], ("yi", "ya"))
def test_help(self):
- self.cfg.add_help_section('bonus', 'a nice additional help')
+ self.cfg.add_help_section("bonus", "a nice additional help")
help = self.cfg.help().strip()
# at least in python 2.4.2 the output is:
# ' -v <string>, --value=<string>'
# it is not unlikely some optik/optparse versions do print -v<string>
# so accept both
- help = help.replace(' -v <string>, ', ' -v<string>, ')
- help = re.sub('[ ]*(\r?\n)', '\\1', help)
+ help = help.replace(" -v <string>, ", " -v<string>, ")
+ help = re.sub("[ ]*(\r?\n)", "\\1", help)
USAGE = """Usage: Just do it ! (tm)
Options:
@@ -378,7 +418,7 @@ Options:
a nice additional help"""
if version_info < (2, 5):
# 'usage' header is not capitalized in this version
- USAGE = USAGE.replace('Usage: ', 'usage: ')
+ USAGE = USAGE.replace("Usage: ", "usage: ")
elif version_info < (2, 4):
USAGE = """usage: Just do it ! (tm)
@@ -398,21 +438,23 @@ options:
"""
self.assertMultiLineEqual(help, USAGE)
-
def test_manpage(self):
pkginfo = {}
- with open(join(DATA, '__pkginfo__.py')) as fobj:
+ with open(join(DATA, "__pkginfo__.py")) as fobj:
exec(fobj.read(), pkginfo)
self.cfg.generate_manpage(attrdict(pkginfo), stream=StringIO())
def test_rewrite_config(self):
- changes = [('renamed', 'renamed', 'choice'),
- ('moved', 'named', 'old', 'test'),
- ]
- read_old_config(self.cfg, changes, join(DATA, 'test.ini'))
+ changes = [
+ ("renamed", "renamed", "choice"),
+ ("moved", "named", "old", "test"),
+ ]
+ read_old_config(self.cfg, changes, join(DATA, "test.ini"))
stream = StringIO()
self.cfg.generate_config(stream)
- self.assertMultiLineEqual(stream.getvalue().strip(), """[TEST]
+ self.assertMultiLineEqual(
+ stream.getvalue().strip(),
+ """[TEST]
dothis=yes
@@ -444,22 +486,26 @@ diffgroup=pouet
#opt-b-1=
-#opt-b-2=""")
+#opt-b-2=""",
+ )
+
class Linter(OptionsManagerMixIn, OptionsProviderMixIn):
options = (
- ('profile', {'type' : 'yn', 'metavar' : '<y_or_n>',
- 'default': False,
- 'help' : 'Profiled execution.'}),
- )
+ (
+ "profile",
+ {"type": "yn", "metavar": "<y_or_n>", "default": False, "help": "Profiled execution."},
+ ),
+ )
+
def __init__(self):
OptionsManagerMixIn.__init__(self, usage="")
OptionsProviderMixIn.__init__(self)
self.register_options_provider(self)
self.load_provider_defaults()
-class RegrTC(TestCase):
+class RegrTC(TestCase):
def setUp(self):
self.linter = Linter()
@@ -472,36 +518,51 @@ class RegrTC(TestCase):
config = Configuration()
self.assertEqual(config.options, ())
new_options = (
- ('option1', {'type': 'string', 'help': '',
- 'group': 'g1', 'level': 2}),
- ('option2', {'type': 'string', 'help': '',
- 'group': 'g1', 'level': 2}),
- ('option3', {'type': 'string', 'help': '',
- 'group': 'g2', 'level': 2}),
+ ("option1", {"type": "string", "help": "", "group": "g1", "level": 2}),
+ ("option2", {"type": "string", "help": "", "group": "g1", "level": 2}),
+ ("option3", {"type": "string", "help": "", "group": "g2", "level": 2}),
)
config.register_options(new_options)
self.assertEqual(config.options, new_options)
class MergeTC(TestCase):
-
def test_merge1(self):
- merged = merge_options([('dothis', {'type':'yn', 'action': 'store', 'default': True, 'metavar': '<y or n>'}),
- ('dothis', {'type':'yn', 'action': 'store', 'default': False, 'metavar': '<y or n>'}),
- ])
+ merged = merge_options(
+ [
+ (
+ "dothis",
+ {"type": "yn", "action": "store", "default": True, "metavar": "<y or n>"},
+ ),
+ (
+ "dothis",
+ {"type": "yn", "action": "store", "default": False, "metavar": "<y or n>"},
+ ),
+ ]
+ )
self.assertEqual(len(merged), 1)
- self.assertEqual(merged[0][0], 'dothis')
- self.assertEqual(merged[0][1]['default'], True)
+ self.assertEqual(merged[0][0], "dothis")
+ self.assertEqual(merged[0][1]["default"], True)
def test_merge2(self):
- merged = merge_options([('dothis', {'type':'yn', 'action': 'store', 'default': True, 'metavar': '<y or n>'}),
- ('value', {'type': 'string', 'metavar': '<string>', 'short': 'v'}),
- ('dothis', {'type':'yn', 'action': 'store', 'default': False, 'metavar': '<y or n>'}),
- ])
+ merged = merge_options(
+ [
+ (
+ "dothis",
+ {"type": "yn", "action": "store", "default": True, "metavar": "<y or n>"},
+ ),
+ ("value", {"type": "string", "metavar": "<string>", "short": "v"}),
+ (
+ "dothis",
+ {"type": "yn", "action": "store", "default": False, "metavar": "<y or n>"},
+ ),
+ ]
+ )
self.assertEqual(len(merged), 2)
- self.assertEqual(merged[0][0], 'value')
- self.assertEqual(merged[1][0], 'dothis')
- self.assertEqual(merged[1][1]['default'], True)
+ self.assertEqual(merged[0][0], "value")
+ self.assertEqual(merged[1][0], "dothis")
+ self.assertEqual(merged[1][1]["default"], True)
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_date.py b/test/test_date.py
index 9ae444b..cf09b11 100644
--- a/test/test_date.py
+++ b/test/test_date.py
@@ -20,20 +20,34 @@ Unittests for date helpers
"""
from logilab.common.testlib import TestCase, unittest_main, tag
-from logilab.common.date import (date_range, endOfMonth, add_days_worked,
- nb_open_days, get_national_holidays, ustrftime, ticks2datetime,
- utcdatetime, datetime2ticks)
+from logilab.common.date import (
+ date_range,
+ endOfMonth,
+ add_days_worked,
+ nb_open_days,
+ get_national_holidays,
+ ustrftime,
+ ticks2datetime,
+ utcdatetime,
+ datetime2ticks,
+)
from datetime import date, datetime, timedelta
from calendar import timegm
import pytz
try:
- from mx.DateTime import Date as mxDate, DateTime as mxDateTime, \
- now as mxNow, RelativeDateTime, RelativeDate
+ from mx.DateTime import (
+ Date as mxDate,
+ DateTime as mxDateTime,
+ now as mxNow,
+ RelativeDateTime,
+ RelativeDate,
+ )
except ImportError:
mxDate = mxDateTime = RelativeDateTime = mxNow = None
+
class DateTC(TestCase):
datecls = date
datetimecls = datetime
@@ -65,8 +79,9 @@ class DateTC(TestCase):
def test_get_national_holidays(self):
holidays = get_national_holidays
- yield self.assertEqual, holidays(self.datecls(2008, 4, 29), self.datecls(2008, 5, 2)), \
- [self.datecls(2008, 5, 1)]
+ yield self.assertEqual, holidays(self.datecls(2008, 4, 29), self.datecls(2008, 5, 2)), [
+ self.datecls(2008, 5, 1)
+ ]
yield self.assertEqual, holidays(self.datecls(2008, 5, 7), self.datecls(2008, 5, 8)), []
x = self.datetimecls(2008, 5, 7, 12, 12, 12)
yield self.assertEqual, holidays(x, x + self.timedeltacls(days=1)), []
@@ -129,22 +144,26 @@ class DateTC(TestCase):
def test_open_days_afternoon(self):
self.assertOpenDays(self.datetimecls(2008, 5, 6, 14), self.datetimecls(2008, 5, 7, 14), 1)
- @tag('posix', '1900')
+ @tag("posix", "1900")
def test_ustrftime_before_1900(self):
date = self.datetimecls(1328, 3, 12, 6, 30)
- self.assertEqual(ustrftime(date, '%Y-%m-%d %H:%M:%S'), u'1328-03-12 06:30:00')
+ self.assertEqual(ustrftime(date, "%Y-%m-%d %H:%M:%S"), "1328-03-12 06:30:00")
- @tag('posix', '1900')
+ @tag("posix", "1900")
def test_ticks2datetime_before_1900(self):
ticks = -2209075200000
date = ticks2datetime(ticks)
- self.assertEqual(ustrftime(date, '%Y-%m-%d'), u'1899-12-31')
+ self.assertEqual(ustrftime(date, "%Y-%m-%d"), "1899-12-31")
def test_month(self):
"""enumerate months"""
- r = list(date_range(self.datecls(2006, 5, 6), self.datecls(2006, 8, 27),
- incmonth=True))
- expected = [self.datecls(2006, 5, 6), self.datecls(2006, 6, 1), self.datecls(2006, 7, 1), self.datecls(2006, 8, 1)]
+ r = list(date_range(self.datecls(2006, 5, 6), self.datecls(2006, 8, 27), incmonth=True))
+ expected = [
+ self.datecls(2006, 5, 6),
+ self.datecls(2006, 6, 1),
+ self.datecls(2006, 7, 1),
+ self.datecls(2006, 8, 1),
+ ]
self.assertListEqual(expected, r)
def test_utcdatetime(self):
@@ -155,14 +174,12 @@ class DateTC(TestCase):
self.assertEqual(d, self.datetimecls(2014, 11, 26, 12, 0, 0, 57))
self.assertIsNone(d.tzinfo)
- d = pytz.timezone('Europe/Paris').localize(
- self.datetimecls(2014, 11, 26, 12, 0, 0, 57))
+ d = pytz.timezone("Europe/Paris").localize(self.datetimecls(2014, 11, 26, 12, 0, 0, 57))
d = utcdatetime(d)
self.assertEqual(d, self.datetimecls(2014, 11, 26, 11, 0, 0, 57))
self.assertIsNone(d.tzinfo)
- d = pytz.timezone('Europe/Paris').localize(
- self.datetimecls(2014, 7, 26, 12, 0, 0, 57))
+ d = pytz.timezone("Europe/Paris").localize(self.datetimecls(2014, 7, 26, 12, 0, 0, 57))
d = utcdatetime(d)
self.assertEqual(d, self.datetimecls(2014, 7, 26, 10, 0, 0, 57))
self.assertIsNone(d.tzinfo)
@@ -188,7 +205,7 @@ class MxDateTC(DateTC):
def check_mx(self):
if mxDate is None:
- self.skipTest('mx.DateTime is not installed')
+ self.skipTest("mx.DateTime is not installed")
def setUp(self):
self.check_mx()
@@ -199,8 +216,13 @@ class MxDateTC(DateTC):
expected = [self.datecls(2000, 1, 2), self.datecls(2000, 2, 29), self.datecls(2000, 3, 31)]
self.assertListEqual(r, expected)
r = list(date_range(self.datecls(2000, 11, 30), self.datecls(2001, 2, 3), endOfMonth))
- expected = [self.datecls(2000, 11, 30), self.datecls(2000, 12, 31), self.datecls(2001, 1, 31)]
+ expected = [
+ self.datecls(2000, 11, 30),
+ self.datecls(2000, 12, 31),
+ self.datecls(2001, 1, 31),
+ ]
self.assertListEqual(r, expected)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_decorators.py b/test/test_decorators.py
index e97a56f..42d8d8f 100644
--- a/test/test_decorators.py
+++ b/test/test_decorators.py
@@ -21,20 +21,23 @@ import sys
import types
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.decorators import (monkeypatch, cached, clear_cache,
- copy_cache, cachedproperty)
+from logilab.common.decorators import monkeypatch, cached, clear_cache, copy_cache, cachedproperty
-class DecoratorsTC(TestCase):
+class DecoratorsTC(TestCase):
def test_monkeypatch_instance_method(self):
- class MyClass: pass
+ class MyClass:
+ pass
+
@monkeypatch(MyClass)
def meth1(self):
return 12
+
class XXX(object):
@monkeypatch(MyClass)
def meth2(self):
return 12
+
if sys.version_info < (3, 0):
self.assertIsInstance(MyClass.meth1, types.MethodType)
self.assertIsInstance(MyClass.meth2, types.MethodType)
@@ -46,51 +49,66 @@ class DecoratorsTC(TestCase):
self.assertEqual(MyClass().meth2(), 12)
def test_monkeypatch_property(self):
- class MyClass: pass
- @monkeypatch(MyClass, methodname='prop1')
+ class MyClass:
+ pass
+
+ @monkeypatch(MyClass, methodname="prop1")
@property
def meth1(self):
return 12
+
self.assertIsInstance(MyClass.prop1, property)
self.assertEqual(MyClass().prop1, 12)
def test_monkeypatch_arbitrary_callable(self):
- class MyClass: pass
+ class MyClass:
+ pass
+
class ArbitraryCallable(object):
def __call__(self):
return 12
+
# ensure it complains about missing __name__
with self.assertRaises(AttributeError) as cm:
monkeypatch(MyClass)(ArbitraryCallable())
- self.assertTrue(str(cm.exception).endswith('has no __name__ attribute: you should provide an explicit `methodname`'))
+ self.assertTrue(
+ str(cm.exception).endswith(
+ "has no __name__ attribute: you should provide an explicit `methodname`"
+ )
+ )
# ensure no black magic under the hood
- monkeypatch(MyClass, 'foo')(ArbitraryCallable())
+ monkeypatch(MyClass, "foo")(ArbitraryCallable())
self.assertTrue(callable(MyClass.foo))
self.assertEqual(MyClass().foo(), 12)
def test_monkeypatch_with_same_name(self):
- class MyClass: pass
+ class MyClass:
+ pass
+
@monkeypatch(MyClass)
def meth1(self):
return 12
- self.assertEqual([attr for attr in dir(MyClass) if attr[:2] != '__'],
- ['meth1'])
+
+ self.assertEqual([attr for attr in dir(MyClass) if attr[:2] != "__"], ["meth1"])
inst = MyClass()
self.assertEqual(inst.meth1(), 12)
def test_monkeypatch_with_custom_name(self):
- class MyClass: pass
- @monkeypatch(MyClass, 'foo')
+ class MyClass:
+ pass
+
+ @monkeypatch(MyClass, "foo")
def meth2(self, param):
return param + 12
- self.assertEqual([attr for attr in dir(MyClass) if attr[:2] != '__'],
- ['foo'])
+
+ self.assertEqual([attr for attr in dir(MyClass) if attr[:2] != "__"], ["foo"])
inst = MyClass()
self.assertEqual(inst.foo(4), 16)
def test_cannot_cache_generator(self):
def foo():
yield 42
+
self.assertRaises(AssertionError, cached, foo)
def test_cached_preserves_docstrings_and_name(self):
@@ -98,85 +116,95 @@ class DecoratorsTC(TestCase):
@cached
def foo(self):
""" what's up doc ? """
+
def bar(self, zogzog):
""" what's up doc ? """
+
bar = cached(bar, 1)
+
@cached
def quux(self, zogzog):
""" what's up doc ? """
+
self.assertEqual(Foo.foo.__doc__, """ what's up doc ? """)
- self.assertEqual(Foo.foo.__name__, 'foo')
+ self.assertEqual(Foo.foo.__name__, "foo")
self.assertEqual(Foo.bar.__doc__, """ what's up doc ? """)
- self.assertEqual(Foo.bar.__name__, 'bar')
+ self.assertEqual(Foo.bar.__name__, "bar")
self.assertEqual(Foo.quux.__doc__, """ what's up doc ? """)
- self.assertEqual(Foo.quux.__name__, 'quux')
+ self.assertEqual(Foo.quux.__name__, "quux")
def test_cached_single_cache(self):
class Foo(object):
- @cached(cacheattr=u'_foo')
+ @cached(cacheattr="_foo")
def foo(self):
""" what's up doc ? """
+
foo = Foo()
foo.foo()
- self.assertTrue(hasattr(foo, '_foo'))
- clear_cache(foo, 'foo')
- self.assertFalse(hasattr(foo, '_foo'))
+ self.assertTrue(hasattr(foo, "_foo"))
+ clear_cache(foo, "foo")
+ self.assertFalse(hasattr(foo, "_foo"))
def test_cached_multi_cache(self):
class Foo(object):
- @cached(cacheattr=u'_foo')
+ @cached(cacheattr="_foo")
def foo(self, args):
""" what's up doc ? """
+
foo = Foo()
foo.foo(1)
self.assertEqual(foo._foo, {(1,): None})
- clear_cache(foo, 'foo')
- self.assertFalse(hasattr(foo, '_foo'))
+ clear_cache(foo, "foo")
+ self.assertFalse(hasattr(foo, "_foo"))
def test_cached_keyarg_cache(self):
class Foo(object):
- @cached(cacheattr=u'_foo', keyarg=1)
+ @cached(cacheattr="_foo", keyarg=1)
def foo(self, other, args):
""" what's up doc ? """
+
foo = Foo()
foo.foo(2, 1)
self.assertEqual(foo._foo, {2: None})
- clear_cache(foo, 'foo')
- self.assertFalse(hasattr(foo, '_foo'))
+ clear_cache(foo, "foo")
+ self.assertFalse(hasattr(foo, "_foo"))
def test_cached_property(self):
class Foo(object):
@property
- @cached(cacheattr=u'_foo')
+ @cached(cacheattr="_foo")
def foo(self):
""" what's up doc ? """
+
foo = Foo()
foo.foo
self.assertEqual(foo._foo, None)
- clear_cache(foo, 'foo')
- self.assertFalse(hasattr(foo, '_foo'))
+ clear_cache(foo, "foo")
+ self.assertFalse(hasattr(foo, "_foo"))
def test_copy_cache(self):
class Foo(object):
- @cached(cacheattr=u'_foo')
+ @cached(cacheattr="_foo")
def foo(self, args):
""" what's up doc ? """
+
foo = Foo()
foo.foo(1)
self.assertEqual(foo._foo, {(1,): None})
foo2 = Foo()
- self.assertFalse(hasattr(foo2, '_foo'))
- copy_cache(foo2, 'foo', foo)
+ self.assertFalse(hasattr(foo2, "_foo"))
+ copy_cache(foo2, "foo", foo)
self.assertEqual(foo2._foo, {(1,): None})
-
def test_cachedproperty(self):
class Foo(object):
x = 0
+
@cachedproperty
def bar(self):
self.__class__.x += 1
return self.__class__.x
+
@cachedproperty
def quux(self):
""" some prop """
@@ -184,15 +212,13 @@ class DecoratorsTC(TestCase):
foo = Foo()
self.assertEqual(Foo.x, 0)
- self.assertFalse('bar' in foo.__dict__)
+ self.assertFalse("bar" in foo.__dict__)
self.assertEqual(foo.bar, 1)
- self.assertTrue('bar' in foo.__dict__)
+ self.assertTrue("bar" in foo.__dict__)
self.assertEqual(foo.bar, 1)
self.assertEqual(foo.quux, 42)
- self.assertEqual(Foo.bar.__doc__,
- '<wrapped by the cachedproperty decorator>')
- self.assertEqual(Foo.quux.__doc__,
- '<wrapped by the cachedproperty decorator>\n some prop ')
+ self.assertEqual(Foo.bar.__doc__, "<wrapped by the cachedproperty decorator>")
+ self.assertEqual(Foo.quux.__doc__, "<wrapped by the cachedproperty decorator>\n some prop ")
foo2 = Foo()
self.assertEqual(foo2.bar, 2)
@@ -202,7 +228,9 @@ class DecoratorsTC(TestCase):
class Kallable(object):
def __call__(self):
return 42
+
self.assertRaises(TypeError, cachedproperty, Kallable())
-if __name__ == '__main__':
+
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_fileutils.py b/test/test_fileutils.py
index 555e73f..49955cd 100644
--- a/test/test_fileutils.py
+++ b/test/test_fileutils.py
@@ -27,47 +27,48 @@ from logilab.common.testlib import TestCase, unittest_main, unittest
from logilab.common.fileutils import *
-DATA_DIR = join(os.path.abspath(os.path.dirname(__file__)), 'data')
-NEWLINES_TXT = join(DATA_DIR, 'newlines.txt')
+DATA_DIR = join(os.path.abspath(os.path.dirname(__file__)), "data")
+NEWLINES_TXT = join(DATA_DIR, "newlines.txt")
class FirstleveldirectoryTC(TestCase):
-
def test_known_values_first_level_directory(self):
"""return the first level directory of a path"""
- self.assertEqual(first_level_directory('truc/bidule/chouette'), 'truc', None)
- self.assertEqual(first_level_directory('/truc/bidule/chouette'), '/', None)
+ self.assertEqual(first_level_directory("truc/bidule/chouette"), "truc", None)
+ self.assertEqual(first_level_directory("/truc/bidule/chouette"), "/", None)
+
class IsBinaryTC(TestCase):
def test(self):
- self.assertEqual(is_binary('toto.txt'), 0)
- #self.assertEqual(is_binary('toto.xml'), 0)
- self.assertEqual(is_binary('toto.bin'), 1)
- self.assertEqual(is_binary('toto.sxi'), 1)
- self.assertEqual(is_binary('toto.whatever'), 1)
+ self.assertEqual(is_binary("toto.txt"), 0)
+ # self.assertEqual(is_binary('toto.xml'), 0)
+ self.assertEqual(is_binary("toto.bin"), 1)
+ self.assertEqual(is_binary("toto.sxi"), 1)
+ self.assertEqual(is_binary("toto.whatever"), 1)
+
class GetModeTC(TestCase):
def test(self):
- self.assertEqual(write_open_mode('toto.txt'), 'w')
- #self.assertEqual(write_open_mode('toto.xml'), 'w')
- self.assertEqual(write_open_mode('toto.bin'), 'wb')
- self.assertEqual(write_open_mode('toto.sxi'), 'wb')
+ self.assertEqual(write_open_mode("toto.txt"), "w")
+ # self.assertEqual(write_open_mode('toto.xml'), 'w')
+ self.assertEqual(write_open_mode("toto.bin"), "wb")
+ self.assertEqual(write_open_mode("toto.sxi"), "wb")
+
class NormReadTC(TestCase):
def test_known_values_norm_read(self):
with io.open(NEWLINES_TXT) as f:
data = f.read()
- self.assertEqual(data.strip(), '\n'.join(['# mixed new lines', '1', '2', '3']))
+ self.assertEqual(data.strip(), "\n".join(["# mixed new lines", "1", "2", "3"]))
class LinesTC(TestCase):
def test_known_values_lines(self):
- self.assertEqual(lines(NEWLINES_TXT),
- ['# mixed new lines', '1', '2', '3'])
+ self.assertEqual(lines(NEWLINES_TXT), ["# mixed new lines", "1", "2", "3"])
def test_known_values_lines_comment(self):
- self.assertEqual(lines(NEWLINES_TXT, comments='#'),
- ['1', '2', '3'])
+ self.assertEqual(lines(NEWLINES_TXT, comments="#"), ["1", "2", "3"])
+
class ExportTC(TestCase):
def setUp(self):
@@ -76,18 +77,19 @@ class ExportTC(TestCase):
def test(self):
export(DATA_DIR, self.tempdir, verbose=0)
- self.assertTrue(exists(join(self.tempdir, '__init__.py')))
- self.assertTrue(exists(join(self.tempdir, 'sub')))
- self.assertTrue(not exists(join(self.tempdir, '__init__.pyc')))
- self.assertTrue(not exists(join(self.tempdir, 'CVS')))
+ self.assertTrue(exists(join(self.tempdir, "__init__.py")))
+ self.assertTrue(exists(join(self.tempdir, "sub")))
+ self.assertTrue(not exists(join(self.tempdir, "__init__.pyc")))
+ self.assertTrue(not exists(join(self.tempdir, "CVS")))
def tearDown(self):
shutil.rmtree(self.tempdir)
+
class ProtectedFileTC(TestCase):
def setUp(self):
- self.rpath = join(DATA_DIR, 'write_protected_file.txt')
- self.rwpath = join(DATA_DIR, 'normal_file.txt')
+ self.rpath = join(DATA_DIR, "write_protected_file.txt")
+ self.rwpath = join(DATA_DIR, "normal_file.txt")
# Make sure rpath is not writable !
os.chmod(self.rpath, 33060)
# Make sure rwpath is writable !
@@ -96,51 +98,53 @@ class ProtectedFileTC(TestCase):
def test_mode_change(self):
"""tests that mode is changed when needed"""
# test on non-writable file
- #self.assertTrue(not os.access(self.rpath, os.W_OK))
+ # self.assertTrue(not os.access(self.rpath, os.W_OK))
self.assertTrue(not os.stat(self.rpath).st_mode & S_IWRITE)
- wp_file = ProtectedFile(self.rpath, 'w')
+ wp_file = ProtectedFile(self.rpath, "w")
self.assertTrue(os.stat(self.rpath).st_mode & S_IWRITE)
self.assertTrue(os.access(self.rpath, os.W_OK))
# test on writable-file
self.assertTrue(os.stat(self.rwpath).st_mode & S_IWRITE)
self.assertTrue(os.access(self.rwpath, os.W_OK))
- wp_file = ProtectedFile(self.rwpath, 'w')
+ wp_file = ProtectedFile(self.rwpath, "w")
self.assertTrue(os.stat(self.rwpath).st_mode & S_IWRITE)
self.assertTrue(os.access(self.rwpath, os.W_OK))
def test_restore_on_close(self):
"""tests original mode is restored on close"""
# test on non-writable file
- #self.assertTrue(not os.access(self.rpath, os.W_OK))
+ # self.assertTrue(not os.access(self.rpath, os.W_OK))
self.assertTrue(not os.stat(self.rpath).st_mode & S_IWRITE)
- ProtectedFile(self.rpath, 'w').close()
- #self.assertTrue(not os.access(self.rpath, os.W_OK))
+ ProtectedFile(self.rpath, "w").close()
+ # self.assertTrue(not os.access(self.rpath, os.W_OK))
self.assertTrue(not os.stat(self.rpath).st_mode & S_IWRITE)
# test on writable-file
self.assertTrue(os.access(self.rwpath, os.W_OK))
self.assertTrue(os.stat(self.rwpath).st_mode & S_IWRITE)
- ProtectedFile(self.rwpath, 'w').close()
+ ProtectedFile(self.rwpath, "w").close()
self.assertTrue(os.access(self.rwpath, os.W_OK))
self.assertTrue(os.stat(self.rwpath).st_mode & S_IWRITE)
def test_mode_change_on_append(self):
"""tests that mode is changed when file is opened in 'a' mode"""
- #self.assertTrue(not os.access(self.rpath, os.W_OK))
+ # self.assertTrue(not os.access(self.rpath, os.W_OK))
self.assertTrue(not os.stat(self.rpath).st_mode & S_IWRITE)
- wp_file = ProtectedFile(self.rpath, 'a')
+ wp_file = ProtectedFile(self.rpath, "a")
self.assertTrue(os.access(self.rpath, os.W_OK))
self.assertTrue(os.stat(self.rpath).st_mode & S_IWRITE)
wp_file.close()
- #self.assertTrue(not os.access(self.rpath, os.W_OK))
+ # self.assertTrue(not os.access(self.rpath, os.W_OK))
self.assertTrue(not os.stat(self.rpath).st_mode & S_IWRITE)
if sys.version_info < (3, 0):
+
def load_tests(loader, tests, ignore):
from logilab.common import fileutils
+
tests.addTests(doctest.DocTestSuite(fileutils))
return tests
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_graph.py b/test/test_graph.py
index 9a2e8bc..03ed3ef 100644
--- a/test/test_graph.py
+++ b/test/test_graph.py
@@ -20,70 +20,70 @@
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.graph import get_cycles, has_path, ordered_nodes, UnorderableGraph
-class getCyclesTC(TestCase):
+class getCyclesTC(TestCase):
def test_known0(self):
- self.assertEqual(get_cycles({1:[2], 2:[3], 3:[1]}), [[1, 2, 3]])
+ self.assertEqual(get_cycles({1: [2], 2: [3], 3: [1]}), [[1, 2, 3]])
def test_known1(self):
- self.assertEqual(get_cycles({1:[2], 2:[3], 3:[1, 4], 4:[3]}), [[1, 2, 3], [3, 4]])
+ self.assertEqual(get_cycles({1: [2], 2: [3], 3: [1, 4], 4: [3]}), [[1, 2, 3], [3, 4]])
def test_known2(self):
- self.assertEqual(get_cycles({1:[2], 2:[3], 3:[0], 0:[]}), [])
+ self.assertEqual(get_cycles({1: [2], 2: [3], 3: [0], 0: []}), [])
class hasPathTC(TestCase):
-
def test_direct_connection(self):
- self.assertEqual(has_path({'A': ['B'], 'B': ['A']}, 'A', 'B'), ['B'])
+ self.assertEqual(has_path({"A": ["B"], "B": ["A"]}, "A", "B"), ["B"])
def test_indirect_connection(self):
- self.assertEqual(has_path({'A': ['B'], 'B': ['A', 'C'], 'C': ['B']}, 'A', 'C'), ['B', 'C'])
+ self.assertEqual(has_path({"A": ["B"], "B": ["A", "C"], "C": ["B"]}, "A", "C"), ["B", "C"])
def test_no_connection(self):
- self.assertEqual(has_path({'A': ['B'], 'B': ['A']}, 'A', 'C'), None)
+ self.assertEqual(has_path({"A": ["B"], "B": ["A"]}, "A", "C"), None)
def test_cycle(self):
- self.assertEqual(has_path({'A': ['A']}, 'A', 'B'), None)
+ self.assertEqual(has_path({"A": ["A"]}, "A", "B"), None)
-class ordered_nodesTC(TestCase):
+class ordered_nodesTC(TestCase):
def test_one_item(self):
- graph = {'a':[]}
+ graph = {"a": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a',))
+ self.assertEqual(ordered, ("a",))
def test_single_dependency(self):
- graph = {'a':['b'], 'b':[]}
+ graph = {"a": ["b"], "b": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a','b'))
- graph = {'a':[], 'b':['a']}
+ self.assertEqual(ordered, ("a", "b"))
+ graph = {"a": [], "b": ["a"]}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('b','a'))
+ self.assertEqual(ordered, ("b", "a"))
def test_two_items_no_dependency(self):
- graph = {'a':[], 'b':[]}
+ graph = {"a": [], "b": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a','b'))
+ self.assertEqual(ordered, ("a", "b"))
def test_three_items_no_dependency(self):
- graph = {'a':[], 'b':[], 'c':[]}
+ graph = {"a": [], "b": [], "c": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a', 'b', 'c'))
+ self.assertEqual(ordered, ("a", "b", "c"))
def test_three_items_one_dependency(self):
- graph = {'a': ['c'], 'b': [], 'c':[]}
+ graph = {"a": ["c"], "b": [], "c": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a', 'b', 'c'))
+ self.assertEqual(ordered, ("a", "b", "c"))
def test_three_items_two_dependencies(self):
- graph = {'a': ['b'], 'b': ['c'], 'c':[]}
+ graph = {"a": ["b"], "b": ["c"], "c": []}
ordered = ordered_nodes(graph)
- self.assertEqual(ordered, ('a', 'b', 'c'))
+ self.assertEqual(ordered, ("a", "b", "c"))
def test_bad_graph(self):
- graph = {'a':['b']}
+ graph = {"a": ["b"]}
self.assertRaises(UnorderableGraph, ordered_nodes, graph)
+
if __name__ == "__main__":
unittest_main()
diff --git a/test/test_interface.py b/test/test_interface.py
index 1dbed7a..a3c20bf 100644
--- a/test/test_interface.py
+++ b/test/test_interface.py
@@ -18,31 +18,44 @@
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.interface import *
-class IFace1(Interface): pass
-class IFace2(Interface): pass
-class IFace3(Interface): pass
+
+class IFace1(Interface):
+ pass
+
+
+class IFace2(Interface):
+ pass
+
+
+class IFace3(Interface):
+ pass
class A(object):
__implements__ = (IFace1,)
-class B(A): pass
+class B(A):
+ pass
class C1(B):
__implements__ = list(B.__implements__) + [IFace3]
+
class C2(B):
__implements__ = B.__implements__ + (IFace2,)
+
class D(C1):
__implements__ = ()
-class Z(object): pass
-class ExtendTC(TestCase):
+class Z(object):
+ pass
+
+class ExtendTC(TestCase):
def setUp(self):
global aimpl, c1impl, c2impl, dimpl
aimpl = A.__implements__
@@ -73,9 +86,10 @@ class ExtendTC(TestCase):
self.assertTrue(C1.__implements__ is c1impl)
self.assertTrue(D.__implements__ is dimpl)
-
def test_nonregr_implements_baseinterface(self):
- class SubIFace(IFace1): pass
+ class SubIFace(IFace1):
+ pass
+
class X(object):
__implements__ = (SubIFace,)
@@ -83,5 +97,5 @@ class ExtendTC(TestCase):
self.assertTrue(IFace1.is_implemented_by(X))
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_pytest.py b/test/test_pytest.py
index 48e36ce..02a34d4 100644
--- a/test/test_pytest.py
+++ b/test/test_pytest.py
@@ -19,6 +19,7 @@ from os.path import join
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.pytest import *
+
class ModuleFunctionTC(TestCase):
def test_this_is_testdir(self):
self.assertTrue(this_is_a_testdir("test"))
@@ -82,5 +83,5 @@ class ModuleFunctionTC(TestCase):
myfn()
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_shellutils.py b/test/test_shellutils.py
index 49b06c7..8cb06ca 100644
--- a/test/test_shellutils.py
+++ b/test/test_shellutils.py
@@ -24,54 +24,96 @@ from unittest.mock import patch
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.shellutils import (globfind, find, ProgressBar,
- RawInput)
+from logilab.common.shellutils import globfind, find, ProgressBar, RawInput
from logilab.common.compat import StringIO
-DATA_DIR = join(dirname(abspath(__file__)), 'data', 'find_test')
+DATA_DIR = join(dirname(abspath(__file__)), "data", "find_test")
class FindTC(TestCase):
def test_include(self):
- files = set(find(DATA_DIR, '.py'))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['__init__.py', 'module.py',
- 'module2.py', 'noendingnewline.py',
- 'nonregr.py', join('sub', 'momo.py')]]))
- files = set(find(DATA_DIR, ('.py',), blacklist=('sub',)))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['__init__.py', 'module.py',
- 'module2.py', 'noendingnewline.py',
- 'nonregr.py']]))
+ files = set(find(DATA_DIR, ".py"))
+ self.assertSetEqual(
+ files,
+ set(
+ [
+ join(DATA_DIR, f)
+ for f in [
+ "__init__.py",
+ "module.py",
+ "module2.py",
+ "noendingnewline.py",
+ "nonregr.py",
+ join("sub", "momo.py"),
+ ]
+ ]
+ ),
+ )
+ files = set(find(DATA_DIR, (".py",), blacklist=("sub",)))
+ self.assertSetEqual(
+ files,
+ set(
+ [
+ join(DATA_DIR, f)
+ for f in [
+ "__init__.py",
+ "module.py",
+ "module2.py",
+ "noendingnewline.py",
+ "nonregr.py",
+ ]
+ ]
+ ),
+ )
def test_exclude(self):
- files = set(find(DATA_DIR, ('.py', '.pyc'), exclude=True))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['foo.txt',
- 'newlines.txt',
- 'normal_file.txt',
- 'test.ini',
- 'test1.msg',
- 'test2.msg',
- 'spam.txt',
- join('sub', 'doc.txt'),
- 'write_protected_file.txt',
- ]]))
+ files = set(find(DATA_DIR, (".py", ".pyc"), exclude=True))
+ self.assertSetEqual(
+ files,
+ set(
+ [
+ join(DATA_DIR, f)
+ for f in [
+ "foo.txt",
+ "newlines.txt",
+ "normal_file.txt",
+ "test.ini",
+ "test1.msg",
+ "test2.msg",
+ "spam.txt",
+ join("sub", "doc.txt"),
+ "write_protected_file.txt",
+ ]
+ ]
+ ),
+ )
def test_globfind(self):
- files = set(globfind(DATA_DIR, '*.py'))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['__init__.py', 'module.py',
- 'module2.py', 'noendingnewline.py',
- 'nonregr.py', join('sub', 'momo.py')]]))
- files = set(globfind(DATA_DIR, 'mo*.py'))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['module.py', 'module2.py',
- join('sub', 'momo.py')]]))
- files = set(globfind(DATA_DIR, 'mo*.py', blacklist=('sub',)))
- self.assertSetEqual(files,
- set([join(DATA_DIR, f) for f in ['module.py', 'module2.py']]))
+ files = set(globfind(DATA_DIR, "*.py"))
+ self.assertSetEqual(
+ files,
+ set(
+ [
+ join(DATA_DIR, f)
+ for f in [
+ "__init__.py",
+ "module.py",
+ "module2.py",
+ "noendingnewline.py",
+ "nonregr.py",
+ join("sub", "momo.py"),
+ ]
+ ]
+ ),
+ )
+ files = set(globfind(DATA_DIR, "mo*.py"))
+ self.assertSetEqual(
+ files,
+ set([join(DATA_DIR, f) for f in ["module.py", "module2.py", join("sub", "momo.py")]]),
+ )
+ files = set(globfind(DATA_DIR, "mo*.py", blacklist=("sub",)))
+ self.assertSetEqual(files, set([join(DATA_DIR, f) for f in ["module.py", "module2.py"]]))
class ProgressBarTC(TestCase):
@@ -79,9 +121,11 @@ class ProgressBarTC(TestCase):
pgb_stream = StringIO()
expected_stream = StringIO()
pgb = ProgressBar(20, stream=pgb_stream)
- self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue()) # nothing print before refresh
+ self.assertEqual(
+ pgb_stream.getvalue(), expected_stream.getvalue()
+ ) # nothing print before refresh
pgb.refresh()
- expected_stream.write("\r["+' '*20+"]")
+ expected_stream.write("\r[" + " " * 20 + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
def test_refresh_g_size(self):
@@ -89,7 +133,7 @@ class ProgressBarTC(TestCase):
expected_stream = StringIO()
pgb = ProgressBar(20, 35, stream=pgb_stream)
pgb.refresh()
- expected_stream.write("\r["+' '*35+"]")
+ expected_stream.write("\r[" + " " * 35 + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
def test_refresh_l_size(self):
@@ -97,27 +141,27 @@ class ProgressBarTC(TestCase):
expected_stream = StringIO()
pgb = ProgressBar(20, 3, stream=pgb_stream)
pgb.refresh()
- expected_stream.write("\r["+' '*3+"]")
+ expected_stream.write("\r[" + " " * 3 + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
- def _update_test(self, nbops, expected, size = None):
+ def _update_test(self, nbops, expected, size=None):
pgb_stream = StringIO()
expected_stream = StringIO()
if size is None:
pgb = ProgressBar(nbops, stream=pgb_stream)
- size=20
+ size = 20
else:
pgb = ProgressBar(nbops, size, stream=pgb_stream)
last = 0
for round in expected:
- if not hasattr(round, '__int__'):
+ if not hasattr(round, "__int__"):
dots, update = round
else:
dots, update = round, None
pgb.update()
if update or (update is None and dots != last):
last = dots
- expected_stream.write("\r["+('='*dots)+(' '*(size-dots))+"]")
+ expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
def test_default(self):
@@ -125,16 +169,20 @@ class ProgressBarTC(TestCase):
def test_nbops_gt_size(self):
"""Test the progress bar for nbops > size"""
+
def half(total):
- for counter in range(1, total+1):
+ for counter in range(1, total + 1):
yield counter // 2
+
self._update_test(40, half(40))
def test_nbops_lt_size(self):
"""Test the progress bar for nbops < size"""
+
def double(total):
- for counter in range(1, total+1):
+ for counter in range(1, total + 1):
yield counter * 2
+
self._update_test(10, double(10))
def test_nbops_nomul_size(self):
@@ -147,30 +195,29 @@ class ProgressBarTC(TestCase):
def test_update_exact(self):
pgb_stream = StringIO()
expected_stream = StringIO()
- size=20
+ size = 20
pgb = ProgressBar(100, size, stream=pgb_stream)
last = 0
for dots in range(10, 105, 15):
pgb.update(dots, exact=True)
dots //= 5
- expected_stream.write("\r["+('='*dots)+(' '*(size-dots))+"]")
+ expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
def test_update_relative(self):
pgb_stream = StringIO()
expected_stream = StringIO()
- size=20
+ size = 20
pgb = ProgressBar(100, size, stream=pgb_stream)
last = 0
for dots in range(5, 105, 5):
pgb.update(5, exact=False)
dots //= 5
- expected_stream.write("\r["+('='*dots)+(' '*(size-dots))+"]")
+ expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
class RawInputTC(TestCase):
-
def auto_input(self, *args):
self.input_args = args
return self.input_answer
@@ -180,61 +227,62 @@ class RawInputTC(TestCase):
self.qa = RawInput(self.auto_input, null_printer)
def test_ask_using_builtin_input(self):
- with patch('builtins.input', return_value='no'):
+ with patch("builtins.input", return_value="no"):
qa = RawInput()
- answer = qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'no')
+ answer = qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "no")
def test_ask_default(self):
- self.input_answer = ''
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'yes')
- self.input_answer = ' '
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'yes')
+ self.input_answer = ""
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "yes")
+ self.input_answer = " "
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "yes")
def test_ask_case(self):
- self.input_answer = 'no'
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'no')
- self.input_answer = 'No'
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'no')
- self.input_answer = 'NO'
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'no')
- self.input_answer = 'nO'
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'no')
- self.input_answer = 'YES'
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(answer, 'yes')
+ self.input_answer = "no"
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "no")
+ self.input_answer = "No"
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "no")
+ self.input_answer = "NO"
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "no")
+ self.input_answer = "nO"
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "no")
+ self.input_answer = "YES"
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(answer, "yes")
def test_ask_prompt(self):
- self.input_answer = ''
- answer = self.qa.ask('text', ('yes', 'no'), 'yes')
- self.assertEqual(self.input_args[0], 'text [Y(es)/n(o)]: ')
- answer = self.qa.ask('text', ('y', 'n'), 'y')
- self.assertEqual(self.input_args[0], 'text [Y/n]: ')
- answer = self.qa.ask('text', ('n', 'y'), 'y')
- self.assertEqual(self.input_args[0], 'text [n/Y]: ')
- answer = self.qa.ask('text', ('yes', 'no', 'maybe', '1'), 'yes')
- self.assertEqual(self.input_args[0], 'text [Y(es)/n(o)/m(aybe)/1]: ')
+ self.input_answer = ""
+ answer = self.qa.ask("text", ("yes", "no"), "yes")
+ self.assertEqual(self.input_args[0], "text [Y(es)/n(o)]: ")
+ answer = self.qa.ask("text", ("y", "n"), "y")
+ self.assertEqual(self.input_args[0], "text [Y/n]: ")
+ answer = self.qa.ask("text", ("n", "y"), "y")
+ self.assertEqual(self.input_args[0], "text [n/Y]: ")
+ answer = self.qa.ask("text", ("yes", "no", "maybe", "1"), "yes")
+ self.assertEqual(self.input_args[0], "text [Y(es)/n(o)/m(aybe)/1]: ")
def test_ask_ambiguous(self):
- self.input_answer = 'y'
- self.assertRaises(Exception, self.qa.ask, 'text', ('yes', 'yep'), 'yes')
+ self.input_answer = "y"
+ self.assertRaises(Exception, self.qa.ask, "text", ("yes", "yep"), "yes")
def test_confirm(self):
- self.input_answer = 'y'
- self.assertEqual(self.qa.confirm('Say yes'), True)
- self.assertEqual(self.qa.confirm('Say yes', default_is_yes=False), True)
- self.input_answer = 'n'
- self.assertEqual(self.qa.confirm('Say yes'), False)
- self.assertEqual(self.qa.confirm('Say yes', default_is_yes=False), False)
- self.input_answer = ''
- self.assertEqual(self.qa.confirm('Say default'), True)
- self.assertEqual(self.qa.confirm('Say default', default_is_yes=False), False)
-
-if __name__ == '__main__':
+ self.input_answer = "y"
+ self.assertEqual(self.qa.confirm("Say yes"), True)
+ self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), True)
+ self.input_answer = "n"
+ self.assertEqual(self.qa.confirm("Say yes"), False)
+ self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), False)
+ self.input_answer = ""
+ self.assertEqual(self.qa.confirm("Say default"), True)
+ self.assertEqual(self.qa.confirm("Say default", default_is_yes=False), False)
+
+
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_table.py b/test/test_table.py
index 0c40a7c..5c0ac19 100644
--- a/test/test_table.py
+++ b/test/test_table.py
@@ -25,8 +25,16 @@ import os
from logilab.common.compat import StringIO
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.table import Table, TableStyleSheet, DocbookTableWriter, \
- DocbookRenderer, TableStyle, TableWriter, TableCellRenderer
+from logilab.common.table import (
+ Table,
+ TableStyleSheet,
+ DocbookTableWriter,
+ DocbookRenderer,
+ TableStyle,
+ TableWriter,
+ TableCellRenderer,
+)
+
class TableTC(TestCase):
"""Table TestCase class"""
@@ -36,12 +44,12 @@ class TableTC(TestCase):
# from logilab.common import table
# reload(table)
self.table = Table()
- self.table.create_rows(['row1', 'row2', 'row3'])
- self.table.create_columns(['col1', 'col2'])
+ self.table.create_rows(["row1", "row2", "row3"])
+ self.table.create_columns(["col1", "col2"])
def test_valeur_scalaire(self):
tab = Table()
- tab.create_columns(['col1'])
+ tab.create_columns(["col1"])
tab.append_row([1])
self.assertEqual(tab, [[1]])
tab.append_row([2])
@@ -50,13 +58,13 @@ class TableTC(TestCase):
def test_valeur_ligne(self):
tab = Table()
- tab.create_columns(['col1', 'col2'])
+ tab.create_columns(["col1", "col2"])
tab.append_row([1, 2])
self.assertEqual(tab, [[1, 2]])
def test_valeur_colonne(self):
tab = Table()
- tab.create_columns(['col1'])
+ tab.create_columns(["col1"])
tab.append_row([1])
tab.append_row([2])
self.assertEqual(tab, [[1], [2]])
@@ -77,25 +85,25 @@ class TableTC(TestCase):
"""tests Table.get_rows()"""
self.assertEqual(self.table, [[0, 0], [0, 0], [0, 0]])
self.assertEqual(self.table[:], [[0, 0], [0, 0], [0, 0]])
- self.table.insert_column(1, range(3), 'supp')
+ self.table.insert_column(1, range(3), "supp")
self.assertEqual(self.table, [[0, 0, 0], [0, 1, 0], [0, 2, 0]])
self.assertEqual(self.table[:], [[0, 0, 0], [0, 1, 0], [0, 2, 0]])
def test_get_cells(self):
- self.table.insert_column(1, range(3), 'supp')
+ self.table.insert_column(1, range(3), "supp")
self.assertEqual(self.table[0, 1], 0)
self.assertEqual(self.table[1, 1], 1)
self.assertEqual(self.table[2, 1], 2)
- self.assertEqual(self.table['row1', 'supp'], 0)
- self.assertEqual(self.table['row2', 'supp'], 1)
- self.assertEqual(self.table['row3', 'supp'], 2)
- self.assertRaises(KeyError, self.table.__getitem__, ('row1', 'foo'))
- self.assertRaises(KeyError, self.table.__getitem__, ('foo', 'bar'))
+ self.assertEqual(self.table["row1", "supp"], 0)
+ self.assertEqual(self.table["row2", "supp"], 1)
+ self.assertEqual(self.table["row3", "supp"], 2)
+ self.assertRaises(KeyError, self.table.__getitem__, ("row1", "foo"))
+ self.assertRaises(KeyError, self.table.__getitem__, ("foo", "bar"))
def test_shape(self):
"""tests table shape"""
self.assertEqual(self.table.shape, (3, 2))
- self.table.insert_column(1, range(3), 'supp')
+ self.table.insert_column(1, range(3), "supp")
self.assertEqual(self.table.shape, (3, 3))
def test_set_column(self):
@@ -109,33 +117,33 @@ class TableTC(TestCase):
def test_set_column_by_id(self):
"""Tests that table.set_column_by_id() works fine.
"""
- self.table.set_column_by_id('col1', range(3))
+ self.table.set_column_by_id("col1", range(3))
self.assertEqual(self.table[0, 0], 0)
self.assertEqual(self.table[1, 0], 1)
self.assertEqual(self.table[2, 0], 2)
- self.assertRaises(KeyError, self.table.set_column_by_id, 'col123', range(3))
+ self.assertRaises(KeyError, self.table.set_column_by_id, "col123", range(3))
def test_cells_ids(self):
"""tests that we can access cells by giving row/col ids"""
- self.assertRaises(KeyError, self.table.set_cell_by_ids, 'row12', 'col1', 12)
- self.assertRaises(KeyError, self.table.set_cell_by_ids, 'row1', 'col12', 12)
+ self.assertRaises(KeyError, self.table.set_cell_by_ids, "row12", "col1", 12)
+ self.assertRaises(KeyError, self.table.set_cell_by_ids, "row1", "col12", 12)
self.assertEqual(self.table[0, 0], 0)
- self.table.set_cell_by_ids('row1', 'col1', 'DATA')
- self.assertEqual(self.table[0, 0], 'DATA')
- self.assertRaises(KeyError, self.table.set_row_by_id, 'row12', [])
- self.table.set_row_by_id('row1', ['1.0', '1.1'])
- self.assertEqual(self.table[0, 0], '1.0')
+ self.table.set_cell_by_ids("row1", "col1", "DATA")
+ self.assertEqual(self.table[0, 0], "DATA")
+ self.assertRaises(KeyError, self.table.set_row_by_id, "row12", [])
+ self.table.set_row_by_id("row1", ["1.0", "1.1"])
+ self.assertEqual(self.table[0, 0], "1.0")
def test_insert_row(self):
"""tests a row insertion"""
- tmp_data = ['tmp1', 'tmp2']
- self.table.insert_row(1, tmp_data, 'tmprow')
+ tmp_data = ["tmp1", "tmp2"]
+ self.table.insert_row(1, tmp_data, "tmprow")
self.assertEqual(self.table[1], tmp_data)
- self.assertEqual(self.table['tmprow'], tmp_data)
- self.table.delete_row_by_id('tmprow')
- self.assertRaises(KeyError, self.table.delete_row_by_id, 'tmprow')
+ self.assertEqual(self.table["tmprow"], tmp_data)
+ self.table.delete_row_by_id("tmprow")
+ self.assertRaises(KeyError, self.table.delete_row_by_id, "tmprow")
self.assertEqual(self.table[1], [0, 0])
- self.assertRaises(KeyError, self.table.__getitem__, 'tmprow')
+ self.assertRaises(KeyError, self.table.__getitem__, "tmprow")
def test_get_column(self):
"""Tests that table.get_column() works fine.
@@ -143,7 +151,7 @@ class TableTC(TestCase):
self.table.set_cell(0, 1, 12)
self.table.set_cell(2, 1, 13)
self.assertEqual(self.table[:, 1], [12, 0, 13])
- self.assertEqual(self.table[:, 'col2'], [12, 0, 13])
+ self.assertEqual(self.table[:, "col2"], [12, 0, 13])
def test_get_columns(self):
"""Tests if table.get_columns() works fine.
@@ -157,26 +165,25 @@ class TableTC(TestCase):
"""
self.table.insert_column(1, range(3), "inserted_column")
self.assertEqual(self.table[:, 1], [0, 1, 2])
- self.assertEqual(self.table.col_names,
- ['col1', 'inserted_column', 'col2'])
+ self.assertEqual(self.table.col_names, ["col1", "inserted_column", "col2"])
def test_delete_column(self):
"""Tests that table.delete_column() works fine.
"""
self.table.delete_column(1)
- self.assertEqual(self.table.col_names, ['col1'])
+ self.assertEqual(self.table.col_names, ["col1"])
self.assertEqual(self.table[:, 0], [0, 0, 0])
- self.assertRaises(KeyError, self.table.delete_column_by_id, 'col2')
- self.table.delete_column_by_id('col1')
+ self.assertRaises(KeyError, self.table.delete_column_by_id, "col2")
+ self.table.delete_column_by_id("col1")
self.assertEqual(self.table.col_names, [])
def test_transpose(self):
"""Tests that table.transpose() works fine.
"""
- self.table.append_column(range(5, 8), 'col3')
+ self.table.append_column(range(5, 8), "col3")
ttable = self.table.transpose()
- self.assertEqual(ttable.row_names, ['col1', 'col2', 'col3'])
- self.assertEqual(ttable.col_names, ['row1', 'row2', 'row3'])
+ self.assertEqual(ttable.row_names, ["col1", "col2", "col3"])
+ self.assertEqual(ttable.col_names, ["row1", "row2", "row3"])
self.assertEqual(ttable.data, [[0, 0, 0], [0, 0, 0], [5, 6, 7]])
def test_sort_table(self):
@@ -185,22 +192,22 @@ class TableTC(TestCase):
self.table.set_column(0, [3, 1, 2])
self.table.set_column(1, [1, 2, 3])
self.table.sort_by_column_index(0)
- self.assertEqual(self.table.row_names, ['row2', 'row3', 'row1'])
+ self.assertEqual(self.table.row_names, ["row2", "row3", "row1"])
self.assertEqual(self.table.data, [[1, 2], [2, 3], [3, 1]])
- self.table.sort_by_column_index(1, 'desc')
- self.assertEqual(self.table.row_names, ['row3', 'row2', 'row1'])
+ self.table.sort_by_column_index(1, "desc")
+ self.assertEqual(self.table.row_names, ["row3", "row2", "row1"])
self.assertEqual(self.table.data, [[2, 3], [1, 2], [3, 1]])
def test_sort_by_id(self):
"""tests sort_by_column_id()"""
- self.table.set_column_by_id('col1', [3, 1, 2])
- self.table.set_column_by_id('col2', [1, 2, 3])
- self.table.sort_by_column_id('col1')
- self.assertRaises(KeyError, self.table.sort_by_column_id, 'col123')
- self.assertEqual(self.table.row_names, ['row2', 'row3', 'row1'])
+ self.table.set_column_by_id("col1", [3, 1, 2])
+ self.table.set_column_by_id("col2", [1, 2, 3])
+ self.table.sort_by_column_id("col1")
+ self.assertRaises(KeyError, self.table.sort_by_column_id, "col123")
+ self.assertEqual(self.table.row_names, ["row2", "row3", "row1"])
self.assertEqual(self.table.data, [[1, 2], [2, 3], [3, 1]])
- self.table.sort_by_column_id('col2', 'desc')
- self.assertEqual(self.table.row_names, ['row3', 'row2', 'row1'])
+ self.table.sort_by_column_id("col2", "desc")
+ self.assertEqual(self.table.row_names, ["row3", "row2", "row1"])
self.assertEqual(self.table.data, [[2, 3], [1, 2], [3, 1]])
def test_pprint(self):
@@ -211,67 +218,74 @@ class TableTC(TestCase):
class GroupByTC(TestCase):
"""specific test suite for groupby()"""
+
def setUp(self):
t = Table()
- t.create_columns(['date', 'res', 'task', 'usage'])
- t.append_row(['date1', 'ing1', 'task1', 0.3])
- t.append_row(['date1', 'ing2', 'task2', 0.3])
- t.append_row(['date2', 'ing3', 'task3', 0.3])
- t.append_row(['date3', 'ing4', 'task2', 0.3])
- t.append_row(['date1', 'ing1', 'task3', 0.3])
- t.append_row(['date3', 'ing1', 'task3', 0.3])
+ t.create_columns(["date", "res", "task", "usage"])
+ t.append_row(["date1", "ing1", "task1", 0.3])
+ t.append_row(["date1", "ing2", "task2", 0.3])
+ t.append_row(["date2", "ing3", "task3", 0.3])
+ t.append_row(["date3", "ing4", "task2", 0.3])
+ t.append_row(["date1", "ing1", "task3", 0.3])
+ t.append_row(["date3", "ing1", "task3", 0.3])
self.table = t
def test_single_groupby(self):
"""tests groupby() on several columns"""
- grouped = self.table.groupby('date')
+ grouped = self.table.groupby("date")
self.assertEqual(len(grouped), 3)
- self.assertEqual(len(grouped['date1']), 3)
- self.assertEqual(len(grouped['date2']), 1)
- self.assertEqual(len(grouped['date3']), 2)
- self.assertEqual(grouped['date1'], [
- ('date1', 'ing1', 'task1', 0.3),
- ('date1', 'ing2', 'task2', 0.3),
- ('date1', 'ing1', 'task3', 0.3),
- ])
- self.assertEqual(grouped['date2'], [('date2', 'ing3', 'task3', 0.3)])
- self.assertEqual(grouped['date3'], [
- ('date3', 'ing4', 'task2', 0.3),
- ('date3', 'ing1', 'task3', 0.3),
- ])
+ self.assertEqual(len(grouped["date1"]), 3)
+ self.assertEqual(len(grouped["date2"]), 1)
+ self.assertEqual(len(grouped["date3"]), 2)
+ self.assertEqual(
+ grouped["date1"],
+ [
+ ("date1", "ing1", "task1", 0.3),
+ ("date1", "ing2", "task2", 0.3),
+ ("date1", "ing1", "task3", 0.3),
+ ],
+ )
+ self.assertEqual(grouped["date2"], [("date2", "ing3", "task3", 0.3)])
+ self.assertEqual(
+ grouped["date3"], [("date3", "ing4", "task2", 0.3), ("date3", "ing1", "task3", 0.3),]
+ )
def test_multiple_groupby(self):
"""tests groupby() on several columns"""
- grouped = self.table.groupby('date', 'task')
+ grouped = self.table.groupby("date", "task")
self.assertEqual(len(grouped), 3)
- self.assertEqual(len(grouped['date1']), 3)
- self.assertEqual(len(grouped['date2']), 1)
- self.assertEqual(len(grouped['date3']), 2)
- self.assertEqual(grouped['date1']['task1'], [('date1', 'ing1', 'task1', 0.3)])
- self.assertEqual(grouped['date2']['task3'], [('date2', 'ing3', 'task3', 0.3)])
- self.assertEqual(grouped['date3']['task2'], [('date3', 'ing4', 'task2', 0.3)])
- date3 = grouped['date3']
- self.assertRaises(KeyError, date3.__getitem__, 'task1')
-
+ self.assertEqual(len(grouped["date1"]), 3)
+ self.assertEqual(len(grouped["date2"]), 1)
+ self.assertEqual(len(grouped["date3"]), 2)
+ self.assertEqual(grouped["date1"]["task1"], [("date1", "ing1", "task1", 0.3)])
+ self.assertEqual(grouped["date2"]["task3"], [("date2", "ing3", "task3", 0.3)])
+ self.assertEqual(grouped["date3"]["task2"], [("date3", "ing4", "task2", 0.3)])
+ date3 = grouped["date3"]
+ self.assertRaises(KeyError, date3.__getitem__, "task1")
def test_select(self):
"""tests Table.select() method"""
- rows = self.table.select('date', 'date1')
- self.assertEqual(rows, [
- ('date1', 'ing1', 'task1', 0.3),
- ('date1', 'ing2', 'task2', 0.3),
- ('date1', 'ing1', 'task3', 0.3),
- ])
+ rows = self.table.select("date", "date1")
+ self.assertEqual(
+ rows,
+ [
+ ("date1", "ing1", "task1", 0.3),
+ ("date1", "ing2", "task2", 0.3),
+ ("date1", "ing1", "task3", 0.3),
+ ],
+ )
+
class TableStyleSheetTC(TestCase):
"""The Stylesheet test case
"""
+
def setUp(self):
"""Builds a simple table to test the stylesheet
"""
self.table = Table()
- self.table.create_row('row1')
- self.table.create_columns(['a', 'b', 'c'])
+ self.table.create_row("row1")
+ self.table.create_columns(["a", "b", "c"])
self.stylesheet = TableStyleSheet()
# We don't want anything to be printed
self.stdout_backup = sys.stdout
@@ -283,20 +297,20 @@ class TableStyleSheetTC(TestCase):
def test_add_rule(self):
"""Tests that the regex pattern works as expected.
"""
- rule = '0_2 = sqrt(0_0**2 + 0_1**2)'
+ rule = "0_2 = sqrt(0_0**2 + 0_1**2)"
self.stylesheet.add_rule(rule)
self.table.set_row(0, [3, 4, 0])
self.table.apply_stylesheet(self.stylesheet)
self.assertEqual(self.table[0], [3, 4, 5])
self.assertEqual(len(self.stylesheet.rules), 1)
- self.stylesheet.add_rule('some bad rule with bad syntax')
+ self.stylesheet.add_rule("some bad rule with bad syntax")
self.assertEqual(len(self.stylesheet.rules), 1, "Ill-formed rule mustn't be added")
self.assertEqual(len(self.stylesheet.instructions), 1, "Ill-formed rule mustn't be added")
def test_stylesheet_init(self):
"""tests Stylesheet.__init__"""
- rule = '0_2 = 1'
- sheet = TableStyleSheet([rule, 'bad rule'])
+ rule = "0_2 = 1"
+ sheet = TableStyleSheet([rule, "bad rule"])
self.assertEqual(len(sheet.rules), 1, "Ill-formed rule mustn't be added")
self.assertEqual(len(sheet.instructions), 1, "Ill-formed rule mustn't be added")
@@ -309,7 +323,6 @@ class TableStyleSheetTC(TestCase):
val = self.table[0, 2]
self.assertEqual(int(val), 15)
-
def test_rowsum_rule(self):
"""Tests that add_rowsum_rule works as expected
"""
@@ -319,116 +332,114 @@ class TableStyleSheetTC(TestCase):
val = self.table[0, 2]
self.assertEqual(val, 30)
-
def test_colavg_rule(self):
"""Tests that add_colavg_rule works as expected
"""
self.table.set_row(0, [10, 20, 0])
- self.table.append_row([12, 8, 3], 'row2')
- self.table.create_row('row3')
+ self.table.append_row([12, 8, 3], "row2")
+ self.table.create_row("row3")
self.stylesheet.add_colavg_rule((2, 0), 0, 0, 1)
self.table.apply_stylesheet(self.stylesheet)
val = self.table[2, 0]
self.assertEqual(int(val), 11)
-
def test_colsum_rule(self):
"""Tests that add_colsum_rule works as expected
"""
self.table.set_row(0, [10, 20, 0])
- self.table.append_row([12, 8, 3], 'row2')
- self.table.create_row('row3')
+ self.table.append_row([12, 8, 3], "row2")
+ self.table.create_row("row3")
self.stylesheet.add_colsum_rule((2, 0), 0, 0, 1)
self.table.apply_stylesheet(self.stylesheet)
val = self.table[2, 0]
self.assertEqual(val, 22)
-
class TableStyleTC(TestCase):
"""Test suite for TableSuite"""
+
def setUp(self):
self.table = Table()
- self.table.create_rows(['row1', 'row2', 'row3'])
- self.table.create_columns(['col1', 'col2'])
+ self.table.create_rows(["row1", "row2", "row3"])
+ self.table.create_columns(["col1", "col2"])
self.style = TableStyle(self.table)
- self._tested_attrs = (('size', '1*'),
- ('alignment', 'right'),
- ('unit', ''))
+ self._tested_attrs = (("size", "1*"), ("alignment", "right"), ("unit", ""))
def test_getset(self):
"""tests style's get and set methods"""
for attrname, default_value in self._tested_attrs:
- getter = getattr(self.style, 'get_%s' % attrname)
- setter = getattr(self.style, 'set_%s' % attrname)
- self.assertRaises(KeyError, getter, 'badcol')
- self.assertEqual(getter('col1'), default_value)
- setter('FOO', 'col1')
- self.assertEqual(getter('col1'), 'FOO')
+ getter = getattr(self.style, "get_%s" % attrname)
+ setter = getattr(self.style, "set_%s" % attrname)
+ self.assertRaises(KeyError, getter, "badcol")
+ self.assertEqual(getter("col1"), default_value)
+ setter("FOO", "col1")
+ self.assertEqual(getter("col1"), "FOO")
def test_getset_index(self):
"""tests style's get and set by index methods"""
for attrname, default_value in self._tested_attrs:
- getter = getattr(self.style, 'get_%s' % attrname)
- setter = getattr(self.style, 'set_%s' % attrname)
- igetter = getattr(self.style, 'get_%s_by_index' % attrname)
- isetter = getattr(self.style, 'set_%s_by_index' % attrname)
- self.assertEqual(getter('__row_column__'), default_value)
- isetter('FOO', 0)
- self.assertEqual(getter('__row_column__'), 'FOO')
- self.assertEqual(igetter(0), 'FOO')
- self.assertEqual(getter('col1'), default_value)
- isetter('FOO', 1)
- self.assertEqual(getter('col1'), 'FOO')
- self.assertEqual(igetter(1), 'FOO')
+ getter = getattr(self.style, "get_%s" % attrname)
+ setter = getattr(self.style, "set_%s" % attrname)
+ igetter = getattr(self.style, "get_%s_by_index" % attrname)
+ isetter = getattr(self.style, "set_%s_by_index" % attrname)
+ self.assertEqual(getter("__row_column__"), default_value)
+ isetter("FOO", 0)
+ self.assertEqual(getter("__row_column__"), "FOO")
+ self.assertEqual(igetter(0), "FOO")
+ self.assertEqual(getter("col1"), default_value)
+ isetter("FOO", 1)
+ self.assertEqual(getter("col1"), "FOO")
+ self.assertEqual(igetter(1), "FOO")
class RendererTC(TestCase):
"""Test suite for DocbookRenderer"""
+
def setUp(self):
- self.renderer = DocbookRenderer(alignment = True)
+ self.renderer = DocbookRenderer(alignment=True)
self.table = Table()
- self.table.create_rows(['row1', 'row2', 'row3'])
- self.table.create_columns(['col1', 'col2'])
+ self.table.create_rows(["row1", "row2", "row3"])
+ self.table.create_columns(["col1", "col2"])
self.style = TableStyle(self.table)
self.base_renderer = TableCellRenderer()
def test_cell_content(self):
"""test how alignment is rendered"""
- entry_xml = self.renderer._render_cell_content('data', self.style, 1)
+ entry_xml = self.renderer._render_cell_content("data", self.style, 1)
self.assertEqual(entry_xml, "<entry align='right'>data</entry>\n")
- self.style.set_alignment_by_index('left', 1)
- entry_xml = self.renderer._render_cell_content('data', self.style, 1)
+ self.style.set_alignment_by_index("left", 1)
+ entry_xml = self.renderer._render_cell_content("data", self.style, 1)
self.assertEqual(entry_xml, "<entry align='left'>data</entry>\n")
def test_default_content_rendering(self):
"""tests that default rendering just prints the cell's content"""
- rendered_cell = self.base_renderer._render_cell_content('data', self.style, 1)
+ rendered_cell = self.base_renderer._render_cell_content("data", self.style, 1)
self.assertEqual(rendered_cell, "data")
def test_replacement_char(self):
"""tests that 0 is replaced when asked for"""
cell_content = self.base_renderer._make_cell_content(0, self.style, 1)
self.assertEqual(cell_content, 0)
- self.base_renderer.properties['skip_zero'] = '---'
+ self.base_renderer.properties["skip_zero"] = "---"
cell_content = self.base_renderer._make_cell_content(0, self.style, 1)
- self.assertEqual(cell_content, '---')
+ self.assertEqual(cell_content, "---")
def test_unit(self):
"""tests if units are added"""
- self.base_renderer.properties['units'] = True
- self.style.set_unit_by_index('EUR', 1)
+ self.base_renderer.properties["units"] = True
+ self.style.set_unit_by_index("EUR", 1)
cell_content = self.base_renderer._make_cell_content(12, self.style, 1)
- self.assertEqual(cell_content, '12 EUR')
+ self.assertEqual(cell_content, "12 EUR")
class DocbookTableWriterTC(TestCase):
"""TestCase for table's writer"""
+
def setUp(self):
self.stream = StringIO()
self.table = Table()
- self.table.create_rows(['row1', 'row2', 'row3'])
- self.table.create_columns(['col1', 'col2'])
+ self.table.create_rows(["row1", "row2", "row3"])
+ self.table.create_columns(["col1", "col2"])
self.writer = DocbookTableWriter(self.stream, self.table, None)
self.writer.set_renderer(DocbookRenderer())
@@ -442,5 +453,5 @@ class DocbookTableWriterTC(TestCase):
self.assertRaises(NotImplementedError, writer.write_table)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_taskqueue.py b/test/test_taskqueue.py
index d8b6a9e..3e49ade 100644
--- a/test/test_taskqueue.py
+++ b/test/test_taskqueue.py
@@ -19,53 +19,53 @@ from logilab.common.testlib import TestCase, unittest_main
from logilab.common.tasksqueue import *
-class TaskTC(TestCase):
+class TaskTC(TestCase):
def test_eq(self):
- self.assertFalse(Task('t1') == Task('t2'))
- self.assertTrue(Task('t1') == Task('t1'))
+ self.assertFalse(Task("t1") == Task("t2"))
+ self.assertTrue(Task("t1") == Task("t1"))
def test_cmp(self):
- self.assertTrue(Task('t1', LOW) < Task('t2', MEDIUM))
- self.assertFalse(Task('t1', LOW) > Task('t2', MEDIUM))
- self.assertTrue(Task('t1', HIGH) > Task('t2', MEDIUM))
- self.assertFalse(Task('t1', HIGH) < Task('t2', MEDIUM))
+ self.assertTrue(Task("t1", LOW) < Task("t2", MEDIUM))
+ self.assertFalse(Task("t1", LOW) > Task("t2", MEDIUM))
+ self.assertTrue(Task("t1", HIGH) > Task("t2", MEDIUM))
+ self.assertFalse(Task("t1", HIGH) < Task("t2", MEDIUM))
class PrioritizedTasksQueueTC(TestCase):
-
def test_priority(self):
queue = PrioritizedTasksQueue()
- queue.put(Task('t1'))
- queue.put(Task('t2', MEDIUM))
- queue.put(Task('t3', HIGH))
- queue.put(Task('t4', LOW))
- self.assertEqual(queue.get().id, 't3')
- self.assertEqual(queue.get().id, 't2')
- self.assertEqual(queue.get().id, 't1')
- self.assertEqual(queue.get().id, 't4')
+ queue.put(Task("t1"))
+ queue.put(Task("t2", MEDIUM))
+ queue.put(Task("t3", HIGH))
+ queue.put(Task("t4", LOW))
+ self.assertEqual(queue.get().id, "t3")
+ self.assertEqual(queue.get().id, "t2")
+ self.assertEqual(queue.get().id, "t1")
+ self.assertEqual(queue.get().id, "t4")
def test_remove_equivalent(self):
queue = PrioritizedTasksQueue()
- queue.put(Task('t1'))
- queue.put(Task('t2', MEDIUM))
- queue.put(Task('t1', HIGH))
- queue.put(Task('t3', MEDIUM))
- queue.put(Task('t2', MEDIUM))
+ queue.put(Task("t1"))
+ queue.put(Task("t2", MEDIUM))
+ queue.put(Task("t1", HIGH))
+ queue.put(Task("t3", MEDIUM))
+ queue.put(Task("t2", MEDIUM))
self.assertEqual(queue.qsize(), 3)
- self.assertEqual(queue.get().id, 't1')
- self.assertEqual(queue.get().id, 't2')
- self.assertEqual(queue.get().id, 't3')
+ self.assertEqual(queue.get().id, "t1")
+ self.assertEqual(queue.get().id, "t2")
+ self.assertEqual(queue.get().id, "t3")
self.assertEqual(queue.qsize(), 0)
def test_remove(self):
queue = PrioritizedTasksQueue()
- queue.put(Task('t1'))
- queue.put(Task('t2'))
- queue.put(Task('t3'))
- queue.remove('t2')
- self.assertEqual([t.id for t in queue], ['t3', 't1'])
- self.assertRaises(ValueError, queue.remove, 't4')
+ queue.put(Task("t1"))
+ queue.put(Task("t2"))
+ queue.put(Task("t3"))
+ queue.remove("t2")
+ self.assertEqual([t.id for t in queue], ["t3", "t1"])
+ self.assertRaises(ValueError, queue.remove, "t4")
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_testlib.py b/test/test_testlib.py
index d26f2a6..8f95711 100644
--- a/test/test_testlib.py
+++ b/test/test_testlib.py
@@ -31,10 +31,21 @@ except NameError:
__file__ = sys.argv[0]
from logilab.common.compat import StringIO
-from logilab.common.testlib import (unittest, TestSuite, unittest_main, Tags,
- TestCase, mock_object, create_files, InnerTest, with_tempdir, tag,
- require_version, require_module)
-from logilab.common.pytest import SkipAwareTextTestRunner, NonStrictTestLoader
+from logilab.common.testlib import (
+ unittest,
+ TestSuite,
+ unittest_main,
+ Tags,
+ TestCase,
+ mock_object,
+ create_files,
+ InnerTest,
+ with_tempdir,
+ tag,
+ require_version,
+ require_module,
+)
+from logilab.common.pytest import SkipAwareTextTestRunner, NonStrictTestLoader
class MockTestCase(TestCase):
@@ -45,39 +56,38 @@ class MockTestCase(TestCase):
def fail(self, msg):
raise AssertionError(msg)
-class UtilTC(TestCase):
+class UtilTC(TestCase):
def test_mockobject(self):
- obj = mock_object(foo='bar', baz='bam')
- self.assertEqual(obj.foo, 'bar')
- self.assertEqual(obj.baz, 'bam')
+ obj = mock_object(foo="bar", baz="bam")
+ self.assertEqual(obj.foo, "bar")
+ self.assertEqual(obj.baz, "bam")
def test_create_files(self):
chroot = tempfile.mkdtemp()
path_to = lambda path: join(chroot, path)
dircontent = lambda path: sorted(os.listdir(join(chroot, path)))
try:
- self.assertFalse(isdir(path_to('a/')))
- create_files(['a/b/foo.py', 'a/b/c/', 'a/b/c/d/e.py'], chroot)
+ self.assertFalse(isdir(path_to("a/")))
+ create_files(["a/b/foo.py", "a/b/c/", "a/b/c/d/e.py"], chroot)
# make sure directories exist
- self.assertTrue(isdir(path_to('a')))
- self.assertTrue(isdir(path_to('a/b')))
- self.assertTrue(isdir(path_to('a/b/c')))
- self.assertTrue(isdir(path_to('a/b/c/d')))
+ self.assertTrue(isdir(path_to("a")))
+ self.assertTrue(isdir(path_to("a/b")))
+ self.assertTrue(isdir(path_to("a/b/c")))
+ self.assertTrue(isdir(path_to("a/b/c/d")))
# make sure files exist
- self.assertTrue(isfile(path_to('a/b/foo.py')))
- self.assertTrue(isfile(path_to('a/b/c/d/e.py')))
+ self.assertTrue(isfile(path_to("a/b/foo.py")))
+ self.assertTrue(isfile(path_to("a/b/c/d/e.py")))
# make sure only asked files were created
- self.assertEqual(dircontent('a'), ['b'])
- self.assertEqual(dircontent('a/b'), ['c', 'foo.py'])
- self.assertEqual(dircontent('a/b/c'), ['d'])
- self.assertEqual(dircontent('a/b/c/d'), ['e.py'])
+ self.assertEqual(dircontent("a"), ["b"])
+ self.assertEqual(dircontent("a/b"), ["c", "foo.py"])
+ self.assertEqual(dircontent("a/b/c"), ["d"])
+ self.assertEqual(dircontent("a/b/c/d"), ["e.py"])
finally:
shutil.rmtree(chroot)
class TestlibTC(TestCase):
-
def mkdir(self, path):
if not exists(path):
self._dirs.add(path)
@@ -88,13 +98,13 @@ class TestlibTC(TestCase):
self._dirs = set()
def tearDown(self):
- while(self._dirs):
+ while self._dirs:
shutil.rmtree(self._dirs.pop(), ignore_errors=True)
def test_dict_equals(self):
"""tests TestCase.assertDictEqual"""
- d1 = {'a' : 1, 'b' : 2}
- d2 = {'a' : 1, 'b' : 3}
+ d1 = {"a": 1, "b": 2}
+ d2 = {"a": 1, "b": 3}
d3 = dict(d1)
self.assertRaises(AssertionError, self.tc.assertDictEqual, d1, d2)
self.tc.assertDictEqual(d1, d3)
@@ -112,8 +122,8 @@ class TestlibTC(TestCase):
self.tc.assertListEqual(l3, l1)
def test_equality_for_sets(self):
- s1 = set('ab')
- s2 = set('a')
+ s1 = set("ab")
+ s2 = set("a")
self.assertRaises(AssertionError, self.tc.assertSetEqual, s1, s2)
self.tc.assertSetEqual(s1, s1)
self.tc.assertSetEqual(set(), set())
@@ -123,16 +133,20 @@ class TestlibTC(TestCase):
self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, "toto", 12)
self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, "toto", None)
self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, "toto", None)
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 3.12, u"toto")
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 3.12, u"toto")
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, None, u"toto")
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, None, u"toto")
- self.tc.assertMultiLineEqual('toto\ntiti', 'toto\ntiti')
- self.tc.assertMultiLineEqual('toto\ntiti', 'toto\ntiti')
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 'toto\ntiti', 'toto\n titi\n')
- self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 'toto\ntiti', 'toto\n titi\n')
- foo = join(dirname(__file__), 'data', 'foo.txt')
- spam = join(dirname(__file__), 'data', 'spam.txt')
+ self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 3.12, "toto")
+ self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, 3.12, "toto")
+ self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, None, "toto")
+ self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, None, "toto")
+ self.tc.assertMultiLineEqual("toto\ntiti", "toto\ntiti")
+ self.tc.assertMultiLineEqual("toto\ntiti", "toto\ntiti")
+ self.assertRaises(
+ AssertionError, self.tc.assertMultiLineEqual, "toto\ntiti", "toto\n titi\n"
+ )
+ self.assertRaises(
+ AssertionError, self.tc.assertMultiLineEqual, "toto\ntiti", "toto\n titi\n"
+ )
+ foo = join(dirname(__file__), "data", "foo.txt")
+ spam = join(dirname(__file__), "data", "spam.txt")
with open(foo) as fobj:
text1 = fobj.read()
self.tc.assertMultiLineEqual(text1, text1)
@@ -143,37 +157,41 @@ class TestlibTC(TestCase):
self.assertRaises(AssertionError, self.tc.assertMultiLineEqual, text1, text2)
def test_default_datadir(self):
- expected_datadir = join(dirname(abspath(__file__)), 'data')
+ expected_datadir = join(dirname(abspath(__file__)), "data")
self.assertEqual(self.datadir, expected_datadir)
- self.assertEqual(self.datapath('foo'), join(expected_datadir, 'foo'))
+ self.assertEqual(self.datapath("foo"), join(expected_datadir, "foo"))
def test_multiple_args_datadir(self):
- expected_datadir = join(dirname(abspath(__file__)), 'data')
+ expected_datadir = join(dirname(abspath(__file__)), "data")
self.assertEqual(self.datadir, expected_datadir)
- self.assertEqual(self.datapath('foo', 'bar'), join(expected_datadir, 'foo', 'bar'))
+ self.assertEqual(self.datapath("foo", "bar"), join(expected_datadir, "foo", "bar"))
def test_custom_datadir(self):
class MyTC(TestCase):
- datadir = 'foo'
- def test_1(self): pass
+ datadir = "foo"
+
+ def test_1(self):
+ pass
# class' custom datadir
- tc = MyTC('test_1')
- self.assertEqual(tc.datapath('bar'), join('foo', 'bar'))
+ tc = MyTC("test_1")
+ self.assertEqual(tc.datapath("bar"), join("foo", "bar"))
def test_cached_datadir(self):
"""test datadir is cached on the class"""
+
class MyTC(TestCase):
- def test_1(self): pass
+ def test_1(self):
+ pass
- expected_datadir = join(dirname(abspath(__file__)), 'data')
- tc = MyTC('test_1')
+ expected_datadir = join(dirname(abspath(__file__)), "data")
+ tc = MyTC("test_1")
self.assertEqual(tc.datadir, expected_datadir)
# changing module should not change the datadir
- MyTC.__module__ = 'os'
+ MyTC.__module__ = "os"
self.assertEqual(tc.datadir, expected_datadir)
# even on new instances
- tc2 = MyTC('test_1')
+ tc2 = MyTC("test_1")
self.assertEqual(tc2.datadir, expected_datadir)
def test_is(self):
@@ -198,16 +216,15 @@ class TestlibTC(TestCase):
def test_in(self):
self.assertIn("a", "dsqgaqg")
- obj, seq = 'a', ('toto', "azf", "coin")
+ obj, seq = "a", ("toto", "azf", "coin")
self.assertRaises(AssertionError, self.assertIn, obj, seq)
def test_not_in(self):
- self.assertNotIn('a', ('toto', "azf", "coin"))
- self.assertRaises(AssertionError, self.assertNotIn, 'a', "dsqgaqg")
+ self.assertNotIn("a", ("toto", "azf", "coin"))
+ self.assertRaises(AssertionError, self.assertNotIn, "a", "dsqgaqg")
class GenerativeTestsTC(TestCase):
-
def setUp(self):
output = StringIO()
self.runner = SkipAwareTextTestRunner(stream=output)
@@ -217,7 +234,8 @@ class GenerativeTestsTC(TestCase):
def test_generative(self):
for i in range(10):
yield self.assertEqual, i, i
- result = self.runner.run(FooTC('test_generative'))
+
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 0)
@@ -226,8 +244,9 @@ class GenerativeTestsTC(TestCase):
class FooTC(TestCase):
def test_generative(self):
for i in range(10):
- yield self.assertEqual, i%2, 0
- result = self.runner.run(FooTC('test_generative'))
+ yield self.assertEqual, i % 2, 0
+
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 5)
self.assertEqual(len(result.errors), 0)
@@ -237,10 +256,10 @@ class GenerativeTestsTC(TestCase):
def test_generative(self):
for i in range(10):
if i == 5:
- raise ValueError('STOP !')
+ raise ValueError("STOP !")
yield self.assertEqual, i, i
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 5)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 1)
@@ -252,8 +271,11 @@ class GenerativeTestsTC(TestCase):
if i == 5:
yield self.ouch
yield self.assertEqual, i, i
- def ouch(self): raise ValueError('stop !')
- result = self.runner.run(FooTC('test_generative'))
+
+ def ouch(self):
+ raise ValueError("stop !")
+
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 11)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 1)
@@ -261,12 +283,13 @@ class GenerativeTestsTC(TestCase):
def test_generative_setup(self):
class FooTC(TestCase):
def setUp(self):
- raise ValueError('STOP !')
+ raise ValueError("STOP !")
+
def test_generative(self):
for i in range(10):
yield self.assertEqual, i, i
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 1)
@@ -281,9 +304,9 @@ class GenerativeTestsTC(TestCase):
def test_generative(self):
for i in range(10):
- yield InnerTest("check_%s"%i, self.check, i)
+ yield InnerTest("check_%s" % i, self.check, i)
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 0)
@@ -299,9 +322,9 @@ class GenerativeTestsTC(TestCase):
def test_generative(self):
for i in range(10):
- yield InnerTest("check_%s"%i, self.check, i)
+ yield InnerTest("check_%s" % i, self.check, i)
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 0)
@@ -317,9 +340,9 @@ class GenerativeTestsTC(TestCase):
def test_generative(self):
for i in range(10):
- yield InnerTest("check_%s"%i, self.check, i)
+ yield InnerTest("check_%s" % i, self.check, i)
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 1)
@@ -329,28 +352,27 @@ class GenerativeTestsTC(TestCase):
class FooTC(TestCase):
def check(self, val):
if val == 5:
- self.assertEqual(val, val+1)
+ self.assertEqual(val, val + 1)
else:
self.assertEqual(val, val)
def test_generative(self):
for i in range(10):
- yield InnerTest("check_%s"%i, self.check, i)
+ yield InnerTest("check_%s" % i, self.check, i)
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 10)
self.assertEqual(len(result.failures), 1)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.skipped), 0)
-
def test_generative_outer_failure(self):
class FooTC(TestCase):
def test_generative(self):
self.fail()
yield
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 0)
self.assertEqual(len(result.failures), 1)
self.assertEqual(len(result.errors), 0)
@@ -359,10 +381,10 @@ class GenerativeTestsTC(TestCase):
def test_generative_outer_skip(self):
class FooTC(TestCase):
def test_generative(self):
- self.skipTest('blah')
+ self.skipTest("blah")
yield
- result = self.runner.run(FooTC('test_generative'))
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 0)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 0)
@@ -376,10 +398,16 @@ class ExitFirstTC(TestCase):
def test_failure_exit_first(self):
class FooTC(TestCase):
- def test_1(self): pass
- def test_2(self): assert False
- def test_3(self): pass
- tests = [FooTC('test_1'), FooTC('test_2')]
+ def test_1(self):
+ pass
+
+ def test_2(self):
+ assert False
+
+ def test_3(self):
+ pass
+
+ tests = [FooTC("test_1"), FooTC("test_2")]
result = self.runner.run(TestSuite(tests))
self.assertEqual(result.testsRun, 2)
self.assertEqual(len(result.failures), 1)
@@ -387,10 +415,16 @@ class ExitFirstTC(TestCase):
def test_error_exit_first(self):
class FooTC(TestCase):
- def test_1(self): pass
- def test_2(self): raise ValueError()
- def test_3(self): pass
- tests = [FooTC('test_1'), FooTC('test_2'), FooTC('test_3')]
+ def test_1(self):
+ pass
+
+ def test_2(self):
+ raise ValueError()
+
+ def test_3(self):
+ pass
+
+ tests = [FooTC("test_1"), FooTC("test_2"), FooTC("test_3")]
result = self.runner.run(TestSuite(tests))
self.assertEqual(result.testsRun, 2)
self.assertEqual(len(result.failures), 0)
@@ -401,7 +435,8 @@ class ExitFirstTC(TestCase):
def test_generative(self):
for i in range(10):
yield self.assertTrue, False
- result = self.runner.run(FooTC('test_generative'))
+
+ result = self.runner.run(FooTC("test_generative"))
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.failures), 1)
self.assertEqual(len(result.errors), 0)
@@ -410,17 +445,26 @@ class ExitFirstTC(TestCase):
class TestLoaderTC(TestCase):
## internal classes for test purposes ########
class FooTC(TestCase):
- def test_foo1(self): pass
- def test_foo2(self): pass
- def test_bar1(self): pass
+ def test_foo1(self):
+ pass
+
+ def test_foo2(self):
+ pass
+
+ def test_bar1(self):
+ pass
class BarTC(TestCase):
- def test_bar2(self): pass
+ def test_bar2(self):
+ pass
+
##############################################
def setUp(self):
self.loader = NonStrictTestLoader()
- self.module = TestLoaderTC # mock_object(FooTC=TestLoaderTC.FooTC, BarTC=TestLoaderTC.BarTC)
+ self.module = (
+ TestLoaderTC # mock_object(FooTC=TestLoaderTC.FooTC, BarTC=TestLoaderTC.BarTC)
+ )
self.output = StringIO()
self.runner = SkipAwareTextTestRunner(stream=self.output)
@@ -446,114 +490,172 @@ class TestLoaderTC(TestCase):
self.assertEqual(len(suite1._tests) + len(suite2._tests), 4)
def test_collect_with_classname(self):
- self.assertRunCount('FooTC', self.module, 3)
- self.assertRunCount('BarTC', self.module, 1)
+ self.assertRunCount("FooTC", self.module, 3)
+ self.assertRunCount("BarTC", self.module, 1)
def test_collect_with_classname_and_pattern(self):
- data = [('FooTC.test_foo1', 1), ('FooTC.test_foo', 2), ('FooTC.test_fo', 2),
- ('FooTC.foo1', 1), ('FooTC.foo', 2), ('FooTC.whatever', 0)
- ]
+ data = [
+ ("FooTC.test_foo1", 1),
+ ("FooTC.test_foo", 2),
+ ("FooTC.test_fo", 2),
+ ("FooTC.foo1", 1),
+ ("FooTC.foo", 2),
+ ("FooTC.whatever", 0),
+ ]
for pattern, expected_count in data:
yield self.assertRunCount, pattern, self.module, expected_count
def test_collect_with_pattern(self):
- data = [('test_foo1', 1), ('test_foo', 2), ('test_bar', 2),
- ('foo1', 1), ('foo', 2), ('bar', 2), ('ba', 2),
- ('test', 4), ('ab', 0),
- ]
+ data = [
+ ("test_foo1", 1),
+ ("test_foo", 2),
+ ("test_bar", 2),
+ ("foo1", 1),
+ ("foo", 2),
+ ("bar", 2),
+ ("ba", 2),
+ ("test", 4),
+ ("ab", 0),
+ ]
for pattern, expected_count in data:
yield self.assertRunCount, pattern, self.module, expected_count
def test_testcase_with_custom_metaclass(self):
- class mymetaclass(type): pass
+ class mymetaclass(type):
+ pass
+
class MyMod:
class MyTestCase(TestCase):
__metaclass__ = mymetaclass
- def test_foo1(self): pass
- def test_foo2(self): pass
- def test_bar(self): pass
- data = [('test_foo1', 1), ('test_foo', 2), ('test_bar', 1),
- ('foo1', 1), ('foo', 2), ('bar', 1), ('ba', 1),
- ('test', 3), ('ab', 0),
- ('MyTestCase.test_foo1', 1), ('MyTestCase.test_foo', 2),
- ('MyTestCase.test_fo', 2), ('MyTestCase.foo1', 1),
- ('MyTestCase.foo', 2), ('MyTestCase.whatever', 0)
- ]
+
+ def test_foo1(self):
+ pass
+
+ def test_foo2(self):
+ pass
+
+ def test_bar(self):
+ pass
+
+ data = [
+ ("test_foo1", 1),
+ ("test_foo", 2),
+ ("test_bar", 1),
+ ("foo1", 1),
+ ("foo", 2),
+ ("bar", 1),
+ ("ba", 1),
+ ("test", 3),
+ ("ab", 0),
+ ("MyTestCase.test_foo1", 1),
+ ("MyTestCase.test_foo", 2),
+ ("MyTestCase.test_fo", 2),
+ ("MyTestCase.foo1", 1),
+ ("MyTestCase.foo", 2),
+ ("MyTestCase.whatever", 0),
+ ]
for pattern, expected_count in data:
yield self.assertRunCount, pattern, MyMod, expected_count
def test_collect_everything_and_skipped_patterns(self):
- testdata = [ (['foo1'], 3), (['foo'], 2),
- (['foo', 'bar'], 0), ]
+ testdata = [
+ (["foo1"], 3),
+ (["foo"], 2),
+ (["foo", "bar"], 0),
+ ]
for skipped, expected_count in testdata:
yield self.assertRunCount, None, self.module, expected_count, skipped
def test_collect_specific_pattern_and_skip_some(self):
- testdata = [ ('bar', ['foo1'], 2), ('bar', [], 2),
- ('bar', ['bar'], 0), ]
+ testdata = [
+ ("bar", ["foo1"], 2),
+ ("bar", [], 2),
+ ("bar", ["bar"], 0),
+ ]
for runpattern, skipped, expected_count in testdata:
yield self.assertRunCount, runpattern, self.module, expected_count, skipped
def test_skip_classname(self):
- testdata = [ (['BarTC'], 3), (['FooTC'], 1), ]
+ testdata = [
+ (["BarTC"], 3),
+ (["FooTC"], 1),
+ ]
for skipped, expected_count in testdata:
yield self.assertRunCount, None, self.module, expected_count, skipped
def test_skip_classname_and_specific_collect(self):
- testdata = [ ('bar', ['BarTC'], 1), ('foo', ['FooTC'], 0), ]
+ testdata = [
+ ("bar", ["BarTC"], 1),
+ ("foo", ["FooTC"], 0),
+ ]
for runpattern, skipped, expected_count in testdata:
yield self.assertRunCount, runpattern, self.module, expected_count, skipped
def test_nonregr_dotted_path(self):
- self.assertRunCount('FooTC.test_foo', self.module, 2)
+ self.assertRunCount("FooTC.test_foo", self.module, 2)
def test_inner_tests_selection(self):
class MyMod:
class MyTestCase(TestCase):
- def test_foo(self): pass
+ def test_foo(self):
+ pass
+
def test_foobar(self):
for i in range(5):
- if i%2 == 0:
- yield InnerTest('even', lambda: None)
+ if i % 2 == 0:
+ yield InnerTest("even", lambda: None)
else:
- yield InnerTest('odd', lambda: None)
+ yield InnerTest("odd", lambda: None)
yield lambda: None
# FIXME InnerTest masked by pattern usage
# data = [('foo', 7), ('test_foobar', 6), ('even', 3), ('odd', 2), ]
- data = [('foo', 7), ('test_foobar', 6), ('even', 0), ('odd', 0), ]
+ data = [
+ ("foo", 7),
+ ("test_foobar", 6),
+ ("even", 0),
+ ("odd", 0),
+ ]
for pattern, expected_count in data:
yield self.assertRunCount, pattern, MyMod, expected_count
def test_nonregr_class_skipped_option(self):
class MyMod:
class MyTestCase(TestCase):
- def test_foo(self): pass
- def test_bar(self): pass
+ def test_foo(self):
+ pass
+
+ def test_bar(self):
+ pass
+
class FooTC(TestCase):
- def test_foo(self): pass
- self.assertRunCount('foo', MyMod, 2)
+ def test_foo(self):
+ pass
+
+ self.assertRunCount("foo", MyMod, 2)
self.assertRunCount(None, MyMod, 3)
- self.assertRunCount('foo', MyMod, 1, ['FooTC'])
- self.assertRunCount(None, MyMod, 2, ['FooTC'])
+ self.assertRunCount("foo", MyMod, 1, ["FooTC"])
+ self.assertRunCount(None, MyMod, 2, ["FooTC"])
def test__classes_are_ignored(self):
class MyMod:
class _Base(TestCase):
- def test_1(self): pass
+ def test_1(self):
+ pass
+
class MyTestCase(_Base):
- def test_2(self): pass
+ def test_2(self):
+ pass
+
self.assertRunCount(None, MyMod, 2)
class DecoratorTC(TestCase):
-
@with_tempdir
def test_tmp_dir_normal_1(self):
tempdir = tempfile.gettempdir()
# assert temp directory is empty
- self.assertListEqual(list(os.walk(tempdir)),
- [(tempdir, [], [])])
+ self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])
witness = []
@@ -575,16 +677,13 @@ class DecoratorTC(TestCase):
self.assertEqual(tempfile.gettempdir(), tempdir)
# assert temp directory is empty
- self.assertListEqual(list(os.walk(tempdir)),
- [(tempdir, [], [])])
+ self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])
@with_tempdir
def test_tmp_dir_normal_2(self):
tempdir = tempfile.gettempdir()
# assert temp directory is empty
- self.assertListEqual(list(os.walk(tempfile.tempdir)),
- [(tempfile.tempdir, [], [])])
-
+ self.assertListEqual(list(os.walk(tempfile.tempdir)), [(tempfile.tempdir, [], [])])
class WitnessException(Exception):
pass
@@ -606,8 +705,7 @@ class DecoratorTC(TestCase):
self.assertEqual(tempfile.gettempdir(), tempdir)
# assert temp directory is empty
- self.assertListEqual(list(os.walk(tempdir)),
- [(tempdir, [], [])])
+ self.assertListEqual(list(os.walk(tempdir)), [(tempdir, [], [])])
def test_tmpdir_generator(self):
orig_tempdir = tempfile.gettempdir()
@@ -629,37 +727,51 @@ class DecoratorTC(TestCase):
def test_require_version_good(self):
""" should return the same function
"""
- def func() :
+
+ def func():
pass
- sys.version_info = (2, 5, 5, 'final', 4)
+
+ sys.version_info = (2, 5, 5, "final", 4)
current = sys.version_info[:3]
- compare = ('2.4', '2.5', '2.5.4', '2.5.5')
+ compare = ("2.4", "2.5", "2.5.4", "2.5.5")
for version in compare:
decorator = require_version(version)
- self.assertEqual(func, decorator(func), '%s =< %s : function \
- return by the decorator should be the same.' % (version,
- '.'.join([str(element) for element in current])))
+ self.assertEqual(
+ func,
+ decorator(func),
+ "%s =< %s : function \
+ return by the decorator should be the same."
+ % (version, ".".join([str(element) for element in current])),
+ )
def test_require_version_bad(self):
""" should return a different function : skipping test
"""
- def func() :
+
+ def func():
pass
- sys.version_info = (2, 5, 5, 'final', 4)
+
+ sys.version_info = (2, 5, 5, "final", 4)
current = sys.version_info[:3]
- compare = ('2.5.6', '2.6', '2.6.5')
+ compare = ("2.5.6", "2.6", "2.6.5")
for version in compare:
decorator = require_version(version)
- self.assertNotEqual(func, decorator(func), '%s >= %s : function \
- return by the decorator should NOT be the same.'
- % ('.'.join([str(element) for element in current]), version))
+ self.assertNotEqual(
+ func,
+ decorator(func),
+ "%s >= %s : function \
+ return by the decorator should NOT be the same."
+ % (".".join([str(element) for element in current]), version),
+ )
def test_require_version_exception(self):
""" should throw a ValueError exception
"""
- def func() :
+
+ def func():
pass
- compare = ('2.5.a', '2.a', 'azerty')
+
+ compare = ("2.5.a", "2.a", "azerty")
for version in compare:
decorator = require_version(version)
self.assertRaises(ValueError, decorator, func)
@@ -667,122 +779,139 @@ class DecoratorTC(TestCase):
def test_require_module_good(self):
""" should return the same function
"""
- def func() :
+
+ def func():
pass
- module = 'sys'
+
+ module = "sys"
decorator = require_module(module)
- self.assertEqual(func, decorator(func), 'module %s exists : function \
- return by the decorator should be the same.' % module)
+ self.assertEqual(
+ func,
+ decorator(func),
+ "module %s exists : function \
+ return by the decorator should be the same."
+ % module,
+ )
def test_require_module_bad(self):
""" should return a different function : skipping test
"""
- def func() :
+
+ def func():
pass
- modules = ('bla', 'blo', 'bli')
+
+ modules = ("bla", "blo", "bli")
for module in modules:
try:
__import__(module)
pass
except ImportError:
decorator = require_module(module)
- self.assertNotEqual(func, decorator(func), 'module %s does \
+ self.assertNotEqual(
+ func,
+ decorator(func),
+ "module %s does \
not exist : function return by the decorator should \
- NOT be the same.' % module)
+ NOT be the same."
+ % module,
+ )
return
- print('all modules in %s exist. Could not test %s' % (', '.join(modules),
- sys._getframe().f_code.co_name))
+ print(
+ "all modules in %s exist. Could not test %s"
+ % (", ".join(modules), sys._getframe().f_code.co_name)
+ )
-class TagTC(TestCase):
+class TagTC(TestCase):
def setUp(self):
- @tag('testing', 'bob')
+ @tag("testing", "bob")
def bob(a, b, c):
return (a + b) * c
self.func = bob
class TagTestTC(TestCase):
- tags = Tags('one', 'two')
+ tags = Tags("one", "two")
def test_one(self):
self.assertTrue(True)
- @tag('two', 'three')
+ @tag("two", "three")
def test_two(self):
self.assertTrue(True)
- @tag('three', inherit=False)
+ @tag("three", inherit=False)
def test_three(self):
self.assertTrue(True)
+
self.cls = TagTestTC
def test_tag_decorator(self):
bob = self.func
self.assertEqual(bob(2, 3, 7), 35)
- self.assertTrue(hasattr(bob, 'tags'))
- self.assertSetEqual(bob.tags, set(['testing', 'bob']))
+ self.assertTrue(hasattr(bob, "tags"))
+ self.assertSetEqual(bob.tags, set(["testing", "bob"]))
def test_tags_class(self):
tags = self.func.tags
- self.assertTrue(tags['testing'])
- self.assertFalse(tags['Not inside'])
+ self.assertTrue(tags["testing"])
+ self.assertFalse(tags["Not inside"])
def test_tags_match(self):
tags = self.func.tags
- self.assertTrue(tags.match('testing'))
- self.assertFalse(tags.match('other'))
+ self.assertTrue(tags.match("testing"))
+ self.assertFalse(tags.match("other"))
- self.assertFalse(tags.match('testing and coin'))
- self.assertTrue(tags.match('testing or other'))
+ self.assertFalse(tags.match("testing and coin"))
+ self.assertTrue(tags.match("testing or other"))
- self.assertTrue(tags.match('not other'))
+ self.assertTrue(tags.match("not other"))
- self.assertTrue(tags.match('not other or (testing and bibi)'))
- self.assertTrue(tags.match('other or (testing and bob)'))
+ self.assertTrue(tags.match("not other or (testing and bibi)"))
+ self.assertTrue(tags.match("other or (testing and bob)"))
def test_tagged_class(self):
def options(tags):
class Options(object):
tags_pattern = tags
+
return Options()
- tc = self.cls('test_one')
+ tc = self.cls("test_one")
runner = SkipAwareTextTestRunner()
self.assertTrue(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertTrue(runner.does_match_tags(tc.test_three))
- runner = SkipAwareTextTestRunner(options=options('one'))
+ runner = SkipAwareTextTestRunner(options=options("one"))
self.assertTrue(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertFalse(runner.does_match_tags(tc.test_three))
- runner = SkipAwareTextTestRunner(options=options('two'))
+ runner = SkipAwareTextTestRunner(options=options("two"))
self.assertTrue(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertFalse(runner.does_match_tags(tc.test_three))
- runner = SkipAwareTextTestRunner(options=options('three'))
+ runner = SkipAwareTextTestRunner(options=options("three"))
self.assertFalse(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertTrue(runner.does_match_tags(tc.test_three))
- runner = SkipAwareTextTestRunner(options=options('two or three'))
+ runner = SkipAwareTextTestRunner(options=options("two or three"))
self.assertTrue(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertTrue(runner.does_match_tags(tc.test_three))
- runner = SkipAwareTextTestRunner(options=options('two and three'))
+ runner = SkipAwareTextTestRunner(options=options("two and three"))
self.assertFalse(runner.does_match_tags(tc.test_one))
self.assertTrue(runner.does_match_tags(tc.test_two))
self.assertFalse(runner.does_match_tags(tc.test_three))
-
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_textutils.py b/test/test_textutils.py
index 3e9a343..8aa2ded 100644
--- a/test/test_textutils.py
+++ b/test/test_textutils.py
@@ -29,245 +29,304 @@ from logilab.common import textutils as tu
from logilab.common.testlib import TestCase, unittest_main
-if linesep != '\n':
+if linesep != "\n":
import re
+
LINE_RGX = re.compile(linesep)
+
def ulines(string):
- return LINE_RGX.sub('\n', string)
+ return LINE_RGX.sub("\n", string)
+
+
else:
+
def ulines(string):
return string
-class NormalizeTextTC(TestCase):
+class NormalizeTextTC(TestCase):
def test_known_values(self):
- self.assertEqual(ulines(tu.normalize_text('''some really malformated
+ self.assertEqual(
+ ulines(
+ tu.normalize_text(
+ """some really malformated
text.
With some times some veeeeeeeeeeeeeeerrrrryyyyyyyyyyyyyyyyyyy loooooooooooooooooooooong linnnnnnnnnnnes
and empty lines!
- ''')),
- '''some really malformated text. With some times some
+ """
+ )
+ ),
+ """some really malformated text. With some times some
veeeeeeeeeeeeeeerrrrryyyyyyyyyyyyyyyyyyy loooooooooooooooooooooong
linnnnnnnnnnnes
-and empty lines!''')
- self.assertMultiLineEqual(ulines(tu.normalize_text('''\
+and empty lines!""",
+ )
+ self.assertMultiLineEqual(
+ ulines(
+ tu.normalize_text(
+ """\
some ReST formated text
=======================
With some times some veeeeeeeeeeeeeeerrrrryyyyyyyyyyyyyyyyyyy loooooooooooooooooooooong linnnnnnnnnnnes
and normal lines!
another paragraph
- ''', rest=True)),
- '''\
+ """,
+ rest=True,
+ )
+ ),
+ """\
some ReST formated text
=======================
With some times some veeeeeeeeeeeeeeerrrrryyyyyyyyyyyyyyyyyyy
loooooooooooooooooooooong linnnnnnnnnnnes
and normal lines!
-another paragraph''')
+another paragraph""",
+ )
def test_nonregr_unsplitable_word(self):
- self.assertEqual(ulines(tu.normalize_text('''petit complement :
+ self.assertEqual(
+ ulines(
+ tu.normalize_text(
+ """petit complement :
http://www.plonefr.net/blog/archive/2005/10/30/tester-la-future-infrastructure-i18n
-''', 80)),
- '''petit complement :
-
-http://www.plonefr.net/blog/archive/2005/10/30/tester-la-future-infrastructure-i18n''')
+""",
+ 80,
+ )
+ ),
+ """petit complement :
+http://www.plonefr.net/blog/archive/2005/10/30/tester-la-future-infrastructure-i18n""",
+ )
def test_nonregr_rest_normalize(self):
- self.assertEqual(ulines(tu.normalize_text("""... Il est donc evident que tout le monde doit lire le compte-rendu de RSH et aller discuter avec les autres si c'est utile ou necessaire.
- """, rest=True)), """... Il est donc evident que tout le monde doit lire le compte-rendu de RSH et
-aller discuter avec les autres si c'est utile ou necessaire.""")
+ self.assertEqual(
+ ulines(
+ tu.normalize_text(
+ """... Il est donc evident que tout le monde doit lire le compte-rendu de RSH et aller discuter avec les autres si c'est utile ou necessaire.
+ """,
+ rest=True,
+ )
+ ),
+ """... Il est donc evident que tout le monde doit lire le compte-rendu de RSH et
+aller discuter avec les autres si c'est utile ou necessaire.""",
+ )
def test_normalize_rest_paragraph(self):
- self.assertEqual(ulines(tu.normalize_rest_paragraph("""**nico**: toto""")),
- """**nico**: toto""")
+ self.assertEqual(
+ ulines(tu.normalize_rest_paragraph("""**nico**: toto""")), """**nico**: toto"""
+ )
def test_normalize_rest_paragraph2(self):
- self.assertEqual(ulines(tu.normalize_rest_paragraph(""".. _tdm: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
-.. _extrait: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""", indent='> ')),
- """> .. _tdm:
+ self.assertEqual(
+ ulines(
+ tu.normalize_rest_paragraph(
+ """.. _tdm: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
+.. _extrait: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""",
+ indent="> ",
+ )
+ ),
+ """> .. _tdm:
> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
> .. _extrait:
-> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""")
+> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""",
+ )
def test_normalize_paragraph2(self):
- self.assertEqual(ulines(tu.normalize_paragraph(""".. _tdm: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
-.. _extrait: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""", indent='> ')),
- """> .. _tdm:
+ self.assertEqual(
+ ulines(
+ tu.normalize_paragraph(
+ """.. _tdm: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
+.. _extrait: http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""",
+ indent="> ",
+ )
+ ),
+ """> .. _tdm:
> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Table-des-matieres/.20_adaa41fb-c125-4919-aece-049601e81c8e_0_0.pdf
> .. _extrait:
-> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""")
+> http://www.editions-eni.fr/Livres/Python-Les-fondamentaux-du-langage---La-programmation-pour-les-scientifiques-Extrait-du-livre/.20_d6eed0be-0d36-4384-be59-2dd09e081012_0_0.pdf""",
+ )
class NormalizeParagraphTC(TestCase):
-
def test_known_values(self):
- self.assertEqual(ulines(tu.normalize_text("""This package contains test files shared by the logilab-common package. It isn't
+ self.assertEqual(
+ ulines(
+ tu.normalize_text(
+ """This package contains test files shared by the logilab-common package. It isn't
necessary to install this package unless you want to execute or look at
-the tests.""", indent=' ', line_len=70)),
- """\
+the tests.""",
+ indent=" ",
+ line_len=70,
+ )
+ ),
+ """\
This package contains test files shared by the logilab-common
package. It isn't necessary to install this package unless you want
- to execute or look at the tests.""")
+ to execute or look at the tests.""",
+ )
class GetCsvTC(TestCase):
-
def test_known(self):
- self.assertEqual(tu.splitstrip('a, b,c '), ['a', 'b', 'c'])
+ self.assertEqual(tu.splitstrip("a, b,c "), ["a", "b", "c"])
-class UnitsTC(TestCase):
+class UnitsTC(TestCase):
def setUp(self):
self.units = {
- 'm': 60,
- 'kb': 1024,
- 'mb': 1024*1024,
- }
+ "m": 60,
+ "kb": 1024,
+ "mb": 1024 * 1024,
+ }
def test_empty_base(self):
- self.assertEqual(tu.apply_units('17', {}), 17)
+ self.assertEqual(tu.apply_units("17", {}), 17)
def test_empty_inter(self):
def inter(value):
return int(float(value)) * 2
- result = tu.apply_units('12.4', {}, inter=inter)
+
+ result = tu.apply_units("12.4", {}, inter=inter)
self.assertEqual(result, 12 * 2)
self.assertIsInstance(result, float)
def test_empty_final(self):
# int('12.4') raise value error
- self.assertRaises(ValueError, tu.apply_units, '12.4', {}, final=int)
+ self.assertRaises(ValueError, tu.apply_units, "12.4", {}, final=int)
def test_empty_inter_final(self):
- result = tu.apply_units('12.4', {}, inter=float, final=int)
+ result = tu.apply_units("12.4", {}, inter=float, final=int)
self.assertEqual(result, 12)
self.assertIsInstance(result, int)
def test_blank_base(self):
- result = tu.apply_units(' 42 ', {}, final=int)
+ result = tu.apply_units(" 42 ", {}, final=int)
self.assertEqual(result, 42)
def test_blank_space(self):
- result = tu.apply_units(' 1 337 ', {}, final=int)
+ result = tu.apply_units(" 1 337 ", {}, final=int)
self.assertEqual(result, 1337)
def test_blank_coma(self):
- result = tu.apply_units(' 4,298.42 ', {})
+ result = tu.apply_units(" 4,298.42 ", {})
self.assertEqual(result, 4298.42)
def test_blank_mixed(self):
- result = tu.apply_units('45, 317, 337', {}, final=int)
+ result = tu.apply_units("45, 317, 337", {}, final=int)
self.assertEqual(result, 45317337)
def test_unit_singleunit_singleletter(self):
- result = tu.apply_units('15m', self.units)
- self.assertEqual(result, 15 * self.units['m'])
+ result = tu.apply_units("15m", self.units)
+ self.assertEqual(result, 15 * self.units["m"])
def test_unit_singleunit_multipleletter(self):
- result = tu.apply_units('47KB', self.units)
- self.assertEqual(result, 47 * self.units['kb'])
+ result = tu.apply_units("47KB", self.units)
+ self.assertEqual(result, 47 * self.units["kb"])
def test_unit_singleunit_caseinsensitive(self):
- result = tu.apply_units('47kb', self.units)
- self.assertEqual(result, 47 * self.units['kb'])
+ result = tu.apply_units("47kb", self.units)
+ self.assertEqual(result, 47 * self.units["kb"])
def test_unit_multipleunit(self):
- result = tu.apply_units('47KB 1.5MB', self.units)
- self.assertEqual(result, 47 * self.units['kb'] + 1.5 * self.units['mb'])
+ result = tu.apply_units("47KB 1.5MB", self.units)
+ self.assertEqual(result, 47 * self.units["kb"] + 1.5 * self.units["mb"])
def test_unit_with_blank(self):
- result = tu.apply_units('1 000 KB', self.units)
- self.assertEqual(result, 1000 * self.units['kb'])
+ result = tu.apply_units("1 000 KB", self.units)
+ self.assertEqual(result, 1000 * self.units["kb"])
def test_unit_wrong_input(self):
- self.assertRaises(
- ValueError, tu.apply_units, '', self.units)
- self.assertRaises(
- ValueError, tu.apply_units, 'wrong input', self.units)
- self.assertRaises(
- ValueError, tu.apply_units, 'wrong13 input', self.units)
- self.assertRaises(
- ValueError, tu.apply_units, 'wrong input42', self.units)
+ self.assertRaises(ValueError, tu.apply_units, "", self.units)
+ self.assertRaises(ValueError, tu.apply_units, "wrong input", self.units)
+ self.assertRaises(ValueError, tu.apply_units, "wrong13 input", self.units)
+ self.assertRaises(ValueError, tu.apply_units, "wrong input42", self.units)
with self.assertRaises(ValueError) as cm:
- tu.apply_units('42 cakes', self.units)
- self.assertIn('invalid unit cakes.', str(cm.exception))
+ tu.apply_units("42 cakes", self.units)
+ self.assertIn("invalid unit cakes.", str(cm.exception))
-RGX = re.compile('abcd')
+RGX = re.compile("abcd")
class PrettyMatchTC(TestCase):
-
def test_known(self):
- string = 'hiuherabcdef'
- self.assertEqual(ulines(tu.pretty_match(RGX.search(string), string)),
- 'hiuherabcdef\n ^^^^')
+ string = "hiuherabcdef"
+ self.assertEqual(
+ ulines(tu.pretty_match(RGX.search(string), string)), "hiuherabcdef\n ^^^^"
+ )
+
def test_known_values_1(self):
- rgx = re.compile('(to*)')
- string = 'toto'
+ rgx = re.compile("(to*)")
+ string = "toto"
match = rgx.search(string)
- self.assertEqual(ulines(tu.pretty_match(match, string)), '''toto
-^^''')
+ self.assertEqual(
+ ulines(tu.pretty_match(match, string)),
+ """toto
+^^""",
+ )
def test_known_values_2(self):
- rgx = re.compile('(to*)')
- string = ''' ... ... to to
- ... ... '''
+ rgx = re.compile("(to*)")
+ string = """ ... ... to to
+ ... ... """
match = rgx.search(string)
- self.assertEqual(ulines(tu.pretty_match(match, string)), ''' ... ... to to
+ self.assertEqual(
+ ulines(tu.pretty_match(match, string)),
+ """ ... ... to to
^^
- ... ...''')
-
+ ... ...""",
+ )
class UnquoteTC(TestCase):
def test(self):
- self.assertEqual(tu.unquote('"toto"'), 'toto')
+ self.assertEqual(tu.unquote('"toto"'), "toto")
self.assertEqual(tu.unquote("'l'inenarrable toto'"), "l'inenarrable toto")
self.assertEqual(tu.unquote("no quote"), "no quote")
class ColorizeAnsiTC(TestCase):
def test_known(self):
- self.assertEqual(tu.colorize_ansi('hello', 'blue', 'strike'), '\x1b[9;34mhello\x1b[0m')
- self.assertEqual(tu.colorize_ansi('hello', style='strike, inverse'), '\x1b[9;7mhello\x1b[0m')
- self.assertEqual(tu.colorize_ansi('hello', None, None), 'hello')
- self.assertEqual(tu.colorize_ansi('hello', '', ''), 'hello')
+ self.assertEqual(tu.colorize_ansi("hello", "blue", "strike"), "\x1b[9;34mhello\x1b[0m")
+ self.assertEqual(
+ tu.colorize_ansi("hello", style="strike, inverse"), "\x1b[9;7mhello\x1b[0m"
+ )
+ self.assertEqual(tu.colorize_ansi("hello", None, None), "hello")
+ self.assertEqual(tu.colorize_ansi("hello", "", ""), "hello")
+
def test_raise(self):
- self.assertRaises(KeyError, tu.colorize_ansi, 'hello', 'bleu', None)
- self.assertRaises(KeyError, tu.colorize_ansi, 'hello', None, 'italique')
+ self.assertRaises(KeyError, tu.colorize_ansi, "hello", "bleu", None)
+ self.assertRaises(KeyError, tu.colorize_ansi, "hello", None, "italique")
class UnormalizeTC(TestCase):
def test_unormalize_no_substitute(self):
- data = [(u'\u0153nologie', u'oenologie'),
- (u'\u0152nologie', u'OEnologie'),
- (u'l\xf8to', u'loto'),
- (u'été', u'ete'),
- (u'àèùéïîôêç', u'aeueiioec'),
- (u'ÀÈÙÉÏÎÔÊÇ', u'AEUEIIOEC'),
- (u'\xa0', u' '), # NO-BREAK SPACE managed by NFKD decomposition
- (u'\u0154', u'R'),
- (u'Pointe d\u2019Yves', u"Pointe d'Yves"),
- (u'Bordeaux\u2013Mérignac', u'Bordeaux-Merignac'),
- ]
+ data = [
+ ("\u0153nologie", "oenologie"),
+ ("\u0152nologie", "OEnologie"),
+ ("l\xf8to", "loto"),
+ ("été", "ete"),
+ ("àèùéïîôêç", "aeueiioec"),
+ ("ÀÈÙÉÏÎÔÊÇ", "AEUEIIOEC"),
+ ("\xa0", " "), # NO-BREAK SPACE managed by NFKD decomposition
+ ("\u0154", "R"),
+ ("Pointe d\u2019Yves", "Pointe d'Yves"),
+ ("Bordeaux\u2013Mérignac", "Bordeaux-Merignac"),
+ ]
for input, output in data:
yield self.assertEqual, tu.unormalize(input), output
def test_unormalize_substitute(self):
- self.assertEqual(tu.unormalize(u'ab \u8000 cd', substitute='_'),
- 'ab _ cd')
+ self.assertEqual(tu.unormalize("ab \u8000 cd", substitute="_"), "ab _ cd")
def test_unormalize_backward_compat(self):
- self.assertRaises(ValueError, tu.unormalize, u"\u8000")
- self.assertEqual(tu.unormalize(u"\u8000", substitute=''), u'')
+ self.assertRaises(ValueError, tu.unormalize, "\u8000")
+ self.assertEqual(tu.unormalize("\u8000", substitute=""), "")
def load_tests(loader, tests, ignore):
@@ -275,5 +334,5 @@ def load_tests(loader, tests, ignore):
return tests
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_tree.py b/test/test_tree.py
index ea5af81..58a8ba9 100644
--- a/test/test_tree.py
+++ b/test/test_tree.py
@@ -23,12 +23,14 @@ squeleton generated by /home/syt/bin/py2tests on Jan 20 at 10:43:25
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.tree import *
-tree = ('root', (
- ('child_1_1', (
- ('child_2_1', ()), ('child_2_2', (
- ('child_3_1', ()),
- )))),
- ('child_1_2', (('child_2_3', ()),))))
+tree = (
+ "root",
+ (
+ ("child_1_1", (("child_2_1", ()), ("child_2_2", (("child_3_1", ()),)))),
+ ("child_1_2", (("child_2_3", ()),)),
+ ),
+)
+
def make_tree(tuple):
n = Node(tuple[0])
@@ -36,123 +38,155 @@ def make_tree(tuple):
n.append(make_tree(child))
return n
+
class Node_ClassTest(TestCase):
""" a basic tree node, caracterised by an id"""
+
def setUp(self):
""" called before each test from this class """
self.o = make_tree(tree)
-
def test_flatten(self):
result = [r.id for r in self.o.flatten()]
- expected = ['root', 'child_1_1', 'child_2_1', 'child_2_2', 'child_3_1', 'child_1_2', 'child_2_3']
+ expected = [
+ "root",
+ "child_1_1",
+ "child_2_1",
+ "child_2_2",
+ "child_3_1",
+ "child_1_2",
+ "child_2_3",
+ ]
self.assertListEqual(result, expected)
def test_flatten_with_outlist(self):
resultnodes = []
self.o.flatten(resultnodes)
result = [r.id for r in resultnodes]
- expected = ['root', 'child_1_1', 'child_2_1', 'child_2_2', 'child_3_1', 'child_1_2', 'child_2_3']
+ expected = [
+ "root",
+ "child_1_1",
+ "child_2_1",
+ "child_2_2",
+ "child_3_1",
+ "child_1_2",
+ "child_2_3",
+ ]
self.assertListEqual(result, expected)
-
def test_known_values_remove(self):
"""
remove a child node
"""
- self.o.remove(self.o.get_node_by_id('child_1_1'))
- self.assertRaises(NodeNotFound, self.o.get_node_by_id, 'child_1_1')
+ self.o.remove(self.o.get_node_by_id("child_1_1"))
+ self.assertRaises(NodeNotFound, self.o.get_node_by_id, "child_1_1")
def test_known_values_replace(self):
"""
replace a child node with another
"""
- self.o.replace(self.o.get_node_by_id('child_1_1'), Node('hoho'))
- self.assertRaises(NodeNotFound, self.o.get_node_by_id, 'child_1_1')
- self.assertEqual(self.o.get_node_by_id('hoho'), self.o.children[0])
+ self.o.replace(self.o.get_node_by_id("child_1_1"), Node("hoho"))
+ self.assertRaises(NodeNotFound, self.o.get_node_by_id, "child_1_1")
+ self.assertEqual(self.o.get_node_by_id("hoho"), self.o.children[0])
def test_known_values_get_sibling(self):
"""
return the sibling node that has given id
"""
- self.assertEqual(self.o.children[0].get_sibling('child_1_2'), self.o.children[1], None)
+ self.assertEqual(self.o.children[0].get_sibling("child_1_2"), self.o.children[1], None)
def test_raise_get_sibling_NodeNotFound(self):
- self.assertRaises(NodeNotFound, self.o.children[0].get_sibling, 'houhou')
+ self.assertRaises(NodeNotFound, self.o.children[0].get_sibling, "houhou")
def test_known_values_get_node_by_id(self):
"""
return node in whole hierarchy that has given id
"""
- self.assertEqual(self.o.get_node_by_id('child_1_1'), self.o.children[0])
+ self.assertEqual(self.o.get_node_by_id("child_1_1"), self.o.children[0])
def test_raise_get_node_by_id_NodeNotFound(self):
- self.assertRaises(NodeNotFound, self.o.get_node_by_id, 'houhou')
+ self.assertRaises(NodeNotFound, self.o.get_node_by_id, "houhou")
def test_known_values_get_child_by_id(self):
"""
return child of given id
"""
- self.assertEqual(self.o.get_child_by_id('child_2_1', recurse=1), self.o.children[0].children[0])
+ self.assertEqual(
+ self.o.get_child_by_id("child_2_1", recurse=1), self.o.children[0].children[0]
+ )
def test_raise_get_child_by_id_NodeNotFound(self):
- self.assertRaises(NodeNotFound, self.o.get_child_by_id, nid='child_2_1')
- self.assertRaises(NodeNotFound, self.o.get_child_by_id, 'houhou')
+ self.assertRaises(NodeNotFound, self.o.get_child_by_id, nid="child_2_1")
+ self.assertRaises(NodeNotFound, self.o.get_child_by_id, "houhou")
def test_known_values_get_child_by_path(self):
"""
return child of given path (path is a list of ids)
"""
- self.assertEqual(self.o.get_child_by_path(['root', 'child_1_1', 'child_2_1']), self.o.children[0].children[0])
+ self.assertEqual(
+ self.o.get_child_by_path(["root", "child_1_1", "child_2_1"]),
+ self.o.children[0].children[0],
+ )
def test_raise_get_child_by_path_NodeNotFound(self):
- self.assertRaises(NodeNotFound, self.o.get_child_by_path, ['child_1_1', 'child_2_11'])
+ self.assertRaises(NodeNotFound, self.o.get_child_by_path, ["child_1_1", "child_2_11"])
def test_known_values_depth_down(self):
"""
return depth of this node in the tree
"""
self.assertEqual(self.o.depth_down(), 4)
- self.assertEqual(self.o.get_child_by_id('child_2_1', True).depth_down(), 1)
+ self.assertEqual(self.o.get_child_by_id("child_2_1", True).depth_down(), 1)
def test_known_values_depth(self):
"""
return depth of this node in the tree
"""
self.assertEqual(self.o.depth(), 0)
- self.assertEqual(self.o.get_child_by_id('child_2_1', True).depth(), 2)
+ self.assertEqual(self.o.get_child_by_id("child_2_1", True).depth(), 2)
def test_known_values_width(self):
"""
return depth of this node in the tree
"""
self.assertEqual(self.o.width(), 3)
- self.assertEqual(self.o.get_child_by_id('child_2_1', True).width(), 1)
+ self.assertEqual(self.o.get_child_by_id("child_2_1", True).width(), 1)
def test_known_values_root(self):
"""
return the root node of the tree
"""
- self.assertEqual(self.o.get_child_by_id('child_2_1', True).root(), self.o)
+ self.assertEqual(self.o.get_child_by_id("child_2_1", True).root(), self.o)
def test_known_values_leaves(self):
"""
return a list with all the leaf nodes descendant from this task
"""
- self.assertEqual(self.o.leaves(), [self.o.get_child_by_id('child_2_1', True),
- self.o.get_child_by_id('child_3_1', True),
- self.o.get_child_by_id('child_2_3', True)])
+ self.assertEqual(
+ self.o.leaves(),
+ [
+ self.o.get_child_by_id("child_2_1", True),
+ self.o.get_child_by_id("child_3_1", True),
+ self.o.get_child_by_id("child_2_3", True),
+ ],
+ )
def test_known_values_lineage(self):
- c31 = self.o.get_child_by_id('child_3_1', True)
- self.assertEqual(c31.lineage(), [self.o.get_child_by_id('child_3_1', True),
- self.o.get_child_by_id('child_2_2', True),
- self.o.get_child_by_id('child_1_1', True),
- self.o])
+ c31 = self.o.get_child_by_id("child_3_1", True)
+ self.assertEqual(
+ c31.lineage(),
+ [
+ self.o.get_child_by_id("child_3_1", True),
+ self.o.get_child_by_id("child_2_2", True),
+ self.o.get_child_by_id("child_1_1", True),
+ self.o,
+ ],
+ )
class post_order_list_FunctionTest(TestCase):
""""""
+
def setUp(self):
""" called before each test from this class """
self.o = make_tree(tree)
@@ -162,7 +196,7 @@ class post_order_list_FunctionTest(TestCase):
create a list with tree nodes for which the <filter> function returned true
in a post order foashion
"""
- L = ['child_2_1', 'child_3_1', 'child_2_2', 'child_1_1', 'child_2_3', 'child_1_2', 'root']
+ L = ["child_2_1", "child_3_1", "child_2_2", "child_1_1", "child_2_3", "child_1_2", "root"]
l = [n.id for n in post_order_list(self.o)]
self.assertEqual(l, L, l)
@@ -171,23 +205,26 @@ class post_order_list_FunctionTest(TestCase):
create a list with tree nodes for which the <filter> function returned true
in a post order foashion
"""
+
def filter(node):
- if node.id == 'child_2_2':
+ if node.id == "child_2_2":
return 0
return 1
- L = ['child_2_1', 'child_1_1', 'child_2_3', 'child_1_2', 'root']
+
+ L = ["child_2_1", "child_1_1", "child_2_3", "child_1_2", "root"]
l = [n.id for n in post_order_list(self.o, filter)]
self.assertEqual(l, L, l)
class PostfixedDepthFirstIterator_ClassTest(TestCase):
""""""
+
def setUp(self):
""" called before each test from this class """
self.o = make_tree(tree)
def test_known_values_next(self):
- L = ['child_2_1', 'child_3_1', 'child_2_2', 'child_1_1', 'child_2_3', 'child_1_2', 'root']
+ L = ["child_2_1", "child_3_1", "child_2_2", "child_1_1", "child_2_3", "child_1_2", "root"]
iter = PostfixedDepthFirstIterator(self.o)
o = next(iter)
i = 0
@@ -199,6 +236,7 @@ class PostfixedDepthFirstIterator_ClassTest(TestCase):
class pre_order_list_FunctionTest(TestCase):
""""""
+
def setUp(self):
""" called before each test from this class """
self.o = make_tree(tree)
@@ -208,7 +246,7 @@ class pre_order_list_FunctionTest(TestCase):
create a list with tree nodes for which the <filter> function returned true
in a pre order fashion
"""
- L = ['root', 'child_1_1', 'child_2_1', 'child_2_2', 'child_3_1', 'child_1_2', 'child_2_3']
+ L = ["root", "child_1_1", "child_2_1", "child_2_2", "child_3_1", "child_1_2", "child_2_3"]
l = [n.id for n in pre_order_list(self.o)]
self.assertEqual(l, L, l)
@@ -217,23 +255,26 @@ class pre_order_list_FunctionTest(TestCase):
create a list with tree nodes for which the <filter> function returned true
in a pre order fashion
"""
+
def filter(node):
- if node.id == 'child_2_2':
+ if node.id == "child_2_2":
return 0
return 1
- L = ['root', 'child_1_1', 'child_2_1', 'child_1_2', 'child_2_3']
+
+ L = ["root", "child_1_1", "child_2_1", "child_1_2", "child_2_3"]
l = [n.id for n in pre_order_list(self.o, filter)]
self.assertEqual(l, L, l)
class PrefixedDepthFirstIterator_ClassTest(TestCase):
""""""
+
def setUp(self):
""" called before each test from this class """
self.o = make_tree(tree)
def test_known_values_next(self):
- L = ['root', 'child_1_1', 'child_2_1', 'child_2_2', 'child_3_1', 'child_1_2', 'child_2_3']
+ L = ["root", "child_1_1", "child_2_1", "child_2_2", "child_3_1", "child_1_2", "child_2_3"]
iter = PrefixedDepthFirstIterator(self.o)
o = next(iter)
i = 0
@@ -243,5 +284,5 @@ class PrefixedDepthFirstIterator_ClassTest(TestCase):
i += 1
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_umessage.py b/test/test_umessage.py
index e70e386..7ae13e6 100644
--- a/test/test_umessage.py
+++ b/test/test_umessage.py
@@ -23,35 +23,36 @@ from os.path import join, dirname, abspath
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.umessage import UMessage, decode_QP, message_from_string
-DATA = join(dirname(abspath(__file__)), 'data')
+DATA = join(dirname(abspath(__file__)), "data")
-class UMessageTC(TestCase):
+class UMessageTC(TestCase):
def setUp(self):
if sys.version_info >= (3, 2):
import io
- msg1 = email.message_from_file(io.open(join(DATA, 'test1.msg'), encoding='utf8'))
- msg2 = email.message_from_file(io.open(join(DATA, 'test2.msg'), encoding='utf8'))
+
+ msg1 = email.message_from_file(io.open(join(DATA, "test1.msg"), encoding="utf8"))
+ msg2 = email.message_from_file(io.open(join(DATA, "test2.msg"), encoding="utf8"))
else:
- msg1 = email.message_from_file(open(join(DATA, 'test1.msg')))
- msg2 = email.message_from_file(open(join(DATA, 'test2.msg')))
+ msg1 = email.message_from_file(open(join(DATA, "test1.msg")))
+ msg2 = email.message_from_file(open(join(DATA, "test2.msg")))
self.umessage1 = UMessage(msg1)
self.umessage2 = UMessage(msg2)
def test_get_subject(self):
- subj = self.umessage2.get('Subject')
+ subj = self.umessage2.get("Subject")
self.assertEqual(type(subj), str)
- self.assertEqual(subj, u' LA MER')
+ self.assertEqual(subj, " LA MER")
def test_get_all(self):
- to = self.umessage2.get_all('To')
+ to = self.umessage2.get_all("To")
self.assertEqual(type(to[0]), str)
- self.assertEqual(to, [u'lment accents <alf@logilab.fr>'])
+ self.assertEqual(to, ["lment accents <alf@logilab.fr>"])
def test_get_payload_no_multi(self):
payload = self.umessage1.get_payload()
self.assertEqual(type(payload), str)
-
+
def test_get_payload_decode(self):
msg = """\
MIME-Version: 1.0
@@ -67,26 +68,26 @@ Date: now
dW4gcGV0aXQgY8O2dWNvdQ==
"""
msg = message_from_string(msg)
- self.assertEqual(msg.get_payload(decode=True), u'un petit cucou')
+ self.assertEqual(msg.get_payload(decode=True), "un petit cucou")
def test_decode_QP(self):
- test_line = '=??b?UmFwaGHrbA==?= DUPONT<raphael.dupont@societe.fr>'
+ test_line = "=??b?UmFwaGHrbA==?= DUPONT<raphael.dupont@societe.fr>"
test = decode_QP(test_line)
self.assertEqual(type(test), str)
- self.assertEqual(test, u'Raphal DUPONT<raphael.dupont@societe.fr>')
+ self.assertEqual(test, "Raphal DUPONT<raphael.dupont@societe.fr>")
def test_decode_QP_utf8(self):
- test_line = '=?utf-8?q?o=C3=AEm?= <oim@logilab.fr>'
+ test_line = "=?utf-8?q?o=C3=AEm?= <oim@logilab.fr>"
test = decode_QP(test_line)
self.assertEqual(type(test), str)
- self.assertEqual(test, u'om <oim@logilab.fr>')
+ self.assertEqual(test, "om <oim@logilab.fr>")
def test_decode_QP_ascii(self):
- test_line = 'test <test@logilab.fr>'
+ test_line = "test <test@logilab.fr>"
test = decode_QP(test_line)
self.assertEqual(type(test), str)
- self.assertEqual(test, u'test <test@logilab.fr>')
+ self.assertEqual(test, "test <test@logilab.fr>")
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_ureports_html.py b/test/test_ureports_html.py
index 2298eec..16300c9 100644
--- a/test/test_ureports_html.py
+++ b/test/test_ureports_html.py
@@ -15,31 +15,31 @@
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
-'''unit tests for ureports.html_writer
-'''
+"""unit tests for ureports.html_writer
+"""
from utils import WriterTC
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.ureports.html_writer import *
-class HTMLWriterTC(TestCase, WriterTC):
+class HTMLWriterTC(TestCase, WriterTC):
def setUp(self):
self.writer = HTMLWriter(1)
# Section tests ###########################################################
- section_base = '''<div>
+ section_base = """<div>
<h1>Section title</h1>
<p>Section\'s description.
Blabla bla</p></div>
-'''
- section_nested = '''<div>\n<h1>Section title</h1>\n<p>Section\'s description.\nBlabla bla</p><div>\n<h2>Subsection</h2>\n<p>Sub section description</p></div>\n</div>\n'''
+"""
+ section_nested = """<div>\n<h1>Section title</h1>\n<p>Section\'s description.\nBlabla bla</p><div>\n<h2>Subsection</h2>\n<p>Sub section description</p></div>\n</div>\n"""
# List tests ##############################################################
- list_base = '''<ul>\n<li>item1</li>\n<li>item2</li>\n<li>item3</li>\n<li>item4</li>\n</ul>\n'''
+ list_base = """<ul>\n<li>item1</li>\n<li>item2</li>\n<li>item3</li>\n<li>item4</li>\n</ul>\n"""
- nested_list = '''<ul>
+ nested_list = """<ul>
<li><p>blabla<ul>
<li>1</li>
<li>2</li>
@@ -48,16 +48,16 @@ Blabla bla</p></div>
</p></li>
<li>an other point</li>
</ul>
-'''
+"""
# Table tests #############################################################
- table_base = '''<table>\n<tr class="odd">\n<td>head1</td>\n<td>head2</td>\n</tr>\n<tr class="even">\n<td>cell1</td>\n<td>cell2</td>\n</tr>\n</table>\n'''
- field_table = '''<table class="field" id="mytable">\n<tr class="odd">\n<td>f1</td>\n<td>v1</td>\n</tr>\n<tr class="even">\n<td>f22</td>\n<td>v22</td>\n</tr>\n<tr class="odd">\n<td>f333</td>\n<td>v333</td>\n</tr>\n</table>\n'''
- advanced_table = '''<table class="whatever" id="mytable">\n<tr class="header">\n<th>field</th>\n<th>value</th>\n</tr>\n<tr class="even">\n<td>f1</td>\n<td>v1</td>\n</tr>\n<tr class="odd">\n<td>f22</td>\n<td>v22</td>\n</tr>\n<tr class="even">\n<td>f333</td>\n<td>v333</td>\n</tr>\n<tr class="odd">\n<td> <a href="http://www.perdu.com">toi perdu ?</a></td>\n<td>&#160;</td>\n</tr>\n</table>\n'''
-
+ table_base = """<table>\n<tr class="odd">\n<td>head1</td>\n<td>head2</td>\n</tr>\n<tr class="even">\n<td>cell1</td>\n<td>cell2</td>\n</tr>\n</table>\n"""
+ field_table = """<table class="field" id="mytable">\n<tr class="odd">\n<td>f1</td>\n<td>v1</td>\n</tr>\n<tr class="even">\n<td>f22</td>\n<td>v22</td>\n</tr>\n<tr class="odd">\n<td>f333</td>\n<td>v333</td>\n</tr>\n</table>\n"""
+ advanced_table = """<table class="whatever" id="mytable">\n<tr class="header">\n<th>field</th>\n<th>value</th>\n</tr>\n<tr class="even">\n<td>f1</td>\n<td>v1</td>\n</tr>\n<tr class="odd">\n<td>f22</td>\n<td>v22</td>\n</tr>\n<tr class="even">\n<td>f333</td>\n<td>v333</td>\n</tr>\n<tr class="odd">\n<td> <a href="http://www.perdu.com">toi perdu ?</a></td>\n<td>&#160;</td>\n</tr>\n</table>\n"""
# VerbatimText tests ######################################################
- verbatim_base = '''<pre>blablabla</pre>'''
+ verbatim_base = """<pre>blablabla</pre>"""
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_ureports_text.py b/test/test_ureports_text.py
index dd39dd8..cc602dc 100644
--- a/test/test_ureports_text.py
+++ b/test/test_ureports_text.py
@@ -15,27 +15,28 @@
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
-'''unit tests for ureports.text_writer
-'''
+"""unit tests for ureports.text_writer
+"""
from utils import WriterTC
from logilab.common.testlib import TestCase, unittest_main
from logilab.common.ureports.text_writer import TextWriter
+
class TextWriterTC(TestCase, WriterTC):
def setUp(self):
self.writer = TextWriter()
# Section tests ###########################################################
- section_base = '''
+ section_base = """
Section title
=============
Section\'s description.
Blabla bla
-'''
- section_nested = '''
+"""
+ section_nested = """
Section title
=============
Section\'s description.
@@ -46,38 +47,38 @@ Subsection
Sub section description
-'''
+"""
# List tests ##############################################################
- list_base = '''
+ list_base = """
* item1
* item2
* item3
-* item4'''
+* item4"""
- nested_list = '''
+ nested_list = """
* blabla
- 1
- 2
- 3
-* an other point'''
+* an other point"""
# Table tests #############################################################
- table_base = '''
+ table_base = """
+------+------+
|head1 |head2 |
+------+------+
|cell1 |cell2 |
+------+------+
-'''
- field_table = '''
+"""
+ field_table = """
f1 : v1
f22 : v22
f333: v333
-'''
- advanced_table = '''
+"""
+ advanced_table = """
+---------------+------+
|field |value |
+===============+======+
@@ -90,15 +91,15 @@ f333: v333
|`toi perdu ?`_ | |
+---------------+------+
-'''
-
+"""
# VerbatimText tests ######################################################
- verbatim_base = '''::
+ verbatim_base = """::
blablabla
-'''
+"""
+
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
diff --git a/test/test_xmlutils.py b/test/test_xmlutils.py
index 3d82da9..92a058d 100644
--- a/test/test_xmlutils.py
+++ b/test/test_xmlutils.py
@@ -26,7 +26,7 @@ class ProcessingInstructionDataParsingTest(TestCase):
"""
Tests the parsing of the data of an empty processing instruction.
"""
- pi_data = u" \t \n "
+ pi_data = " \t \n "
data = parse_pi_data(pi_data)
self.assertEqual(data, {})
@@ -35,41 +35,39 @@ class ProcessingInstructionDataParsingTest(TestCase):
Tests the parsing of the data of a simple processing instruction using
double quotes for embedding the value.
"""
- pi_data = u""" \t att="value"\n """
+ pi_data = """ \t att="value"\n """
data = parse_pi_data(pi_data)
- self.assertEqual(data, {u"att": u"value"})
+ self.assertEqual(data, {"att": "value"})
def test_simple_pi_with_simple_quotes(self):
"""
Tests the parsing of the data of a simple processing instruction using
simple quotes for embedding the value.
"""
- pi_data = u""" \t att='value'\n """
+ pi_data = """ \t att='value'\n """
data = parse_pi_data(pi_data)
- self.assertEqual(data, {u"att": u"value"})
+ self.assertEqual(data, {"att": "value"})
def test_complex_pi_with_different_quotes(self):
"""
Tests the parsing of the data of a complex processing instruction using
simple quotes or double quotes for embedding the values.
"""
- pi_data = u""" \t att='value'\n att2="value2" att3='value3'"""
+ pi_data = """ \t att='value'\n att2="value2" att3='value3'"""
data = parse_pi_data(pi_data)
- self.assertEqual(data, {u"att": u"value", u"att2": u"value2",
- u"att3": u"value3"})
+ self.assertEqual(data, {"att": "value", "att2": "value2", "att3": "value3"})
def test_pi_with_non_attribute_data(self):
"""
Tests the parsing of the data of a complex processing instruction
containing non-attribute data.
"""
- pi_data = u""" \t keyword att1="value1" """
+ pi_data = """ \t keyword att1="value1" """
data = parse_pi_data(pi_data)
- self.assertEqual(data, {u"keyword": None, u"att1": u"value1"})
+ self.assertEqual(data, {"keyword": None, "att1": "value1"})
# definitions for automatic unit testing
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest_main()
-
diff --git a/test/utils.py b/test/utils.py
index ca1730e..de9bc23 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -15,21 +15,24 @@
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
-'''unit tests utilities for ureports
-'''
+"""unit tests utilities for ureports
+"""
from __future__ import print_function
import sys
from io import StringIO
+
buffers = [StringIO]
if sys.version_info < (3, 0):
from cStringIO import StringIO as cStringIO
from StringIO import StringIO as pStringIO
+
buffers += [cStringIO, pStringIO]
from logilab.common.ureports.nodes import *
+
class WriterTC:
def _test_output(self, test_id, layout, msg=None):
for buffercls in buffers:
@@ -40,55 +43,53 @@ class WriterTC:
try:
self.assertMultiLineEqual(got, expected)
except:
- print('**** using a %s' % buffer.__class__)
- print('**** got for %s' % test_id)
+ print("**** using a %s" % buffer.__class__)
+ print("**** got for %s" % test_id)
print(got)
- print('**** while expected')
+ print("**** while expected")
print(expected)
- print('****')
+ print("****")
raise
def test_section(self):
- layout = Section('Section title',
- 'Section\'s description.\nBlabla bla')
- self._test_output('section_base', layout)
- layout.append(Section('Subsection', 'Sub section description'))
- self._test_output('section_nested', layout)
+ layout = Section("Section title", "Section's description.\nBlabla bla")
+ self._test_output("section_base", layout)
+ layout.append(Section("Subsection", "Sub section description"))
+ self._test_output("section_nested", layout)
def test_verbatim(self):
- layout = VerbatimText('blablabla')
- self._test_output('verbatim_base', layout)
-
+ layout = VerbatimText("blablabla")
+ self._test_output("verbatim_base", layout)
def test_list(self):
- layout = List(children=('item1', 'item2', 'item3', 'item4'))
- self._test_output('list_base', layout)
+ layout = List(children=("item1", "item2", "item3", "item4"))
+ self._test_output("list_base", layout)
def test_nested_list(self):
- layout = List(children=(Paragraph(("blabla", List(children=('1', "2", "3")))),
- "an other point"))
- self._test_output('nested_list', layout)
-
+ layout = List(
+ children=(Paragraph(("blabla", List(children=("1", "2", "3")))), "an other point")
+ )
+ self._test_output("nested_list", layout)
def test_table(self):
- layout = Table(cols=2, children=('head1', 'head2', 'cell1', 'cell2'))
- self._test_output('table_base', layout)
+ layout = Table(cols=2, children=("head1", "head2", "cell1", "cell2"))
+ self._test_output("table_base", layout)
def test_field_table(self):
- table = Table(cols=2, klass='field', id='mytable')
- for field, value in (('f1', 'v1'), ('f22', 'v22'), ('f333', 'v333')):
+ table = Table(cols=2, klass="field", id="mytable")
+ for field, value in (("f1", "v1"), ("f22", "v22"), ("f333", "v333")):
table.append(Text(field))
table.append(Text(value))
- self._test_output('field_table', table)
+ self._test_output("field_table", table)
def test_advanced_table(self):
- table = Table(cols=2, klass='whatever', id='mytable', rheaders=1)
- for field, value in (('field', 'value'), ('f1', 'v1'), ('f22', 'v22'), ('f333', 'v333')):
+ table = Table(cols=2, klass="whatever", id="mytable", rheaders=1)
+ for field, value in (("field", "value"), ("f1", "v1"), ("f22", "v22"), ("f333", "v333")):
table.append(Text(field))
table.append(Text(value))
- table.append(Link('http://www.perdu.com', 'toi perdu ?'))
- table.append(Text(''))
- self._test_output('advanced_table', table)
+ table.append(Link("http://www.perdu.com", "toi perdu ?"))
+ table.append(Text(""))
+ self._test_output("advanced_table", table)
## def test_image(self):