diff options
author | David Douard <david.douard@logilab.fr> | 2015-03-13 15:18:12 +0100 |
---|---|---|
committer | David Douard <david.douard@logilab.fr> | 2015-03-13 15:18:12 +0100 |
commit | 84ba0c13c480f1e0fb3853caa6bc8ee48dd13178 (patch) | |
tree | 61ef71cc521fdba98a5b496029caa009e346ec88 /logilab | |
parent | b95ae183478e43f8a2229d6cbdfe79e389c0f6e3 (diff) | |
download | logilab-common-84ba0c13c480f1e0fb3853caa6bc8ee48dd13178.tar.gz |
[layout] change the source directory layout
The logilab.common package now lives in a logilab/common directory to make it pip compliant.
Related to #294479.
Diffstat (limited to 'logilab')
47 files changed, 14291 insertions, 0 deletions
diff --git a/logilab/__init__.py b/logilab/__init__.py new file mode 100644 index 0000000..de40ea7 --- /dev/null +++ b/logilab/__init__.py @@ -0,0 +1 @@ +__import__('pkg_resources').declare_namespace(__name__) diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py new file mode 100644 index 0000000..6a7c8b4 --- /dev/null +++ b/logilab/common/__init__.py @@ -0,0 +1,175 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Logilab common library (aka Logilab's extension to the standard library). 
+ +:type STD_BLACKLIST: tuple +:var STD_BLACKLIST: directories ignored by default by the functions in + this package which have to recurse into directories + +:type IGNORED_EXTENSIONS: tuple +:var IGNORED_EXTENSIONS: file extensions that may usually be ignored +""" +__docformat__ = "restructuredtext en" + + +from logilab.common.__pkginfo__ import version as __version__ + +STD_BLACKLIST = ('CVS', '.svn', '.hg', 'debian', 'dist', 'build') + +IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc', '~', '.swp', '.orig') + +# set this to False if you've mx DateTime installed but you don't want your db +# adapter to use it (should be set before you got a connection) +USE_MX_DATETIME = True + + +class attrdict(dict): + """A dictionary for which keys are also accessible as attributes.""" + def __getattr__(self, attr): + try: + return self[attr] + except KeyError: + raise AttributeError(attr) + +class dictattr(dict): + def __init__(self, proxy): + self.__proxy = proxy + + def __getitem__(self, attr): + try: + return getattr(self.__proxy, attr) + except AttributeError: + raise KeyError(attr) + +class nullobject(object): + def __repr__(self): + return '<nullobject>' + def __bool__(self): + return False + __nonzero__ = __bool__ + +class tempattr(object): + def __init__(self, obj, attr, value): + self.obj = obj + self.attr = attr + self.value = value + + def __enter__(self): + self.oldvalue = getattr(self.obj, self.attr) + setattr(self.obj, self.attr, self.value) + return self.obj + + def __exit__(self, exctype, value, traceback): + setattr(self.obj, self.attr, self.oldvalue) + + + +# flatten ----- +# XXX move in a specific module and use yield instead +# do not mix flatten and translate +# +# def iterable(obj): +# try: iter(obj) +# except: return False +# return True +# +# def is_string_like(obj): +# try: obj +'' +# except (TypeError, ValueError): return False +# return True +# +#def is_scalar(obj): +# return is_string_like(obj) or not iterable(obj) +# +#def flatten(seq): +# for item in 
seq: +# if is_scalar(item): +# yield item +# else: +# for subitem in flatten(item): +# yield subitem + +def flatten(iterable, tr_func=None, results=None): + """Flatten a list of list with any level. + + If tr_func is not None, it should be a one argument function that'll be called + on each final element. + + :rtype: list + + >>> flatten([1, [2, 3]]) + [1, 2, 3] + """ + if results is None: + results = [] + for val in iterable: + if isinstance(val, (list, tuple)): + flatten(val, tr_func, results) + elif tr_func is None: + results.append(val) + else: + results.append(tr_func(val)) + return results + + +# XXX is function below still used ? + +def make_domains(lists): + """ + Given a list of lists, return a list of domain for each list to produce all + combinations of possibles values. + + :rtype: list + + Example: + + >>> make_domains(['a', 'b'], ['c','d', 'e']) + [['a', 'b', 'a', 'b', 'a', 'b'], ['c', 'c', 'd', 'd', 'e', 'e']] + """ + from six.moves import range + domains = [] + for iterable in lists: + new_domain = iterable[:] + for i in range(len(domains)): + domains[i] = domains[i]*len(iterable) + if domains: + missing = (len(domains[0]) - len(iterable)) / len(iterable) + i = 0 + for j in range(len(iterable)): + value = iterable[j] + for dummy in range(missing): + new_domain.insert(i, value) + i += 1 + i += 1 + domains.append(new_domain) + return domains + + +# private stuff ################################################################ + +def _handle_blacklist(blacklist, dirnames, filenames): + """remove files/directories in the black list + + dirnames/filenames are usually from os.walk + """ + for norecurs in blacklist: + if norecurs in dirnames: + dirnames.remove(norecurs) + elif norecurs in filenames: + filenames.remove(norecurs) + diff --git a/logilab/common/__pkginfo__.py b/logilab/common/__pkginfo__.py new file mode 100644 index 0000000..55a2cc3 --- /dev/null +++ b/logilab/common/__pkginfo__.py @@ -0,0 +1,57 @@ +# copyright 2003-2014 LOGILAB S.A. 
(Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""logilab.common packaging information""" +__docformat__ = "restructuredtext en" +import sys +import os + +distname = 'logilab-common' +modname = 'common' +subpackage_of = 'logilab' +subpackage_master = True + +numversion = (0, 63, 2) +version = '.'.join([str(num) for num in numversion]) + +license = 'LGPL' # 2.1 or later +description = "collection of low-level Python packages and modules used by Logilab projects" +web = "http://www.logilab.org/project/%s" % distname +mailinglist = "mailto://python-projects@lists.logilab.org" +author = "Logilab" +author_email = "contact@logilab.fr" + + +from os.path import join +scripts = [join('bin', 'pytest')] +include_dirs = [join('test', 'data')] + +install_requires = [ + 'six >= 1.4.0', + ] +test_require = ['pytz'] + +if sys.version_info < (2, 7): + install_requires.append('unittest2 >= 0.5.1') +if os.name == 'nt': + install_requires.append('colorama') + +classifiers = ["Topic :: Utilities", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 3", + ] diff --git a/logilab/common/cache.py b/logilab/common/cache.py new file mode 100644 index 0000000..11ed137 --- 
/dev/null +++ b/logilab/common/cache.py @@ -0,0 +1,114 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Cache module, with a least recently used algorithm for the management of the +deletion of entries. + + + + +""" +__docformat__ = "restructuredtext en" + +from threading import Lock + +from logilab.common.decorators import locked + +_marker = object() + +class Cache(dict): + """A dictionary like cache. + + inv: + len(self._usage) <= self.size + len(self.data) <= self.size + """ + + def __init__(self, size=100): + """ Warning : Cache.__init__() != dict.__init__(). + Constructor does not take any arguments beside size. 
+ """ + assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)' + self.size = size + self._usage = [] + self._lock = Lock() + super(Cache, self).__init__() + + def _acquire(self): + self._lock.acquire() + + def _release(self): + self._lock.release() + + def _update_usage(self, key): + if not self._usage: + self._usage.append(key) + elif self._usage[-1] != key: + try: + self._usage.remove(key) + except ValueError: + # we are inserting a new key + # check the size of the dictionary + # and remove the oldest item in the cache + if self.size and len(self._usage) >= self.size: + super(Cache, self).__delitem__(self._usage[0]) + del self._usage[0] + self._usage.append(key) + else: + pass # key is already the most recently used key + + def __getitem__(self, key): + value = super(Cache, self).__getitem__(key) + self._update_usage(key) + return value + __getitem__ = locked(_acquire, _release)(__getitem__) + + def __setitem__(self, key, item): + # Just make sure that size > 0 before inserting a new item in the cache + if self.size > 0: + super(Cache, self).__setitem__(key, item) + self._update_usage(key) + __setitem__ = locked(_acquire, _release)(__setitem__) + + def __delitem__(self, key): + super(Cache, self).__delitem__(key) + self._usage.remove(key) + __delitem__ = locked(_acquire, _release)(__delitem__) + + def clear(self): + super(Cache, self).clear() + self._usage = [] + clear = locked(_acquire, _release)(clear) + + def pop(self, key, default=_marker): + if key in self: + self._usage.remove(key) + #if default is _marker: + # return super(Cache, self).pop(key) + return super(Cache, self).pop(key, default) + pop = locked(_acquire, _release)(pop) + + def popitem(self): + raise NotImplementedError() + + def setdefault(self, key, default=None): + raise NotImplementedError() + + def update(self, other): + raise NotImplementedError() + + diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py new file mode 100644 index 0000000..2fff2ed --- /dev/null 
+++ b/logilab/common/changelog.py @@ -0,0 +1,238 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Manipulation of upstream change log files. + +The upstream change log files format handled is simpler than the one +often used such as those generated by the default Emacs changelog mode. + +Sample ChangeLog format:: + + Change log for project Yoo + ========================== + + -- + * add a new functionality + + 2002-02-01 -- 0.1.1 + * fix bug #435454 + * fix bug #434356 + + 2002-01-01 -- 0.1 + * initial release + + +There is 3 entries in this change log, one for each released version and one +for the next version (i.e. the current entry). +Each entry contains a set of messages corresponding to changes done in this +release. +All the non empty lines before the first entry are considered as the change +log title. 
+""" + +__docformat__ = "restructuredtext en" + +import sys +from stat import S_IWRITE + +from six import string_types + +BULLET = '*' +SUBBULLET = '-' +INDENT = ' ' * 4 + +class NoEntry(Exception): + """raised when we are unable to find an entry""" + +class EntryNotFound(Exception): + """raised when we are unable to find a given entry""" + +class Version(tuple): + """simple class to handle soft version number has a tuple while + correctly printing it as X.Y.Z + """ + def __new__(cls, versionstr): + if isinstance(versionstr, string_types): + versionstr = versionstr.strip(' :') # XXX (syt) duh? + parsed = cls.parse(versionstr) + else: + parsed = versionstr + return tuple.__new__(cls, parsed) + + @classmethod + def parse(cls, versionstr): + versionstr = versionstr.strip(' :') + try: + return [int(i) for i in versionstr.split('.')] + except ValueError as ex: + raise ValueError("invalid literal for version '%s' (%s)"%(versionstr, ex)) + + def __str__(self): + return '.'.join([str(i) for i in self]) + +# upstream change log ######################################################### + +class ChangeLogEntry(object): + """a change log entry, i.e. 
a set of messages associated to a version and + its release date + """ + version_class = Version + + def __init__(self, date=None, version=None, **kwargs): + self.__dict__.update(kwargs) + if version: + self.version = self.version_class(version) + else: + self.version = None + self.date = date + self.messages = [] + + def add_message(self, msg): + """add a new message""" + self.messages.append(([msg], [])) + + def complete_latest_message(self, msg_suite): + """complete the latest added message + """ + if not self.messages: + raise ValueError('unable to complete last message as there is no previous message)') + if self.messages[-1][1]: # sub messages + self.messages[-1][1][-1].append(msg_suite) + else: # message + self.messages[-1][0].append(msg_suite) + + def add_sub_message(self, sub_msg, key=None): + if not self.messages: + raise ValueError('unable to complete last message as there is no previous message)') + if key is None: + self.messages[-1][1].append([sub_msg]) + else: + raise NotImplementedError("sub message to specific key are not implemented yet") + + def write(self, stream=sys.stdout): + """write the entry to file """ + stream.write('%s -- %s\n' % (self.date or '', self.version or '')) + for msg, sub_msgs in self.messages: + stream.write('%s%s %s\n' % (INDENT, BULLET, msg[0])) + stream.write(''.join(msg[1:])) + if sub_msgs: + stream.write('\n') + for sub_msg in sub_msgs: + stream.write('%s%s %s\n' % (INDENT * 2, SUBBULLET, sub_msg[0])) + stream.write(''.join(sub_msg[1:])) + stream.write('\n') + + stream.write('\n\n') + +class ChangeLog(object): + """object representation of a whole ChangeLog file""" + + entry_class = ChangeLogEntry + + def __init__(self, changelog_file, title=''): + self.file = changelog_file + self.title = title + self.additional_content = '' + self.entries = [] + self.load() + + def __repr__(self): + return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self), + len(self.entries)) + + def add_entry(self, entry): + """add a new 
entry to the change log""" + self.entries.append(entry) + + def get_entry(self, version='', create=None): + """ return a given changelog entry + if version is omitted, return the current entry + """ + if not self.entries: + if version or not create: + raise NoEntry() + self.entries.append(self.entry_class()) + if not version: + if self.entries[0].version and create is not None: + self.entries.insert(0, self.entry_class()) + return self.entries[0] + version = self.version_class(version) + for entry in self.entries: + if entry.version == version: + return entry + raise EntryNotFound() + + def add(self, msg, create=None): + """add a new message to the latest opened entry""" + entry = self.get_entry(create=create) + entry.add_message(msg) + + def load(self): + """ read a logilab's ChangeLog from file """ + try: + stream = open(self.file) + except IOError: + return + last = None + expect_sub = False + for line in stream.readlines(): + sline = line.strip() + words = sline.split() + # if new entry + if len(words) == 1 and words[0] == '--': + expect_sub = False + last = self.entry_class() + self.add_entry(last) + # if old entry + elif len(words) == 3 and words[1] == '--': + expect_sub = False + last = self.entry_class(words[0], words[2]) + self.add_entry(last) + # if title + elif sline and last is None: + self.title = '%s%s' % (self.title, line) + # if new entry + elif sline and sline[0] == BULLET: + expect_sub = False + last.add_message(sline[1:].strip()) + # if new sub_entry + elif expect_sub and sline and sline[0] == SUBBULLET: + last.add_sub_message(sline[1:].strip()) + # if new line for current entry + elif sline and last.messages: + last.complete_latest_message(line) + else: + expect_sub = True + self.additional_content += line + stream.close() + + def format_title(self): + return '%s\n\n' % self.title.strip() + + def save(self): + """write back change log""" + # filetutils isn't importable in appengine, so import locally + from logilab.common.fileutils import 
ensure_fs_mode + ensure_fs_mode(self.file, S_IWRITE) + self.write(open(self.file, 'w')) + + def write(self, stream=sys.stdout): + """write changelog to stream""" + stream.write(self.format_title()) + for entry in self.entries: + entry.write(stream) + diff --git a/logilab/common/clcommands.py b/logilab/common/clcommands.py new file mode 100644 index 0000000..4778b99 --- /dev/null +++ b/logilab/common/clcommands.py @@ -0,0 +1,334 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Helper functions to support command line tools providing more than +one command. + +e.g called as "tool command [options] args..." where <options> and <args> are +command'specific +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +import sys +import logging +from os.path import basename + +from logilab.common.configuration import Configuration +from logilab.common.logging_ext import init_log, get_threshold +from logilab.common.deprecation import deprecated + + +class BadCommandUsage(Exception): + """Raised when an unknown command is used or when a command is not + correctly used (bad options, too much / missing arguments...). 
+ + Trigger display of command usage. + """ + +class CommandError(Exception): + """Raised when a command can't be processed and we want to display it and + exit, without traceback nor usage displayed. + """ + + +# command line access point #################################################### + +class CommandLine(dict): + """Usage: + + >>> LDI = cli.CommandLine('ldi', doc='Logilab debian installer', + version=version, rcfile=RCFILE) + >>> LDI.register(MyCommandClass) + >>> LDI.register(MyOtherCommandClass) + >>> LDI.run(sys.argv[1:]) + + Arguments: + + * `pgm`, the program name, default to `basename(sys.argv[0])` + + * `doc`, a short description of the command line tool + + * `copyright`, additional doc string that will be appended to the generated + doc + + * `version`, version number of string of the tool. If specified, global + --version option will be available. + + * `rcfile`, path to a configuration file. If specified, global --C/--rc-file + option will be available? self.rcfile = rcfile + + * `logger`, logger to propagate to commands, default to + `logging.getLogger(self.pgm))` + """ + def __init__(self, pgm=None, doc=None, copyright=None, version=None, + rcfile=None, logthreshold=logging.ERROR, + check_duplicated_command=True): + if pgm is None: + pgm = basename(sys.argv[0]) + self.pgm = pgm + self.doc = doc + self.copyright = copyright + self.version = version + self.rcfile = rcfile + self.logger = None + self.logthreshold = logthreshold + self.check_duplicated_command = check_duplicated_command + + def register(self, cls, force=False): + """register the given :class:`Command` subclass""" + assert not self.check_duplicated_command or force or not cls.name in self, \ + 'a command %s is already defined' % cls.name + self[cls.name] = cls + return cls + + def run(self, args): + """main command line access point: + * init logging + * handle global options (-h/--help, --version, -C/--rc-file) + * check command + * run command + + Terminate by :exc:`SystemExit` + 
""" + init_log(debug=True, # so that we use StreamHandler + logthreshold=self.logthreshold, + logformat='%(levelname)s: %(message)s') + try: + arg = args.pop(0) + except IndexError: + self.usage_and_exit(1) + if arg in ('-h', '--help'): + self.usage_and_exit(0) + if self.version is not None and arg in ('--version'): + print(self.version) + sys.exit(0) + rcfile = self.rcfile + if rcfile is not None and arg in ('-C', '--rc-file'): + try: + rcfile = args.pop(0) + arg = args.pop(0) + except IndexError: + self.usage_and_exit(1) + try: + command = self.get_command(arg) + except KeyError: + print('ERROR: no %s command' % arg) + print() + self.usage_and_exit(1) + try: + sys.exit(command.main_run(args, rcfile)) + except KeyboardInterrupt as exc: + print('Interrupted', end=' ') + if str(exc): + print(': %s' % exc, end=' ') + print() + sys.exit(4) + except BadCommandUsage as err: + print('ERROR:', err) + print() + print(command.help()) + sys.exit(1) + + def create_logger(self, handler, logthreshold=None): + logger = logging.Logger(self.pgm) + logger.handlers = [handler] + if logthreshold is None: + logthreshold = get_threshold(self.logthreshold) + logger.setLevel(logthreshold) + return logger + + def get_command(self, cmd, logger=None): + if logger is None: + logger = self.logger + if logger is None: + logger = self.logger = logging.getLogger(self.pgm) + logger.setLevel(get_threshold(self.logthreshold)) + return self[cmd](logger) + + def usage(self): + """display usage for the main program (i.e. when no command supplied) + and exit + """ + print('usage:', self.pgm, end=' ') + if self.rcfile: + print('[--rc-file=<configuration file>]', end=' ') + print('<command> [options] <command argument>...') + if self.doc: + print('\n%s' % self.doc) + print(''' +Type "%(pgm)s <command> --help" for more information about a specific +command. 
Available commands are :\n''' % self.__dict__) + max_len = max([len(cmd) for cmd in self]) + padding = ' ' * max_len + for cmdname, cmd in sorted(self.items()): + if not cmd.hidden: + print(' ', (cmdname + padding)[:max_len], cmd.short_description()) + if self.rcfile: + print(''' +Use --rc-file=<configuration file> / -C <configuration file> before the command +to specify a configuration file. Default to %s. +''' % self.rcfile) + print('''%(pgm)s -h/--help + display this usage information and exit''' % self.__dict__) + if self.version: + print('''%(pgm)s -v/--version + display version configuration and exit''' % self.__dict__) + if self.copyright: + print('\n', self.copyright) + + def usage_and_exit(self, status): + self.usage() + sys.exit(status) + + +# base command classes ######################################################### + +class Command(Configuration): + """Base class for command line commands. + + Class attributes: + + * `name`, the name of the command + + * `min_args`, minimum number of arguments, None if unspecified + + * `max_args`, maximum number of arguments, None if unspecified + + * `arguments`, string describing arguments, used in command usage + + * `hidden`, boolean flag telling if the command should be hidden, e.g. does + not appear in help's commands list + + * `options`, options list, as allowed by :mod:configuration + """ + + arguments = '' + name = '' + # hidden from help ? 
+ hidden = False + # max/min args, None meaning unspecified + min_args = None + max_args = None + + @classmethod + def description(cls): + return cls.__doc__.replace(' ', '') + + @classmethod + def short_description(cls): + return cls.description().split('.')[0] + + def __init__(self, logger): + usage = '%%prog %s %s\n\n%s' % (self.name, self.arguments, + self.description()) + Configuration.__init__(self, usage=usage) + self.logger = logger + + def check_args(self, args): + """check command's arguments are provided""" + if self.min_args is not None and len(args) < self.min_args: + raise BadCommandUsage('missing argument') + if self.max_args is not None and len(args) > self.max_args: + raise BadCommandUsage('too many arguments') + + def main_run(self, args, rcfile=None): + """Run the command and return status 0 if everything went fine. + + If :exc:`CommandError` is raised by the underlying command, simply log + the error and return status 2. + + Any other exceptions, including :exc:`BadCommandUsage` will be + propagated. 
+ """ + if rcfile: + self.load_file_configuration(rcfile) + args = self.load_command_line_configuration(args) + try: + self.check_args(args) + self.run(args) + except CommandError as err: + self.logger.error(err) + return 2 + return 0 + + def run(self, args): + """run the command with its specific arguments""" + raise NotImplementedError() + + +class ListCommandsCommand(Command): + """list available commands, useful for bash completion.""" + name = 'listcommands' + arguments = '[command]' + hidden = True + + def run(self, args): + """run the command with its specific arguments""" + if args: + command = args.pop() + cmd = _COMMANDS[command] + for optname, optdict in cmd.options: + print('--help') + print('--' + optname) + else: + commands = sorted(_COMMANDS.keys()) + for command in commands: + cmd = _COMMANDS[command] + if not cmd.hidden: + print(command) + + +# deprecated stuff ############################################################# + +_COMMANDS = CommandLine() + +DEFAULT_COPYRIGHT = '''\ +Copyright (c) 2004-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +http://www.logilab.fr/ -- mailto:contact@logilab.fr''' + +@deprecated('use cls.register(cli)') +def register_commands(commands): + """register existing commands""" + for command_klass in commands: + _COMMANDS.register(command_klass) + +@deprecated('use args.pop(0)') +def main_run(args, doc=None, copyright=None, version=None): + """command line tool: run command specified by argument list (without the + program name). Raise SystemExit with status 0 if everything went fine. 
+ + >>> main_run(sys.argv[1:]) + """ + _COMMANDS.doc = doc + _COMMANDS.copyright = copyright + _COMMANDS.version = version + _COMMANDS.run(args) + +@deprecated('use args.pop(0)') +def pop_arg(args_list, expected_size_after=None, msg="Missing argument"): + """helper function to get and check command line arguments""" + try: + value = args_list.pop(0) + except IndexError: + raise BadCommandUsage(msg) + if expected_size_after is not None and len(args_list) > expected_size_after: + raise BadCommandUsage('too many arguments') + return value + diff --git a/logilab/common/cli.py b/logilab/common/cli.py new file mode 100644 index 0000000..cdeef97 --- /dev/null +++ b/logilab/common/cli.py @@ -0,0 +1,211 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Command line interface helper classes. + +It provides some default commands, a help system, a default readline +configuration with completion and persistent history. 
+ +Example:: + + class BookShell(CLIHelper): + + def __init__(self): + # quit and help are builtins + # CMD_MAP keys are commands, values are topics + self.CMD_MAP['pionce'] = _("Sommeil") + self.CMD_MAP['ronfle'] = _("Sommeil") + CLIHelper.__init__(self) + + help_do_pionce = ("pionce", "pionce duree", _("met ton corps en veille")) + def do_pionce(self): + print('nap is good') + + help_do_ronfle = ("ronfle", "ronfle volume", _("met les autres en veille")) + def do_ronfle(self): + print('fuuuuuuuuuuuu rhhhhhrhrhrrh') + + cl = BookShell() +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +from six.moves import builtins, input + +if not hasattr(builtins, '_'): + builtins._ = str + + +def init_readline(complete_method, histfile=None): + """Init the readline library if available.""" + try: + import readline + readline.parse_and_bind("tab: complete") + readline.set_completer(complete_method) + string = readline.get_completer_delims().replace(':', '') + readline.set_completer_delims(string) + if histfile is not None: + try: + readline.read_history_file(histfile) + except IOError: + pass + import atexit + atexit.register(readline.write_history_file, histfile) + except: + print('readline is not available :-(') + + +class Completer : + """Readline completer.""" + + def __init__(self, commands): + self.list = commands + + def complete(self, text, state): + """Hook called by readline when <tab> is pressed.""" + n = len(text) + matches = [] + for cmd in self.list : + if cmd[:n] == text : + matches.append(cmd) + try: + return matches[state] + except IndexError: + return None + + +class CLIHelper: + """An abstract command line interface client which recognize commands + and provide an help system. 
+ """ + + CMD_MAP = {'help': _("Others"), + 'quit': _("Others"), + } + CMD_PREFIX = '' + + def __init__(self, histfile=None) : + self._topics = {} + self.commands = None + self._completer = Completer(self._register_commands()) + init_readline(self._completer.complete, histfile) + + def run(self): + """loop on user input, exit on EOF""" + while True: + try: + line = input('>>> ') + except EOFError: + print + break + s_line = line.strip() + if not s_line: + continue + args = s_line.split() + if args[0] in self.commands: + try: + cmd = 'do_%s' % self.commands[args[0]] + getattr(self, cmd)(*args[1:]) + except EOFError: + break + except: + import traceback + traceback.print_exc() + else: + try: + self.handle_line(s_line) + except: + import traceback + traceback.print_exc() + + def handle_line(self, stripped_line): + """Method to overload in the concrete class (should handle + lines which are not commands). + """ + raise NotImplementedError() + + + # private methods ######################################################### + + def _register_commands(self): + """ register available commands method and return the list of + commands name + """ + self.commands = {} + self._command_help = {} + commands = [attr[3:] for attr in dir(self) if attr[:3] == 'do_'] + for command in commands: + topic = self.CMD_MAP[command] + help_method = getattr(self, 'help_do_%s' % command) + self._topics.setdefault(topic, []).append(help_method) + self.commands[self.CMD_PREFIX + command] = command + self._command_help[command] = help_method + return self.commands.keys() + + def _print_help(self, cmd, syntax, explanation): + print(_('Command %s') % cmd) + print(_('Syntax: %s') % syntax) + print('\t', explanation) + print() + + + # predefined commands ##################################################### + + def do_help(self, command=None) : + """base input of the help system""" + if command in self._command_help: + self._print_help(*self._command_help[command]) + elif command is None or command not 
in self._topics: + print(_("Use help <topic> or help <command>.")) + print(_("Available topics are:")) + topics = sorted(self._topics.keys()) + for topic in topics: + print('\t', topic) + print() + print(_("Available commands are:")) + commands = self.commands.keys() + commands.sort() + for command in commands: + print('\t', command[len(self.CMD_PREFIX):]) + + else: + print(_('Available commands about %s:') % command) + print + for command_help_method in self._topics[command]: + try: + if callable(command_help_method): + self._print_help(*command_help_method()) + else: + self._print_help(*command_help_method) + except: + import traceback + traceback.print_exc() + print('ERROR in help method %s'% ( + command_help_method.__name__)) + + help_do_help = ("help", "help [topic|command]", + _("print help message for the given topic/command or \ +available topics when no argument")) + + def do_quit(self): + """quit the CLI""" + raise EOFError() + + def help_do_quit(self): + return ("quit", "quit", _("quit the application")) diff --git a/logilab/common/compat.py b/logilab/common/compat.py new file mode 100644 index 0000000..f2eb590 --- /dev/null +++ b/logilab/common/compat.py @@ -0,0 +1,78 @@ +# pylint: disable=E0601,W0622,W0611 +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. 
+# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Wrappers around some builtins introduced in python 2.3, 2.4 and +2.5, making them available in for earlier versions of python. + +See another compatibility snippets from other projects: + + :mod:`lib2to3.fixes` + :mod:`coverage.backward` + :mod:`unittest2.compatibility` +""" + + +__docformat__ = "restructuredtext en" + +import os +import sys +import types +from warnings import warn + +# not used here, but imported to preserve API +from six.moves import builtins + +if sys.version_info < (3, 0): + str_to_bytes = str + def str_encode(string, encoding): + if isinstance(string, unicode): + return string.encode(encoding) + return str(string) +else: + def str_to_bytes(string): + return str.encode(string) + # we have to ignore the encoding in py3k to be able to write a string into a + # TextIOWrapper or like object (which expect an unicode string) + def str_encode(string, encoding): + return str(string) + +# See also http://bugs.python.org/issue11776 +if sys.version_info[0] == 3: + def method_type(callable, instance, klass): + # api change. 
klass is no more considered + return types.MethodType(callable, instance) +else: + # alias types otherwise + method_type = types.MethodType + +# Pythons 2 and 3 differ on where to get StringIO +if sys.version_info < (3, 0): + from cStringIO import StringIO + FileIO = file + BytesIO = StringIO + reload = reload +else: + from io import FileIO, BytesIO, StringIO + from imp import reload + +from logilab.common.deprecation import deprecated + +# Other projects import these from here, keep providing them for +# backwards compat +any = deprecated('use builtin "any"')(any) +all = deprecated('use builtin "all"')(all) diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py new file mode 100644 index 0000000..b292427 --- /dev/null +++ b/logilab/common/configuration.py @@ -0,0 +1,1105 @@ +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Classes to handle advanced configuration in simple to complex applications. + +Allows to load the configuration from a file or from command line +options, to generate a sample configuration file or to display +program's usage. 
Fills the gap between optik/optparse and ConfigParser +by adding data types (which are also available as a standalone optik +extension in the `optik_ext` module). + + +Quick start: simplest usage +--------------------------- + +.. python :: + + >>> import sys + >>> from logilab.common.configuration import Configuration + >>> options = [('dothis', {'type':'yn', 'default': True, 'metavar': '<y or n>'}), + ... ('value', {'type': 'string', 'metavar': '<string>'}), + ... ('multiple', {'type': 'csv', 'default': ('yop',), + ... 'metavar': '<comma separated values>', + ... 'help': 'you can also document the option'}), + ... ('number', {'type': 'int', 'default':2, 'metavar':'<int>'}), + ... ] + >>> config = Configuration(options=options, name='My config') + >>> print config['dothis'] + True + >>> print config['value'] + None + >>> print config['multiple'] + ('yop',) + >>> print config['number'] + 2 + >>> print config.help() + Usage: [options] + + Options: + -h, --help show this help message and exit + --dothis=<y or n> + --value=<string> + --multiple=<comma separated values> + you can also document the option [current: none] + --number=<int> + + >>> f = open('myconfig.ini', 'w') + >>> f.write('''[MY CONFIG] + ... number = 3 + ... dothis = no + ... multiple = 1,2,3 + ... ''') + >>> f.close() + >>> config.load_file_configuration('myconfig.ini') + >>> print config['dothis'] + False + >>> print config['value'] + None + >>> print config['multiple'] + ['1', '2', '3'] + >>> print config['number'] + 3 + >>> sys.argv = ['mon prog', '--value', 'bacon', '--multiple', '4,5,6', + ... 
'nonoptionargument'] + >>> print config.load_command_line_configuration() + ['nonoptionargument'] + >>> print config['value'] + bacon + >>> config.generate_config() + # class for simple configurations which don't need the + # manager / providers model and prefer delegation to inheritance + # + # configuration values are accessible through a dict like interface + # + [MY CONFIG] + + dothis=no + + value=bacon + + # you can also document the option + multiple=4,5,6 + + number=3 + + Note : starting with Python 2.7 ConfigParser is able to take into + account the order of occurrences of the options into a file (by + using an OrderedDict). If you have two options changing some common + state, like a 'disable-all-stuff' and a 'enable-some-stuff-a', their + order of appearance will be significant : the last specified in the + file wins. For earlier version of python and logilab.common newer + than 0.61 the behaviour is unspecified. + +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +__all__ = ('OptionsManagerMixIn', 'OptionsProviderMixIn', + 'ConfigurationMixIn', 'Configuration', + 'OptionsManager2ConfigurationAdapter') + +import os +import sys +import re +from os.path import exists, expanduser +from copy import copy +from warnings import warn + +from six import string_types +from six.moves import range, configparser as cp, input + +from logilab.common.compat import str_encode as _encode +from logilab.common.deprecation import deprecated +from logilab.common.textutils import normalize_text, unquote +from logilab.common import optik_ext + +OptionError = optik_ext.OptionError + +REQUIRED = [] + +class UnsupportedAction(Exception): + """raised by set_option when it doesn't know what to do for an action""" + + +def _get_encoding(encoding, stream): + encoding = encoding or getattr(stream, 'encoding', None) + if not encoding: + import locale + encoding = locale.getpreferredencoding() + return encoding + + +# validation functions 
######################################################## + +# validators will return the validated value or raise optparse.OptionValueError +# XXX add to documentation + +def choice_validator(optdict, name, value): + """validate and return a converted value for option of type 'choice' + """ + if not value in optdict['choices']: + msg = "option %s: invalid value: %r, should be in %s" + raise optik_ext.OptionValueError(msg % (name, value, optdict['choices'])) + return value + +def multiple_choice_validator(optdict, name, value): + """validate and return a converted value for option of type 'choice' + """ + choices = optdict['choices'] + values = optik_ext.check_csv(None, name, value) + for value in values: + if not value in choices: + msg = "option %s: invalid value: %r, should be in %s" + raise optik_ext.OptionValueError(msg % (name, value, choices)) + return values + +def csv_validator(optdict, name, value): + """validate and return a converted value for option of type 'csv' + """ + return optik_ext.check_csv(None, name, value) + +def yn_validator(optdict, name, value): + """validate and return a converted value for option of type 'yn' + """ + return optik_ext.check_yn(None, name, value) + +def named_validator(optdict, name, value): + """validate and return a converted value for option of type 'named' + """ + return optik_ext.check_named(None, name, value) + +def file_validator(optdict, name, value): + """validate and return a filepath for option of type 'file'""" + return optik_ext.check_file(None, name, value) + +def color_validator(optdict, name, value): + """validate and return a valid color for option of type 'color'""" + return optik_ext.check_color(None, name, value) + +def password_validator(optdict, name, value): + """validate and return a string for option of type 'password'""" + return optik_ext.check_password(None, name, value) + +def date_validator(optdict, name, value): + """validate and return a mx DateTime object for option of type 'date'""" + 
return optik_ext.check_date(None, name, value) + +def time_validator(optdict, name, value): + """validate and return a time object for option of type 'time'""" + return optik_ext.check_time(None, name, value) + +def bytes_validator(optdict, name, value): + """validate and return an integer for option of type 'bytes'""" + return optik_ext.check_bytes(None, name, value) + + +VALIDATORS = {'string': unquote, + 'int': int, + 'float': float, + 'file': file_validator, + 'font': unquote, + 'color': color_validator, + 'regexp': re.compile, + 'csv': csv_validator, + 'yn': yn_validator, + 'bool': yn_validator, + 'named': named_validator, + 'password': password_validator, + 'date': date_validator, + 'time': time_validator, + 'bytes': bytes_validator, + 'choice': choice_validator, + 'multiple_choice': multiple_choice_validator, + } + +def _call_validator(opttype, optdict, option, value): + if opttype not in VALIDATORS: + raise Exception('Unsupported type "%s"' % opttype) + try: + return VALIDATORS[opttype](optdict, option, value) + except TypeError: + try: + return VALIDATORS[opttype](value) + except optik_ext.OptionValueError: + raise + except: + raise optik_ext.OptionValueError('%s value (%r) should be of type %s' % + (option, value, opttype)) + +# user input functions ######################################################## + +# user input functions will ask the user for input on stdin then validate +# the result and return the validated value or raise optparse.OptionValueError +# XXX add to documentation + +def input_password(optdict, question='password:'): + from getpass import getpass + while True: + value = getpass(question) + value2 = getpass('confirm: ') + if value == value2: + return value + print('password mismatch, try again') + +def input_string(optdict, question): + value = input(question).strip() + return value or None + +def _make_input_function(opttype): + def input_validator(optdict, question): + while True: + value = input(question) + if not value.strip(): + 
return None + try: + return _call_validator(opttype, optdict, None, value) + except optik_ext.OptionValueError as ex: + msg = str(ex).split(':', 1)[-1].strip() + print('bad value: %s' % msg) + return input_validator + +INPUT_FUNCTIONS = { + 'string': input_string, + 'password': input_password, + } + +for opttype in VALIDATORS.keys(): + INPUT_FUNCTIONS.setdefault(opttype, _make_input_function(opttype)) + +# utility functions ############################################################ + +def expand_default(self, option): + """monkey patch OptionParser.expand_default since we have a particular + way to handle defaults to avoid overriding values in the configuration + file + """ + if self.parser is None or not self.default_tag: + return option.help + optname = option._long_opts[0][2:] + try: + provider = self.parser.options_manager._all_options[optname] + except KeyError: + value = None + else: + optdict = provider.get_option_def(optname) + optname = provider.option_attrname(optname, optdict) + value = getattr(provider.config, optname, optdict) + value = format_option_value(optdict, value) + if value is optik_ext.NO_DEFAULT or not value: + value = self.NO_DEFAULT_VALUE + return option.help.replace(self.default_tag, str(value)) + + +def _validate(value, optdict, name=''): + """return a validated value for an option according to its type + + optional argument name is only used for error message formatting + """ + try: + _type = optdict['type'] + except KeyError: + # FIXME + return value + return _call_validator(_type, optdict, name, value) +convert = deprecated('[0.60] convert() was renamed _validate()')(_validate) + +# format and output functions ################################################## + +def comment(string): + """return string as a comment""" + lines = [line.strip() for line in string.splitlines()] + return '# ' + ('%s# ' % os.linesep).join(lines) + +def format_time(value): + if not value: + return '0' + if value != int(value): + return '%.2fs' % value + 
value = int(value) + nbmin, nbsec = divmod(value, 60) + if nbsec: + return '%ss' % value + nbhour, nbmin_ = divmod(nbmin, 60) + if nbmin_: + return '%smin' % nbmin + nbday, nbhour_ = divmod(nbhour, 24) + if nbhour_: + return '%sh' % nbhour + return '%sd' % nbday + +def format_bytes(value): + if not value: + return '0' + if value != int(value): + return '%.2fB' % value + value = int(value) + prevunit = 'B' + for unit in ('KB', 'MB', 'GB', 'TB'): + next, remain = divmod(value, 1024) + if remain: + return '%s%s' % (value, prevunit) + prevunit = unit + value = next + return '%s%s' % (value, unit) + +def format_option_value(optdict, value): + """return the user input's value from a 'compiled' value""" + if isinstance(value, (list, tuple)): + value = ','.join(value) + elif isinstance(value, dict): + value = ','.join(['%s:%s' % (k, v) for k, v in value.items()]) + elif hasattr(value, 'match'): # optdict.get('type') == 'regexp' + # compiled regexp + value = value.pattern + elif optdict.get('type') == 'yn': + value = value and 'yes' or 'no' + elif isinstance(value, string_types) and value.isspace(): + value = "'%s'" % value + elif optdict.get('type') == 'time' and isinstance(value, (float, int, long)): + value = format_time(value) + elif optdict.get('type') == 'bytes' and hasattr(value, '__int__'): + value = format_bytes(value) + return value + +def ini_format_section(stream, section, options, encoding=None, doc=None): + """format an options section using the INI format""" + encoding = _get_encoding(encoding, stream) + if doc: + print(_encode(comment(doc), encoding), file=stream) + print('[%s]' % section, file=stream) + ini_format(stream, options, encoding) + +def ini_format(stream, options, encoding): + """format options using the INI format""" + for optname, optdict, value in options: + value = format_option_value(optdict, value) + help = optdict.get('help') + if help: + help = normalize_text(help, line_len=79, indent='# ') + print(file=stream) + print(_encode(help, 
encoding), file=stream) + else: + print(file=stream) + if value is None: + print('#%s=' % optname, file=stream) + else: + value = _encode(value, encoding).strip() + print('%s=%s' % (optname, value), file=stream) + +format_section = ini_format_section + +def rest_format_section(stream, section, options, encoding=None, doc=None): + """format an options section using as ReST formatted output""" + encoding = _get_encoding(encoding, stream) + if section: + print('%s\n%s' % (section, "'"*len(section)), file=stream) + if doc: + print(_encode(normalize_text(doc, line_len=79, indent=''), encoding), file=stream) + print(file=stream) + for optname, optdict, value in options: + help = optdict.get('help') + print(':%s:' % optname, file=stream) + if help: + help = normalize_text(help, line_len=79, indent=' ') + print(_encode(help, encoding), file=stream) + if value: + value = _encode(format_option_value(optdict, value), encoding) + print(file=stream) + print(' Default: ``%s``' % value.replace("`` ", "```` ``"), file=stream) + +# Options Manager ############################################################## + +class OptionsManagerMixIn(object): + """MixIn to handle a configuration from both a configuration file and + command line options + """ + + def __init__(self, usage, config_file=None, version=None, quiet=0): + self.config_file = config_file + self.reset_parsers(usage, version=version) + # list of registered options providers + self.options_providers = [] + # dictionary associating option name to checker + self._all_options = {} + self._short_options = {} + self._nocallback_options = {} + self._mygroups = dict() + # verbosity + self.quiet = quiet + self._maxlevel = 0 + + def reset_parsers(self, usage='', version=None): + # configuration file parser + self.cfgfile_parser = cp.ConfigParser() + # command line parser + self.cmdline_parser = optik_ext.OptionParser(usage=usage, version=version) + self.cmdline_parser.options_manager = self + self._optik_option_attrs = 
set(self.cmdline_parser.option_class.ATTRS) + + def register_options_provider(self, provider, own_group=True): + """register an options provider""" + assert provider.priority <= 0, "provider's priority can't be >= 0" + for i in range(len(self.options_providers)): + if provider.priority > self.options_providers[i].priority: + self.options_providers.insert(i, provider) + break + else: + self.options_providers.append(provider) + non_group_spec_options = [option for option in provider.options + if 'group' not in option[1]] + groups = getattr(provider, 'option_groups', ()) + if own_group and non_group_spec_options: + self.add_option_group(provider.name.upper(), provider.__doc__, + non_group_spec_options, provider) + else: + for opt, optdict in non_group_spec_options: + self.add_optik_option(provider, self.cmdline_parser, opt, optdict) + for gname, gdoc in groups: + gname = gname.upper() + goptions = [option for option in provider.options + if option[1].get('group', '').upper() == gname] + self.add_option_group(gname, gdoc, goptions, provider) + + def add_option_group(self, group_name, doc, options, provider): + """add an option group including the listed options + """ + assert options + # add option group to the command line parser + if group_name in self._mygroups: + group = self._mygroups[group_name] + else: + group = optik_ext.OptionGroup(self.cmdline_parser, + title=group_name.capitalize()) + self.cmdline_parser.add_option_group(group) + group.level = provider.level + self._mygroups[group_name] = group + # add section to the config file + if group_name != "DEFAULT": + self.cfgfile_parser.add_section(group_name) + # add provider's specific options + for opt, optdict in options: + self.add_optik_option(provider, group, opt, optdict) + + def add_optik_option(self, provider, optikcontainer, opt, optdict): + if 'inputlevel' in optdict: + warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,' + ' use "level"' % opt, DeprecationWarning) + optdict['level'] = 
optdict.pop('inputlevel') + args, optdict = self.optik_option(provider, opt, optdict) + option = optikcontainer.add_option(*args, **optdict) + self._all_options[opt] = provider + self._maxlevel = max(self._maxlevel, option.level or 0) + + def optik_option(self, provider, opt, optdict): + """get our personal option definition and return a suitable form for + use with optik/optparse + """ + optdict = copy(optdict) + others = {} + if 'action' in optdict: + self._nocallback_options[provider] = opt + else: + optdict['action'] = 'callback' + optdict['callback'] = self.cb_set_provider_option + # default is handled here and *must not* be given to optik if you + # want the whole machinery to work + if 'default' in optdict: + if ('help' in optdict + and optdict.get('default') is not None + and not optdict['action'] in ('store_true', 'store_false')): + optdict['help'] += ' [current: %default]' + del optdict['default'] + args = ['--' + str(opt)] + if 'short' in optdict: + self._short_options[optdict['short']] = opt + args.append('-' + optdict['short']) + del optdict['short'] + # cleanup option definition dict before giving it to optik + for key in list(optdict.keys()): + if not key in self._optik_option_attrs: + optdict.pop(key) + return args, optdict + + def cb_set_provider_option(self, option, opt, value, parser): + """optik callback for option setting""" + if opt.startswith('--'): + # remove -- on long option + opt = opt[2:] + else: + # short option, get its long equivalent + opt = self._short_options[opt[1:]] + # trick since we can't set action='store_true' on options + if value is None: + value = 1 + self.global_set_option(opt, value) + + def global_set_option(self, opt, value): + """set option on the correct option provider""" + self._all_options[opt].set_option(opt, value) + + def generate_config(self, stream=None, skipsections=(), encoding=None): + """write a configuration file according to the current configuration + into the given stream or stdout + """ + 
options_by_section = {} + sections = [] + for provider in self.options_providers: + for section, options in provider.options_by_section(): + if section is None: + section = provider.name + if section in skipsections: + continue + options = [(n, d, v) for (n, d, v) in options + if d.get('type') is not None] + if not options: + continue + if not section in sections: + sections.append(section) + alloptions = options_by_section.setdefault(section, []) + alloptions += options + stream = stream or sys.stdout + encoding = _get_encoding(encoding, stream) + printed = False + for section in sections: + if printed: + print('\n', file=stream) + format_section(stream, section.upper(), options_by_section[section], + encoding) + printed = True + + def generate_manpage(self, pkginfo, section=1, stream=None): + """write a man page for the current configuration into the given + stream or stdout + """ + self._monkeypatch_expand_default() + try: + optik_ext.generate_manpage(self.cmdline_parser, pkginfo, + section, stream=stream or sys.stdout, + level=self._maxlevel) + finally: + self._unmonkeypatch_expand_default() + + # initialization methods ################################################## + + def load_provider_defaults(self): + """initialize configuration using default values""" + for provider in self.options_providers: + provider.load_defaults() + + def load_file_configuration(self, config_file=None): + """load the configuration from file""" + self.read_config_file(config_file) + self.load_config_file() + + def read_config_file(self, config_file=None): + """read the configuration file but do not load it (i.e. dispatching + values to each options provider) + """ + helplevel = 1 + while helplevel <= self._maxlevel: + opt = '-'.join(['long'] * helplevel) + '-help' + if opt in self._all_options: + break # already processed + def helpfunc(option, opt, val, p, level=helplevel): + print(self.help(level)) + sys.exit(0) + helpmsg = '%s verbose help.' 
% ' '.join(['more'] * helplevel) + optdict = {'action' : 'callback', 'callback' : helpfunc, + 'help' : helpmsg} + provider = self.options_providers[0] + self.add_optik_option(provider, self.cmdline_parser, opt, optdict) + provider.options += ( (opt, optdict), ) + helplevel += 1 + if config_file is None: + config_file = self.config_file + if config_file is not None: + config_file = expanduser(config_file) + if config_file and exists(config_file): + parser = self.cfgfile_parser + parser.read([config_file]) + # normalize sections'title + for sect, values in parser._sections.items(): + if not sect.isupper() and values: + parser._sections[sect.upper()] = values + elif not self.quiet: + msg = 'No config file found, using default configuration' + print(msg, file=sys.stderr) + return + + def input_config(self, onlysection=None, inputlevel=0, stream=None): + """interactively get configuration values by asking to the user and generate + a configuration file + """ + if onlysection is not None: + onlysection = onlysection.upper() + for provider in self.options_providers: + for section, option, optdict in provider.all_options(): + if onlysection is not None and section != onlysection: + continue + if not 'type' in optdict: + # ignore action without type (callback, store_true...) 
+ continue + provider.input_option(option, optdict, inputlevel) + # now we can generate the configuration file + if stream is not None: + self.generate_config(stream) + + def load_config_file(self): + """dispatch values previously read from a configuration file to each + options provider) + """ + parser = self.cfgfile_parser + for section in parser.sections(): + for option, value in parser.items(section): + try: + self.global_set_option(option, value) + except (KeyError, OptionError): + # TODO handle here undeclared options appearing in the config file + continue + + def load_configuration(self, **kwargs): + """override configuration according to given parameters + """ + for opt, opt_value in kwargs.items(): + opt = opt.replace('_', '-') + provider = self._all_options[opt] + provider.set_option(opt, opt_value) + + def load_command_line_configuration(self, args=None): + """override configuration according to command line parameters + + return additional arguments + """ + self._monkeypatch_expand_default() + try: + if args is None: + args = sys.argv[1:] + else: + args = list(args) + (options, args) = self.cmdline_parser.parse_args(args=args) + for provider in self._nocallback_options.keys(): + config = provider.config + for attr in config.__dict__.keys(): + value = getattr(options, attr, None) + if value is None: + continue + setattr(config, attr, value) + return args + finally: + self._unmonkeypatch_expand_default() + + + # help methods ############################################################ + + def add_help_section(self, title, description, level=0): + """add a dummy option section for help purpose """ + group = optik_ext.OptionGroup(self.cmdline_parser, + title=title.capitalize(), + description=description) + group.level = level + self._maxlevel = max(self._maxlevel, level) + self.cmdline_parser.add_option_group(group) + + def _monkeypatch_expand_default(self): + # monkey patch optik_ext to deal with our default values + try: + self.__expand_default_backup = 
optik_ext.HelpFormatter.expand_default + optik_ext.HelpFormatter.expand_default = expand_default + except AttributeError: + # python < 2.4: nothing to be done + pass + def _unmonkeypatch_expand_default(self): + # remove monkey patch + if hasattr(optik_ext.HelpFormatter, 'expand_default'): + # unpatch optik_ext to avoid side effects + optik_ext.HelpFormatter.expand_default = self.__expand_default_backup + + def help(self, level=0): + """return the usage string for available options """ + self.cmdline_parser.formatter.output_level = level + self._monkeypatch_expand_default() + try: + return self.cmdline_parser.format_help() + finally: + self._unmonkeypatch_expand_default() + + +class Method(object): + """used to ease late binding of default method (so you can define options + on the class using default methods on the configuration instance) + """ + def __init__(self, methname): + self.method = methname + self._inst = None + + def bind(self, instance): + """bind the method to its instance""" + if self._inst is None: + self._inst = instance + + def __call__(self, *args, **kwargs): + assert self._inst, 'unbound method' + return getattr(self._inst, self.method)(*args, **kwargs) + +# Options Provider ############################################################# + +class OptionsProviderMixIn(object): + """Mixin to provide options to an OptionsManager""" + + # those attributes should be overridden + priority = -1 + name = 'default' + options = () + level = 0 + + def __init__(self): + self.config = optik_ext.Values() + for option in self.options: + try: + option, optdict = option + except ValueError: + raise Exception('Bad option: %r' % option) + if isinstance(optdict.get('default'), Method): + optdict['default'].bind(self) + elif isinstance(optdict.get('callback'), Method): + optdict['callback'].bind(self) + self.load_defaults() + + def load_defaults(self): + """initialize the provider using default values""" + for opt, optdict in self.options: + action = 
optdict.get('action') + if action != 'callback': + # callback action have no default + default = self.option_default(opt, optdict) + if default is REQUIRED: + continue + self.set_option(opt, default, action, optdict) + + def option_default(self, opt, optdict=None): + """return the default value for an option""" + if optdict is None: + optdict = self.get_option_def(opt) + default = optdict.get('default') + if callable(default): + default = default() + return default + + def option_attrname(self, opt, optdict=None): + """get the config attribute corresponding to opt + """ + if optdict is None: + optdict = self.get_option_def(opt) + return optdict.get('dest', opt.replace('-', '_')) + option_name = deprecated('[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()')(option_attrname) + + def option_value(self, opt): + """get the current value for the given option""" + return getattr(self.config, self.option_attrname(opt), None) + + def set_option(self, opt, value, action=None, optdict=None): + """method called to set an option (registered in the options list) + """ + if optdict is None: + optdict = self.get_option_def(opt) + if value is not None: + value = _validate(value, optdict, opt) + if action is None: + action = optdict.get('action', 'store') + if optdict.get('type') == 'named': # XXX need specific handling + optname = self.option_attrname(opt, optdict) + currentvalue = getattr(self.config, optname, None) + if currentvalue: + currentvalue.update(value) + value = currentvalue + if action == 'store': + setattr(self.config, self.option_attrname(opt, optdict), value) + elif action in ('store_true', 'count'): + setattr(self.config, self.option_attrname(opt, optdict), 0) + elif action == 'store_false': + setattr(self.config, self.option_attrname(opt, optdict), 1) + elif action == 'append': + opt = self.option_attrname(opt, optdict) + _list = getattr(self.config, opt, None) + if _list is None: + if isinstance(value, (list, tuple)): + _list = value + 
elif value is not None: + _list = [] + _list.append(value) + setattr(self.config, opt, _list) + elif isinstance(_list, tuple): + setattr(self.config, opt, _list + (value,)) + else: + _list.append(value) + elif action == 'callback': + optdict['callback'](None, opt, value, None) + else: + raise UnsupportedAction(action) + + def input_option(self, option, optdict, inputlevel=99): + default = self.option_default(option, optdict) + if default is REQUIRED: + defaultstr = '(required): ' + elif optdict.get('level', 0) > inputlevel: + return + elif optdict['type'] == 'password' or default is None: + defaultstr = ': ' + else: + defaultstr = '(default: %s): ' % format_option_value(optdict, default) + print(':%s:' % option) + print(optdict.get('help') or option) + inputfunc = INPUT_FUNCTIONS[optdict['type']] + value = inputfunc(optdict, defaultstr) + while default is REQUIRED and not value: + print('please specify a value') + value = inputfunc(optdict, '%s: ' % option) + if value is None and default is not None: + value = default + self.set_option(option, value, optdict=optdict) + + def get_option_def(self, opt): + """return the dictionary defining an option given it's name""" + assert self.options + for option in self.options: + if option[0] == opt: + return option[1] + raise OptionError('no such option %s in section %r' + % (opt, self.name), opt) + + + def all_options(self): + """return an iterator on available options for this provider + option are actually described by a 3-uple: + (section, option name, option dictionary) + """ + for section, options in self.options_by_section(): + if section is None: + if self.name is None: + continue + section = self.name.upper() + for option, optiondict, value in options: + yield section, option, optiondict + + def options_by_section(self): + """return an iterator on options grouped by section + + (section, [list of (optname, optdict, optvalue)]) + """ + sections = {} + for optname, optdict in self.options: + 
# configuration ################################################################

class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn):
    """basic mixin for simple configurations which don't need the
    manager / providers model
    """
    def __init__(self, *args, **kwargs):
        if not args:
            kwargs.setdefault('usage', '')
        kwargs.setdefault('quiet', 1)
        OptionsManagerMixIn.__init__(self, *args, **kwargs)
        OptionsProviderMixIn.__init__(self)
        if not getattr(self, 'option_groups', None):
            # build the option groups list from the options' 'group' entries
            self.option_groups = []
            for option, optdict in self.options:
                try:
                    gdef = (optdict['group'].upper(), '')
                except KeyError:
                    continue
                if not gdef in self.option_groups:
                    self.option_groups.append(gdef)
        self.register_options_provider(self, own_group=False)

    def register_options(self, options):
        """add some options to the configuration

        *options* is an iterable of (option name, option dict) pairs.
        """
        # materialize first: *options* may be a one-shot iterable and is
        # consumed both by the grouping loop and by the final concatenation
        options = list(options)
        options_by_group = {}
        for optname, optdict in options:
            options_by_group.setdefault(optdict.get('group', self.name.upper()), []).append((optname, optdict))
        # BUG FIX: the original rebound the loop variable `options` here, so
        # only the *last* group's options ended up added to self.options
        for group, group_options in options_by_group.items():
            self.add_option_group(group, None, group_options, self)
        self.options += tuple(options)

    def load_defaults(self):
        """initialize self.config from the options' default values"""
        OptionsProviderMixIn.load_defaults(self)

    def __iter__(self):
        # `dict.iteritems` does not exist in python 3; `items` works on both
        return iter(self.config.__dict__.items())

    def __getitem__(self, key):
        try:
            return getattr(self.config, self.option_attrname(key))
        except (optik_ext.OptionValueError, AttributeError):
            raise KeyError(key)

    def __setitem__(self, key, value):
        self.set_option(key, value)

    def get(self, key, default=None):
        """dict-like get: return the option's value or *default*"""
        try:
            return getattr(self.config, self.option_attrname(key))
        except (OptionError, AttributeError):
            return default


class Configuration(ConfigurationMixIn):
    """class for simple configurations which don't need the
    manager / providers model and prefer delegation to inheritance

    configuration values are accessible through a dict like interface
    """

    def __init__(self, config_file=None, options=None, name=None,
                 usage=None, doc=None, version=None):
        if options is not None:
            self.options = options
        if name is not None:
            self.name = name
        if doc is not None:
            self.__doc__ = doc
        super(Configuration, self).__init__(config_file=config_file,
                                            usage=usage, version=version)


class OptionsManager2ConfigurationAdapter(object):
    """Adapt an option manager to behave like a
    `logilab.common.configuration.Configuration` instance
    """
    def __init__(self, provider):
        self.config = provider

    def __getattr__(self, key):
        return getattr(self.config, key)

    def __getitem__(self, key):
        provider = self.config._all_options[key]
        try:
            return getattr(provider.config, provider.option_attrname(key))
        except AttributeError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        self.config.global_set_option(self.config.option_attrname(key), value)

    def get(self, key, default=None):
        """dict-like get: return the option's value or *default*"""
        provider = self.config._all_options[key]
        try:
            return getattr(provider.config, provider.option_attrname(key))
        except AttributeError:
            return default

# other functions ##############################################################

def read_old_config(newconfig, changes, configfile):
    """initialize newconfig from a deprecated configuration file

    possible changes:
    * ('renamed', oldname, newname)
    * ('moved', option, oldgroup, newgroup)
    * ('typechanged', option, oldtype, newvalue)
    """
    # build an index of changes
    changesindex = {}
    for action in changes:
        if action[0] == 'moved':
            option, oldgroup, newgroup = action[1:]
            changesindex.setdefault(option, []).append((action[0], oldgroup, newgroup))
            continue
        if action[0] == 'renamed':
            oldname, newname = action[1:]
            changesindex.setdefault(newname, []).append((action[0], oldname))
            continue
        if action[0] == 'typechanged':
            option, oldtype, newvalue = action[1:]
            changesindex.setdefault(option, []).append((action[0], oldtype, newvalue))
            continue
        # NOTE(review): this tests action[1], while every other branch keys on
        # action[0] -- looks like it should be action[0]; confirm the intended
        # shape of 'added'/'removed' change tuples before touching it
        if action[1] in ('added', 'removed'):
            continue # nothing to do here
        raise Exception('unknown change %s' % action[0])
    # build a config object able to read the old config
    options = []
    for optname, optdef in newconfig.options:
        for action in changesindex.pop(optname, ()):
            if action[0] == 'moved':
                oldgroup, newgroup = action[1:]
                optdef = optdef.copy()
                optdef['group'] = oldgroup
            elif action[0] == 'renamed':
                optname = action[1]
            elif action[0] == 'typechanged':
                oldtype = action[1]
                optdef = optdef.copy()
                optdef['type'] = oldtype
        options.append((optname, optdef))
    if changesindex:
        raise Exception('unapplied changes: %s' % changesindex)
    oldconfig = Configuration(options=options, name=newconfig.name)
    # read the old config
    oldconfig.load_file_configuration(configfile)
    # apply values reverting changes
    changes.reverse()
    done = set()
    for action in changes:
        if action[0] == 'renamed':
            oldname, newname = action[1:]
            newconfig[newname] = oldconfig[oldname]
            done.add(newname)
        elif action[0] == 'typechanged':
            optname, oldtype, newvalue = action[1:]
            newconfig[optname] = newvalue
            done.add(optname)
    for optname, optdef in newconfig.options:
        if optdef.get('type') and not optname in done:
            newconfig.set_option(optname, oldconfig[optname], optdict=optdef)
def merge_options(options, optgroup=None):
    """preprocess a list of options and remove duplicates, returning a new list
    (tuple actually) of options.

    Options dictionaries are copied to avoid later side-effect. Also, if
    `optgroup` argument is specified, ensure all options are in the given group.
    """
    # scan from the end so that, for duplicated options, the *later*
    # occurrence keeps its position while the earlier occurrence's dict
    # entries win (dict.update with the earlier dict)
    alloptions = {}
    options = list(options)
    for i in range(len(options)-1, -1, -1):
        optname, optdict = options[i]
        if optname in alloptions:
            options.pop(i)
            alloptions[optname].update(optdict)
        else:
            # copy so later updates don't mutate the caller's dict
            optdict = optdict.copy()
            options[i] = (optname, optdict)
            alloptions[optname] = optdict
        if optgroup is not None:
            alloptions[optname]['group'] = optgroup
    return tuple(options)
+ + + + +""" +__docformat__ = "restructuredtext en" + +from omniORB import CORBA, PortableServer +import CosNaming + +orb = None + +def get_orb(): + """ + returns a reference to the ORB. + The first call to the method initialized the ORB + This method is mainly used internally in the module. + """ + + global orb + if orb is None: + import sys + orb = CORBA.ORB_init(sys.argv, CORBA.ORB_ID) + return orb + +def get_root_context(): + """ + returns a reference to the NameService object. + This method is mainly used internally in the module. + """ + + orb = get_orb() + nss = orb.resolve_initial_references("NameService") + rootContext = nss._narrow(CosNaming.NamingContext) + assert rootContext is not None, "Failed to narrow root naming context" + return rootContext + +def register_object_name(object, namepath): + """ + Registers a object in the NamingService. + The name path is a list of 2-uples (id,kind) giving the path. + + For instance if the path of an object is [('foo',''),('bar','')], + it is possible to get a reference to the object using the URL + 'corbaname::hostname#foo/bar'. + [('logilab','rootmodule'),('chatbot','application'),('chatter','server')] + is mapped to + 'corbaname::hostname#logilab.rootmodule/chatbot.application/chatter.server' + + The get_object_reference() function can be used to resolve such a URL. + """ + context = get_root_context() + for id, kind in namepath[:-1]: + name = [CosNaming.NameComponent(id, kind)] + try: + context = context.bind_new_context(name) + except CosNaming.NamingContext.AlreadyBound as ex: + context = context.resolve(name)._narrow(CosNaming.NamingContext) + assert context is not None, \ + 'test context exists but is not a NamingContext' + + id, kind = namepath[-1] + name = [CosNaming.NameComponent(id, kind)] + try: + context.bind(name, object._this()) + except CosNaming.NamingContext.AlreadyBound as ex: + context.rebind(name, object._this()) + +def activate_POA(): + """ + This methods activates the Portable Object Adapter. 
def run_orb():
    """Enter the ORB mainloop on the server.

    You should not call this method on the client.
    """
    get_orb().run()

def get_object_reference(url):
    """Resolve a corbaname URL to an object proxy.

    See register_object_name() for examples URLs.
    """
    return get_orb().string_to_object(url)

def get_object_string(host, namepath):
    """Return a corba string identifier for *namepath* on *host*.

    *namepath* is a list of (id, kind) 2-uples as described in
    register_object_name().
    """
    components = ['.'.join(path_elt) for path_elt in namepath]
    return 'corbaname::%s#%s' % (host, '/'.join(components))
+"""A daemonize function (for Unices)""" + +__docformat__ = "restructuredtext en" + +import os +import errno +import signal +import sys +import time +import warnings + +from six.moves import range + +def setugid(user): + """Change process user and group ID + + Argument is a numeric user id or a user name""" + try: + from pwd import getpwuid + passwd = getpwuid(int(user)) + except ValueError: + from pwd import getpwnam + passwd = getpwnam(user) + + if hasattr(os, 'initgroups'): # python >= 2.7 + os.initgroups(passwd.pw_name, passwd.pw_gid) + else: + import ctypes + if ctypes.CDLL(None).initgroups(passwd.pw_name, passwd.pw_gid) < 0: + err = ctypes.c_int.in_dll(ctypes.pythonapi,"errno").value + raise OSError(err, os.strerror(err), 'initgroups') + os.setgid(passwd.pw_gid) + os.setuid(passwd.pw_uid) + os.environ['HOME'] = passwd.pw_dir + + +def daemonize(pidfile=None, uid=None, umask=0o77): + """daemonize a Unix process. Set paranoid umask by default. + + Return 1 in the original process, 2 in the first fork, and None for the + second fork (eg daemon process). + """ + # http://www.faqs.org/faqs/unix-faq/programmer/faq/ + # + # fork so the parent can exit + if os.fork(): # launch child and... + return 1 + # disconnect from tty and create a new session + os.setsid() + # fork again so the parent, (the session group leader), can exit. + # as a non-session group leader, we can never regain a controlling + # terminal. + if os.fork(): # launch child again. 
+ return 2 + # move to the root to avoit mount pb + os.chdir('/') + # redirect standard descriptors + null = os.open('/dev/null', os.O_RDWR) + for i in range(3): + try: + os.dup2(null, i) + except OSError as e: + if e.errno != errno.EBADF: + raise + os.close(null) + # filter warnings + warnings.filterwarnings('ignore') + # write pid in a file + if pidfile: + # ensure the directory where the pid-file should be set exists (for + # instance /var/run/cubicweb may be deleted on computer restart) + piddir = os.path.dirname(pidfile) + if not os.path.exists(piddir): + os.makedirs(piddir) + f = file(pidfile, 'w') + f.write(str(os.getpid())) + f.close() + # set umask if specified + if umask is not None: + os.umask(umask) + # change process uid + if uid: + setugid(uid) + return None diff --git a/logilab/common/date.py b/logilab/common/date.py new file mode 100644 index 0000000..a093a8a --- /dev/null +++ b/logilab/common/date.py @@ -0,0 +1,335 @@ +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. 
+"""Date manipulation helper functions.""" +from __future__ import division + +__docformat__ = "restructuredtext en" + +import math +import re +import sys +from locale import getlocale, LC_TIME +from datetime import date, time, datetime, timedelta +from time import strptime as time_strptime +from calendar import monthrange, timegm + +from six.moves import range + +try: + from mx.DateTime import RelativeDateTime, Date, DateTimeType +except ImportError: + endOfMonth = None + DateTimeType = datetime +else: + endOfMonth = RelativeDateTime(months=1, day=-1) + +# NOTE: should we implement a compatibility layer between date representations +# as we have in lgc.db ? + +FRENCH_FIXED_HOLIDAYS = { + 'jour_an': '%s-01-01', + 'fete_travail': '%s-05-01', + 'armistice1945': '%s-05-08', + 'fete_nat': '%s-07-14', + 'assomption': '%s-08-15', + 'toussaint': '%s-11-01', + 'armistice1918': '%s-11-11', + 'noel': '%s-12-25', + } + +FRENCH_MOBILE_HOLIDAYS = { + 'paques2004': '2004-04-12', + 'ascension2004': '2004-05-20', + 'pentecote2004': '2004-05-31', + + 'paques2005': '2005-03-28', + 'ascension2005': '2005-05-05', + 'pentecote2005': '2005-05-16', + + 'paques2006': '2006-04-17', + 'ascension2006': '2006-05-25', + 'pentecote2006': '2006-06-05', + + 'paques2007': '2007-04-09', + 'ascension2007': '2007-05-17', + 'pentecote2007': '2007-05-28', + + 'paques2008': '2008-03-24', + 'ascension2008': '2008-05-01', + 'pentecote2008': '2008-05-12', + + 'paques2009': '2009-04-13', + 'ascension2009': '2009-05-21', + 'pentecote2009': '2009-06-01', + + 'paques2010': '2010-04-05', + 'ascension2010': '2010-05-13', + 'pentecote2010': '2010-05-24', + + 'paques2011': '2011-04-25', + 'ascension2011': '2011-06-02', + 'pentecote2011': '2011-06-13', + + 'paques2012': '2012-04-09', + 'ascension2012': '2012-05-17', + 'pentecote2012': '2012-05-28', + } + +# XXX this implementation cries for multimethod dispatching + +def get_step(dateobj, nbdays=1): + # assume date is either a python datetime or a mx.DateTime 
object + if isinstance(dateobj, date): + return ONEDAY * nbdays + return nbdays # mx.DateTime is ok with integers + +def datefactory(year, month, day, sampledate): + # assume date is either a python datetime or a mx.DateTime object + if isinstance(sampledate, datetime): + return datetime(year, month, day) + if isinstance(sampledate, date): + return date(year, month, day) + return Date(year, month, day) + +def weekday(dateobj): + # assume date is either a python datetime or a mx.DateTime object + if isinstance(dateobj, date): + return dateobj.weekday() + return dateobj.day_of_week + +def str2date(datestr, sampledate): + # NOTE: datetime.strptime is not an option until we drop py2.4 compat + year, month, day = [int(chunk) for chunk in datestr.split('-')] + return datefactory(year, month, day, sampledate) + +def days_between(start, end): + if isinstance(start, date): + delta = end - start + # datetime.timedelta.days is always an integer (floored) + if delta.seconds: + return delta.days + 1 + return delta.days + else: + return int(math.ceil((end - start).days)) + +def get_national_holidays(begin, end): + """return french national days off between begin and end""" + begin = datefactory(begin.year, begin.month, begin.day, begin) + end = datefactory(end.year, end.month, end.day, end) + holidays = [str2date(datestr, begin) + for datestr in FRENCH_MOBILE_HOLIDAYS.values()] + for year in range(begin.year, end.year+1): + for datestr in FRENCH_FIXED_HOLIDAYS.values(): + date = str2date(datestr % year, begin) + if date not in holidays: + holidays.append(date) + return [day for day in holidays if begin <= day < end] + +def add_days_worked(start, days): + """adds date but try to only take days worked into account""" + step = get_step(start) + weeks, plus = divmod(days, 5) + end = start + ((weeks * 7) + plus) * step + if weekday(end) >= 5: # saturday or sunday + end += (2 * step) + end += len([x for x in get_national_holidays(start, end + step) + if weekday(x) < 5]) * step + if 
weekday(end) >= 5: # saturday or sunday + end += (2 * step) + return end + +def nb_open_days(start, end): + assert start <= end + step = get_step(start) + days = days_between(start, end) + weeks, plus = divmod(days, 7) + if weekday(start) > weekday(end): + plus -= 2 + elif weekday(end) == 6: + plus -= 1 + open_days = weeks * 5 + plus + nb_week_holidays = len([x for x in get_national_holidays(start, end+step) + if weekday(x) < 5 and x < end]) + open_days -= nb_week_holidays + if open_days < 0: + return 0 + return open_days + +def date_range(begin, end, incday=None, incmonth=None): + """yields each date between begin and end + + :param begin: the start date + :param end: the end date + :param incr: the step to use to iterate over dates. Default is + one day. + :param include: None (means no exclusion) or a function taking a + date as parameter, and returning True if the date + should be included. + + When using mx datetime, you should *NOT* use incmonth argument, use instead + oneDay, oneHour, oneMinute, oneSecond, oneWeek or endOfMonth (to enumerate + months) as `incday` argument + """ + assert not (incday and incmonth) + begin = todate(begin) + end = todate(end) + if incmonth: + while begin < end: + yield begin + begin = next_month(begin, incmonth) + else: + incr = get_step(begin, incday or 1) + while begin < end: + yield begin + begin += incr + +# makes py datetime usable ##################################################### + +ONEDAY = timedelta(days=1) +ONEWEEK = timedelta(days=7) + +try: + strptime = datetime.strptime +except AttributeError: # py < 2.5 + from time import strptime as time_strptime + def strptime(value, format): + return datetime(*time_strptime(value, format)[:6]) + +def strptime_time(value, format='%H:%M'): + return time(*time_strptime(value, format)[3:6]) + +def todate(somedate): + """return a date from a date (leaving unchanged) or a datetime""" + if isinstance(somedate, datetime): + return date(somedate.year, somedate.month, somedate.day) + 
def totime(somedate):
    """return a time from a time (leaving unchanged), date or datetime"""
    # XXX mx compat
    if not isinstance(somedate, time):
        return time(somedate.hour, somedate.minute, somedate.second)
    assert isinstance(somedate, (time)), repr(somedate)
    return somedate

def todatetime(somedate):
    """return a datetime from a datetime (leaving unchanged) or a date"""
    # take care, datetime is a subclass of date
    if isinstance(somedate, datetime):
        return somedate
    assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
    return datetime(somedate.year, somedate.month, somedate.day)

def datetime2ticks(somedate):
    """return the number of milliseconds since the Unix epoch for *somedate*"""
    return timegm(somedate.timetuple()) * 1000

def ticks2datetime(ticks):
    """return a datetime from *ticks*, a number of milliseconds since the epoch"""
    # divmod by 1000 yields whole *seconds* and a *millisecond* remainder;
    # the original misnamed them `miliseconds`/`microseconds`, which also made
    # the overflow fallback below off by a factor of 1000
    seconds, milliseconds = divmod(ticks, 1000)
    try:
        return datetime.fromtimestamp(seconds)
    except (ValueError, OverflowError):
        # out of fromtimestamp's range: build the result from the epoch
        epoch = datetime.fromtimestamp(0)
        nb_days, seconds = divmod(int(seconds), 86400)
        delta = timedelta(nb_days, seconds=seconds,
                          microseconds=milliseconds * 1000)
        return epoch + delta

def days_in_month(somedate):
    """return the number of days in *somedate*'s month"""
    return monthrange(somedate.year, somedate.month)[1]

def days_in_year(somedate):
    """return the number of days in *somedate*'s year (365 or 366)"""
    feb = date(somedate.year, 2, 1)
    if days_in_month(feb) == 29:
        return 366
    else:
        return 365

def previous_month(somedate, nbmonth=1):
    """return a date *nbmonth* months before *somedate* (day may change)"""
    while nbmonth:
        somedate = first_day(somedate) - ONEDAY
        nbmonth -= 1
    return somedate

def next_month(somedate, nbmonth=1):
    """return a date *nbmonth* months after *somedate* (day may change)"""
    while nbmonth:
        somedate = last_day(somedate) + ONEDAY
        nbmonth -= 1
    return somedate

def first_day(somedate):
    """return the first day of *somedate*'s month"""
    return date(somedate.year, somedate.month, 1)

def last_day(somedate):
    """return the last day of *somedate*'s month"""
    return date(somedate.year, somedate.month, days_in_month(somedate))

def ustrftime(somedate, fmt='%Y-%m-%d'):
    """like strftime, but returns a unicode string instead of an encoded
    string which may be problematic with localized date.
    """
    if sys.version_info >= (3, 3):
        # datetime.date.strftime() supports dates since year 1 in Python >=3.3.
        return somedate.strftime(fmt)
    else:
        try:
            if sys.version_info < (3, 0):
                encoding = getlocale(LC_TIME)[1] or 'ascii'
                return unicode(somedate.strftime(str(fmt)), encoding)
            else:
                return somedate.strftime(fmt)
        except ValueError:
            if somedate.year >= 1900:
                raise
            # datetime is not happy with dates before 1900
            # we try to work around this, assuming a simple
            # format string
            fields = {'Y': somedate.year,
                      'm': somedate.month,
                      'd': somedate.day,
                      }
            if isinstance(somedate, datetime):
                fields.update({'H': somedate.hour,
                               'M': somedate.minute,
                               'S': somedate.second})
            fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt)
            return unicode(fmt) % fields

def utcdatetime(dt):
    """return the naive UTC equivalent of *dt* (naive datetimes unchanged)"""
    if dt.tzinfo is None:
        return dt
    return (dt.replace(tzinfo=None) - dt.utcoffset())

def utctime(dt):
    """return the naive UTC equivalent of a time (naive times unchanged)"""
    if dt.tzinfo is None:
        return dt
    return (dt + dt.utcoffset() + dt.dst()).replace(tzinfo=None)

def datetime_to_seconds(date):
    """return the number of seconds since the begining of the day for that date
    """
    return date.second+60*date.minute + 3600*date.hour

def timedelta_to_days(delta):
    """return the time delta as a (fractional) number of days"""
    # docstring was swapped with timedelta_to_seconds' in the original
    return delta.days + delta.seconds / (3600*24)

def timedelta_to_seconds(delta):
    """return the time delta as a number of seconds"""
    return delta.days*(3600*24) + delta.seconds
+# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""This is a DBF reader which reads Visual Fox Pro DBF format with Memo field + +Usage: + +>>> rec = readDbf('test.dbf') +>>> for line in rec: +>>> print line['name'] + + +:date: 13/07/2007 + +http://www.physics.ox.ac.uk/users/santoso/Software.Repository.html +page says code is "available as is without any warranty or support". 
+""" +from __future__ import print_function + +import struct +import os, os.path +import sys +import csv +import tempfile + +from six.moves import range + +class Dbase: + def __init__(self): + self.fdb = None + self.fmemo = None + self.db_data = None + self.memo_data = None + self.fields = None + self.num_records = 0 + self.header = None + self.memo_file = '' + self.memo_header = None + self.memo_block_size = 0 + self.memo_header_len = 0 + + def _drop_after_NULL(self, txt): + for i in range(0, len(txt)): + if ord(struct.unpack('c', txt[i])[0])==0: + return txt[:i] + return txt + + def _reverse_endian(self, num): + if not len(num): + return 0 + val = struct.unpack('<L', num) + val = struct.pack('>L', val[0]) + val = struct.unpack('>L', val) + return val[0] + + def _assign_ids(self, lst, ids): + result = {} + idx = 0 + for item in lst: + id = ids[idx] + result[id] = item + idx += 1 + return result + + def open(self, db_name): + filesize = os.path.getsize(db_name) + if filesize <= 68: + raise IOError('The file is not large enough to be a dbf file') + + self.fdb = open(db_name, 'rb') + + self.memo_file = '' + if os.path.isfile(db_name[0:-1] + 't'): + self.memo_file = db_name[0:-1] + 't' + elif os.path.isfile(db_name[0:-3] + 'fpt'): + self.memo_file = db_name[0:-3] + 'fpt' + + if self.memo_file: + #Read memo file + self.fmemo = open(self.memo_file, 'rb') + self.memo_data = self.fmemo.read() + self.memo_header = self._assign_ids(struct.unpack('>6x1H', self.memo_data[:8]), ['Block size']) + block_size = self.memo_header['Block size'] + if not block_size: + block_size = 512 + self.memo_block_size = block_size + self.memo_header_len = block_size + memo_size = os.path.getsize(self.memo_file) + + #Start reading data file + data = self.fdb.read(32) + self.header = self._assign_ids(struct.unpack('<B 3B L 2H 20x', data), ['id', 'Year', 'Month', 'Day', '# of Records', 'Header Size', 'Record Size']) + self.header['id'] = hex(self.header['id']) + + self.num_records = self.header['# 
of Records'] + data = self.fdb.read(self.header['Header Size']-34) + self.fields = {} + x = 0 + header_pattern = '<11s c 4x B B 14x' + ids = ['Field Name', 'Field Type', 'Field Length', 'Field Precision'] + pattern_len = 32 + for offset in range(0, len(data), 32): + if ord(data[offset])==0x0d: + break + x += 1 + data_subset = data[offset: offset+pattern_len] + if len(data_subset) < pattern_len: + data_subset += ' '*(pattern_len-len(data_subset)) + self.fields[x] = self._assign_ids(struct.unpack(header_pattern, data_subset), ids) + self.fields[x]['Field Name'] = self._drop_after_NULL(self.fields[x]['Field Name']) + + self.fdb.read(3) + if self.header['# of Records']: + data_size = (self.header['# of Records'] * self.header['Record Size']) - 1 + self.db_data = self.fdb.read(data_size) + else: + self.db_data = '' + self.row_format = '<' + self.row_ids = [] + self.row_len = 0 + for key in self.fields: + field = self.fields[key] + self.row_format += '%ds ' % (field['Field Length']) + self.row_ids.append(field['Field Name']) + self.row_len += field['Field Length'] + + def close(self): + if self.fdb: + self.fdb.close() + if self.fmemo: + self.fmemo.close() + + def get_numrecords(self): + return self.num_records + + def get_record_with_names(self, rec_no): + """ + This function accept record number from 0 to N-1 + """ + if rec_no < 0 or rec_no > self.num_records: + raise Exception('Unable to extract data outside the range') + + offset = self.header['Record Size'] * rec_no + data = self.db_data[offset:offset+self.row_len] + record = self._assign_ids(struct.unpack(self.row_format, data), self.row_ids) + + if self.memo_file: + for key in self.fields: + field = self.fields[key] + f_type = field['Field Type'] + f_name = field['Field Name'] + c_data = record[f_name] + + if f_type=='M' or f_type=='G' or f_type=='B' or f_type=='P': + c_data = self._reverse_endian(c_data) + if c_data: + record[f_name] = self.read_memo(c_data-1).strip() + else: + record[f_name] = c_data.strip() + 
return record + + def read_memo_record(self, num, in_length): + """ + Read the record of given number. The second parameter is the length of + the record to read. It can be undefined, meaning read the whole record, + and it can be negative, meaning at most the length + """ + if in_length < 0: + in_length = -self.memo_block_size + + offset = self.memo_header_len + num * self.memo_block_size + self.fmemo.seek(offset) + if in_length<0: + in_length = -in_length + if in_length==0: + return '' + return self.fmemo.read(in_length) + + def read_memo(self, num): + result = '' + buffer = self.read_memo_record(num, -1) + if len(buffer)<=0: + return '' + length = struct.unpack('>L', buffer[4:4+4])[0] + 8 + + block_size = self.memo_block_size + if length < block_size: + return buffer[8:length] + rest_length = length - block_size + rest_data = self.read_memo_record(num+1, rest_length) + if len(rest_data)<=0: + return '' + return buffer[8:] + rest_data + +def readDbf(filename): + """ + Read the DBF file specified by the filename and + return the records as a list of dictionary. + + :param: filename File name of the DBF + :return: List of rows + """ + db = Dbase() + db.open(filename) + num = db.get_numrecords() + rec = [] + for i in range(0, num): + record = db.get_record_with_names(i) + rec.append(record) + db.close() + return rec + +if __name__=='__main__': + rec = readDbf('dbf/sptable.dbf') + for line in rec: + print('%s %s' % (line['GENUS'].strip(), line['SPECIES'].strip())) diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py new file mode 100644 index 0000000..1f540a1 --- /dev/null +++ b/logilab/common/debugger.py @@ -0,0 +1,214 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. 
+# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Customized version of pdb's default debugger. + +- sets up a history file +- uses ipython if available to colorize lines of code +- overrides list command to search for current block instead + of using 5 lines of context + + + + +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +try: + import readline +except ImportError: + readline = None +import os +import os.path as osp +import sys +from pdb import Pdb +import inspect + +from logilab.common.compat import StringIO + +try: + from IPython import PyColorize +except ImportError: + def colorize(source, *args): + """fallback colorize function""" + return source + def colorize_source(source, *args): + return source +else: + def colorize(source, start_lineno, curlineno): + """colorize and annotate source with linenos + (as in pdb's list command) + """ + parser = PyColorize.Parser() + output = StringIO() + parser.format(source, output) + annotated = [] + for index, line in enumerate(output.getvalue().splitlines()): + lineno = index + start_lineno + if lineno == curlineno: + annotated.append('%4s\t->\t%s' % (lineno, line)) + else: + annotated.append('%4s\t\t%s' % (lineno, line)) + return '\n'.join(annotated) + + def colorize_source(source): + """colorize given source""" + parser = PyColorize.Parser() + output = 
def getsource(obj):
    """Return the text of the source code for an object and its start line.

    The argument may be a module, class, method, function, traceback, frame,
    or code object.  The source code is returned as a single string.  An
    IOError is raised if the source code cannot be retrieved.
    """
    source_lines, first_lineno = inspect.getsourcelines(obj)
    return ''.join(source_lines), first_lineno
in text: + return self.attr_matches(text, namespace) + return [varname for varname in namespace if varname.startswith(text)] + + + def attr_matches(self, text, namespace): + """implementation coming from rlcompleter.Completer.attr_matches + Compute matches when text contains a dot. + + Assuming the text is of the form NAME.NAME....[NAME], and is + evaluatable in self.namespace, it will be evaluated and its attributes + (as revealed by dir()) are used as possible completions. (For class + instances, class members are also considered.) + + WARNING: this can still invoke arbitrary C code, if an object + with a __getattr__ hook is evaluated. + + """ + import re + m = re.match(r"(\w+(\.\w+)*)\.(\w*)", text) + if not m: + return + expr, attr = m.group(1, 3) + object = eval(expr, namespace) + words = dir(object) + if hasattr(object, '__class__'): + words.append('__class__') + words = words + self.get_class_members(object.__class__) + matches = [] + n = len(attr) + for word in words: + if word[:n] == attr and word != "__builtins__": + matches.append("%s.%s" % (expr, word)) + return matches + + def get_class_members(self, klass): + """implementation coming from rlcompleter.get_class_members""" + ret = dir(klass) + if hasattr(klass, '__bases__'): + for base in klass.__bases__: + ret = ret + self.get_class_members(base) + return ret + + ## specific / overridden commands + def do_list(self, arg): + """overrides default list command to display the surrounding block + instead of 5 lines of context + """ + self.lastcmd = 'list' + if not arg: + try: + source, start_lineno = getsource(self.curframe) + print(colorize(''.join(source), start_lineno, + self.curframe.f_lineno)) + except KeyboardInterrupt: + pass + except IOError: + Pdb.do_list(self, arg) + else: + Pdb.do_list(self, arg) + do_l = do_list + + def do_open(self, arg): + """opens source file corresponding to the current stack level""" + filename = self.curframe.f_code.co_filename + lineno = self.curframe.f_lineno + cmd = 
def pm():
    """Post-mortem: debug the most recent traceback with our custom debugger.

    Mirrors :func:`pdb.pm` but uses :class:`Debugger` (history file,
    colorized listings).
    """
    dbg = Debugger(sys.last_traceback)
    dbg.start()

def set_trace():
    """Enter the custom debugger at the caller's frame (like pdb.set_trace)."""
    Debugger().set_trace(sys._getframe().f_back)
""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +import sys +import types +from time import clock, time +from inspect import isgeneratorfunction, getargspec + +from logilab.common.compat import method_type + +# XXX rewrite so we can use the decorator syntax when keyarg has to be specified + +class cached_decorator(object): + def __init__(self, cacheattr=None, keyarg=None): + self.cacheattr = cacheattr + self.keyarg = keyarg + def __call__(self, callableobj=None): + assert not isgeneratorfunction(callableobj), \ + 'cannot cache generator function: %s' % callableobj + if len(getargspec(callableobj).args) == 1 or self.keyarg == 0: + cache = _SingleValueCache(callableobj, self.cacheattr) + elif self.keyarg: + cache = _MultiValuesKeyArgCache(callableobj, self.keyarg, self.cacheattr) + else: + cache = _MultiValuesCache(callableobj, self.cacheattr) + return cache.closure() + +class _SingleValueCache(object): + def __init__(self, callableobj, cacheattr=None): + self.callable = callableobj + if cacheattr is None: + self.cacheattr = '_%s_cache_' % callableobj.__name__ + else: + assert cacheattr != callableobj.__name__ + self.cacheattr = cacheattr + + def __call__(__me, self, *args): + try: + return self.__dict__[__me.cacheattr] + except KeyError: + value = __me.callable(self, *args) + setattr(self, __me.cacheattr, value) + return value + + def closure(self): + def wrapped(*args, **kwargs): + return self.__call__(*args, **kwargs) + wrapped.cache_obj = self + try: + wrapped.__doc__ = self.callable.__doc__ + wrapped.__name__ = self.callable.__name__ + except: + pass + return wrapped + + def clear(self, holder): + holder.__dict__.pop(self.cacheattr, None) + + +class _MultiValuesCache(_SingleValueCache): + def _get_cache(self, holder): + try: + _cache = holder.__dict__[self.cacheattr] + except KeyError: + _cache = {} + setattr(holder, self.cacheattr, _cache) + return _cache + + def __call__(__me, self, *args, **kwargs): + _cache = 
class cachedproperty(object):
    """Cached property descriptor: equivalent to stacking ``@cached`` and
    ``@property``, but more efficient.

    On first access the computed value is stored in the instance's
    ``__dict__`` under the property name; being a non-data descriptor, the
    stored attribute then shadows this descriptor entirely.  Doing
    ``del obj.<property_name>`` empties the cache.

    Idea taken from the pyramid_ framework and the mercurial_ project.

    .. _pyramid: http://pypi.python.org/pypi/pyramid
    .. _mercurial: http://pypi.python.org/pypi/Mercurial
    """
    __slots__ = ('wrapped',)

    def __init__(self, wrapped):
        # hasattr only swallows AttributeError, matching the original
        # try/except around ``wrapped.__name__``
        if not hasattr(wrapped, '__name__'):
            raise TypeError('%s must have a __name__ attribute' %
                            wrapped)
        self.wrapped = wrapped

    @property
    def __doc__(self):
        wrapped_doc = getattr(self.wrapped, '__doc__', None)
        suffix = '\n%s' % wrapped_doc if wrapped_doc else ''
        return '<wrapped by the cachedproperty decorator>%s' % suffix

    def __get__(self, inst, objtype=None):
        if inst is None:
            # class-level access returns the descriptor itself
            return self
        value = self.wrapped(inst)
        setattr(inst, self.wrapped.__name__, value)
        return value
def timed(f):
    """Decorator printing the CPU and wall-clock time spent in each call
    to *f*, then returning the call's result unchanged.
    """
    # time.clock() (imported at module level) was removed in Python 3.8;
    # use the explicit replacements locally so this decorator keeps working.
    from time import perf_counter, process_time

    def wrap(*args, **kwargs):
        wall = perf_counter()
        cpu = process_time()
        res = f(*args, **kwargs)
        # keep the historical output format: 'clock' = CPU, 'time' = wall
        print('%s clock: %.9f / time: %.9f' % (f.__name__,
                                               process_time() - cpu,
                                               perf_counter() - wall))
        return res
    return wrap
+ >>> a.foo() + 12 + """ + def decorator(func): + try: + name = methodname or func.__name__ + except AttributeError: + raise AttributeError('%s has no __name__ attribute: ' + 'you should provide an explicit `methodname`' + % func) + setattr(klass, name, func) + return func + return decorator diff --git a/logilab/common/deprecation.py b/logilab/common/deprecation.py new file mode 100644 index 0000000..1c81b63 --- /dev/null +++ b/logilab/common/deprecation.py @@ -0,0 +1,189 @@ +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. 
+"""Deprecation utilities.""" + +__docformat__ = "restructuredtext en" + +import sys +from warnings import warn + +from logilab.common.changelog import Version + + +class DeprecationWrapper(object): + """proxy to print a warning on access to any attribute of the wrapped object + """ + def __init__(self, proxied, msg=None): + self._proxied = proxied + self._msg = msg + + def __getattr__(self, attr): + warn(self._msg, DeprecationWarning, stacklevel=2) + return getattr(self._proxied, attr) + + def __setattr__(self, attr, value): + if attr in ('_proxied', '_msg'): + self.__dict__[attr] = value + else: + warn(self._msg, DeprecationWarning, stacklevel=2) + setattr(self._proxied, attr, value) + + +class DeprecationManager(object): + """Manage the deprecation message handling. Messages are dropped for + versions more recent than the 'compatible' version. Example:: + + deprecator = deprecation.DeprecationManager("module_name") + deprecator.compatibility('1.3') + + deprecator.warn('1.2', "message.") + + @deprecator.deprecated('1.2', 'Message') + def any_func(): + pass + + class AnyClass(object): + __metaclass__ = deprecator.class_deprecated('1.2') + """ + def __init__(self, module_name=None): + """ + """ + self.module_name = module_name + self.compatible_version = None + + def compatibility(self, compatible_version): + """Set the compatible version. + """ + self.compatible_version = Version(compatible_version) + + def deprecated(self, version=None, reason=None, stacklevel=2, name=None, doc=None): + """Display a deprecation message only if the version is older than the + compatible version. 
    def class_deprecated(self, version=None):
        """Return a metaclass that emits a deprecation warning each time a
        class using it is instantiated.

        The warning message is taken from the class' ``__deprecation_warning__``
        attribute when present, and is only displayed if `version` is older
        than this manager's compatible version (see :meth:`warn`).

        :param version: version in which the class was deprecated
        """
        class metaclass(type):
            """metaclass to print a warning on instantiation of a deprecated class"""

            def __call__(cls, *args, **kwargs):
                msg = getattr(cls, "__deprecation_warning__",
                              "%(cls)s is deprecated") % {'cls': cls.__name__}
                # stacklevel=3: point the warning at the code instantiating
                # the class, not at this metaclass machinery
                self.warn(version, msg, stacklevel=3)
                return type.__call__(cls, *args, **kwargs)
        return metaclass
    def class_renamed(self, version, old_name, new_class, message=None):
        """Build and return a class named `old_name` that warns about its
        deprecation on instantiation, then behaves exactly like `new_class`.

        :param version: version in which the class was renamed
        :param old_name: the deprecated class name
        :param new_class: the replacement class
        :param message: custom warning message; defaults to
            '<old_name> is deprecated, use <new_class.__name__>'
        """
        clsdict = {}
        if message is None:
            message = '%s is deprecated, use %s' % (old_name, new_class.__name__)
        clsdict['__deprecation_warning__'] = message
        try:
            # new-style class: rely on the class_deprecated metaclass
            return self.class_deprecated(version)(old_name, (new_class,), clsdict)
        except (NameError, TypeError):
            # old-style (Python 2 classic) class: the metaclass approach
            # fails, fall back to overriding __init__ in a subclass
            warn = self.warn
            class DeprecatedClass(new_class):
                """FIXME: There might be a better way to handle old/new-style class
                """
                def __init__(self, *args, **kwargs):
                    warn(version, message, stacklevel=3)
                    new_class.__init__(self, *args, **kwargs)
            return DeprecatedClass
def deprecated(reason=None, stacklevel=2, name=None, doc=None):
    """Decorator factory warning that the decorated function is deprecated.

    Module-level shim delegating to the default :class:`DeprecationManager`
    with no version information.

    NOTE(review): `name` and `doc` are forwarded but DeprecationManager.deprecated
    does not appear to use them -- confirm before relying on them.
    """
    return _defaultdeprecator.deprecated(None, reason, stacklevel, name, doc)
def first_level_directory(path):
    """Return the first level directory of a path.

    >>> first_level_directory('home/syt/work')
    'home'
    >>> first_level_directory('/home/syt/work')
    '/'
    >>> first_level_directory('work')
    'work'
    >>>

    :type path: str
    :param path: the path for which we want the first level directory

    :rtype: str
    :return: the first level directory appearing in `path`
    """
    # peel path components from the right until only the leading one is left
    current = path
    while True:
        head, tail = split(current)
        if not (head and tail):
            break
        current = head
    # empty tail means the path was absolute: head is the fs root
    return tail if tail else head
class UnresolvableError(Exception):
    """Exception raised by relative_path when it's unable to compute relative
    path between two paths.
    """

def relative_path(from_file, to_file):
    """Try to get a relative path from `from_file` to `to_file`
    (path will be absolute if to_file is an absolute file). This function
    is useful to create link in `from_file` to `to_file`. This typical use
    case is used in this function description.

    If both files are relative, they're expected to be relative to the same
    directory.

    >>> relative_path( from_file='toto/index.html', to_file='index.html')
    '../index.html'
    >>> relative_path( from_file='index.html', to_file='toto/index.html')
    'toto/index.html'
    >>> relative_path( from_file='tutu/index.html', to_file='toto/index.html')
    '../toto/index.html'
    >>> relative_path( from_file='toto/index.html', to_file='/index.html')
    '/index.html'
    >>> relative_path( from_file='/toto/index.html', to_file='/index.html')
    '../index.html'
    >>> relative_path( from_file='index.html', to_file='index.html')
    ''

    :type from_file: str
    :param from_file: source file (where links will be inserted)

    :type to_file: str
    :param to_file: target file (on which links point)

    :raise UnresolvableError: if it has been unable to guess a correct path

    :rtype: str
    :return: the relative path of `to_file` from `from_file`
    """
    from_file = normpath(from_file)
    to_file = normpath(to_file)
    if from_file == to_file:
        return ''
    if isabs(to_file):
        # an absolute target is returned as-is when the source is relative
        if not isabs(from_file):
            return to_file
    elif isabs(from_file):
        # relative target from an absolute source: no way to resolve
        raise UnresolvableError()
    src_parts = from_file.split(sep)
    dest_parts = to_file.split(sep)
    diverged = False
    hops = []
    # walk up out of the source's directories; once the two paths diverge,
    # every remaining source directory costs one '..'
    while len(src_parts) > 1:
        segment = src_parts.pop(0)
        if not diverged and len(dest_parts) > 1 and segment == dest_parts[0]:
            dest_parts.pop(0)
        else:
            diverged = True
            hops.append('..')
    return sep.join(hops + dest_parts)
def stream_lines(stream, comments=None):
    """Return the non-empty, non-commented lines of `stream`, stripped.

    :type stream: object implementing 'xreadlines' or 'readlines'
    :param stream: file like object

    :type comments: str or None
    :param comments:
        optional string which can be used to comment a line in the file
        (i.e. lines starting with this string won't be returned)

    :rtype: list
    :return:
        a list of stripped line in the file, without empty and commented
        lines

    :warning: at some point this function will probably return an iterator
    """
    try:
        readlines = stream.xreadlines  # Python 2 file objects only
    except AttributeError:
        readlines = stream.readlines
    kept = []
    for raw in readlines():
        stripped = raw.strip()
        if not stripped:
            continue
        if comments is not None and stripped.startswith(comments):
            continue
        kept.append(stripped)
    return kept
def remove_dead_links(directory, verbose=0):
    """Recursively traverse directory and remove all dead links.

    :type directory: str
    :param directory: directory to cleanup

    :type verbose: bool
    :param verbose:
        flag indicating whether information about deleted links should be
        printed to stderr, default to False
    """
    # BUG FIX: the original unpacked ``dirname`` but iterated ``dirnames``,
    # raising NameError on the first walked directory.
    for dirpath, dirnames, filenames in walk(directory):
        for filename in dirnames + filenames:
            src = join(dirpath, filename)
            # a link whose target no longer exists
            if islink(src) and not exists(src):
                if verbose:
                    # stderr, as documented and consistent with export()
                    print('remove dead link', src, file=sys.stderr)
                remove(src)
def target_info_from_filename(filename):
    """Transforms /some/path/foo.png into ('/some/path', 'foo.png', 'png')."""
    directory = osp.dirname(osp.abspath(filename))
    name = osp.basename(filename)
    # note: last dot-separated chunk of the *whole* path, as before
    extension = filename.split('.')[-1]
    return directory, name, extension
self._source + + source = property(get_source) + + def generate(self, outputfile=None, dotfile=None, mapfile=None): + """Generates a graph file. + + :param outputfile: filename and path [defaults to graphname.png] + :param dotfile: filename and path [defaults to graphname.dot] + + :rtype: str + :return: a path to the generated file + """ + import subprocess # introduced in py 2.4 + name = self.graphname + if not dotfile: + # if 'outputfile' is a dot file use it as 'dotfile' + if outputfile and outputfile.endswith(".dot"): + dotfile = outputfile + else: + dotfile = '%s.dot' % name + if outputfile is not None: + storedir, basename, target = target_info_from_filename(outputfile) + if target != "dot": + pdot, dot_sourcepath = tempfile.mkstemp(".dot", name) + os.close(pdot) + else: + dot_sourcepath = osp.join(storedir, dotfile) + else: + target = 'png' + pdot, dot_sourcepath = tempfile.mkstemp(".dot", name) + ppng, outputfile = tempfile.mkstemp(".png", name) + os.close(pdot) + os.close(ppng) + pdot = codecs.open(dot_sourcepath, 'w', encoding='utf8') + pdot.write(self.source) + pdot.close() + if target != 'dot': + if sys.platform == 'win32': + use_shell = True + else: + use_shell = False + try: + if mapfile: + subprocess.call([self.renderer, '-Tcmapx', '-o', mapfile, '-T', target, dot_sourcepath, '-o', outputfile], + shell=use_shell) + else: + subprocess.call([self.renderer, '-T', target, + dot_sourcepath, '-o', outputfile], + shell=use_shell) + except OSError as e: + if e.errno == errno.ENOENT: + e.strerror = 'File not found: {0}'.format(self.renderer) + raise + os.unlink(dot_sourcepath) + return outputfile + + def emit(self, line): + """Adds <line> to final output.""" + self.lines.append(line) + + def emit_edge(self, name1, name2, **props): + """emit an edge from <name1> to <name2>. 
+ edge properties: see http://www.graphviz.org/doc/info/attrs.html + """ + attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()] + n_from, n_to = normalize_node_id(name1), normalize_node_id(name2) + self.emit('%s -> %s [%s];' % (n_from, n_to, ', '.join(sorted(attrs))) ) + + def emit_node(self, name, **props): + """emit a node with given properties. + node properties: see http://www.graphviz.org/doc/info/attrs.html + """ + attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()] + self.emit('%s [%s];' % (normalize_node_id(name), ', '.join(sorted(attrs)))) + +def normalize_node_id(nid): + """Returns a suitable DOT node id for `nid`.""" + return '"%s"' % nid + +class GraphGenerator: + def __init__(self, backend): + # the backend is responsible to output the graph in a particular format + self.backend = backend + + # XXX doesn't like space in outpufile / mapfile + def generate(self, visitor, propshdlr, outputfile=None, mapfile=None): + # the visitor + # the property handler is used to get node and edge properties + # according to the graph and to the backend + self.propshdlr = propshdlr + for nodeid, node in visitor.nodes(): + props = propshdlr.node_properties(node) + self.backend.emit_node(nodeid, **props) + for subjnode, objnode, edge in visitor.edges(): + props = propshdlr.edge_properties(edge, subjnode, objnode) + self.backend.emit_edge(subjnode, objnode, **props) + return self.backend.generate(outputfile=outputfile, mapfile=mapfile) + + +class UnorderableGraph(Exception): + pass + +def ordered_nodes(graph): + """takes a dependency graph dict as arguments and return an ordered tuple of + nodes starting with nodes without dependencies and up to the outermost node. + + If there is some cycle in the graph, :exc:`UnorderableGraph` will be raised. + + Also the given graph dict will be emptied. 
+ """ + # check graph consistency + cycles = get_cycles(graph) + if cycles: + cycles = '\n'.join([' -> '.join(cycle) for cycle in cycles]) + raise UnorderableGraph('cycles in graph: %s' % cycles) + vertices = set(graph) + to_vertices = set() + for edges in graph.values(): + to_vertices |= set(edges) + missing_vertices = to_vertices - vertices + if missing_vertices: + raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices)) + # order vertices + order = [] + order_set = set() + old_len = None + while graph: + if old_len == len(graph): + raise UnorderableGraph('unknown problem with %s' % graph) + old_len = len(graph) + deps_ok = [] + for node, node_deps in graph.items(): + for dep in node_deps: + if dep not in order_set: + break + else: + deps_ok.append(node) + order.append(deps_ok) + order_set |= set(deps_ok) + for node in deps_ok: + del graph[node] + result = [] + for grp in reversed(order): + result.extend(sorted(grp)) + return tuple(result) + + +def get_cycles(graph_dict, vertices=None): + '''given a dictionary representing an ordered graph (i.e. 
key are vertices + and values is a list of destination vertices representing edges), return a + list of detected cycles + ''' + if not graph_dict: + return () + result = [] + if vertices is None: + vertices = graph_dict.keys() + for vertice in vertices: + _get_cycles(graph_dict, [], set(), result, vertice) + return result + +def _get_cycles(graph_dict, path, visited, result, vertice): + """recursive function doing the real work for get_cycles""" + if vertice in path: + cycle = [vertice] + for node in path[::-1]: + if node == vertice: + break + cycle.insert(0, node) + # make a canonical representation + start_from = min(cycle) + index = cycle.index(start_from) + cycle = cycle[index:] + cycle[0:index] + # append it to result if not already in + if not cycle in result: + result.append(cycle) + return + path.append(vertice) + try: + for node in graph_dict[vertice]: + # don't check already visited nodes again + if node not in visited: + _get_cycles(graph_dict, path, visited, result, node) + visited.add(node) + except KeyError: + pass + path.pop() + +def has_path(graph_dict, fromnode, tonode, path=None): + """generic function taking a simple graph definition as a dictionary, with + node has key associated to a list of nodes directly reachable from it. + + Return None if no path exists to go from `fromnode` to `tonode`, else the + first path found (as a list including the destination node at last) + """ + if path is None: + path = [] + elif fromnode in path: + return None + path.append(fromnode) + for destnode in graph_dict[fromnode]: + if destnode == tonode or has_path(graph_dict, destnode, tonode, path): + return path[1:] + [tonode] + path.pop() + return None + diff --git a/logilab/common/interface.py b/logilab/common/interface.py new file mode 100644 index 0000000..3ea4ab7 --- /dev/null +++ b/logilab/common/interface.py @@ -0,0 +1,71 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. 
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Base classes for interfaces to provide 'light' interface handling.
+
+ TODO:
+ _ implement a check method which checks that an object implements the
+ interface
+ _ Attribute objects
+
+ This module requires at least python 2.2
+"""
+__docformat__ = "restructuredtext en"
+
+
+class Interface(object):
+    """Base class for interfaces."""
+    def is_implemented_by(cls, instance):
+        # delegate to the module-level implements() helper below
+        return implements(instance, cls)
+    is_implemented_by = classmethod(is_implemented_by)
+
+
+def implements(obj, interface):
+    """Return true if the given object (maybe an instance or class) implements
+    the interface.
+
+    The object is expected to advertise the interfaces it implements through
+    an ``__implements__`` attribute holding an interface class or a
+    list/tuple of interface classes; subclasses of `interface` count as
+    implementing it.
+    """
+    kimplements = getattr(obj, '__implements__', ())
+    # normalize a single interface into a one-element tuple
+    if not isinstance(kimplements, (list, tuple)):
+        kimplements = (kimplements,)
+    for implementedinterface in kimplements:
+        if issubclass(implementedinterface, interface):
+            return True
+    return False
+
+
+def extend(klass, interface, _recurs=False):
+    """Add interface to klass' ``__implements__`` if not already implemented in.
+
+    If klass is subclassed, ensure subclasses ``__implements__`` it as well.
+
+    NOTE: klass should be a new class.
+    """
+    if not implements(klass, interface):
+        try:
+            kimplements = klass.__implements__
+            # remember the declared container type (tuple, list, ...) so the
+            # rebuilt __implements__ keeps the same type as the original
+            kimplementsklass = type(kimplements)
+            kimplements = list(kimplements)
+        except AttributeError:
+            # no __implements__ yet: default to a tuple container
+            kimplementsklass = tuple
+            kimplements = []
+        kimplements.append(interface)
+        klass.__implements__ = kimplementsklass(kimplements)
+        for subklass in klass.__subclasses__():
+            extend(subklass, interface, _recurs=True)
+    elif _recurs:
+        # klass already implements the interface, but when called recursively
+        # its subclasses may still need the extension propagated to them
+        for subklass in klass.__subclasses__():
+            extend(subklass, interface, _recurs=True)
diff --git a/logilab/common/logging_ext.py b/logilab/common/logging_ext.py new file mode 100644 index 0000000..3b6a580 --- /dev/null +++ b/logilab/common/logging_ext.py @@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Extends the logging module from the standard library.""" + +__docformat__ = "restructuredtext en" + +import os +import sys +import logging + +from six import string_types + +from logilab.common.textutils import colorize_ansi + + +def set_log_methods(cls, logger): + """bind standard logger's methods as methods on the class""" + cls.__logger = logger + for attr in ('debug', 'info', 'warning', 'error', 'critical', 'exception'): + setattr(cls, attr, getattr(logger, attr)) + + +def xxx_cyan(record): + if 'XXX' in record.message: + return 'cyan' + +class ColorFormatter(logging.Formatter): + """ + A color Formatter for the logging standard module. + + By default, colorize CRITICAL and ERROR in red, WARNING in orange, INFO in + green and DEBUG in yellow. + + self.colors is customizable via the 'color' constructor argument (dictionary). + + self.colorfilters is a list of functions that get the LogRecord + and return a color name or None. + """ + + def __init__(self, fmt=None, datefmt=None, colors=None): + logging.Formatter.__init__(self, fmt, datefmt) + self.colorfilters = [] + self.colors = {'CRITICAL': 'red', + 'ERROR': 'red', + 'WARNING': 'magenta', + 'INFO': 'green', + 'DEBUG': 'yellow', + } + if colors is not None: + assert isinstance(colors, dict) + self.colors.update(colors) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.levelname in self.colors: + color = self.colors[record.levelname] + return colorize_ansi(msg, color) + else: + for cf in self.colorfilters: + color = cf(record) + if color: + return colorize_ansi(msg, color) + return msg + +def set_color_formatter(logger=None, **kw): + """ + Install a color formatter on the 'logger'. If not given, it will + defaults to the default logger. + + Any additional keyword will be passed as-is to the ColorFormatter + constructor. 
+ """ + if logger is None: + logger = logging.getLogger() + if not logger.handlers: + logging.basicConfig() + format_msg = logger.handlers[0].formatter._fmt + fmt = ColorFormatter(format_msg, **kw) + fmt.colorfilters.append(xxx_cyan) + logger.handlers[0].setFormatter(fmt) + + +LOG_FORMAT = '%(asctime)s - (%(name)s) %(levelname)s: %(message)s' +LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S' + +def get_handler(debug=False, syslog=False, logfile=None, rotation_parameters=None): + """get an apropriate handler according to given parameters""" + if os.environ.get('APYCOT_ROOT'): + handler = logging.StreamHandler(sys.stdout) + if debug: + handler = logging.StreamHandler() + elif logfile is None: + if syslog: + from logging import handlers + handler = handlers.SysLogHandler() + else: + handler = logging.StreamHandler() + else: + try: + if rotation_parameters is None: + if os.name == 'posix' and sys.version_info >= (2, 6): + from logging.handlers import WatchedFileHandler + handler = WatchedFileHandler(logfile) + else: + handler = logging.FileHandler(logfile) + else: + from logging.handlers import TimedRotatingFileHandler + handler = TimedRotatingFileHandler( + logfile, **rotation_parameters) + except IOError: + handler = logging.StreamHandler() + return handler + +def get_threshold(debug=False, logthreshold=None): + if logthreshold is None: + if debug: + logthreshold = logging.DEBUG + else: + logthreshold = logging.ERROR + elif isinstance(logthreshold, string_types): + logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold, + logthreshold)) + return logthreshold + +def _colorable_terminal(): + isatty = hasattr(sys.__stdout__, 'isatty') and sys.__stdout__.isatty() + if not isatty: + return False + if os.name == 'nt': + try: + from colorama import init as init_win32_colors + except ImportError: + return False + init_win32_colors() + return True + +def get_formatter(logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT): + if _colorable_terminal(): + fmt = 
ColorFormatter(logformat, logdateformat) + def col_fact(record): + if 'XXX' in record.message: + return 'cyan' + if 'kick' in record.message: + return 'red' + fmt.colorfilters.append(col_fact) + else: + fmt = logging.Formatter(logformat, logdateformat) + return fmt + +def init_log(debug=False, syslog=False, logthreshold=None, logfile=None, + logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT, fmt=None, + rotation_parameters=None, handler=None): + """init the log service""" + logger = logging.getLogger() + if handler is None: + handler = get_handler(debug, syslog, logfile, rotation_parameters) + # only addHandler and removeHandler method while I would like a setHandler + # method, so do it this way :$ + logger.handlers = [handler] + logthreshold = get_threshold(debug, logthreshold) + logger.setLevel(logthreshold) + if fmt is None: + if debug: + fmt = get_formatter(logformat=logformat, logdateformat=logdateformat) + else: + fmt = logging.Formatter(logformat, logdateformat) + handler.setFormatter(fmt) + return handler + +# map logilab.common.logger thresholds to logging thresholds +THRESHOLD_MAP = {'LOG_DEBUG': 'DEBUG', + 'LOG_INFO': 'INFO', + 'LOG_NOTICE': 'INFO', + 'LOG_WARN': 'WARNING', + 'LOG_WARNING': 'WARNING', + 'LOG_ERR': 'ERROR', + 'LOG_ERROR': 'ERROR', + 'LOG_CRIT': 'CRITICAL', + } diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py new file mode 100644 index 0000000..a426a3a --- /dev/null +++ b/logilab/common/modutils.py @@ -0,0 +1,702 @@ +# -*- coding: utf-8 -*- +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. 
+# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Python modules manipulation utility functions. + +:type PY_SOURCE_EXTS: tuple(str) +:var PY_SOURCE_EXTS: list of possible python source file extension + +:type STD_LIB_DIR: str +:var STD_LIB_DIR: directory where standard modules are located + +:type BUILTIN_MODULES: dict +:var BUILTIN_MODULES: dictionary with builtin module names has key +""" + +__docformat__ = "restructuredtext en" + +import sys +import os +from os.path import splitext, join, abspath, isdir, dirname, exists, basename +from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY +from distutils.sysconfig import get_config_var, get_python_lib, get_python_version +from distutils.errors import DistutilsPlatformError + +from six.moves import range + +try: + import zipimport +except ImportError: + zipimport = None + +ZIPFILE = object() + +from logilab.common import STD_BLACKLIST, _handle_blacklist + +# Notes about STD_LIB_DIR +# Consider arch-specific installation for STD_LIB_DIR definition +# :mod:`distutils.sysconfig` contains to much hardcoded values to rely on +# +# :see: `Problems with /usr/lib64 builds <http://bugs.python.org/issue1294959>`_ +# :see: `FHS <http://www.pathname.com/fhs/pub/fhs-2.3.html#LIBLTQUALGTALTERNATEFORMATESSENTIAL>`_ +if sys.platform.startswith('win'): + PY_SOURCE_EXTS = ('py', 'pyw') + PY_COMPILED_EXTS = ('dll', 'pyd') +else: + PY_SOURCE_EXTS = ('py',) + PY_COMPILED_EXTS = ('so',) + +try: + STD_LIB_DIR = get_python_lib(standard_lib=1) +# get_python_lib(standard_lib=1) is not available on pypy, set STD_LIB_DIR to +# non-valid path, see 
https://bugs.pypy.org/issue1164 +except DistutilsPlatformError: + STD_LIB_DIR = '//' + +EXT_LIB_DIR = get_python_lib() + +BUILTIN_MODULES = dict(zip(sys.builtin_module_names, + [1]*len(sys.builtin_module_names))) + + +class NoSourceFile(Exception): + """exception raised when we are not able to get a python + source file for a precompiled file + """ + +class LazyObject(object): + def __init__(self, module, obj): + self.module = module + self.obj = obj + self._imported = None + + def _getobj(self): + if self._imported is None: + self._imported = getattr(load_module_from_name(self.module), + self.obj) + return self._imported + + def __getattribute__(self, attr): + try: + return super(LazyObject, self).__getattribute__(attr) + except AttributeError as ex: + return getattr(self._getobj(), attr) + + def __call__(self, *args, **kwargs): + return self._getobj()(*args, **kwargs) + + +def load_module_from_name(dotted_name, path=None, use_sys=1): + """Load a Python module from its name. + + :type dotted_name: str + :param dotted_name: python name of a module or package + + :type path: list or None + :param path: + optional list of path where the module or package should be + searched (use sys.path if nothing or None is given) + + :type use_sys: bool + :param use_sys: + boolean indicating whether the sys.modules dictionary should be + used or not + + + :raise ImportError: if the module or package is not found + + :rtype: module + :return: the loaded module + """ + return load_module_from_modpath(dotted_name.split('.'), path, use_sys) + + +def load_module_from_modpath(parts, path=None, use_sys=1): + """Load a python module from its splitted name. + + :type parts: list(str) or tuple(str) + :param parts: + python name of a module or package splitted on '.' 
+ + :type path: list or None + :param path: + optional list of path where the module or package should be + searched (use sys.path if nothing or None is given) + + :type use_sys: bool + :param use_sys: + boolean indicating whether the sys.modules dictionary should be used or not + + :raise ImportError: if the module or package is not found + + :rtype: module + :return: the loaded module + """ + if use_sys: + try: + return sys.modules['.'.join(parts)] + except KeyError: + pass + modpath = [] + prevmodule = None + for part in parts: + modpath.append(part) + curname = '.'.join(modpath) + module = None + if len(modpath) != len(parts): + # even with use_sys=False, should try to get outer packages from sys.modules + module = sys.modules.get(curname) + elif use_sys: + # because it may have been indirectly loaded through a parent + module = sys.modules.get(curname) + if module is None: + mp_file, mp_filename, mp_desc = find_module(part, path) + module = load_module(curname, mp_file, mp_filename, mp_desc) + if prevmodule: + setattr(prevmodule, part, module) + _file = getattr(module, '__file__', '') + if not _file and len(modpath) != len(parts): + raise ImportError('no module in %s' % '.'.join(parts[len(modpath):]) ) + path = [dirname( _file )] + prevmodule = module + return module + + +def load_module_from_file(filepath, path=None, use_sys=1, extrapath=None): + """Load a Python module from it's path. 
+ + :type filepath: str + :param filepath: path to the python module or package + + :type path: list or None + :param path: + optional list of path where the module or package should be + searched (use sys.path if nothing or None is given) + + :type use_sys: bool + :param use_sys: + boolean indicating whether the sys.modules dictionary should be + used or not + + + :raise ImportError: if the module or package is not found + + :rtype: module + :return: the loaded module + """ + modpath = modpath_from_file(filepath, extrapath) + return load_module_from_modpath(modpath, path, use_sys) + + +def _check_init(path, mod_path): + """check there are some __init__.py all along the way""" + for part in mod_path: + path = join(path, part) + if not _has_init(path): + return False + return True + + +def modpath_from_file(filename, extrapath=None): + """given a file path return the corresponding splitted module's name + (i.e name of a module or package splitted on '.') + + :type filename: str + :param filename: file's path for which we want the module's name + + :type extrapath: dict + :param extrapath: + optional extra search path, with path as key and package name for the path + as value. This is usually useful to handle package splitted in multiple + directories using __path__ trick. 
+ + + :raise ImportError: + if the corresponding module's name has not been found + + :rtype: list(str) + :return: the corresponding splitted module's name + """ + base = splitext(abspath(filename))[0] + if extrapath is not None: + for path_ in extrapath: + path = abspath(path_) + if path and base[:len(path)] == path: + submodpath = [pkg for pkg in base[len(path):].split(os.sep) + if pkg] + if _check_init(path, submodpath[:-1]): + return extrapath[path_].split('.') + submodpath + for path in sys.path: + path = abspath(path) + if path and base.startswith(path): + modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg] + if _check_init(path, modpath[:-1]): + return modpath + raise ImportError('Unable to find module for %s in %s' % ( + filename, ', \n'.join(sys.path))) + + + +def file_from_modpath(modpath, path=None, context_file=None): + """given a mod path (i.e. splitted module / package name), return the + corresponding file, giving priority to source file over precompiled + file if it exists + + :type modpath: list or tuple + :param modpath: + splitted module's name (i.e name of a module or package splitted + on '.') + (this means explicit relative imports that start with dots have + empty strings in this list!) + + :type path: list or None + :param path: + optional list of path where the module or package should be + searched (use sys.path if nothing or None is given) + + :type context_file: str or None + :param context_file: + context file to consider, necessary if the identifier has been + introduced using a relative import unresolvable in the actual + context (i.e. 
modutils) + + :raise ImportError: if there is no such module in the directory + + :rtype: str or None + :return: + the path to the module's file or None if it's an integrated + builtin module such as 'sys' + """ + if context_file is not None: + context = dirname(context_file) + else: + context = context_file + if modpath[0] == 'xml': + # handle _xmlplus + try: + return _file_from_modpath(['_xmlplus'] + modpath[1:], path, context) + except ImportError: + return _file_from_modpath(modpath, path, context) + elif modpath == ['os', 'path']: + # FIXME: currently ignoring search_path... + return os.path.__file__ + return _file_from_modpath(modpath, path, context) + + + +def get_module_part(dotted_name, context_file=None): + """given a dotted name return the module part of the name : + + >>> get_module_part('logilab.common.modutils.get_module_part') + 'logilab.common.modutils' + + :type dotted_name: str + :param dotted_name: full name of the identifier we are interested in + + :type context_file: str or None + :param context_file: + context file to consider, necessary if the identifier has been + introduced using a relative import unresolvable in the actual + context (i.e. modutils) + + + :raise ImportError: if there is no such module in the directory + + :rtype: str or None + :return: + the module part of the name or None if we have not been able at + all to import the given name + + XXX: deprecated, since it doesn't handle package precedence over module + (see #10066) + """ + # os.path trick + if dotted_name.startswith('os.path'): + return 'os.path' + parts = dotted_name.split('.') + if context_file is not None: + # first check for builtin module which won't be considered latter + # in that case (path != None) + if parts[0] in BUILTIN_MODULES: + if len(parts) > 2: + raise ImportError(dotted_name) + return parts[0] + # don't use += or insert, we want a new list to be created ! 
+ path = None + starti = 0 + if parts[0] == '': + assert context_file is not None, \ + 'explicit relative import, but no context_file?' + path = [] # prevent resolving the import non-relatively + starti = 1 + while parts[starti] == '': # for all further dots: change context + starti += 1 + context_file = dirname(context_file) + for i in range(starti, len(parts)): + try: + file_from_modpath(parts[starti:i+1], + path=path, context_file=context_file) + except ImportError: + if not i >= max(1, len(parts) - 2): + raise + return '.'.join(parts[:i]) + return dotted_name + + +def get_modules(package, src_directory, blacklist=STD_BLACKLIST): + """given a package directory return a list of all available python + modules in the package and its subpackages + + :type package: str + :param package: the python name for the package + + :type src_directory: str + :param src_directory: + path of the directory corresponding to the package + + :type blacklist: list or tuple + :param blacklist: + optional list of files or directory to ignore, default to + the value of `logilab.common.STD_BLACKLIST` + + :rtype: list + :return: + the list of all available python modules in the package and its + subpackages + """ + modules = [] + for directory, dirnames, filenames in os.walk(src_directory): + _handle_blacklist(blacklist, dirnames, filenames) + # check for __init__.py + if not '__init__.py' in filenames: + dirnames[:] = () + continue + if directory != src_directory: + dir_package = directory[len(src_directory):].replace(os.sep, '.') + modules.append(package + dir_package) + for filename in filenames: + if _is_python_file(filename) and filename != '__init__.py': + src = join(directory, filename) + module = package + src[len(src_directory):-3] + modules.append(module.replace(os.sep, '.')) + return modules + + + +def get_module_files(src_directory, blacklist=STD_BLACKLIST): + """given a package directory return a list of all available python + module's files in the package and its subpackages 
+ + :type src_directory: str + :param src_directory: + path of the directory corresponding to the package + + :type blacklist: list or tuple + :param blacklist: + optional list of files or directory to ignore, default to the value of + `logilab.common.STD_BLACKLIST` + + :rtype: list + :return: + the list of all available python module's files in the package and + its subpackages + """ + files = [] + for directory, dirnames, filenames in os.walk(src_directory): + _handle_blacklist(blacklist, dirnames, filenames) + # check for __init__.py + if not '__init__.py' in filenames: + dirnames[:] = () + continue + for filename in filenames: + if _is_python_file(filename): + src = join(directory, filename) + files.append(src) + return files + + +def get_source_file(filename, include_no_ext=False): + """given a python module's file name return the matching source file + name (the filename will be returned identically if it's a already an + absolute path to a python source file...) + + :type filename: str + :param filename: python module's file name + + + :raise NoSourceFile: if no source file exists on the file system + + :rtype: str + :return: the absolute path of the source file if it exists + """ + base, orig_ext = splitext(abspath(filename)) + for ext in PY_SOURCE_EXTS: + source_path = '%s.%s' % (base, ext) + if exists(source_path): + return source_path + if include_no_ext and not orig_ext and exists(base): + return base + raise NoSourceFile(filename) + + +def cleanup_sys_modules(directories): + """remove submodules of `directories` from `sys.modules`""" + cleaned = [] + for modname, module in list(sys.modules.items()): + modfile = getattr(module, '__file__', None) + if modfile: + for directory in directories: + if modfile.startswith(directory): + cleaned.append(modname) + del sys.modules[modname] + break + return cleaned + + +def is_python_source(filename): + """ + rtype: bool + return: True if the filename is a python source file + """ + return splitext(filename)[1][1:] 
in PY_SOURCE_EXTS + + + +def is_standard_module(modname, std_path=(STD_LIB_DIR,)): + """try to guess if a module is a standard python module (by default, + see `std_path` parameter's description) + + :type modname: str + :param modname: name of the module we are interested in + + :type std_path: list(str) or tuple(str) + :param std_path: list of path considered has standard + + + :rtype: bool + :return: + true if the module: + - is located on the path listed in one of the directory in `std_path` + - is a built-in module + """ + modname = modname.split('.')[0] + try: + filename = file_from_modpath([modname]) + except ImportError as ex: + # import failed, i'm probably not so wrong by supposing it's + # not standard... + return 0 + # modules which are not living in a file are considered standard + # (sys and __builtin__ for instance) + if filename is None: + return 1 + filename = abspath(filename) + if filename.startswith(EXT_LIB_DIR): + return 0 + for path in std_path: + if filename.startswith(abspath(path)): + return 1 + return False + + + +def is_relative(modname, from_file): + """return true if the given module name is relative to the given + file name + + :type modname: str + :param modname: name of the module we are interested in + + :type from_file: str + :param from_file: + path of the module from which modname has been imported + + :rtype: bool + :return: + true if the module has been imported relatively to `from_file` + """ + if not isdir(from_file): + from_file = dirname(from_file) + if from_file in sys.path: + return False + try: + find_module(modname.split('.')[0], [from_file]) + return True + except ImportError: + return False + + +# internal only functions ##################################################### + +def _file_from_modpath(modpath, path=None, context=None): + """given a mod path (i.e. 
splitted module / package name), return the + corresponding file + + this function is used internally, see `file_from_modpath`'s + documentation for more information + """ + assert len(modpath) > 0 + if context is not None: + try: + mtype, mp_filename = _module_file(modpath, [context]) + except ImportError: + mtype, mp_filename = _module_file(modpath, path) + else: + mtype, mp_filename = _module_file(modpath, path) + if mtype == PY_COMPILED: + try: + return get_source_file(mp_filename) + except NoSourceFile: + return mp_filename + elif mtype == C_BUILTIN: + # integrated builtin module + return None + elif mtype == PKG_DIRECTORY: + mp_filename = _has_init(mp_filename) + return mp_filename + +def _search_zip(modpath, pic): + for filepath, importer in pic.items(): + if importer is not None: + if importer.find_module(modpath[0]): + if not importer.find_module('/'.join(modpath)): + raise ImportError('No module named %s in %s/%s' % ( + '.'.join(modpath[1:]), filepath, modpath)) + return ZIPFILE, abspath(filepath) + '/' + '/'.join(modpath), filepath + raise ImportError('No module named %s' % '.'.join(modpath)) + +try: + import pkg_resources +except ImportError: + pkg_resources = None + +def _module_file(modpath, path=None): + """get a module type / file path + + :type modpath: list or tuple + :param modpath: + splitted module's name (i.e name of a module or package splitted + on '.'), with leading empty strings for explicit relative import + + :type path: list or None + :param path: + optional list of path where the module or package should be + searched (use sys.path if nothing or None is given) + + + :rtype: tuple(int, str) + :return: the module type flag and the file path for a module + """ + # egg support compat + try: + pic = sys.path_importer_cache + _path = (path is None and sys.path or path) + for __path in _path: + if not __path in pic: + try: + pic[__path] = zipimport.zipimporter(__path) + except zipimport.ZipImportError: + pic[__path] = None + checkeggs = True 
+ except AttributeError: + checkeggs = False + # pkg_resources support (aka setuptools namespace packages) + if (pkg_resources is not None + and modpath[0] in pkg_resources._namespace_packages + and modpath[0] in sys.modules + and len(modpath) > 1): + # setuptools has added into sys.modules a module object with proper + # __path__, get back information from there + module = sys.modules[modpath.pop(0)] + path = module.__path__ + imported = [] + while modpath: + modname = modpath[0] + # take care to changes in find_module implementation wrt builtin modules + # + # Python 2.6.6 (r266:84292, Sep 11 2012, 08:34:23) + # >>> imp.find_module('posix') + # (None, 'posix', ('', '', 6)) + # + # Python 3.3.1 (default, Apr 26 2013, 12:08:46) + # >>> imp.find_module('posix') + # (None, None, ('', '', 6)) + try: + _, mp_filename, mp_desc = find_module(modname, path) + except ImportError: + if checkeggs: + return _search_zip(modpath, pic)[:2] + raise + else: + if checkeggs and mp_filename: + fullabspath = [abspath(x) for x in _path] + try: + pathindex = fullabspath.index(dirname(abspath(mp_filename))) + emtype, emp_filename, zippath = _search_zip(modpath, pic) + if pathindex > _path.index(zippath): + # an egg takes priority + return emtype, emp_filename + except ValueError: + # XXX not in _path + pass + except ImportError: + pass + checkeggs = False + imported.append(modpath.pop(0)) + mtype = mp_desc[2] + if modpath: + if mtype != PKG_DIRECTORY: + raise ImportError('No module %s in %s' % ('.'.join(modpath), + '.'.join(imported))) + # XXX guess if package is using pkgutil.extend_path by looking for + # those keywords in the first four Kbytes + try: + with open(join(mp_filename, '__init__.py')) as stream: + data = stream.read(4096) + except IOError: + path = [mp_filename] + else: + if 'pkgutil' in data and 'extend_path' in data: + # extend_path is called, search sys.path for module/packages + # of this name see pkgutil.extend_path documentation + path = [join(p, *imported) for p in 
sys.path + if isdir(join(p, *imported))] + else: + path = [mp_filename] + return mtype, mp_filename + +def _is_python_file(filename): + """return true if the given filename should be considered as a python file + + .pyc and .pyo are ignored + """ + for ext in ('.py', '.so', '.pyd', '.pyw'): + if filename.endswith(ext): + return True + return False + + +def _has_init(directory): + """if the given directory has a valid __init__ file, return its path, + else return None + """ + mod_or_pack = join(directory, '__init__') + for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'): + if exists(mod_or_pack + '.' + ext): + return mod_or_pack + '.' + ext + return None diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py new file mode 100644 index 0000000..1fd2a7f --- /dev/null +++ b/logilab/common/optik_ext.py @@ -0,0 +1,392 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Add an abstraction level to transparently import optik classes from optparse +(python >= 2.3) or the optik package. 
# ---- continuation of the optik_ext module docstring (opened in the previous
# ---- chunk) ----------------------------------------------------------------
# It also defines several new types for optik/optparse command line parser :
#
#   * regexp:  argument of this type will be converted using re.compile
#   * csv:  argument of this type will be converted using split(',')
#   * yn:  argument of this type will be true if 'y' or 'yes', false if 'n' or 'no'
#   * named:  argument of this type are in the form <NAME>=<VALUE> or <NAME>:<VALUE>
#   * password:  argument of this type won't be converted but this is used by
#     other tools such as interactive prompt for configuration to double check
#     value and use an invisible field
#   * multiple_choice:  same as default "choice" type but multiple choices allowed
#   * file:  argument of this type won't be converted but checked that the
#     given file exists
#   * color:  argument of this type won't be converted but checked it's either
#     a named color or a color specified using hexadecimal notation (preceded
#     by a #)
#   * time:  argument of this type will be converted to a float value in
#     seconds according to time units (ms, s, min, h, d)
#   * bytes:  argument of this type will be converted to a float value in
#     bytes according to byte units (b, kb, mb, gb, tb)
from __future__ import print_function

__docformat__ = "restructuredtext en"

import re
import sys
import time
from copy import copy
from os.path import exists

# python >= 2.3
from optparse import OptionParser as BaseParser, Option as BaseOption, \
     OptionGroup, OptionContainer, OptionValueError, OptionError, \
     Values, HelpFormatter, NO_DEFAULT, SUPPRESS_HELP

try:
    from mx import DateTime
    HAS_MX_DATETIME = True
except ImportError:
    HAS_MX_DATETIME = False

from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, \
    apply_units

# `long` does not exist on Python 3; build the numeric type tuple once so
# check_time does not raise NameError there (bug fix)
try:
    _NUMERIC_TYPES = (int, long, float)
except NameError:  # Python 3
    _NUMERIC_TYPES = (int, float)


def check_regexp(option, opt, value):
    """check a regexp value by trying to compile it
    return the compiled regexp
    """
    if hasattr(value, 'pattern'):
        # already a compiled pattern
        return value
    try:
        return re.compile(value)
    except ValueError:
        raise OptionValueError(
            "option %s: invalid regexp value: %r" % (opt, value))

def check_csv(option, opt, value):
    """check a csv value by trying to split it
    return the list of separated values
    """
    if isinstance(value, (list, tuple)):
        return value
    try:
        return splitstrip(value)
    except ValueError:
        raise OptionValueError(
            "option %s: invalid csv value: %r" % (opt, value))

def check_yn(option, opt, value):
    """check a yn value
    return true for yes and false for no
    """
    if isinstance(value, int):
        return bool(value)
    if value in ('y', 'yes'):
        return True
    if value in ('n', 'no'):
        return False
    msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)"
    raise OptionValueError(msg % (opt, value))

def check_named(option, opt, value):
    """check a named value
    return a dictionary containing (name, value) associations
    """
    if isinstance(value, dict):
        return value
    values = []
    for value in check_csv(option, opt, value):
        if value.find('=') != -1:
            values.append(value.split('=', 1))
        elif value.find(':') != -1:
            values.append(value.split(':', 1))
    if values:
        return dict(values)
    msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \
<NAME>:<VALUE>"
    raise OptionValueError(msg % (opt, value))

def check_password(option, opt, value):
    """check a password value (can't be empty)
    """
    # no actual checking, monkey patch if you want more
    return value

def check_file(option, opt, value):
    """check a file value
    return the filepath
    """
    if exists(value):
        return value
    msg = "option %s: file %r does not exist"
    raise OptionValueError(msg % (opt, value))

# XXX use python datetime
def check_date(option, opt, value):
    """check a date value (expected format yyyy/mm/dd, requires mx.DateTime)
    return the parsed date
    """
    try:
        return DateTime.strptime(value, "%Y/%m/%d")
    except DateTime.Error :
        raise OptionValueError(
            "expected format of %s is yyyy/mm/dd" % opt)

def check_color(option, opt, value):
    """check a color value and returns it
    /!\ does *not* check color labels (like 'red', 'green'), only
    checks hexadecimal forms
    """
    # Case (1) : color label, we trust the end-user
    if re.match('[a-z0-9 ]+$', value, re.I):
        return value
    # Case (2) : only accepts hexadecimal forms; anchor the end so trailing
    # garbage such as '#aabbcczz' is rejected (bug fix: was unanchored)
    if re.match('#[a-f0-9]{6}$', value, re.I):
        return value
    # Else : not a color label neither a valid hexadecimal form => error
    msg = "option %s: invalid color : %r, should be either hexadecimal \
    value or predefined color"
    raise OptionValueError(msg % (opt, value))

def check_time(option, opt, value):
    """check a time value, either a plain number of seconds or a string with
    time units, converted through apply_units
    """
    # was isinstance(value, (int, long, float)): NameError on Python 3 (bug fix)
    if isinstance(value, _NUMERIC_TYPES):
        return value
    return apply_units(value, TIME_UNITS)

def check_bytes(option, opt, value):
    """check a bytes value, either a plain number or a string with byte units,
    converted through apply_units
    """
    if hasattr(value, '__int__'):
        return value
    return apply_units(value, BYTE_UNITS)


class Option(BaseOption):
    """override optik.Option to add some new option types
    """
    TYPES = BaseOption.TYPES + ('regexp', 'csv', 'yn', 'named', 'password',
                                'multiple_choice', 'file', 'color',
                                'time', 'bytes')
    ATTRS = BaseOption.ATTRS + ['hide', 'level']
    TYPE_CHECKER = copy(BaseOption.TYPE_CHECKER)
    TYPE_CHECKER['regexp'] = check_regexp
    TYPE_CHECKER['csv'] = check_csv
    TYPE_CHECKER['yn'] = check_yn
    TYPE_CHECKER['named'] = check_named
    TYPE_CHECKER['multiple_choice'] = check_csv
    TYPE_CHECKER['file'] = check_file
    TYPE_CHECKER['color'] = check_color
    TYPE_CHECKER['password'] = check_password
    TYPE_CHECKER['time'] = check_time
    TYPE_CHECKER['bytes'] = check_bytes
    if HAS_MX_DATETIME:
        TYPES += ('date',)
        TYPE_CHECKER['date'] = check_date

    def __init__(self, *opts, **attrs):
        BaseOption.__init__(self, *opts, **attrs)
        if hasattr(self, "hide") and self.hide:
            self.help = SUPPRESS_HELP

    def _check_choice(self):
        """FIXME: need to override this due to optik misdesign"""
        if self.type in ("choice", "multiple_choice"):
            if self.choices is None:
                raise OptionError(
                    "must supply a list of choices for type 'choice'", self)
            elif not isinstance(self.choices, (tuple, list)):
                raise OptionError(
                    "choices must be a list of strings ('%s' supplied)"
                    % str(type(self.choices)).split("'")[1], self)
        elif self.choices is not None:
            raise OptionError(
                "must not supply choices for type %r" % self.type, self)
    BaseOption.CHECK_METHODS[2] = _check_choice


    def process(self, opt, value, values, parser):
        # First, convert the value(s) to the right type.  Howl if any
        # value(s) are bogus.
        value = self.convert_value(opt, value)
        if self.type == 'named':
            # 'named' options accumulate into a single dict
            existing = getattr(values, self.dest)
            if existing:
                existing.update(value)
                value = existing
        # And then take whatever action is expected of us.
        # This is a separate method to make life easier for
        # subclasses to add new actions.
        return self.take_action(
            self.action, self.dest, opt, value, values, parser)


class OptionParser(BaseParser):
    """override optik.OptionParser to use our Option class
    """
    def __init__(self, option_class=Option, *args, **kwargs):
        # honor the option_class argument instead of always forcing Option
        # (the original passed Option unconditionally -- bug fix)
        BaseParser.__init__(self, option_class=option_class, *args, **kwargs)

    def format_option_help(self, formatter=None):
        if formatter is None:
            formatter = self.formatter
        outputlevel = getattr(formatter, 'output_level', 0)
        formatter.store_option_strings(self)
        result = []
        result.append(formatter.format_heading("Options"))
        formatter.indent()
        if self.option_list:
            result.append(OptionContainer.format_option_help(self, formatter))
            result.append("\n")
        for group in self.option_groups:
            if group.level <= outputlevel and (
                    group.description or level_options(group, outputlevel)):
                result.append(group.format_help(formatter))
                result.append("\n")
        formatter.dedent()
        # Drop the last "\n", or the header if no options or option groups:
        return "".join(result[:-1])


OptionGroup.level = 0

def level_options(group, outputlevel):
    """return the options of `group` visible at `outputlevel`"""
    return [option for option in group.option_list
            if (getattr(option, 'level', 0) or 0) <= outputlevel
            and not option.help is SUPPRESS_HELP]

def format_option_help(self, formatter):
    """monkey-patched onto OptionContainer: only format options whose level
    fits the formatter's output_level"""
    result = []
    outputlevel = getattr(formatter, 'output_level', 0) or 0
    for option in level_options(self, outputlevel):
        result.append(formatter.format_option(option))
    return "".join(result)
OptionContainer.format_option_help = format_option_help


class ManHelpFormatter(HelpFormatter):
    """Format help using man pages ROFF format"""

    def __init__ (self,
                  indent_increment=0,
                  max_help_position=24,
                  width=79,
                  short_first=0):
        HelpFormatter.__init__ (
            self, indent_increment, max_help_position, width, short_first)

    def format_heading(self, heading):
        return '.SH %s\n' % heading.upper()

    def format_description(self, description):
        return description

    def format_option(self, option):
        try:
            optstring = option.option_strings
        except AttributeError:
            optstring = self.format_option_strings(option)
        if option.help:
            help_text = self.expand_default(option)
            help = ' '.join([l.strip() for l in help_text.splitlines()])
        else:
            help = ''
        return '''.IP "%s"
%s
''' % (optstring, help)

    def format_head(self, optparser, pkginfo, section=1):
        long_desc = ""
        try:
            pgm = optparser._get_prog_name()
        except AttributeError:
            # py >= 2.4.X (dunno which X exactly, at least 2)
            pgm = optparser.get_prog_name()
        short_desc = self.format_short_description(pgm, pkginfo.description)
        if hasattr(pkginfo, "long_desc"):
            long_desc = self.format_long_description(pgm, pkginfo.long_desc)
        return '%s\n%s\n%s\n%s' % (self.format_title(pgm, section),
                                   short_desc, self.format_synopsis(pgm),
                                   long_desc)

    def format_title(self, pgm, section):
        date = '-'.join([str(num) for num in time.localtime()[:3]])
        return '.TH %s %s "%s" %s' % (pgm, section, date, pgm)

    def format_short_description(self, pgm, short_desc):
        return '''.SH NAME
.B %s
\- %s
''' % (pgm, short_desc.strip())

    def format_synopsis(self, pgm):
        return '''.SH SYNOPSIS
.B %s
[
.I OPTIONS
] [
.I <arguments>
]
''' % pgm

    def format_long_description(self, pgm, long_desc):
        long_desc = '\n'.join([line.lstrip()
                               for line in long_desc.splitlines()])
        long_desc = long_desc.replace('\n.\n', '\n\n')
        if long_desc.lower().startswith(pgm):
            long_desc = long_desc[len(pgm):]
        return '''.SH DESCRIPTION
.B %s
%s
''' % (pgm, long_desc.strip())

    def format_tail(self, pkginfo):
        tail = '''.SH SEE ALSO
/usr/share/doc/pythonX.Y-%s/

.SH BUGS
Please report bugs on the project\'s mailing list:
%s

.SH AUTHOR
%s <%s>
''' % (getattr(pkginfo, 'debian_name', pkginfo.modname),
       pkginfo.mailinglist, pkginfo.author, pkginfo.author_email)

        if hasattr(pkginfo, "copyright"):
            tail += '''
.SH COPYRIGHT
%s
''' % pkginfo.copyright

        return tail

def generate_manpage(optparser, pkginfo, section=1, stream=sys.stdout, level=0):
    """generate a man page from an optik parser"""
    formatter = ManHelpFormatter()
    formatter.output_level = level
    formatter.parser = optparser
    print(formatter.format_head(optparser, pkginfo, section), file=stream)
    print(optparser.format_option_help(formatter), file=stream)
    print(formatter.format_tail(pkginfo), file=stream)


__all__ = ('OptionParser', 'Option', 'OptionGroup', 'OptionValueError',
           'Values')

# ---------------------------------------------------------------------------
# diff --git a/logilab/common/optparser.py b/logilab/common/optparser.py
# new file mode 100644
# index 0000000..aa17750
# --- /dev/null
# +++ b/logilab/common/optparser.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
# ---- continuation of the optparser.py license header (opened in the
# ---- previous chunk) -------------------------------------------------------
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""Extend OptionParser with commands.

Example:

>>> parser = OptionParser()
>>> parser.usage = '%prog COMMAND [options] <arg> ...'
>>> parser.add_command('build', 'mymod.build')
>>> parser.add_command('clean', run_clean, add_opt_clean)
>>> run, options, args = parser.parse_command(sys.argv[1:])
>>> return run(options, args[1:])

With mymod.build that defines two functions run and add_options
"""
from __future__ import print_function

__docformat__ = "restructuredtext en"

from warnings import warn
warn('lgc.optparser module is deprecated, use lgc.clcommands instead', DeprecationWarning,
     stacklevel=2)

import sys
import optparse

class OptionParser(optparse.OptionParser):
    """optparse.OptionParser extended with named sub-commands.

    A command maps a name to either a module path (the module must define
    ``run`` and ``add_options`` functions) or a ``(run, add_options)`` tuple.
    """

    def __init__(self, *args, **kwargs):
        optparse.OptionParser.__init__(self, *args, **kwargs)
        self._commands = {}
        # accepted number of positional arguments after the command name
        self.min_args, self.max_args = 0, 1

    def add_command(self, name, mod_or_funcs, help=''):
        """name of the command, name of module or tuple of functions
        (run, add_options)
        """
        assert isinstance(mod_or_funcs, str) or isinstance(mod_or_funcs, tuple), \
               "mod_or_funcs has to be a module name or a tuple of functions"
        self._commands[name] = (mod_or_funcs, help)

    def print_main_help(self):
        """print the parser's help followed by the list of commands"""
        optparse.OptionParser.print_help(self)
        print('\ncommands:')
        for cmdname, (_, help) in self._commands.items():
            print('% 10s - %s' % (cmdname, help))

    def parse_command(self, args):
        """parse `args`, dispatching on the leading command name

        :return: a (run callable, options, remaining args) tuple
        :raise SystemExit: on missing/unknown command, --help/--version, or a
          wrong number of positional arguments
        """
        if len(args) == 0:
            self.print_main_help()
            sys.exit(1)
        cmd = args[0]
        args = args[1:]
        if cmd not in self._commands:
            if cmd in ('-h', '--help'):
                self.print_main_help()
                sys.exit(0)
            elif self.version is not None and cmd == "--version":
                self.print_version()
                sys.exit(0)
            self.error('unknown command')
        self.prog = '%s %s' % (self.prog, cmd)
        mod_or_f, help = self._commands[cmd]
        # optparse inserts self.description between usage and options help
        self.description = help
        if isinstance(mod_or_f, str):
            # the original used exec('from %s import run, add_options'),
            # which does not bind names into function locals on Python 3 and
            # left `run`/`add_options` undefined below -- import explicitly
            # instead (bug fix)
            module = __import__(mod_or_f, fromlist=['run', 'add_options'])
            run, add_options = module.run, module.add_options
        else:
            run, add_options = mod_or_f
        add_options(self)
        (options, args) = self.parse_args(args)
        if not (self.min_args <= len(args) <= self.max_args):
            self.error('incorrect number of arguments')
        return run, options, args


# ---------------------------------------------------------------------------
# diff --git a/logilab/common/proc.py b/logilab/common/proc.py
# new file mode 100644
# index 0000000..c27356c
# --- /dev/null
# +++ b/logilab/common/proc.py
# ---------------------------------------------------------------------------
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""module providing: +* process information (linux specific: rely on /proc) +* a class for resource control (memory / time / cpu time) + +This module doesn't work on windows platforms (only tested on linux) + +:organization: Logilab + + + +""" +__docformat__ = "restructuredtext en" + +import os +import stat +from resource import getrlimit, setrlimit, RLIMIT_CPU, RLIMIT_AS +from signal import signal, SIGXCPU, SIGKILL, SIGUSR2, SIGUSR1 +from threading import Timer, currentThread, Thread, Event +from time import time + +from logilab.common.tree import Node + +class NoSuchProcess(Exception): pass + +def proc_exists(pid): + """check the a pid is registered in /proc + raise NoSuchProcess exception if not + """ + if not os.path.exists('/proc/%s' % pid): + raise NoSuchProcess() + +PPID = 3 +UTIME = 13 +STIME = 14 +CUTIME = 15 +CSTIME = 16 +VSIZE = 22 + +class ProcInfo(Node): + """provide access to process information found in /proc""" + + def __init__(self, pid): + self.pid = int(pid) + Node.__init__(self, self.pid) + proc_exists(self.pid) + self.file = '/proc/%s/stat' % self.pid + self.ppid = int(self.status()[PPID]) + + def memory_usage(self): + """return the memory usage of the process in Ko""" + try : + return int(self.status()[VSIZE]) + except IOError: + return 0 + + def lineage_memory_usage(self): + return self.memory_usage() + sum([child.lineage_memory_usage() + for child in self.children]) + + def time(self, children=0): + """return the number of jiffies that this process has been scheduled + in user and kernel mode""" + status = self.status() + time = int(status[UTIME]) + int(status[STIME]) + if children: + time += int(status[CUTIME]) + int(status[CSTIME]) + return time + + def status(self): + """return the list of fields found in /proc/<pid>/stat""" + return open(self.file).read().split() + + def name(self): + """return the process name found in /proc/<pid>/stat + """ + return self.status()[1].strip('()') + + def age(self): + """return the age of the process + 
""" + return os.stat(self.file)[stat.ST_MTIME] + +class ProcInfoLoader: + """manage process information""" + + def __init__(self): + self._loaded = {} + + def list_pids(self): + """return a list of existent process ids""" + for subdir in os.listdir('/proc'): + if subdir.isdigit(): + yield int(subdir) + + def load(self, pid): + """get a ProcInfo object for a given pid""" + pid = int(pid) + try: + return self._loaded[pid] + except KeyError: + procinfo = ProcInfo(pid) + procinfo.manager = self + self._loaded[pid] = procinfo + return procinfo + + + def load_all(self): + """load all processes information""" + for pid in self.list_pids(): + try: + procinfo = self.load(pid) + if procinfo.parent is None and procinfo.ppid: + pprocinfo = self.load(procinfo.ppid) + pprocinfo.append(procinfo) + except NoSuchProcess: + pass + + +try: + class ResourceError(BaseException): + """Error raise when resource limit is reached""" + limit = "Unknown Resource Limit" +except NameError: + class ResourceError(Exception): + """Error raise when resource limit is reached""" + limit = "Unknown Resource Limit" + + +class XCPUError(ResourceError): + """Error raised when CPU Time limit is reached""" + limit = "CPU Time" + +class LineageMemoryError(ResourceError): + """Error raised when the total amount of memory used by a process and + it's child is reached""" + limit = "Lineage total Memory" + +class TimeoutError(ResourceError): + """Error raised when the process is running for to much time""" + limit = "Real Time" + +# Can't use subclass because the StandardError MemoryError raised +RESOURCE_LIMIT_EXCEPTION = (ResourceError, MemoryError) + + +class MemorySentinel(Thread): + """A class checking a process don't use too much memory in a separated + daemonic thread + """ + def __init__(self, interval, memory_limit, gpid=os.getpid()): + Thread.__init__(self, target=self._run, name="Test.Sentinel") + self.memory_limit = memory_limit + self._stop = Event() + self.interval = interval + 
self.setDaemon(True) + self.gpid = gpid + + def stop(self): + """stop ap""" + self._stop.set() + + def _run(self): + pil = ProcInfoLoader() + while not self._stop.isSet(): + if self.memory_limit <= pil.load(self.gpid).lineage_memory_usage(): + os.killpg(self.gpid, SIGUSR1) + self._stop.wait(self.interval) + + +class ResourceController: + + def __init__(self, max_cpu_time=None, max_time=None, max_memory=None, + max_reprieve=60): + if SIGXCPU == -1: + raise RuntimeError("Unsupported platform") + self.max_time = max_time + self.max_memory = max_memory + self.max_cpu_time = max_cpu_time + self._reprieve = max_reprieve + self._timer = None + self._msentinel = None + self._old_max_memory = None + self._old_usr1_hdlr = None + self._old_max_cpu_time = None + self._old_usr2_hdlr = None + self._old_sigxcpu_hdlr = None + self._limit_set = 0 + self._abort_try = 0 + self._start_time = None + self._elapse_time = 0 + + def _hangle_sig_timeout(self, sig, frame): + raise TimeoutError() + + def _hangle_sig_memory(self, sig, frame): + if self._abort_try < self._reprieve: + self._abort_try += 1 + raise LineageMemoryError("Memory limit reached") + else: + os.killpg(os.getpid(), SIGKILL) + + def _handle_sigxcpu(self, sig, frame): + if self._abort_try < self._reprieve: + self._abort_try += 1 + raise XCPUError("Soft CPU time limit reached") + else: + os.killpg(os.getpid(), SIGKILL) + + def _time_out(self): + if self._abort_try < self._reprieve: + self._abort_try += 1 + os.killpg(os.getpid(), SIGUSR2) + if self._limit_set > 0: + self._timer = Timer(1, self._time_out) + self._timer.start() + else: + os.killpg(os.getpid(), SIGKILL) + + def setup_limit(self): + """set up the process limit""" + assert currentThread().getName() == 'MainThread' + os.setpgrp() + if self._limit_set <= 0: + if self.max_time is not None: + self._old_usr2_hdlr = signal(SIGUSR2, self._hangle_sig_timeout) + self._timer = Timer(max(1, int(self.max_time) - self._elapse_time), + self._time_out) + self._start_time = 
int(time()) + self._timer.start() + if self.max_cpu_time is not None: + self._old_max_cpu_time = getrlimit(RLIMIT_CPU) + cpu_limit = (int(self.max_cpu_time), self._old_max_cpu_time[1]) + self._old_sigxcpu_hdlr = signal(SIGXCPU, self._handle_sigxcpu) + setrlimit(RLIMIT_CPU, cpu_limit) + if self.max_memory is not None: + self._msentinel = MemorySentinel(1, int(self.max_memory) ) + self._old_max_memory = getrlimit(RLIMIT_AS) + self._old_usr1_hdlr = signal(SIGUSR1, self._hangle_sig_memory) + as_limit = (int(self.max_memory), self._old_max_memory[1]) + setrlimit(RLIMIT_AS, as_limit) + self._msentinel.start() + self._limit_set += 1 + + def clean_limit(self): + """reinstall the old process limit""" + if self._limit_set > 0: + if self.max_time is not None: + self._timer.cancel() + self._elapse_time += int(time())-self._start_time + self._timer = None + signal(SIGUSR2, self._old_usr2_hdlr) + if self.max_cpu_time is not None: + setrlimit(RLIMIT_CPU, self._old_max_cpu_time) + signal(SIGXCPU, self._old_sigxcpu_hdlr) + if self.max_memory is not None: + self._msentinel.stop() + self._msentinel = None + setrlimit(RLIMIT_AS, self._old_max_memory) + signal(SIGUSR1, self._old_usr1_hdlr) + self._limit_set -= 1 diff --git a/logilab/common/pyro_ext.py b/logilab/common/pyro_ext.py new file mode 100644 index 0000000..5204b1b --- /dev/null +++ b/logilab/common/pyro_ext.py @@ -0,0 +1,180 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. 
# ---- continuation of the pyro_ext.py license header (opened in the
# ---- previous chunk) -------------------------------------------------------
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""Python Remote Object utilities

Main functions available:

* `register_object` to expose arbitrary object through pyro using delegation
  approach and register it in the nameserver.
* `ns_unregister` unregister an object identifier from the nameserver.
* `ns_get_proxy` get a pyro proxy from a nameserver object identifier.
"""

__docformat__ = "restructuredtext en"

import logging
import tempfile

from Pyro import core, naming, errors, util, config

_LOGGER = logging.getLogger('pyro')
# sentinel used to tell "no explicit default group given" apart from None
_MARKER = object()

config.PYRO_STORAGE = tempfile.gettempdir()

def ns_group_and_id(idstr, defaultnsgroup=_MARKER):
    """split an identifier such as 'group.id' into a (nsgroup, nsid) pair,
    normalizing the group with a leading ':'"""
    try:
        nsgroup, nsid = idstr.rsplit('.', 1)
    except ValueError:
        # no dot in idstr: fall back on the default group
        nsgroup = (config.PYRO_NS_DEFAULTGROUP if defaultnsgroup is _MARKER
                   else defaultnsgroup)
        nsid = idstr
    if nsgroup is not None and not nsgroup.startswith(':'):
        nsgroup = ':' + nsgroup
    return nsgroup, nsid

def host_and_port(hoststr):
    """split a 'host[:port]' string into (host, port-or-None); an empty or
    None input yields (None, None)"""
    if not hoststr:
        return None, None
    try:
        host, port = hoststr.split(':')
    except ValueError:
        # no single ':' separator: keep the string as-is, no port
        return hoststr, None
    return host, int(port)

# daemon and exposed-object registries, keyed resp. by daemon host string
# and by qualified nameserver id
_DAEMONS = {}
_PYRO_OBJS = {}
def _get_daemon(daemonhost, start=True):
    """return (possibly creating) the pyro daemon bound to `daemonhost`"""
    if daemonhost not in _DAEMONS:
        if not start:
            raise Exception('no daemon for %s' % daemonhost)
        if not _DAEMONS:
            # very first daemon: initialize the pyro server layer
            core.initServer(banner=0)
        host, port = host_and_port(daemonhost)
        _DAEMONS[daemonhost] = core.Daemon(host=host, port=port)
    return _DAEMONS[daemonhost]


def locate_ns(nshost):
    """locate and return the pyro name server to the daemon"""
    core.initClient(banner=False)
    return naming.NameServerLocator().getNS(*host_and_port(nshost))


def register_object(object, nsid, defaultnsgroup=_MARKER,
                    daemonhost=None, nshost=None, use_pyrons=True):
    """expose the object as a pyro object and register it in the name-server

    if use_pyrons is False, then the object is exposed, but no
    attempt to register it to a pyro nameserver is made.

    return the pyro daemon object
    """
    nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
    daemon = _get_daemon(daemonhost)
    if use_pyrons:
        nsd = locate_ns(nshost)
        # make sure our namespace group exists
        try:
            nsd.createGroup(nsgroup)
        except errors.NamingError:
            pass
        daemon.useNameServer(nsd)
    # use Delegation approach
    impl = core.ObjBase()
    impl.delegateTo(object)
    qnsid = '%s.%s' % (nsgroup, nsid)
    uri = daemon.connect(impl, qnsid)
    _PYRO_OBJS[qnsid] = str(uri)
    _LOGGER.info('registered %s a pyro object using group %s and id %s',
                 object, nsgroup, nsid)
    return daemon

def get_object_uri(qnsid):
    """return the uri recorded for an object registered under `qnsid`"""
    return _PYRO_OBJS[qnsid]

def ns_unregister(nsid, defaultnsgroup=_MARKER, nshost=None):
    """unregister the object with the given nsid from the pyro name server"""
    nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
    try:
        nsd = locate_ns(nshost)
    except errors.PyroError as ex:
        # name server not responding
        _LOGGER.error('can\'t locate pyro name server: %s', ex)
        return
    try:
        nsd.unregister('%s.%s' % (nsgroup, nsid))
        _LOGGER.info('%s unregistered from pyro name server', nsid)
    except errors.NamingError:
        _LOGGER.warning('%s not registered in pyro name server', nsid)


def ns_reregister(nsid, defaultnsgroup=_MARKER, nshost=None):
    """reregister a pyro object into the name server. You only have to specify
    the name-server id of the object (though you MUST have gone through
    `register_object` for the given object previously).

    This is especially useful for long running server while the name server may
    have been restarted, and its records lost.
    """
    nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
    qnsid = '%s.%s' % (nsgroup, nsid)
    nsd = locate_ns(nshost)
    try:
        nsd.unregister(qnsid)
    except errors.NamingError:
        # the record is gone: recreate the group before re-registering
        try:
            nsd.createGroup(nsgroup)
        except errors.NamingError:
            pass
    nsd.register(qnsid, _PYRO_OBJS[qnsid])

def ns_get_proxy(nsid, defaultnsgroup=_MARKER, nshost=None):
    """
    if nshost is None, the nameserver is found by a broadcast.
    """
    # resolve the Pyro object
    nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
    try:
        nsd = locate_ns(nshost)
        pyrouri = nsd.resolve('%s.%s' % (nsgroup, nsid))
    except errors.ProtocolError as ex:
        raise errors.PyroError(
            'Could not connect to the Pyro name server (host: %s)' % nshost)
    except errors.NamingError:
        raise errors.PyroError(
            'Could not get proxy for %s (not registered in Pyro), '
            'you may have to restart your server-side application' % nsid)
    return core.getProxyForURI(pyrouri)

def get_proxy(pyro_uri):
    """get a proxy for the passed pyro uri without using a nameserver
    """
    return core.getProxyForURI(pyro_uri)

def set_pyro_log_threshold(level):
    """route Pyro's own logger through the root handler at `level`"""
    pyrologger = logging.getLogger('Pyro.%s' % str(id(util.Log)))
    # remove handlers so only the root handler is used
    pyrologger.handlers = []
    pyrologger.setLevel(level)

# ---------------------------------------------------------------------------
# diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py
# new file mode 100644
# index 0000000..58515a9
# --- /dev/null
# +++ b/logilab/common/pytest.py
# ---------------------------------------------------------------------------
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
+# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""pytest is a tool that eases test running and debugging. + +To be able to use pytest, you should either write tests using +the logilab.common.testlib's framework or the unittest module of the +Python's standard library. + +You can customize pytest's behaviour by defining a ``pytestconf.py`` file +somewhere in your test directory. In this file, you can add options or +change the way tests are run. + +To add command line options, you must define a ``update_parser`` function in +your ``pytestconf.py`` file. The function must accept a single parameter +that will be the OptionParser's instance to customize. + +If you wish to customize the tester, you'll have to define a class named +``CustomPyTester``. This class should extend the default `PyTester` class +defined in the pytest module. Take a look at the `PyTester` and `DjangoTester` +classes for more information about what can be done. 
For instance, if you wish to add a custom -l option to specify a loglevel, you
could define the following ``pytestconf.py`` file ::

    import logging
    from logilab.common.pytest import PyTester

    def update_parser(parser):
        parser.add_option('-l', '--loglevel', dest='loglevel', action='store',
                          choices=('debug', 'info', 'warning', 'error', 'critical'),
                          default='critical', help="the default log level possible choices are "
                          "('debug', 'info', 'warning', 'error', 'critical')")
        return parser


    class CustomPyTester(PyTester):
        def __init__(self, cvg, options):
            super(CustomPyTester, self).__init__(cvg, options)
            loglevel = options.loglevel.upper()
            logger = logging.getLogger('erudi')
            logger.setLevel(logging.getLevelName(loglevel))


In your TestCase class you can then get the value of a specific option with
the ``optval`` method::

    class MyTestCase(TestCase):
        def test_foo(self):
            loglevel = self.optval('loglevel')
            # ...


You can also tag your tests for fine filtering.

With those tags::

    from logilab.common.testlib import tag, TestCase

    class Exemple(TestCase):

        @tag('rouge', 'carre')
        def toto(self):
            pass

        @tag('carre', 'vert')
        def tata(self):
            pass

        @tag('rouge')
        def titi(test):
            pass

you can filter the functions with a simple python expression:

 * ``toto`` and ``titi`` match ``rouge``
 * ``toto``, ``tata`` and ``titi`` match ``rouge or carre``
 * ``tata`` and ``titi`` match ``rouge ^ carre``
 * ``titi`` matches ``rouge and not carre``
"""

from __future__ import print_function

__docformat__ = "restructuredtext en"

PYTEST_DOC = """%prog [OPTIONS] [testfile [testpattern]]

examples:

pytest path/to/mytests.py
pytest path/to/mytests.py TheseTests
pytest path/to/mytests.py TheseTests.test_thisone
pytest path/to/mytests.py -m '(not long and database) or regr'

pytest one (will run both test_thisone and test_thatone)
pytest path/to/mytests.py -s not (will skip test_notthisone)
+""" + +ENABLE_DBC = False +FILE_RESTART = ".pytest.restart" + +import os, sys, re +import os.path as osp +from time import time, clock +import warnings +import types +from inspect import isgeneratorfunction, isclass +from contextlib import contextmanager + +from logilab.common.fileutils import abspath_listdir +from logilab.common import textutils +from logilab.common import testlib, STD_BLACKLIST +# use the same unittest module as testlib +from logilab.common.testlib import unittest, start_interactive_mode +from logilab.common.deprecation import deprecated +import doctest + +import unittest as unittest_legacy +if not getattr(unittest_legacy, "__package__", None): + try: + import unittest2.suite as unittest_suite + except ImportError: + sys.exit("You have to install python-unittest2 to use this module") +else: + import unittest.suite as unittest_suite + +try: + import django + from logilab.common.modutils import modpath_from_file, load_module_from_modpath + DJANGO_FOUND = True +except ImportError: + DJANGO_FOUND = False + +CONF_FILE = 'pytestconf.py' + +## coverage pausing tools + +@contextmanager +def replace_trace(trace=None): + """A context manager that temporary replaces the trace function""" + oldtrace = sys.gettrace() + sys.settrace(trace) + try: + yield + finally: + # specific hack to work around a bug in pycoverage, see + # https://bitbucket.org/ned/coveragepy/issue/123 + if (oldtrace is not None and not callable(oldtrace) and + hasattr(oldtrace, 'pytrace')): + oldtrace = oldtrace.pytrace + sys.settrace(oldtrace) + + +def pause_trace(): + """A context manager that temporary pauses any tracing""" + return replace_trace() + +class TraceController(object): + ctx_stack = [] + + @classmethod + @deprecated('[lgc 0.63.1] Use the pause_trace() context manager') + def pause_tracing(cls): + cls.ctx_stack.append(pause_trace()) + cls.ctx_stack[-1].__enter__() + + @classmethod + @deprecated('[lgc 0.63.1] Use the pause_trace() context manager') + def resume_tracing(cls): 
+ cls.ctx_stack.pop().__exit__(None, None, None) + + +pause_tracing = TraceController.pause_tracing +resume_tracing = TraceController.resume_tracing + + +def nocoverage(func): + """Function decorator that pauses tracing functions""" + if hasattr(func, 'uncovered'): + return func + func.uncovered = True + + def not_covered(*args, **kwargs): + with pause_trace(): + return func(*args, **kwargs) + not_covered.uncovered = True + return not_covered + +## end of coverage pausing tools + + +TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$") +def this_is_a_testfile(filename): + """returns True if `filename` seems to be a test file""" + return TESTFILE_RE.match(osp.basename(filename)) + +TESTDIR_RE = re.compile("^(unit)?tests?$") +def this_is_a_testdir(dirpath): + """returns True if `filename` seems to be a test directory""" + return TESTDIR_RE.match(osp.basename(dirpath)) + + +def load_pytest_conf(path, parser): + """loads a ``pytestconf.py`` file and update default parser + and / or tester. 
+ """ + namespace = {} + exec(open(path, 'rb').read(), namespace) + if 'update_parser' in namespace: + namespace['update_parser'](parser) + return namespace.get('CustomPyTester', PyTester) + + +def project_root(parser, projdir=os.getcwd()): + """try to find project's root and add it to sys.path""" + previousdir = curdir = osp.abspath(projdir) + testercls = PyTester + conf_file_path = osp.join(curdir, CONF_FILE) + if osp.isfile(conf_file_path): + testercls = load_pytest_conf(conf_file_path, parser) + while this_is_a_testdir(curdir) or \ + osp.isfile(osp.join(curdir, '__init__.py')): + newdir = osp.normpath(osp.join(curdir, os.pardir)) + if newdir == curdir: + break + previousdir = curdir + curdir = newdir + conf_file_path = osp.join(curdir, CONF_FILE) + if osp.isfile(conf_file_path): + testercls = load_pytest_conf(conf_file_path, parser) + return previousdir, testercls + + +class GlobalTestReport(object): + """this class holds global test statistics""" + def __init__(self): + self.ran = 0 + self.skipped = 0 + self.failures = 0 + self.errors = 0 + self.ttime = 0 + self.ctime = 0 + self.modulescount = 0 + self.errmodules = [] + + def feed(self, filename, testresult, ttime, ctime): + """integrates new test information into internal statistics""" + ran = testresult.testsRun + self.ran += ran + self.skipped += len(getattr(testresult, 'skipped', ())) + self.failures += len(testresult.failures) + self.errors += len(testresult.errors) + self.ttime += ttime + self.ctime += ctime + self.modulescount += 1 + if not testresult.wasSuccessful(): + problems = len(testresult.failures) + len(testresult.errors) + self.errmodules.append((filename[:-3], problems, ran)) + + def failed_to_test_module(self, filename): + """called when the test module could not be imported by unittest + """ + self.errors += 1 + self.modulescount += 1 + self.ran += 1 + self.errmodules.append((filename[:-3], 1, 1)) + + def skip_module(self, filename): + self.modulescount += 1 + self.ran += 1 + 
self.errmodules.append((filename[:-3], 0, 0)) + + def __str__(self): + """this is just presentation stuff""" + line1 = ['Ran %s test cases in %.2fs (%.2fs CPU)' + % (self.ran, self.ttime, self.ctime)] + if self.errors: + line1.append('%s errors' % self.errors) + if self.failures: + line1.append('%s failures' % self.failures) + if self.skipped: + line1.append('%s skipped' % self.skipped) + modulesok = self.modulescount - len(self.errmodules) + if self.errors or self.failures: + line2 = '%s modules OK (%s failed)' % (modulesok, + len(self.errmodules)) + descr = ', '.join(['%s [%s/%s]' % info for info in self.errmodules]) + line3 = '\nfailures: %s' % descr + elif modulesok: + line2 = 'All %s modules OK' % modulesok + line3 = '' + else: + return '' + return '%s\n%s%s' % (', '.join(line1), line2, line3) + + + +def remove_local_modules_from_sys(testdir): + """remove all modules from cache that come from `testdir` + + This is used to avoid strange side-effects when using the + testall() mode of pytest. 
class PyTester(object):
    """encapsulates testrun logic"""

    def __init__(self, cvg, options):
        self.report = GlobalTestReport()
        self.cvg = cvg
        self.options = options
        self.firstwrite = True
        self._errcode = None

    def show_report(self):
        """prints the report and returns appropriate exitcode"""
        # everything has been ran, print report
        print("*" * 79)
        print(self.report)

    def get_errcode(self):
        # an explicitly set errcode (e.g. from a SystemExit) takes precedence
        # over the computed failures + errors count
        if self._errcode is not None:
            return self._errcode
        return self.report.failures + self.report.errors

    def set_errcode(self, errcode):
        self._errcode = errcode
    errcode = property(get_errcode, set_errcode)

    def testall(self, exitfirst=False):
        """walks through current working directory, finds something
        which can be considered as a testdir and runs every test there
        """
        cwd = os.getcwd()
        for dirpath, subdirs, _ in os.walk(cwd):
            for blacklisted in STD_BLACKLIST:
                if blacklisted in subdirs:
                    subdirs.remove(blacklisted)
            if this_is_a_testdir(osp.basename(dirpath)):
                print("going into", dirpath)
                # we found a testdir, let's explore it !
                if not self.testonedir(dirpath, exitfirst):
                    break
                # do not recurse below a test directory
                subdirs[:] = []
        if self.report.ran == 0:
            print("no test dir found testing here:", cwd)
            # if no test was found during the visit, consider
            # the local directory as a test directory even if
            # it doesn't have a traditional test directory name
            self.testonedir(cwd)

    def testonedir(self, testdir, exitfirst=False):
        """finds each testfile in the `testdir` and runs it

        return true when all tests has been executed, false if exitfirst and
        some test has failed.
        """
        for filename in abspath_listdir(testdir):
            if not this_is_a_testfile(filename):
                continue
            if self.options.exitfirst and not self.options.restart:
                # overwrite restart file
                try:
                    with open(FILE_RESTART, "w"):
                        pass
                except Exception:
                    print("Error while overwriting succeeded test file :",
                          osp.join(os.getcwd(), FILE_RESTART),
                          file=sys.__stderr__)
                    raise
            # run test and collect information
            prog = self.testfile(filename, batchmode=True)
            if exitfirst and (prog is None or not prog.result.wasSuccessful()):
                return False
            self.firstwrite = True
        # clean local modules
        remove_local_modules_from_sys(testdir)
        return True

    def testfile(self, filename, batchmode=False):
        """runs every test in `filename`

        :param filename: an absolute path pointing to a unittest file
        """
        origdir = os.getcwd()
        dirname = osp.dirname(filename)
        if dirname:
            os.chdir(dirname)
        # overwrite restart file if it has not been done already
        if self.options.exitfirst and not self.options.restart and self.firstwrite:
            try:
                with open(FILE_RESTART, "w"):
                    pass
            except Exception:
                print("Error while overwriting succeeded test file :",
                      osp.join(os.getcwd(), FILE_RESTART), file=sys.__stderr__)
                raise
        modname = osp.basename(filename)[:-3]
        print((' %s ' % osp.basename(filename)).center(70, '='),
              file=sys.__stderr__)
        try:
            tstart, cstart = time(), clock()
            try:
                testprog = SkipAwareTestProgram(modname, batchmode=batchmode,
                                                cvg=self.cvg,
                                                options=self.options,
                                                outstream=sys.stderr)
            except KeyboardInterrupt:
                raise
            except SystemExit as exc:
                # remember the requested exit code for the final sys.exit
                self.errcode = exc.code
                raise
            except testlib.SkipTest:
                print("Module skipped:", filename)
                self.report.skip_module(filename)
                return None
            except Exception:
                self.report.failed_to_test_module(filename)
                print('unhandled exception occurred while testing', modname,
                      file=sys.stderr)
                import traceback
                traceback.print_exc(file=sys.stderr)
                return None
            tend, cend = time(), clock()
            self.report.feed(filename, testprog.result,
                             tend - tstart, cend - cstart)
            return testprog
        finally:
            if dirname:
                os.chdir(origdir)
class DjangoTester(PyTester):
    """PyTester variant driving Django's test environment setup/teardown"""

    def load_django_settings(self, dirname):
        """try to find project's setting and load it"""
        curdir = osp.abspath(dirname)
        previousdir = curdir
        # walk up the package tree until a settings.py is found
        while not osp.isfile(osp.join(curdir, 'settings.py')) and \
                  osp.isfile(osp.join(curdir, '__init__.py')):
            newdir = osp.normpath(osp.join(curdir, os.pardir))
            if newdir == curdir:
                raise AssertionError('could not find settings.py')
            previousdir = curdir
            curdir = newdir
        # late django initialization
        settings = load_module_from_modpath(
            modpath_from_file(osp.join(curdir, 'settings.py')))
        from django.core.management import setup_environ
        setup_environ(settings)
        settings.DEBUG = False
        self.settings = settings
        # add settings dir to pythonpath since it's the project's root
        if curdir not in sys.path:
            sys.path.insert(1, curdir)

    def before_testfile(self):
        # Those imports must be done **after** setup_environ was called
        from django.test.utils import setup_test_environment
        from django.test.utils import create_test_db
        setup_test_environment()
        create_test_db(verbosity=0)
        self.dbname = self.settings.TEST_DATABASE_NAME

    def after_testfile(self):
        # Those imports must be done **after** setup_environ was called
        from django.test.utils import teardown_test_environment
        from django.test.utils import destroy_test_db
        teardown_test_environment()
        print('destroying', self.dbname)
        destroy_test_db(self.dbname, verbosity=0)

    def testall(self, exitfirst=False):
        """walks through current working directory, finds something
        which can be considered as a testdir and runs every test there
        """
        for dirpath, subdirs, files in os.walk(os.getcwd()):
            for vcsdir in ('CVS', '.svn', '.hg'):
                if vcsdir in subdirs:
                    subdirs.remove(vcsdir)
            if 'tests.py' in files:
                if not self.testonedir(dirpath, exitfirst):
                    break
                subdirs[:] = []
            elif osp.basename(dirpath) in ('test', 'tests'):
                print("going into", dirpath)
                # we found a testdir, let's explore it !
                if not self.testonedir(dirpath, exitfirst):
                    break
                subdirs[:] = []

    def testonedir(self, testdir, exitfirst=False):
        """finds each testfile in the `testdir` and runs it

        return true when all tests has been executed, false if exitfirst and
        some test has failed.
        """
        # special django behaviour : if tests are splitted in several files,
        # remove the main tests.py file and tests each test file separately
        testfiles = [fpath for fpath in abspath_listdir(testdir)
                     if this_is_a_testfile(fpath)]
        if len(testfiles) > 1:
            try:
                testfiles.remove(osp.join(testdir, 'tests.py'))
            except ValueError:
                pass
        for filename in testfiles:
            # run test and collect information
            prog = self.testfile(filename, batchmode=True)
            if exitfirst and (prog is None or not prog.result.wasSuccessful()):
                return False
        # clean local modules
        remove_local_modules_from_sys(testdir)
        return True

    def testfile(self, filename, batchmode=False):
        """runs every test in `filename`

        :param filename: an absolute path pointing to a unittest file
        """
        origdir = os.getcwd()
        dirname = osp.dirname(filename)
        if dirname:
            os.chdir(dirname)
        self.load_django_settings(dirname)
        modname = osp.basename(filename)[:-3]
        print((' %s ' % osp.basename(filename)).center(70, '='),
              file=sys.stderr)
        try:
            try:
                tstart, cstart = time(), clock()
                self.before_testfile()
                testprog = SkipAwareTestProgram(modname, batchmode=batchmode,
                                                cvg=self.cvg)
                tend, cend = time(), clock()
                self.report.feed(filename, testprog.result,
                                 tend - tstart, cend - cstart)
                return testprog
            except SystemExit:
                raise
            except Exception as exc:
                import traceback
                traceback.print_exc()
                self.report.failed_to_test_module(filename)
                print('unhandled exception occurred while testing', modname)
                print('error: %s' % exc)
                return None
        finally:
            self.after_testfile()
            if dirname:
                os.chdir(origdir)
def make_parser():
    """creates the OptionParser instance

    :return: an ``optparse.OptionParser`` with an extra ``newargs`` list
             collecting the options to forward to unittest_main
    """
    from optparse import OptionParser
    parser = OptionParser(usage=PYTEST_DOC)

    parser.newargs = []
    def rebuild_cmdline(option, opt, value, parser):
        """carry the option to unittest_main"""
        parser.newargs.append(opt)

    def rebuild_and_store(option, opt, value, parser):
        """carry the option to unittest_main and store
        the value on current parser
        """
        parser.newargs.append(opt)
        setattr(parser.values, option.dest, True)

    def capture_and_rebuild(option, opt, value, parser):
        warnings.simplefilter('ignore', DeprecationWarning)
        rebuild_cmdline(option, opt, value, parser)

    # pytest options
    parser.add_option('-t', dest='testdir', default=None,
                      help="directory where the tests will be found")
    parser.add_option('-d', dest='dbc', default=False,
                      action="store_true", help="enable design-by-contract")
    # unittest_main options provided and passed through pytest
    parser.add_option('-v', '--verbose', callback=rebuild_cmdline,
                      action="callback", help="Verbose output")
    parser.add_option('-i', '--pdb', callback=rebuild_and_store,
                      dest="pdb", action="callback",
                      help="Enable test failure inspection")
    parser.add_option('-x', '--exitfirst', callback=rebuild_and_store,
                      dest="exitfirst", default=False,
                      action="callback", help="Exit on first failure "
                      "(only make sense when pytest run one test file)")
    parser.add_option('-R', '--restart', callback=rebuild_and_store,
                      dest="restart", default=False,
                      action="callback",
                      help="Restart tests from where it failed (implies exitfirst) "
                      "(only make sense if tests previously ran with exitfirst only)")
    parser.add_option('--color', callback=rebuild_cmdline,
                      action="callback",
                      help="colorize tracebacks")
    parser.add_option('-s', '--skip',
                      # XXX: I wish I could use the callback action but it
                      #      doesn't seem to be able to get the value
                      #      associated to the option
                      action="store", dest="skipped", default=None,
                      help="test names matching this name will be skipped "
                      "to skip several patterns, use commas")
    parser.add_option('-q', '--quiet', callback=rebuild_cmdline,
                      action="callback", help="Minimal output")
    parser.add_option('-P', '--profile', default=None, dest='profile',
                      help="Profile execution and store data in the given file")
    parser.add_option('-m', '--match', default=None, dest='tags_pattern',
                      help="only execute test whose tag match the current pattern")

    if DJANGO_FOUND:
        parser.add_option('-J', '--django', dest='django', default=False,
                          action="store_true",
                          help='use pytest for django test cases')
    return parser


def parseargs(parser):
    """Parse the command line.

    :return: a 2-tuple ``(options, explicitfile)`` where `explicitfile` is
        the single ``.py`` file given on the command line (or None); options
        to pass to unittest_main() are accumulated in ``parser.newargs``.
    """
    # parse the command line
    options, args = parser.parse_args()
    filenames = [arg for arg in args if arg.endswith('.py')]
    if filenames:
        if len(filenames) > 1:
            parser.error("only one filename is acceptable")
        explicitfile = filenames[0]
        args.remove(explicitfile)
    else:
        explicitfile = None
    # someone wants DBC
    testlib.ENABLE_DBC = options.dbc
    newargs = parser.newargs
    if options.skipped:
        newargs.extend(['--skip', options.skipped])
    # restart implies exitfirst
    if options.restart:
        options.exitfirst = True
    # append additional args to the new sys.argv and let unittest_main
    # do the rest
    newargs += args
    return options, explicitfile



def run():
    """pytest entry point: parse the command line, pick a tester and run it"""
    parser = make_parser()
    rootdir, testercls = project_root(parser)
    options, explicitfile = parseargs(parser)
    # mock a new command line
    sys.argv[1:] = parser.newargs
    cvg = None
    if '' not in sys.path:
        sys.path.insert(0, '')
    if DJANGO_FOUND and options.django:
        tester = DjangoTester(cvg, options)
    else:
        tester = testercls(cvg, options)
    if explicitfile:
        cmd, args = tester.testfile, (explicitfile,)
    elif options.testdir:
        cmd, args = tester.testonedir, (options.testdir, options.exitfirst)
    else:
        cmd, args = tester.testall, (options.exitfirst,)
    try:
        try:
            if options.profile:
                # NOTE(review): hotshot is a Python 2-only profiler, removed
                # in Python 3 -- confirm before relying on -P there
                import hotshot
                prof = hotshot.Profile(options.profile)
                prof.runcall(cmd, *args)
                prof.close()
                print('profile data saved in', options.profile)
            else:
                cmd(*args)
        except SystemExit:
            raise
        except:
            # deliberately broad: any crash in the runner must still print a
            # traceback and fall through to the final report
            import traceback
            traceback.print_exc()
    finally:
        tester.show_report()
        sys.exit(tester.errcode)
class SkipAwareTestProgram(unittest.TestProgram):
    # XXX: don't try to stay close to unittest.py, use optparse
    USAGE = """\
Usage: %(progName)s [options] [test] [...]

Options:
  -h, --help       Show this message
  -v, --verbose    Verbose output
  -i, --pdb        Enable test failure inspection
  -x, --exitfirst  Exit on first failure
  -s, --skip       skip test matching this pattern (no regexp for now)
  -q, --quiet      Minimal output
  --color          colorize tracebacks

  -m, --match      Run only test whose tag match this pattern

  -P, --profile    FILE: Run the tests using cProfile and saving results
                   in FILE

Examples:
  %(progName)s                               - run default set of tests
  %(progName)s MyTestSuite                   - run suite 'MyTestSuite'
  %(progName)s MyTestCase.testSomething      - run MyTestCase.testSomething
  %(progName)s MyTestCase                    - run all 'test*' test methods
                                               in MyTestCase
"""
    def __init__(self, module='__main__', defaultTest=None, batchmode=False,
                 cvg=None, options=None, outstream=sys.stderr):
        self.batchmode = batchmode
        self.cvg = cvg
        self.options = options
        self.outstream = outstream
        super(SkipAwareTestProgram, self).__init__(
            module=module, defaultTest=defaultTest,
            testLoader=NonStrictTestLoader())

    def parseArgs(self, argv):
        """parse the pytest-specific command line and build self.test"""
        self.pdbmode = False
        self.exitfirst = False
        self.skipped_patterns = []
        self.test_pattern = None
        self.tags_pattern = None
        self.colorize = False
        self.profile_name = None
        import getopt
        try:
            options, args = getopt.getopt(argv[1:], 'hHvixrqcp:s:m:P:',
                                          ['help', 'verbose', 'quiet', 'pdb',
                                           'exitfirst', 'restart',
                                           'skip=', 'color', 'match=', 'profile='])
            for opt, value in options:
                if opt in ('-h', '-H', '--help'):
                    self.usageExit()
                if opt in ('-i', '--pdb'):
                    self.pdbmode = True
                if opt in ('-x', '--exitfirst'):
                    self.exitfirst = True
                if opt in ('-r', '--restart'):
                    self.restart = True
                    self.exitfirst = True
                if opt in ('-q', '--quiet'):
                    self.verbosity = 0
                if opt in ('-v', '--verbose'):
                    self.verbosity = 2
                if opt in ('-s', '--skip'):
                    self.skipped_patterns = [pat.strip() for pat in
                                             value.split(', ')]
                if opt == '--color':
                    self.colorize = True
                if opt in ('-m', '--match'):
                    # BUGFIX: this used to be ``self.options["tag_pattern"] = value``
                    # which raises TypeError on an optparse Values instance and
                    # uses the wrong key anyway -- SkipAwareTextTestRunner reads
                    # the ``tags_pattern`` *attribute* of options
                    self.options.tags_pattern = value
                if opt in ('-P', '--profile'):
                    self.profile_name = value
            self.testLoader.skipped_patterns = self.skipped_patterns
            if len(args) == 0 and self.defaultTest is None:
                suitefunc = getattr(self.module, 'suite', None)
                if isinstance(suitefunc, (types.FunctionType,
                                          types.MethodType)):
                    self.test = self.module.suite()
                else:
                    self.test = self.testLoader.loadTestsFromModule(self.module)
                return
            if len(args) > 0:
                self.test_pattern = args[0]
                self.testNames = args
            else:
                self.testNames = (self.defaultTest, )
            self.createTests()
        except getopt.error as msg:
            self.usageExit(msg)

    def runTests(self):
        if self.profile_name:
            import cProfile
            cProfile.runctx('self._runTests()', globals(), locals(),
                            self.profile_name)
        else:
            return self._runTests()

    def _runTests(self):
        self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity,
                                                  stream=self.outstream,
                                                  exitfirst=self.exitfirst,
                                                  pdbmode=self.pdbmode,
                                                  cvg=self.cvg,
                                                  test_pattern=self.test_pattern,
                                                  skipped_patterns=self.skipped_patterns,
                                                  colorize=self.colorize,
                                                  batchmode=self.batchmode,
                                                  options=self.options)

        def removeSucceededTests(obj, succTests):
            """ Recursive function that removes succTests from
            a TestSuite or TestCase
            """
            if isinstance(obj, unittest.TestSuite):
                removeSucceededTests(obj._tests, succTests)
            if isinstance(obj, list):
                for el in obj[:]:
                    if isinstance(el, unittest.TestSuite):
                        removeSucceededTests(el, succTests)
                    elif isinstance(el, unittest.TestCase):
                        descr = '.'.join((el.__class__.__module__,
                                          el.__class__.__name__,
                                          el._testMethodName))
                        if descr in succTests:
                            obj.remove(el)
        # take care, self.options may be None
        if getattr(self.options, 'restart', False):
            # retrieve succeeded tests from FILE_RESTART
            try:
                restartfile = open(FILE_RESTART, 'r')
                try:
                    succeededtests = list(elem.rstrip('\n\r') for elem in
                                          restartfile.readlines())
                    removeSucceededTests(self.test, succeededtests)
                finally:
                    restartfile.close()
            except Exception as ex:
                raise Exception("Error while reading succeeded tests into %s: %s"
                                % (osp.join(os.getcwd(), FILE_RESTART), ex))

        result = self.testRunner.run(self.test)
        # help garbage collection: we want TestSuite, which hold refs to every
        # executed TestCase, to be gc'ed
        del self.test
        if getattr(result, "debuggers", None) and \
               getattr(self, "pdbmode", None):
            start_interactive_mode(result)
        if not getattr(self, "batchmode", None):
            sys.exit(not result.wasSuccessful())
        self.result = result
class SkipAwareTextTestRunner(unittest.TextTestRunner):
    """TextTestRunner honouring skip patterns, tag matching and pdb mode"""

    def __init__(self, stream=sys.stderr, verbosity=1,
                 exitfirst=False, pdbmode=False, cvg=None, test_pattern=None,
                 skipped_patterns=(), colorize=False, batchmode=False,
                 options=None):
        super(SkipAwareTextTestRunner, self).__init__(stream=stream,
                                                      verbosity=verbosity)
        self.exitfirst = exitfirst
        self.pdbmode = pdbmode
        self.cvg = cvg
        self.test_pattern = test_pattern
        self.skipped_patterns = skipped_patterns
        self.colorize = colorize
        self.batchmode = batchmode
        self.options = options

    def _this_is_skipped(self, testedname):
        return any(pat in testedname for pat in self.skipped_patterns)

    def _runcondition(self, test, skipgenerator=True):
        if isinstance(test, testlib.InnerTest):
            testname = test.name
        else:
            if isinstance(test, testlib.TestCase):
                meth = test._get_test_method()
                testname = '%s.%s' % (test.__name__, meth.__name__)
            elif isinstance(test, types.FunctionType):
                func = test
                testname = func.__name__
            elif isinstance(test, types.MethodType):
                cls = test.__self__.__class__
                testname = '%s.%s' % (cls.__name__, test.__name__)
            else:
                return True # Not sure when this happens
            if isgeneratorfunction(test) and skipgenerator:
                return self.does_match_tags(test) # Let inner tests decide at run time
        if self._this_is_skipped(testname):
            return False # this was explicitly skipped
        if self.test_pattern is not None:
            try:
                classpattern, testpattern = self.test_pattern.split('.')
                klass, name = testname.split('.')
                if classpattern not in klass or testpattern not in name:
                    return False
            except ValueError:
                if self.test_pattern not in testname:
                    return False

        return self.does_match_tags(test)

    def does_match_tags(self, test):
        if self.options is not None:
            tags_pattern = getattr(self.options, 'tags_pattern', None)
            if tags_pattern is not None:
                tags = getattr(test, 'tags', testlib.Tags())
                if tags.inherit and isinstance(test, types.MethodType):
                    # BUGFIX: ``test.im_class`` is Python 2 only; use
                    # ``__self__.__class__`` as _runcondition already does
                    tags = tags | getattr(test.__self__.__class__,
                                          'tags', testlib.Tags())
                return tags.match(tags_pattern)
        return True # no pattern

    def _makeResult(self):
        return testlib.SkipAwareTestResult(self.stream, self.descriptions,
                                           self.verbosity, self.exitfirst,
                                           self.pdbmode, self.cvg, self.colorize)

    def run(self, test):
        "Run the given test case or test suite."
        result = self._makeResult()
        startTime = time()
        test(result, runcondition=self._runcondition, options=self.options)
        stopTime = time()
        timeTaken = stopTime - startTime
        result.printErrors()
        if not self.batchmode:
            self.stream.writeln(result.separator2)
            run = result.testsRun
            self.stream.writeln("Ran %d test%s in %.3fs" %
                                (run, run != 1 and "s" or "", timeTaken))
            self.stream.writeln()
            if not result.wasSuccessful():
                if self.colorize:
                    self.stream.write(textutils.colorize_ansi("FAILED", color='red'))
                else:
                    self.stream.write("FAILED")
            else:
                if self.colorize:
                    self.stream.write(textutils.colorize_ansi("OK", color='green'))
                else:
                    self.stream.write("OK")
            failed, errored, skipped = map(len, (result.failures,
                                                 result.errors,
                                                 result.skipped))

            det_results = []
            for name, value in (("failures", result.failures),
                                ("errors",result.errors),
                                ("skipped", result.skipped)):
                if value:
                    det_results.append("%s=%i" % (name, len(value)))
            if det_results:
                self.stream.write(" (")
                self.stream.write(', '.join(det_results))
                self.stream.write(")")
            self.stream.writeln("")
        return result
class NonStrictTestLoader(unittest.TestLoader):
    """
    Overrides default testloader to be able to omit classname when
    specifying tests to run on command line.

    For example, if the file test_foo.py contains ::

        class FooTC(TestCase):
            def test_foo1(self): # ...
            def test_foo2(self): # ...
            def test_bar1(self): # ...

        class BarTC(TestCase):
            def test_bar2(self): # ...

    'python test_foo.py' will run the 3 tests in FooTC
    'python test_foo.py FooTC' will run the 3 tests in FooTC
    'python test_foo.py test_foo' will run test_foo1 and test_foo2
    'python test_foo.py test_foo1' will run test_foo1
    'python test_foo.py test_bar' will run FooTC.test_bar1 and BarTC.test_bar2
    """

    def __init__(self):
        # initialize the base TestLoader state (it was skipped before, leaving
        # only class attributes to fall back on)
        super(NonStrictTestLoader, self).__init__()
        self.skipped_patterns = ()

    # some magic here to accept empty list by extending
    # and to provide callable capability
    def loadTestsFromNames(self, names, module=None):
        suites = []
        for name in names:
            suites.extend(self.loadTestsFromName(name, module))
        return self.suiteClass(suites)

    def _collect_tests(self, module):
        """map TestCase class names of `module` to (class, methodnames)"""
        tests = {}
        for obj in vars(module).values():
            if isclass(obj) and issubclass(obj, unittest.TestCase):
                classname = obj.__name__
                if classname[0] == '_' or self._this_is_skipped(classname):
                    continue
                methodnames = []
                # obj is a TestCase class
                for attrname in dir(obj):
                    if attrname.startswith(self.testMethodPrefix):
                        attr = getattr(obj, attrname)
                        if callable(attr):
                            methodnames.append(attrname)
                # keep track of class (obj) for convenience
                tests[classname] = (obj, methodnames)
        return tests

    def loadTestsFromSuite(self, module, suitename):
        try:
            suite = getattr(module, suitename)()
        except AttributeError:
            return []
        assert hasattr(suite, '_tests'), \
               "%s.%s is not a valid TestSuite" % (module.__name__, suitename)
        # python2.3 does not implement __iter__ on suites, we need to return
        # _tests explicitly
        return suite._tests

    def loadTestsFromName(self, name, module=None):
        parts = name.split('.')
        if module is None or len(parts) > 2:
            # let the base class do its job here
            return [super(NonStrictTestLoader, self).loadTestsFromName(name)]
        tests = self._collect_tests(module)
        collected = []
        if len(parts) == 1:
            pattern = parts[0]
            if callable(getattr(module, pattern, None)
                        ) and pattern not in tests:
                # consider it as a suite
                return self.loadTestsFromSuite(module, pattern)
            if pattern in tests:
                # case python unittest_foo.py MyTestTC
                # (the old code rebuilt this same list once per method name)
                klass, methodnames = tests[pattern]
                collected = [klass(methodname)
                             for methodname in methodnames]
            else:
                # case python unittest_foo.py something
                for klass, methodnames in tests.values():
                    # skip methodname if matched by skipped_patterns
                    for skip_pattern in self.skipped_patterns:
                        methodnames = [methodname
                                       for methodname in methodnames
                                       if skip_pattern not in methodname]
                    collected += [klass(methodname)
                                  for methodname in methodnames
                                  if pattern in methodname]
        elif len(parts) == 2:
            # case "MyClass.test_1"
            classname, pattern = parts
            klass, methodnames = tests.get(classname, (None, []))
            collected = [klass(methodname) for methodname in methodnames
                         if pattern in methodname]
        return collected

    def _this_is_skipped(self, testedname):
        return any(pat in testedname for pat in self.skipped_patterns)

    def getTestCaseNames(self, testCaseClass):
        """Return a sorted sequence of method names found within testCaseClass
        """
        is_skipped = self._this_is_skipped
        classname = testCaseClass.__name__
        if classname[0] == '_' or is_skipped(classname):
            return []
        testnames = super(NonStrictTestLoader, self).getTestCaseNames(
            testCaseClass)
        return [testname for testname in testnames if not is_skipped(testname)]


# The 2 functions below are modified versions of the TestSuite.run method
# that is provided with unittest2 for python 2.6, in unittest2/suite.py
# It is used to monkeypatch the original implementation to support
# extra runcondition and options arguments (see in testlib.py)

def _ts_run(self, result, runcondition=None, options=None):
    self._wrapped_run(result, runcondition=runcondition, options=options)
    self._tearDownPreviousClass(None, result)
    self._handleModuleTearDown(result)
    return result
runcondition=None, options=None): + for test in self: + if result.shouldStop: + break + if unittest_suite._isnotsuite(test): + self._tearDownPreviousClass(test, result) + self._handleModuleFixture(test, result) + self._handleClassSetUp(test, result) + result._previousTestClass = test.__class__ + if (getattr(test.__class__, '_classSetupFailed', False) or + getattr(result, '_moduleSetUpFailed', False)): + continue + + # --- modifications to deal with _wrapped_run --- + # original code is: + # + # if not debug: + # test(result) + # else: + # test.debug() + if hasattr(test, '_wrapped_run'): + try: + test._wrapped_run(result, debug, runcondition=runcondition, options=options) + except TypeError: + test._wrapped_run(result, debug) + elif not debug: + try: + test(result, runcondition, options) + except TypeError: + test(result) + else: + test.debug() + # --- end of modifications to deal with _wrapped_run --- + return result + +if sys.version_info >= (2, 7): + # The function below implements a modified version of the + # TestSuite.run method that is provided with python 2.7, in + # unittest/suite.py + def _ts_run(self, result, debug=False, runcondition=None, options=None): + topLevel = False + if getattr(result, '_testRunEntered', False) is False: + result._testRunEntered = topLevel = True + + self._wrapped_run(result, debug, runcondition, options) + + if topLevel: + self._tearDownPreviousClass(None, result) + self._handleModuleTearDown(result) + result._testRunEntered = False + return result + + +def enable_dbc(*args): + """ + Without arguments, return True if contracts can be enabled and should be + enabled (see option -d), return False otherwise. + + With arguments, return False if contracts can't or shouldn't be enabled, + otherwise weave ContractAspect with items passed as arguments. 
+ """ + if not ENABLE_DBC: + return False + try: + from logilab.aspects.weaver import weaver + from logilab.aspects.lib.contracts import ContractAspect + except ImportError: + sys.stderr.write( + 'Warning: logilab.aspects is not available. Contracts disabled.') + return False + for arg in args: + weaver.weave_module(arg, ContractAspect) + return True + + +# monkeypatch unittest and doctest (ouch !) +unittest._TextTestResult = testlib.SkipAwareTestResult +unittest.TextTestRunner = SkipAwareTextTestRunner +unittest.TestLoader = NonStrictTestLoader +unittest.TestProgram = SkipAwareTestProgram + +if sys.version_info >= (2, 4): + doctest.DocTestCase.__bases__ = (testlib.TestCase,) + # XXX check python2.6 compatibility + #doctest.DocTestCase._cleanups = [] + #doctest.DocTestCase._out = [] +else: + unittest.FunctionTestCase.__bases__ = (testlib.TestCase,) +unittest.TestSuite.run = _ts_run +unittest.TestSuite._wrapped_run = _ts_wrapped_run diff --git a/logilab/common/registry.py b/logilab/common/registry.py new file mode 100644 index 0000000..a52b2eb --- /dev/null +++ b/logilab/common/registry.py @@ -0,0 +1,1119 @@ +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of Logilab-common. +# +# Logilab-common is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 2.1 of the License, or (at your +# option) any later version. +# +# Logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with Logilab-common. If not, see <http://www.gnu.org/licenses/>. 
+"""This module provides bases for predicates dispatching (the pattern in use +here is similar to what's refered as multi-dispatch or predicate-dispatch in the +literature, though a bit different since the idea is to select across different +implementation 'e.g. classes), not to dispatch a message to a function or +method. It contains the following classes: + +* :class:`RegistryStore`, the top level object which loads implementation + objects and stores them into registries. You'll usually use it to access + registries and their contained objects; + +* :class:`Registry`, the base class which contains objects semantically grouped + (for instance, sharing a same API, hence the 'implementation' name). You'll + use it to select the proper implementation according to a context. Notice you + may use registries on their own without using the store. + +.. Note:: + + implementation objects are usually designed to be accessed through the + registry and not by direct instantiation, besides to use it as base classe. + +The selection procedure is delegated to a selector, which is responsible for +scoring the object according to some context. At the end of the selection, if an +implementation has been found, an instance of this class is returned. A selector +is built from one or more predicates combined together using AND, OR, NOT +operators (actually `&`, `|` and `~`). You'll thus find some base classes to +build predicates: + +* :class:`Predicate`, the abstract base predicate class + +* :class:`AndPredicate`, :class:`OrPredicate`, :class:`NotPredicate`, which you + shouldn't have to use directly. You'll use `&`, `|` and '~' operators between + predicates directly + +* :func:`objectify_predicate` + +You'll eventually find one concrete predicate: :class:`yes` + +.. autoclass:: RegistryStore +.. autoclass:: Registry + +Predicates +---------- +.. autoclass:: Predicate +.. autofunc:: objectify_predicate +.. autoclass:: yes + +Debugging +--------- +.. 
autoclass:: traced_selection + +Exceptions +---------- +.. autoclass:: RegistryException +.. autoclass:: RegistryNotFound +.. autoclass:: ObjectNotFound +.. autoclass:: NoSelectableObject +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +import sys +import types +import weakref +import traceback as tb +from os import listdir, stat +from os.path import join, isdir, exists +from logging import getLogger +from warnings import warn + +from six import string_types, add_metaclass + +from logilab.common.modutils import modpath_from_file +from logilab.common.logging_ext import set_log_methods +from logilab.common.decorators import classproperty + + +class RegistryException(Exception): + """Base class for registry exception.""" + +class RegistryNotFound(RegistryException): + """Raised when an unknown registry is requested. + + This is usually a programming/typo error. + """ + +class ObjectNotFound(RegistryException): + """Raised when an unregistered object is requested. + + This may be a programming/typo or a misconfiguration error. + """ + +class NoSelectableObject(RegistryException): + """Raised when no object is selectable for a given context.""" + def __init__(self, args, kwargs, objects): + self.args = args + self.kwargs = kwargs + self.objects = objects + + def __str__(self): + return ('args: %s, kwargs: %s\ncandidates: %s' + % (self.args, self.kwargs.keys(), self.objects)) + + +def _modname_from_path(path, extrapath=None): + modpath = modpath_from_file(path, extrapath) + # omit '__init__' from package's name to avoid loading that module + # once for each name when it is imported by some other object + # module. This supposes import in modules are done as:: + # + # from package import something + # + # not:: + # + # from package.__init__ import something + # + # which seems quite correct. 
    if modpath[-1] == '__init__':
        modpath.pop()
    return '.'.join(modpath)


def _toload_info(path, extrapath, _toload=None):
    """Return a dictionary of <modname>: <modpath> and an ordered list of
    (file, module name) to load

    `path` is a list of files and/or directories; directories containing an
    __init__.py (i.e. packages) are walked recursively.  `_toload` is the
    shared accumulator threaded through recursive calls.
    """
    if _toload is None:
        # top-level call: path must be the initial list of files/directories
        assert isinstance(path, list)
        _toload = {}, []
    for fileordir in path:
        if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
            # recurse into the package directory
            subfiles = [join(fileordir, fname) for fname in listdir(fileordir)]
            _toload_info(subfiles, extrapath, _toload)
        elif fileordir[-3:] == '.py':
            modname = _modname_from_path(fileordir, extrapath)
            _toload[0][modname] = fileordir
            _toload[1].append((fileordir, modname))
    return _toload


class RegistrableObject(object):
    """This is the base class for registrable objects which are selected
    according to a context.

    :attr:`__registry__`
      name of the registry for this object (string like 'views',
      'templates'...). You may want to define `__registries__` directly if your
      object should be registered in several registries.

    :attr:`__regid__`
      object's identifier in the registry (string like 'main',
      'primary', 'folder_box')

    :attr:`__select__`
      class'selector

    Moreover, the `__abstract__` attribute may be set to True to indicate that a
    class is abstract and should not be registered.

    You don't have to inherit from this class to put it in a registry (having
    `__regid__` and `__select__` is enough), though this is needed for classes
    that should be automatically registered.
    """

    __registry__ = None
    __regid__ = None
    __select__ = None
    __abstract__ = True # see doc snippets below (in Registry class)

    @classproperty
    def __registries__(cls):
        # default: single-element tuple built from __registry__, or an empty
        # tuple meaning "register nowhere"
        if cls.__registry__ is None:
            return ()
        return (cls.__registry__,)


class RegistrableInstance(RegistrableObject):
    """Inherit this class if you want instances of the classes to be
    automatically registered.
+ """ + + def __new__(cls, *args, **kwargs): + """Add a __module__ attribute telling the module where the instance was + created, for automatic registration. + """ + obj = super(RegistrableInstance, cls).__new__(cls) + # XXX subclass must no override __new__ + filepath = tb.extract_stack(limit=2)[0][0] + obj.__module__ = _modname_from_path(filepath) + return obj + + +class Registry(dict): + """The registry store a set of implementations associated to identifier: + + * to each identifier are associated a list of implementations + + * to select an implementation of a given identifier, you should use one of the + :meth:`select` or :meth:`select_or_none` method + + * to select a list of implementations for a context, you should use the + :meth:`possible_objects` method + + * dictionary like access to an identifier will return the bare list of + implementations for this identifier. + + To be usable in a registry, the only requirement is to have a `__select__` + attribute. + + At the end of the registration process, the :meth:`__registered__` + method is called on each registered object which have them, given the + registry in which it's registered as argument. + + Registration methods: + + .. automethod: register + .. automethod: unregister + + Selection methods: + + .. automethod: select + .. automethod: select_or_none + .. automethod: possible_objects + .. 
automethod: object_by_id
    """
    def __init__(self, debugmode):
        super(Registry, self).__init__()
        # when true, ambiguous selections raise instead of being logged
        # (see _select_best) and predicates get traced (see
        # initialization_completed)
        self.debugmode = debugmode

    def __getitem__(self, name):
        """return the registry (list of implementation objects) associated to
        this name

        raise :exc:`ObjectNotFound` (instead of a bare KeyError) for an
        unknown name
        """
        try:
            return super(Registry, self).__getitem__(name)
        except KeyError:
            exc = ObjectNotFound(name)
            # keep the original traceback available on the raised exception
            exc.__traceback__ = sys.exc_info()[-1]
            raise exc

    @classmethod
    def objid(cls, obj):
        """returns a unique identifier for an object stored in the registry"""
        return '%s.%s' % (obj.__module__, cls.objname(obj))

    @classmethod
    def objname(cls, obj):
        """returns a readable name for an object stored in the registry"""
        return getattr(obj, '__name__', id(obj))

    def initialization_completed(self):
        """call method __registered__() on registered objects when the callback
        is defined"""
        for objects in self.values():
            for objectcls in objects:
                registered = getattr(objectcls, '__registered__', None)
                if registered:
                    registered(self)
        # in debug mode, instrument every known predicate so selections can
        # be traced with traced_selection
        if self.debugmode:
            wrap_predicates(_lltrace)

    def register(self, obj, oid=None, clear=False):
        """base method to add an object in the registry

        `oid` defaults to obj.__regid__; with `clear`, any previously
        registered objects under the same identifier are dropped first.
        """
        assert not '__abstract__' in obj.__dict__, obj
        assert obj.__select__, obj
        oid = oid or obj.__regid__
        assert oid, ('no explicit name supplied to register object %s, '
                     'which has no __regid__ set' % obj)
        if clear:
            objects = self[oid] = []
        else:
            objects = self.setdefault(oid, [])
        assert not obj in objects, 'object %s is already registered' % obj
        objects.append(obj)

    def register_and_replace(self, obj, replaced):
        """remove <replaced> and register <obj>"""
        # XXXFIXME this is a duplication of unregister()
        # remove register_and_replace in favor of unregister + register
        # or simplify by calling unregister then register here
        if not isinstance(replaced, string_types):
            replaced = self.objid(replaced)
        # prevent from misspelling
        assert obj is not replaced, 'replacing an object by itself: %s' %
obj + registered_objs = self.get(obj.__regid__, ()) + for index, registered in enumerate(registered_objs): + if self.objid(registered) == replaced: + del registered_objs[index] + break + else: + self.warning('trying to replace %s that is not registered with %s', + replaced, obj) + self.register(obj) + + def unregister(self, obj): + """remove object <obj> from this registry""" + objid = self.objid(obj) + oid = obj.__regid__ + for registered in self.get(oid, ()): + # use self.objid() to compare objects because vreg will probably + # have its own version of the object, loaded through execfile + if self.objid(registered) == objid: + self[oid].remove(registered) + break + else: + self.warning('can\'t remove %s, no id %s in the registry', + objid, oid) + + def all_objects(self): + """return a list containing all objects in this registry. + """ + result = [] + for objs in self.values(): + result += objs + return result + + # dynamic selection methods ################################################ + + def object_by_id(self, oid, *args, **kwargs): + """return object with the `oid` identifier. Only one object is expected + to be found. + + raise :exc:`ObjectNotFound` if there are no object with id `oid` in this + registry + + raise :exc:`AssertionError` if there is more than one object there + """ + objects = self[oid] + assert len(objects) == 1, objects + return objects[0](*args, **kwargs) + + def select(self, __oid, *args, **kwargs): + """return the most specific object among those with the given oid + according to the given context. 
+ + raise :exc:`ObjectNotFound` if there are no object with id `oid` in this + registry + + raise :exc:`NoSelectableObject` if no object can be selected + """ + obj = self._select_best(self[__oid], *args, **kwargs) + if obj is None: + raise NoSelectableObject(args, kwargs, self[__oid] ) + return obj + + def select_or_none(self, __oid, *args, **kwargs): + """return the most specific object among those with the given oid + according to the given context, or None if no object applies. + """ + try: + return self._select_best(self[__oid], *args, **kwargs) + except ObjectNotFound: + return None + + def possible_objects(self, *args, **kwargs): + """return an iterator on possible objects in this registry for the given + context + """ + for objects in self.values(): + obj = self._select_best(objects, *args, **kwargs) + if obj is None: + continue + yield obj + + def _select_best(self, objects, *args, **kwargs): + """return an instance of the most specific object according + to parameters + + return None if not object apply (don't raise `NoSelectableObject` since + it's costly when searching objects using `possible_objects` + (e.g. searching for hooks). 
+ """ + score, winners = 0, None + for obj in objects: + objectscore = obj.__select__(obj, *args, **kwargs) + if objectscore > score: + score, winners = objectscore, [obj] + elif objectscore > 0 and objectscore == score: + winners.append(obj) + if winners is None: + return None + if len(winners) > 1: + # log in production environement / test, error while debugging + msg = 'select ambiguity: %s\n(args: %s, kwargs: %s)' + if self.debugmode: + # raise bare exception in debug mode + raise Exception(msg % (winners, args, kwargs.keys())) + self.error(msg, winners, args, kwargs.keys()) + # return the result of calling the object + return self.selected(winners[0], args, kwargs) + + def selected(self, winner, args, kwargs): + """override here if for instance you don't want "instanciation" + """ + return winner(*args, **kwargs) + + # these are overridden by set_log_methods below + # only defining here to prevent pylint from complaining + info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None + + +def obj_registries(cls, registryname=None): + """return a tuple of registry names (see __registries__)""" + if registryname: + return (registryname,) + return cls.__registries__ + + +class RegistryStore(dict): + """This class is responsible for loading objects and storing them + in their registry which is created on the fly as needed. + + It handles dynamic registration of objects and provides a + convenient api to access them. To be recognized as an object that + should be stored into one of the store's registry + (:class:`Registry`), an object must provide the following + attributes, used control how they interact with the registry: + + :attr:`__registries__` + list of registry names (string like 'views', 'templates'...) 
into which + the object should be registered + + :attr:`__regid__` + object identifier in the registry (string like 'main', + 'primary', 'folder_box') + + :attr:`__select__` + the object predicate selectors + + Moreover, the :attr:`__abstract__` attribute may be set to `True` + to indicate that an object is abstract and should not be registered + (such inherited attributes not considered). + + .. Note:: + + When using the store to load objects dynamically, you *always* have + to use **super()** to get the methods and attributes of the + superclasses, and not use the class identifier. If not, you'll get into + trouble at reload time. + + For example, instead of writing:: + + class Thing(Parent): + __regid__ = 'athing' + __select__ = yes() + + def f(self, arg1): + Parent.f(self, arg1) + + You must write:: + + class Thing(Parent): + __regid__ = 'athing' + __select__ = yes() + + def f(self, arg1): + super(Thing, self).f(arg1) + + Controlling object registration + ------------------------------- + + Dynamic loading is triggered by calling the + :meth:`register_objects` method, given a list of directories to + inspect for python modules. + + .. automethod: register_objects + + For each module, by default, all compatible objects are registered + automatically. However if some objects come as replacement of + other objects, or have to be included only if some condition is + met, you'll have to define a `registration_callback(vreg)` + function in the module and explicitly register **all objects** in + this module, using the api defined below. + + + .. automethod:: RegistryStore.register_all + .. automethod:: RegistryStore.register_and_replace + .. automethod:: RegistryStore.register + .. automethod:: RegistryStore.unregister + + .. Note:: + Once the function `registration_callback(vreg)` is implemented in a + module, all the objects from this module have to be explicitly + registered as it disables the automatic object registration. + + + Examples: + + .. 
sourcecode:: python + + def registration_callback(store): + # register everything in the module except BabarClass + store.register_all(globals().values(), __name__, (BabarClass,)) + + # conditionally register BabarClass + if 'babar_relation' in store.schema: + store.register(BabarClass) + + In this example, we register all application object classes defined in the module + except `BabarClass`. This class is then registered only if the 'babar_relation' + relation type is defined in the instance schema. + + .. sourcecode:: python + + def registration_callback(store): + store.register(Elephant) + # replace Babar by Celeste + store.register_and_replace(Celeste, Babar) + + In this example, we explicitly register classes one by one: + + * the `Elephant` class + * the `Celeste` to replace `Babar` + + If at some point we register a new appobject class in this module, it won't be + registered at all without modification to the `registration_callback` + implementation. The first example will register it though, thanks to the call + to the `register_all` method. + + Controlling registry instantiation + ---------------------------------- + + The `REGISTRY_FACTORY` class dictionary allows to specify which class should + be instantiated for a given registry name. The class associated to `None` + key will be the class used when there is no specific class for a name. 
+ """ + + def __init__(self, debugmode=False): + super(RegistryStore, self).__init__() + self.debugmode = debugmode + + def reset(self): + """clear all registries managed by this store""" + # don't use self.clear, we want to keep existing subdictionaries + for subdict in self.values(): + subdict.clear() + self._lastmodifs = {} + + def __getitem__(self, name): + """return the registry (dictionary of class objects) associated to + this name + """ + try: + return super(RegistryStore, self).__getitem__(name) + except KeyError: + exc = RegistryNotFound(name) + exc.__traceback__ = sys.exc_info()[-1] + raise exc + + # methods for explicit (un)registration ################################### + + # default class, when no specific class set + REGISTRY_FACTORY = {None: Registry} + + def registry_class(self, regid): + """return existing registry named regid or use factory to create one and + return it""" + try: + return self.REGISTRY_FACTORY[regid] + except KeyError: + return self.REGISTRY_FACTORY[None] + + def setdefault(self, regid): + try: + return self[regid] + except RegistryNotFound: + self[regid] = self.registry_class(regid)(self.debugmode) + return self[regid] + + def register_all(self, objects, modname, butclasses=()): + """register registrable objects into `objects`. + + Registrable objects are properly configured subclasses of + :class:`RegistrableObject`. Objects which are not defined in the module + `modname` or which are in `butclasses` won't be registered. + + Typical usage is: + + .. sourcecode:: python + + store.register_all(globals().values(), __name__, (ClassIWantToRegisterExplicitly,)) + + So you get partially automatic registration, keeping manual registration + for some object (to use + :meth:`~logilab.common.registry.RegistryStore.register_and_replace` for + instance). 
+ """ + assert isinstance(modname, string_types), \ + 'modname expected to be a module name (ie string), got %r' % modname + for obj in objects: + if self.is_registrable(obj) and obj.__module__ == modname and not obj in butclasses: + if isinstance(obj, type): + self._load_ancestors_then_object(modname, obj, butclasses) + else: + self.register(obj) + + def register(self, obj, registryname=None, oid=None, clear=False): + """register `obj` implementation into `registryname` or + `obj.__registries__` if not specified, with identifier `oid` or + `obj.__regid__` if not specified. + + If `clear` is true, all objects with the same identifier will be + previously unregistered. + """ + assert not obj.__dict__.get('__abstract__'), obj + for registryname in obj_registries(obj, registryname): + registry = self.setdefault(registryname) + registry.register(obj, oid=oid, clear=clear) + self.debug("register %s in %s['%s']", + registry.objname(obj), registryname, oid or obj.__regid__) + self._loadedmods.setdefault(obj.__module__, {})[registry.objid(obj)] = obj + + def unregister(self, obj, registryname=None): + """unregister `obj` object from the registry `registryname` or + `obj.__registries__` if not specified. + """ + for registryname in obj_registries(obj, registryname): + registry = self[registryname] + registry.unregister(obj) + self.debug("unregister %s from %s['%s']", + registry.objname(obj), registryname, obj.__regid__) + + def register_and_replace(self, obj, replaced, registryname=None): + """register `obj` object into `registryname` or + `obj.__registries__` if not specified. If found, the `replaced` object + will be unregistered first (else a warning will be issued as it is + generally unexpected). 
+ """ + for registryname in obj_registries(obj, registryname): + registry = self[registryname] + registry.register_and_replace(obj, replaced) + self.debug("register %s in %s['%s'] instead of %s", + registry.objname(obj), registryname, obj.__regid__, + registry.objname(replaced)) + + # initialization methods ################################################### + + def init_registration(self, path, extrapath=None): + """reset registry and walk down path to return list of (path, name) + file modules to be loaded""" + # XXX make this private by renaming it to _init_registration ? + self.reset() + # compute list of all modules that have to be loaded + self._toloadmods, filemods = _toload_info(path, extrapath) + # XXX is _loadedmods still necessary ? It seems like it's useful + # to avoid loading same module twice, especially with the + # _load_ancestors_then_object logic but this needs to be checked + self._loadedmods = {} + return filemods + + def register_objects(self, path, extrapath=None): + """register all objects found walking down <path>""" + # load views from each directory in the instance's path + # XXX inline init_registration ? + filemods = self.init_registration(path, extrapath) + for filepath, modname in filemods: + self.load_file(filepath, modname) + self.initialization_completed() + + def initialization_completed(self): + """call initialization_completed() on all known registries""" + for reg in self.values(): + reg.initialization_completed() + + def _mdate(self, filepath): + """ return the modification date of a file path """ + try: + return stat(filepath)[-2] + except OSError: + # this typically happens on emacs backup files (.#foo.py) + self.warning('Unable to load %s. 
It is likely to be a backup file',
                         filepath)
            return None

    def is_reload_needed(self, path):
        """return True if something module changed and the registry should be
        reloaded
        """
        lastmodifs = self._lastmodifs
        for fileordir in path:
            if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
                # recurse into package directories
                if self.is_reload_needed([join(fileordir, fname)
                                          for fname in listdir(fileordir)]):
                    return True
            elif fileordir[-3:] == '.py':
                mdate = self._mdate(fileordir)
                if mdate is None:
                    continue # backup file, see _mdate implementation
                elif "flymake" in fileordir:
                    # flymake + pylint in use, don't consider these they will corrupt the registry
                    continue
                if fileordir not in lastmodifs or lastmodifs[fileordir] < mdate:
                    self.info('File %s changed since last visit', fileordir)
                    return True
        return False

    def load_file(self, filepath, modname):
        """ load registrable objects (if any) from a python file """
        # NOTE(review): modutils is already imported at module level for
        # modpath_from_file; this local import is presumably historical —
        # confirm before cleaning up
        from logilab.common.modutils import load_module_from_name
        if modname in self._loadedmods:
            # already loaded, e.g. through _load_ancestors_then_object
            return
        self._loadedmods[modname] = {}
        mdate = self._mdate(filepath)
        if mdate is None:
            return # backup file, see _mdate implementation
        elif "flymake" in filepath:
            # flymake + pylint in use, don't consider these they will corrupt the registry
            return
        # set update time before module loading, else we get some reloading
        # weirdness in case of syntax error or other error while importing the
        # module
        self._lastmodifs[filepath] = mdate
        # load the module
        module = load_module_from_name(modname)
        self.load_module(module)

    def load_module(self, module):
        """Automatically handle module objects registration.
+ + Instances are registered as soon as they are hashable and have the + following attributes: + + * __regid__ (a string) + * __select__ (a callable) + * __registries__ (a tuple/list of string) + + For classes this is a bit more complicated : + + - first ensure parent classes are already registered + + - class with __abstract__ == True in their local dictionary are skipped + + - object class needs to have registries and identifier properly set to a + non empty string to be registered. + """ + self.info('loading %s from %s', module.__name__, module.__file__) + if hasattr(module, 'registration_callback'): + module.registration_callback(self) + else: + self.register_all(vars(module).values(), module.__name__) + + def _load_ancestors_then_object(self, modname, objectcls, butclasses=()): + """handle class registration according to rules defined in + :meth:`load_module` + """ + # backward compat, we used to allow whatever else than classes + if not isinstance(objectcls, type): + if self.is_registrable(objectcls) and objectcls.__module__ == modname: + self.register(objectcls) + return + # imported classes + objmodname = objectcls.__module__ + if objmodname != modname: + # The module of the object is not the same as the currently + # worked on module, or this is actually an instance, which + # has no module at all + if objmodname in self._toloadmods: + # if this is still scheduled for loading, let's proceed immediately, + # but using the object module + self.load_file(self._toloadmods[objmodname], objmodname) + return + # ensure object hasn't been already processed + clsid = '%s.%s' % (modname, objectcls.__name__) + if clsid in self._loadedmods[modname]: + return + self._loadedmods[modname][clsid] = objectcls + # ensure ancestors are registered + for parent in objectcls.__bases__: + self._load_ancestors_then_object(modname, parent, butclasses) + # ensure object is registrable + if objectcls in butclasses or not self.is_registrable(objectcls): + return + # backward compat + 
reg = self.setdefault(obj_registries(objectcls)[0]) + if reg.objname(objectcls)[0] == '_': + warn("[lgc 0.59] object whose name start with '_' won't be " + "skipped anymore at some point, use __abstract__ = True " + "instead (%s)" % objectcls, DeprecationWarning) + return + # register, finally + self.register(objectcls) + + @classmethod + def is_registrable(cls, obj): + """ensure `obj` should be registered + + as arbitrary stuff may be registered, do a lot of check and warn about + weird cases (think to dumb proxy objects) + """ + if isinstance(obj, type): + if not issubclass(obj, RegistrableObject): + # ducktyping backward compat + if not (getattr(obj, '__registries__', None) + and getattr(obj, '__regid__', None) + and getattr(obj, '__select__', None)): + return False + elif issubclass(obj, RegistrableInstance): + return False + elif not isinstance(obj, RegistrableInstance): + return False + if not obj.__regid__: + return False # no regid + registries = obj.__registries__ + if not registries: + return False # no registries + selector = obj.__select__ + if not selector: + return False # no selector + if obj.__dict__.get('__abstract__', False): + return False + # then detect potential problems that should be warned + if not isinstance(registries, (tuple, list)): + cls.warning('%s has __registries__ which is not a list or tuple', obj) + return False + if not callable(selector): + cls.warning('%s has not callable __select__', obj) + return False + return True + + # these are overridden by set_log_methods below + # only defining here to prevent pylint from complaining + info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None + + +# init logging +set_log_methods(RegistryStore, getLogger('registry.store')) +set_log_methods(Registry, getLogger('registry')) + + +# helpers for debugging selectors +TRACED_OIDS = None + +def _trace_selector(cls, selector, args, ret): + vobj = args[0] + if TRACED_OIDS == 'all' or vobj.__regid__ in TRACED_OIDS: + 
class traced_selection(object): # pylint: disable=C0103
    """Context manager enabling selector tracing while its block runs.

    .. sourcecode:: python

        >>> from logilab.common.registry import traced_selection
        >>> with traced_selection():
        ...     # some code in which you want to debug selectors
        ...     # for all objects

    This will yield lines like this in the logs::

        selector one_line_rset returned 0 for <class 'elephant.Babar'>

    To restrict tracing to specific objects, pass their identifiers:

    .. sourcecode:: python

        >>> with traced_selection( ('regid1', 'regid2') ):
        ...     # some code in which you want to debug selectors
        ...     # for objects with __regid__ 'regid1' and 'regid2'

    A potentially useful point to set up such a tracing function is
    the `logilab.common.registry.Registry.select` method body.
    """

    def __init__(self, traced='all'):
        # 'all' traces every object; otherwise a container of __regid__s
        self.traced = traced

    def __enter__(self):
        global TRACED_OIDS
        TRACED_OIDS = self.traced

    def __exit__(self, exctype, exc, traceback):
        global TRACED_OIDS
        TRACED_OIDS = None
        # propagate any exception raised inside the with-block
        return traceback is None
#: registry of live predicate classes, keyed by id() of a weak proxy to them
_PREDICATES = {}

def wrap_predicates(decorator):
    """Apply *decorator* to the ``__call__`` of every known predicate class.

    Each decorator is applied at most once per predicate class: classes
    already wrapped with *decorator* are left untouched.
    """
    for pred_cls in _PREDICATES.values():
        if '_decorators' not in pred_cls.__dict__:
            pred_cls._decorators = set()
        if decorator not in pred_cls._decorators:
            pred_cls._decorators.add(decorator)
            pred_cls.__call__ = decorator(pred_cls.__call__)

class PredicateMetaClass(type):
    """Metaclass recording every predicate class into ``_PREDICATES``.

    Classes are tracked through weak proxies so the registry never keeps a
    predicate class alive on its own; the proxy callback drops the entry
    once the class is garbage collected.
    """
    def __new__(mcs, *args, **kwargs):
        # use __new__ so subclasses doesn't have to call Predicate.__init__
        klass = type.__new__(mcs, *args, **kwargs)
        proxy = weakref.proxy(klass, lambda stale: _PREDICATES.pop(id(stale)))
        _PREDICATES[id(proxy)] = proxy
        return klass
+ """ + if self is selector: + return self + if (isinstance(selector, type) or isinstance(selector, tuple)) and \ + isinstance(self, selector): + return self + return None + + def __str__(self): + return self.__class__.__name__ + + def __and__(self, other): + return AndPredicate(self, other) + def __rand__(self, other): + return AndPredicate(other, self) + def __iand__(self, other): + return AndPredicate(self, other) + def __or__(self, other): + return OrPredicate(self, other) + def __ror__(self, other): + return OrPredicate(other, self) + def __ior__(self, other): + return OrPredicate(self, other) + + def __invert__(self): + return NotPredicate(self) + + # XXX (function | function) or (function & function) not managed yet + + def __call__(self, cls, *args, **kwargs): + return NotImplementedError("selector %s must implement its logic " + "in its __call__ method" % self.__class__) + + def __repr__(self): + return u'<Predicate %s at %x>' % (self.__class__.__name__, id(self)) + + +class MultiPredicate(Predicate): + """base class for compound selector classes""" + + def __init__(self, *selectors): + self.selectors = self.merge_selectors(selectors) + + def __str__(self): + return '%s(%s)' % (self.__class__.__name__, + ','.join(str(s) for s in self.selectors)) + + @classmethod + def merge_selectors(cls, selectors): + """deal with selector instanciation when necessary and merge + multi-selectors if possible: + + AndPredicate(AndPredicate(sel1, sel2), AndPredicate(sel3, sel4)) + ==> AndPredicate(sel1, sel2, sel3, sel4) + """ + merged_selectors = [] + for selector in selectors: + # XXX do we really want magic-transformations below? + # if so, wanna warn about them? 
class AndPredicate(MultiPredicate):
    """Compound predicate scoring as the sum of its children's scores.

    Selection fails (score 0) as soon as any child returns a falsy score.
    """
    def __call__(self, cls, *args, **kwargs):
        total = 0
        for child in self.selectors:
            part = child(cls, *args, **kwargs)
            if not part:
                return 0
            total += part
        return total


class OrPredicate(MultiPredicate):
    """Compound predicate returning the first non-zero child score."""
    def __call__(self, cls, *args, **kwargs):
        for child in self.selectors:
            part = child(cls, *args, **kwargs)
            if part:
                return part
        return 0

class NotPredicate(Predicate):
    """Predicate negating the selector it wraps (score is 1 or 0)."""
    def __init__(self, selector):
        self.selector = selector

    def __call__(self, cls, *args, **kwargs):
        return int(not self.selector(cls, *args, **kwargs))

    def __str__(self):
        return 'NOT(%s)' % self.selector
@deprecated('[lgc 0.59] use Registry.objid class method instead')
def classid(cls):
    """Return the fully dotted name of *cls* (module + class name)."""
    return '.'.join((cls.__module__, cls.__name__))

@deprecated('[lgc 0.59] use obj_registries function instead')
def class_registries(cls, registryname):
    """Deprecated alias for :func:`obj_registries`."""
    return obj_registries(cls, registryname)
+""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +import os +import glob +import shutil +import stat +import sys +import tempfile +import time +import fnmatch +import errno +import string +import random +import subprocess +from os.path import exists, isdir, islink, basename, join + +from six import string_types +from six.moves import range, input as raw_input + +from logilab.common import STD_BLACKLIST, _handle_blacklist +from logilab.common.compat import str_to_bytes +from logilab.common.deprecation import deprecated + +try: + from logilab.common.proc import ProcInfo, NoSuchProcess +except ImportError: + # windows platform + class NoSuchProcess(Exception): pass + + def ProcInfo(pid): + raise NoSuchProcess() + + +class tempdir(object): + + def __enter__(self): + self.path = tempfile.mkdtemp() + return self.path + + def __exit__(self, exctype, value, traceback): + # rmtree in all cases + shutil.rmtree(self.path) + return traceback is None + + +class pushd(object): + def __init__(self, directory): + self.directory = directory + + def __enter__(self): + self.cwd = os.getcwd() + os.chdir(self.directory) + return self.directory + + def __exit__(self, exctype, value, traceback): + os.chdir(self.cwd) + + +def chown(path, login=None, group=None): + """Same as `os.chown` function but accepting user login or group name as + argument. If login or group is omitted, it's left unchanged. + + Note: you must own the file to chown it (or be root). Otherwise OSError is raised. + """ + if login is None: + uid = -1 + else: + try: + uid = int(login) + except ValueError: + import pwd # Platforms: Unix + uid = pwd.getpwnam(login).pw_uid + if group is None: + gid = -1 + else: + try: + gid = int(group) + except ValueError: + import grp + gid = grp.getgrnam(group).gr_gid + os.chown(path, uid, gid) + +def mv(source, destination, _action=shutil.move): + """A shell-like mv, supporting wildcards. 
+ """ + sources = glob.glob(source) + if len(sources) > 1: + assert isdir(destination) + for filename in sources: + _action(filename, join(destination, basename(filename))) + else: + try: + source = sources[0] + except IndexError: + raise OSError('No file matching %s' % source) + if isdir(destination) and exists(destination): + destination = join(destination, basename(source)) + try: + _action(source, destination) + except OSError as ex: + raise OSError('Unable to move %r to %r (%s)' % ( + source, destination, ex)) + +def rm(*files): + """A shell-like rm, supporting wildcards. + """ + for wfile in files: + for filename in glob.glob(wfile): + if islink(filename): + os.remove(filename) + elif isdir(filename): + shutil.rmtree(filename) + else: + os.remove(filename) + +def cp(source, destination): + """A shell-like cp, supporting wildcards. + """ + mv(source, destination, _action=shutil.copy) + +def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST): + """Recursively find files ending with the given extensions from the directory. 
+ + :type directory: str + :param directory: + directory where the search should start + + :type exts: basestring or list or tuple + :param exts: + extensions or lists or extensions to search + + :type exclude: boolean + :param exts: + if this argument is True, returning files NOT ending with the given + extensions + + :type blacklist: list or tuple + :param blacklist: + optional list of files or directory to ignore, default to the value of + `logilab.common.STD_BLACKLIST` + + :rtype: list + :return: + the list of all matching files + """ + if isinstance(exts, string_types): + exts = (exts,) + if exclude: + def match(filename, exts): + for ext in exts: + if filename.endswith(ext): + return False + return True + else: + def match(filename, exts): + for ext in exts: + if filename.endswith(ext): + return True + return False + files = [] + for dirpath, dirnames, filenames in os.walk(directory): + _handle_blacklist(blacklist, dirnames, filenames) + # don't append files if the directory is blacklisted + dirname = basename(dirpath) + if dirname in blacklist: + continue + files.extend([join(dirpath, f) for f in filenames if match(f, exts)]) + return files + + +def globfind(directory, pattern, blacklist=STD_BLACKLIST): + """Recursively finds files matching glob `pattern` under `directory`. + + This is an alternative to `logilab.common.shellutils.find`. + + :type directory: str + :param directory: + directory where the search should start + + :type pattern: basestring + :param pattern: + the glob pattern (e.g *.py, foo*.py, etc.) 
def unzip(archive, destdir):
    """Extract zip `archive` into directory `destdir`.

    `destdir` and any intermediate directories of archive members are
    created when missing (many archives carry no explicit directory
    entries, which used to crash the plain ``os.mkdir`` based version).

    .. warning:: member names are used as-is, so there is no protection
       against '../' path traversal — do not call this on untrusted
       archives.
    """
    import zipfile
    if not exists(destdir):
        os.makedirs(destdir)
    with zipfile.ZipFile(archive) as zfobj:
        for name in zfobj.namelist():
            target = join(destdir, name)
            if name.endswith('/'):
                # explicit directory entry
                if not exists(target):
                    os.makedirs(target)
            else:
                # create missing parent directories before writing the file
                parent = os.path.dirname(target)
                if parent and not exists(parent):
                    os.makedirs(parent)
                with open(target, 'wb') as outfile:
                    outfile.write(zfobj.read(name))
def release_lock(lock_file):
    """Release a lock represented by a file on the file system.

    Simply removes the lock file; raises OSError if it does not exist,
    i.e. if the lock was not actually held.
    """
    os.remove(lock_file)
class DummyProgressBar(object):
    """No-op stand-in for :class:`ProgressBar`, exposing the same interface."""
    # bug fix: was misspelled ``__slot__``, which had no effect at all
    __slots__ = ('text',)

    def refresh(self):
        pass

    def update(self, offset=1, exact=False):
        # bug fix: accept the same arguments as ProgressBar.update so this
        # really is a drop-in replacement
        pass

    def finish(self):
        pass


# sentinel distinguishing "argument not given" from an explicit None
_MARKER = object()

class progress(object):
    """Context manager yielding a (possibly dummy) progress bar.

    Arguments left at the internal marker default are not forwarded to
    :class:`ProgressBar`, so its own defaults apply.  With
    ``enabled=False``, a :class:`DummyProgressBar` is returned instead and
    no output is produced.
    """

    def __init__(self, nbops=_MARKER, size=_MARKER, stream=_MARKER, title=_MARKER, enabled=True):
        self.nbops = nbops
        self.size = size
        self.stream = stream
        self.title = title
        self.enabled = enabled

    def __enter__(self):
        if self.enabled:
            # only forward explicitly given arguments
            kwargs = {}
            for attr in ('nbops', 'size', 'stream', 'title'):
                value = getattr(self, attr)
                if value is not _MARKER:
                    kwargs[attr] = value
            self.pb = ProgressBar(**kwargs)
        else:
            self.pb = DummyProgressBar()
        return self.pb

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pb.finish()
def generate_password(length=8, vocab=string.ascii_letters + string.digits):
    """Return a random password of `length` characters drawn from `vocab`.

    Uses ``random.SystemRandom`` (OS entropy source) instead of the
    seedable, predictable Mersenne Twister, since passwords are security
    sensitive.
    """
    rng = random.SystemRandom()
    return ''.join(rng.choice(vocab) for _ in range(length))
class DocstringOnlyModuleDocumenter(autodoc.ModuleDocumenter):
    """sphinx autodoc documenter rendering only a module's docstring:
    signature, directive header and member documentation are all
    suppressed (hence the no-op overrides below).
    """
    objtype = 'docstring'
    def format_signature(self):
        pass
    def add_directive_header(self, sig):
        pass
    def document_members(self, all_members=False):
        pass

    def resolve_name(self, modname, parents, path, base):
        # same resolution rule as ModuleDocumenter: either an explicit
        # module name plus nested parents, or a dotted path built from
        # the directive argument
        if modname is not None:
            return modname, parents + [base]
        return (path or '') + base, []


#autodoc.add_documenter(DocstringOnlyModuleDocumenter)

def setup(app):
    """sphinx extension entry point: register the documenter above."""
    app.add_autodocumenter(DocstringOnlyModuleDocumenter)
self.arguments[0]) + documenter.generate(more_content=self.content) + if not self.result: + return self.warnings + + # record all filenames as dependencies -- this will at least + # partially make automatic invalidation possible + for fn in self.filename_set: + self.env.note_dependency(fn) + + # use a custom reporter that correctly assigns lines to source + # filename/description and lineno + old_reporter = self.state.memo.reporter + self.state.memo.reporter = AutodocReporter(self.result, + self.state.memo.reporter) + if self.name in ('automodule', 'autodocstring'): + node = nodes.section() + # necessary so that the child nodes get the right source/line set + node.document = self.state.document + nested_parse_with_titles(self.state, self.result, node) + else: + node = nodes.paragraph() + node.document = self.state.document + self.state.nested_parse(self.result, 0, node) + self.state.memo.reporter = old_reporter + return self.warnings + node.children diff --git a/logilab/common/sphinxutils.py b/logilab/common/sphinxutils.py new file mode 100644 index 0000000..ab6e8a1 --- /dev/null +++ b/logilab/common/sphinxutils.py @@ -0,0 +1,122 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. 
def module_members(module):
    """Return the sorted ``(name, value)`` pairs defined by *module*.

    Only members whose ``__module__`` matches the module's own name are
    kept, which filters out imported names.
    """
    own_name = module.__name__
    return sorted((name, value)
                  for name, value in inspect.getmembers(module)
                  if getattr(value, '__module__', None) == own_name)


def class_members(klass):
    """Return the sorted attribute names of *klass*, minus dunder noise."""
    ignored = ('__doc__', '__module__', '__dict__', '__weakref__')
    return sorted(name for name in vars(klass) if name not in ignored)
autoclass:: %s + :members: %s + +""" + + def __init__(self, project_title, code_dir): + self.title = project_title + self.code_dir = osp.abspath(code_dir) + + def generate(self, dest_file, exclude_dirs=STD_BLACKLIST): + """make the module file""" + self.fn = open(dest_file, 'w') + num = len(self.title) + 6 + title = "=" * num + "\n %s API\n" % self.title + "=" * num + self.fn.write(self.file_header % title) + self.gen_modules(exclude_dirs=exclude_dirs) + self.fn.close() + + def gen_modules(self, exclude_dirs): + """generate all modules""" + for module in self.find_modules(exclude_dirs): + modname = module.__name__ + classes = [] + modmembers = [] + for objname, obj in module_members(module): + if inspect.isclass(obj): + classmembers = class_members(obj) + classes.append( (objname, classmembers) ) + else: + modmembers.append(objname) + self.fn.write(self.module_def % (modname, '=' * len(modname), + modname, + ', '.join(modmembers))) + for klass, members in classes: + self.fn.write(self.class_def % (klass, ', '.join(members))) + + def find_modules(self, exclude_dirs): + basepath = osp.dirname(self.code_dir) + basedir = osp.basename(basepath) + osp.sep + if basedir not in sys.path: + sys.path.insert(1, basedir) + for filepath in globfind(self.code_dir, '*.py', exclude_dirs): + if osp.basename(filepath) in ('setup.py', '__pkginfo__.py'): + continue + try: + module = load_module_from_file(filepath) + except: # module might be broken or magic + dotted_path = modpath_from_file(filepath) + module = type('.'.join(dotted_path), (), {}) # mock it + yield module + + +if __name__ == '__main__': + # example : + title, code_dir, outfile = sys.argv[1:] + generator = ModuleGenerator(title, code_dir) + # XXX modnames = ['logilab'] + generator.generate(outfile, ('test', 'tests', 'examples', + 'data', 'doc', '.hg', 'migration')) diff --git a/logilab/common/table.py b/logilab/common/table.py new file mode 100644 index 0000000..2f3df69 --- /dev/null +++ b/logilab/common/table.py @@ -0,0 
+1,929 @@ +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Table management module.""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +from six.moves import range + +class Table(object): + """Table defines a data table with column and row names. 
    def __init__(self, default_value=0, col_names=None, row_names=None):
        """Initialize an empty table, optionally pre-creating columns/rows.

        :param default_value: value used to fill cells of newly created
          rows and columns
        """
        self.col_names = []
        self.row_names = []
        self.data = []
        self.default_value = default_value
        if col_names:
            self.create_columns(col_names)
        if row_names:
            self.create_rows(row_names)

    def _next_row_name(self):
        # default name for anonymous rows: row1, row2, ...
        return 'row%s' % (len(self.row_names)+1)

    def __iter__(self):
        # iterate over the rows (lists of cell values)
        return iter(self.data)

    def __eq__(self, other):
        # equality compares row data only; column/row names are ignored
        if other is None:
            return False
        else:
            return list(self) == list(other)

    # __eq__ is redefined above, so restore default identity hashing
    __hash__ = object.__hash__

    def __ne__(self, other):
        return not self == other

    def __len__(self):
        # the table length is its number of rows
        return len(self.row_names)

    ## Rows / Columns creation #################################################
    def create_rows(self, row_names):
        """Appends row_names to the list of existing rows, filling the new
        rows with `default_value`.
        """
        self.row_names.extend(row_names)
        for row_name in row_names:
            self.data.append([self.default_value]*len(self.col_names))

    def create_columns(self, col_names):
        """Appends col_names to the list of existing columns
        (see :meth:`create_column`).
        """
        for col_name in col_names:
            self.create_column(col_name)

    def create_row(self, row_name=None):
        """Creates a rowname to the row_names list (auto-named when
        `row_name` is None), filled with `default_value`.
        """
        row_name = row_name or self._next_row_name()
        self.row_names.append(row_name)
        self.data.append([self.default_value]*len(self.col_names))


    def create_column(self, col_name):
        """Creates a colname to the col_names list, extending every existing
        row with `default_value`.
        """
        self.col_names.append(col_name)
        for row in self.data:
            row.append(self.default_value)
table" % (col_id)) + + + def sort_by_column_index(self, col_index, method = 'asc'): + """Sorts the table 'in-place' according to data stored in col_index + + method should be in ('asc', 'desc') + """ + sort_list = sorted([(row[col_index], row, row_name) + for row, row_name in zip(self.data, self.row_names)]) + # Sorting sort_list will sort according to col_index + # If we want reverse sort, then reverse list + if method.lower() == 'desc': + sort_list.reverse() + + # Rebuild data / row names + self.data = [] + self.row_names = [] + for val, row, row_name in sort_list: + self.data.append(row) + self.row_names.append(row_name) + + def groupby(self, colname, *others): + """builds indexes of data + :returns: nested dictionaries pointing to actual rows + """ + groups = {} + colnames = (colname,) + others + col_indexes = [self.col_names.index(col_id) for col_id in colnames] + for row in self.data: + ptr = groups + for col_index in col_indexes[:-1]: + ptr = ptr.setdefault(row[col_index], {}) + ptr = ptr.setdefault(row[col_indexes[-1]], + Table(default_value=self.default_value, + col_names=self.col_names)) + ptr.append_row(tuple(row)) + return groups + + def select(self, colname, value): + grouped = self.groupby(colname) + try: + return grouped[value] + except KeyError: + return [] + + def remove(self, colname, value): + col_index = self.col_names.index(colname) + for row in self.data[:]: + if row[col_index] == value: + self.data.remove(row) + + + ## The 'setter' part ####################################################### + def set_cell(self, row_index, col_index, data): + """sets value of cell 'row_indew', 'col_index' to data + """ + self.data[row_index][col_index] = data + + + def set_cell_by_ids(self, row_id, col_id, data): + """sets value of cell mapped by row_id and col_id to data + Raises a KeyError if row_id or col_id are not found in the table + """ + try: + row_index = self.row_names.index(row_id) + except ValueError: + raise KeyError("Row (%s) not found in table" 
% (row_id)) + else: + try: + col_index = self.col_names.index(col_id) + self.data[row_index][col_index] = data + except ValueError: + raise KeyError("Column (%s) not found in table" % (col_id)) + + + def set_row(self, row_index, row_data): + """sets the 'row_index' row + pre: + type(row_data) == types.ListType + len(row_data) == len(self.col_names) + """ + self.data[row_index] = row_data + + + def set_row_by_id(self, row_id, row_data): + """sets the 'row_id' column + pre: + type(row_data) == types.ListType + len(row_data) == len(self.row_names) + Raises a KeyError if row_id is not found + """ + try: + row_index = self.row_names.index(row_id) + self.set_row(row_index, row_data) + except ValueError: + raise KeyError('Row (%s) not found in table' % (row_id)) + + + def append_row(self, row_data, row_name=None): + """Appends a row to the table + pre: + type(row_data) == types.ListType + len(row_data) == len(self.col_names) + """ + row_name = row_name or self._next_row_name() + self.row_names.append(row_name) + self.data.append(row_data) + return len(self.data) - 1 + + def insert_row(self, index, row_data, row_name=None): + """Appends row_data before 'index' in the table. To make 'insert' + behave like 'list.insert', inserting in an out of range index will + insert row_data to the end of the list + pre: + type(row_data) == types.ListType + len(row_data) == len(self.col_names) + """ + row_name = row_name or self._next_row_name() + self.row_names.insert(index, row_name) + self.data.insert(index, row_data) + + + def delete_row(self, index): + """Deletes the 'index' row in the table, and returns it. + Raises an IndexError if index is out of range + """ + self.row_names.pop(index) + return self.data.pop(index) + + + def delete_row_by_id(self, row_id): + """Deletes the 'row_id' row in the table. + Raises a KeyError if row_id was not found. 
+ """ + try: + row_index = self.row_names.index(row_id) + self.delete_row(row_index) + except ValueError: + raise KeyError('Row (%s) not found in table' % (row_id)) + + + def set_column(self, col_index, col_data): + """sets the 'col_index' column + pre: + type(col_data) == types.ListType + len(col_data) == len(self.row_names) + """ + + for row_index, cell_data in enumerate(col_data): + self.data[row_index][col_index] = cell_data + + + def set_column_by_id(self, col_id, col_data): + """sets the 'col_id' column + pre: + type(col_data) == types.ListType + len(col_data) == len(self.col_names) + Raises a KeyError if col_id is not found + """ + try: + col_index = self.col_names.index(col_id) + self.set_column(col_index, col_data) + except ValueError: + raise KeyError('Column (%s) not found in table' % (col_id)) + + + def append_column(self, col_data, col_name): + """Appends the 'col_index' column + pre: + type(col_data) == types.ListType + len(col_data) == len(self.row_names) + """ + self.col_names.append(col_name) + for row_index, cell_data in enumerate(col_data): + self.data[row_index].append(cell_data) + + + def insert_column(self, index, col_data, col_name): + """Appends col_data before 'index' in the table. To make 'insert' + behave like 'list.insert', inserting in an out of range index will + insert col_data to the end of the list + pre: + type(col_data) == types.ListType + len(col_data) == len(self.row_names) + """ + self.col_names.insert(index, col_name) + for row_index, cell_data in enumerate(col_data): + self.data[row_index].insert(index, cell_data) + + + def delete_column(self, index): + """Deletes the 'index' column in the table, and returns it. + Raises an IndexError if index is out of range + """ + self.col_names.pop(index) + return [row.pop(index) for row in self.data] + + + def delete_column_by_id(self, col_id): + """Deletes the 'col_id' col in the table. + Raises a KeyError if col_id was not found. 
+ """ + try: + col_index = self.col_names.index(col_id) + self.delete_column(col_index) + except ValueError: + raise KeyError('Column (%s) not found in table' % (col_id)) + + + ## The 'getter' part ####################################################### + + def get_shape(self): + """Returns a tuple which represents the table's shape + """ + return len(self.row_names), len(self.col_names) + shape = property(get_shape) + + def __getitem__(self, indices): + """provided for convenience""" + rows, multirows = None, False + cols, multicols = None, False + if isinstance(indices, tuple): + rows = indices[0] + if len(indices) > 1: + cols = indices[1] + else: + rows = indices + # define row slice + if isinstance(rows, str): + try: + rows = self.row_names.index(rows) + except ValueError: + raise KeyError("Row (%s) not found in table" % (rows)) + if isinstance(rows, int): + rows = slice(rows, rows+1) + multirows = False + else: + rows = slice(None) + multirows = True + # define col slice + if isinstance(cols, str): + try: + cols = self.col_names.index(cols) + except ValueError: + raise KeyError("Column (%s) not found in table" % (cols)) + if isinstance(cols, int): + cols = slice(cols, cols+1) + multicols = False + else: + cols = slice(None) + multicols = True + # get sub-table + tab = Table() + tab.default_value = self.default_value + tab.create_rows(self.row_names[rows]) + tab.create_columns(self.col_names[cols]) + for idx, row in enumerate(self.data[rows]): + tab.set_row(idx, row[cols]) + if multirows : + if multicols: + return tab + else: + return [item[0] for item in tab.data] + else: + if multicols: + return tab.data[0] + else: + return tab.data[0][0] + + def get_cell_by_ids(self, row_id, col_id): + """Returns the element at [row_id][col_id] + """ + try: + row_index = self.row_names.index(row_id) + except ValueError: + raise KeyError("Row (%s) not found in table" % (row_id)) + else: + try: + col_index = self.col_names.index(col_id) + except ValueError: + raise 
            KeyError("Column (%s) not found in table" % (col_id))
        return self.data[row_index][col_index]

    def get_row_by_id(self, row_id):
        """Returns the 'row_id' row

        Raises a KeyError if row_id is not a known row name.
        """
        try:
            row_index = self.row_names.index(row_id)
        except ValueError:
            raise KeyError("Row (%s) not found in table" % (row_id))
        return self.data[row_index]

    def get_column_by_id(self, col_id, distinct=False):
        """Returns the 'col_id' col

        Raises a KeyError if col_id is not a known column name.
        """
        try:
            col_index = self.col_names.index(col_id)
        except ValueError:
            raise KeyError("Column (%s) not found in table" % (col_id))
        return self.get_column(col_index, distinct)

    def get_columns(self):
        """Returns all the columns in the table
        """
        # each self[:, index] slice builds a one-column sub-Table via
        # __getitem__, then collapses it to a plain list of cell values
        return [self[:, index] for index in range(len(self.col_names))]

    def get_column(self, col_index, distinct=False):
        """get a column by index

        When distinct is true duplicates are dropped; note the resulting
        order is then arbitrary since a set is used to deduplicate.
        """
        col = [row[col_index] for row in self.data]
        if distinct:
            col = list(set(col))
        return col

    def apply_stylesheet(self, stylesheet):
        """Applies the stylesheet to this table
        """
        for instruction in stylesheet.instructions:
            # SECURITY NOTE: instructions are code objects compiled from
            # stylesheet rules and executed here via eval(); never build a
            # stylesheet from untrusted input.
            eval(instruction)


    def transpose(self):
        """Keeps the self object intact, and returns the transposed (rotated)
        table.
        """
        transposed = Table()
        transposed.create_rows(self.col_names)
        transposed.create_columns(self.row_names)
        # writing each source column as a row of the new table performs
        # the rotation
        for col_index, column in enumerate(self.get_columns()):
            transposed.set_row(col_index, column)
        return transposed


    def pprint(self):
        """returns a string representing the table in a pretty
        printed 'text' format.
+ """ + # The maximum row name (to know the start_index of the first col) + max_row_name = 0 + for row_name in self.row_names: + if len(row_name) > max_row_name: + max_row_name = len(row_name) + col_start = max_row_name + 5 + + lines = [] + # Build the 'first' line <=> the col_names one + # The first cell <=> an empty one + col_names_line = [' '*col_start] + for col_name in self.col_names: + col_names_line.append(col_name + ' '*5) + lines.append('|' + '|'.join(col_names_line) + '|') + max_line_length = len(lines[0]) + + # Build the table + for row_index, row in enumerate(self.data): + line = [] + # First, build the row_name's cell + row_name = self.row_names[row_index] + line.append(row_name + ' '*(col_start-len(row_name))) + + # Then, build all the table's cell for this line. + for col_index, cell in enumerate(row): + col_name_length = len(self.col_names[col_index]) + 5 + data = str(cell) + line.append(data + ' '*(col_name_length - len(data))) + lines.append('|' + '|'.join(line) + '|') + if len(lines[-1]) > max_line_length: + max_line_length = len(lines[-1]) + + # Wrap the table with '-' to make a frame + lines.insert(0, '-'*max_line_length) + lines.append('-'*max_line_length) + return '\n'.join(lines) + + + def __repr__(self): + return repr(self.data) + + def as_text(self): + data = [] + # We must convert cells into strings before joining them + for row in self.data: + data.append([str(cell) for cell in row]) + lines = ['\t'.join(row) for row in data] + return '\n'.join(lines) + + + +class TableStyle: + """Defines a table's style + """ + + def __init__(self, table): + + self._table = table + self.size = dict([(col_name, '1*') for col_name in table.col_names]) + # __row_column__ is a special key to define the first column which + # actually has no name (<=> left most column <=> row names column) + self.size['__row_column__'] = '1*' + self.alignment = dict([(col_name, 'right') + for col_name in table.col_names]) + self.alignment['__row_column__'] = 'right' + + # We 
shouldn't have to create an entry for + # the 1st col (the row_column one) + self.units = dict([(col_name, '') for col_name in table.col_names]) + self.units['__row_column__'] = '' + + # XXX FIXME : params order should be reversed for all set() methods + def set_size(self, value, col_id): + """sets the size of the specified col_id to value + """ + self.size[col_id] = value + + def set_size_by_index(self, value, col_index): + """Allows to set the size according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + self.size[col_id] = value + + + def set_alignment(self, value, col_id): + """sets the alignment of the specified col_id to value + """ + self.alignment[col_id] = value + + + def set_alignment_by_index(self, value, col_index): + """Allows to set the alignment according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + self.alignment[col_id] = value + + + def set_unit(self, value, col_id): + """sets the unit of the specified col_id to value + """ + self.units[col_id] = value + + + def set_unit_by_index(self, value, col_index): + """Allows to set the unit according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! 
+ (Note that in the 'unit' case, you shouldn't have to set a unit + for the 1st column (the __row__column__ one)) + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + self.units[col_id] = value + + + def get_size(self, col_id): + """Returns the size of the specified col_id + """ + return self.size[col_id] + + + def get_size_by_index(self, col_index): + """Allows to get the size according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + return self.size[col_id] + + + def get_alignment(self, col_id): + """Returns the alignment of the specified col_id + """ + return self.alignment[col_id] + + + def get_alignment_by_index(self, col_index): + """Allors to get the alignment according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + return self.alignment[col_id] + + + def get_unit(self, col_id): + """Returns the unit of the specified col_id + """ + return self.units[col_id] + + + def get_unit_by_index(self, col_index): + """Allors to get the unit according to the column index rather than + using the column's id. + BE CAREFUL : the '0' column is the '__row_column__' one ! + """ + if col_index == 0: + col_id = '__row_column__' + else: + col_id = self._table.col_names[col_index-1] + + return self.units[col_id] + + +import re +CELL_PROG = re.compile("([0-9]+)_([0-9]+)") + +class TableStyleSheet: + """A simple Table stylesheet + Rules are expressions where cells are defined by the row_index + and col_index separated by an underscore ('_'). 
+ For example, suppose you want to say that the (2,5) cell must be + the sum of its two preceding cells in the row, you would create + the following rule : + 2_5 = 2_3 + 2_4 + You can also use all the math.* operations you want. For example: + 2_5 = sqrt(2_3**2 + 2_4**2) + """ + + def __init__(self, rules = None): + rules = rules or [] + self.rules = [] + self.instructions = [] + for rule in rules: + self.add_rule(rule) + + + def add_rule(self, rule): + """Adds a rule to the stylesheet rules + """ + try: + source_code = ['from math import *'] + source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule)) + self.instructions.append(compile('\n'.join(source_code), + 'table.py', 'exec')) + self.rules.append(rule) + except SyntaxError: + print("Bad Stylesheet Rule : %s [skipped]" % rule) + + + def add_rowsum_rule(self, dest_cell, row_index, start_col, end_col): + """Creates and adds a rule to sum over the row at row_index from + start_col to end_col. + dest_cell is a tuple of two elements (x,y) of the destination cell + No check is done for indexes ranges. + pre: + start_col >= 0 + end_col > start_col + """ + cell_list = ['%d_%d'%(row_index, index) for index in range(start_col, + end_col + 1)] + rule = '%d_%d=' % dest_cell + '+'.join(cell_list) + self.add_rule(rule) + + + def add_rowavg_rule(self, dest_cell, row_index, start_col, end_col): + """Creates and adds a rule to make the row average (from start_col + to end_col) + dest_cell is a tuple of two elements (x,y) of the destination cell + No check is done for indexes ranges. + pre: + start_col >= 0 + end_col > start_col + """ + cell_list = ['%d_%d'%(row_index, index) for index in range(start_col, + end_col + 1)] + num = (end_col - start_col + 1) + rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num + self.add_rule(rule) + + + def add_colsum_rule(self, dest_cell, col_index, start_row, end_row): + """Creates and adds a rule to sum over the col at col_index from + start_row to end_row. 
+ dest_cell is a tuple of two elements (x,y) of the destination cell + No check is done for indexes ranges. + pre: + start_row >= 0 + end_row > start_row + """ + cell_list = ['%d_%d'%(index, col_index) for index in range(start_row, + end_row + 1)] + rule = '%d_%d=' % dest_cell + '+'.join(cell_list) + self.add_rule(rule) + + + def add_colavg_rule(self, dest_cell, col_index, start_row, end_row): + """Creates and adds a rule to make the col average (from start_row + to end_row) + dest_cell is a tuple of two elements (x,y) of the destination cell + No check is done for indexes ranges. + pre: + start_row >= 0 + end_row > start_row + """ + cell_list = ['%d_%d'%(index, col_index) for index in range(start_row, + end_row + 1)] + num = (end_row - start_row + 1) + rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num + self.add_rule(rule) + + + +class TableCellRenderer: + """Defines a simple text renderer + """ + + def __init__(self, **properties): + """keywords should be properties with an associated boolean as value. + For example : + renderer = TableCellRenderer(units = True, alignment = False) + An unspecified property will have a 'False' value by default. 
+ Possible properties are : + alignment, unit + """ + self.properties = properties + + + def render_cell(self, cell_coord, table, table_style): + """Renders the cell at 'cell_coord' in the table, using table_style + """ + row_index, col_index = cell_coord + cell_value = table.data[row_index][col_index] + final_content = self._make_cell_content(cell_value, + table_style, col_index +1) + return self._render_cell_content(final_content, + table_style, col_index + 1) + + + def render_row_cell(self, row_name, table, table_style): + """Renders the cell for 'row_id' row + """ + cell_value = row_name + return self._render_cell_content(cell_value, table_style, 0) + + + def render_col_cell(self, col_name, table, table_style): + """Renders the cell for 'col_id' row + """ + cell_value = col_name + col_index = table.col_names.index(col_name) + return self._render_cell_content(cell_value, table_style, col_index +1) + + + + def _render_cell_content(self, content, table_style, col_index): + """Makes the appropriate rendering for this cell content. 
+ Rendering properties will be searched using the + *table_style.get_xxx_by_index(col_index)' methods + + **This method should be overridden in the derived renderer classes.** + """ + return content + + + def _make_cell_content(self, cell_content, table_style, col_index): + """Makes the cell content (adds decoration data, like units for + example) + """ + final_content = cell_content + if 'skip_zero' in self.properties: + replacement_char = self.properties['skip_zero'] + else: + replacement_char = 0 + if replacement_char and final_content == 0: + return replacement_char + + try: + units_on = self.properties['units'] + if units_on: + final_content = self._add_unit( + cell_content, table_style, col_index) + except KeyError: + pass + + return final_content + + + def _add_unit(self, cell_content, table_style, col_index): + """Adds unit to the cell_content if needed + """ + unit = table_style.get_unit_by_index(col_index) + return str(cell_content) + " " + unit + + + +class DocbookRenderer(TableCellRenderer): + """Defines how to render a cell for a docboook table + """ + + def define_col_header(self, col_index, table_style): + """Computes the colspec element according to the style + """ + size = table_style.get_size_by_index(col_index) + return '<colspec colname="c%d" colwidth="%s"/>\n' % \ + (col_index, size) + + + def _render_cell_content(self, cell_content, table_style, col_index): + """Makes the appropriate rendering for this cell content. + Rendering properties will be searched using the + table_style.get_xxx_by_index(col_index)' methods. 
+ """ + try: + align_on = self.properties['alignment'] + alignment = table_style.get_alignment_by_index(col_index) + if align_on: + return "<entry align='%s'>%s</entry>\n" % \ + (alignment, cell_content) + except KeyError: + # KeyError <=> Default alignment + return "<entry>%s</entry>\n" % cell_content + + +class TableWriter: + """A class to write tables + """ + + def __init__(self, stream, table, style, **properties): + self._stream = stream + self.style = style or TableStyle(table) + self._table = table + self.properties = properties + self.renderer = None + + + def set_style(self, style): + """sets the table's associated style + """ + self.style = style + + + def set_renderer(self, renderer): + """sets the way to render cell + """ + self.renderer = renderer + + + def update_properties(self, **properties): + """Updates writer's properties (for cell rendering) + """ + self.properties.update(properties) + + + def write_table(self, title = ""): + """Writes the table + """ + raise NotImplementedError("write_table must be implemented !") + + + +class DocbookTableWriter(TableWriter): + """Defines an implementation of TableWriter to write a table in Docbook + """ + + def _write_headers(self): + """Writes col headers + """ + # Define col_headers (colstpec elements) + for col_index in range(len(self._table.col_names)+1): + self._stream.write(self.renderer.define_col_header(col_index, + self.style)) + + self._stream.write("<thead>\n<row>\n") + # XXX FIXME : write an empty entry <=> the first (__row_column) column + self._stream.write('<entry></entry>\n') + for col_name in self._table.col_names: + self._stream.write(self.renderer.render_col_cell( + col_name, self._table, + self.style)) + + self._stream.write("</row>\n</thead>\n") + + + def _write_body(self): + """Writes the table body + """ + self._stream.write('<tbody>\n') + + for row_index, row in enumerate(self._table.data): + self._stream.write('<row>\n') + row_name = self._table.row_names[row_index] + # Write the first 
entry (row_name) + self._stream.write(self.renderer.render_row_cell(row_name, + self._table, + self.style)) + + for col_index, cell in enumerate(row): + self._stream.write(self.renderer.render_cell( + (row_index, col_index), + self._table, self.style)) + + self._stream.write('</row>\n') + + self._stream.write('</tbody>\n') + + + def write_table(self, title = ""): + """Writes the table + """ + self._stream.write('<table>\n<title>%s></title>\n'%(title)) + self._stream.write( + '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n'% + (len(self._table.col_names)+1)) + self._write_headers() + self._write_body() + + self._stream.write('</tgroup>\n</table>\n') + + diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py new file mode 100644 index 0000000..ed74cf5 --- /dev/null +++ b/logilab/common/tasksqueue.py @@ -0,0 +1,101 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. 
"""Prioritized tasks queue.

A `PrioritizedTasksQueue` is a `queue.Queue` whose entries are `Task`
objects served from highest to lowest priority.  Putting a task that is
equivalent to one already queued (same id) merges the two instead of
queuing a duplicate, keeping the higher of the two priorities.
"""

__docformat__ = "restructuredtext en"

from bisect import insort_left

# FIX: use the standard library directly instead of requiring the
# third-party six package; behaviour is identical on both major versions.
try:
    import queue
except ImportError:  # Python 2
    import Queue as queue

# priority levels, from lowest to highest
LOW = 0
MEDIUM = 10
HIGH = 100

PRIORITY = {
    'LOW': LOW,
    'MEDIUM': MEDIUM,
    'HIGH': HIGH,
    }
# reverse mapping: numeric level -> symbolic name
REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items())



class PrioritizedTasksQueue(queue.Queue):
    """Queue of `Task` objects, served from highest to lowest priority."""

    def _init(self, maxsize):
        """Initialize the queue representation"""
        self.maxsize = maxsize
        # ordered list of tasks, from the lowest to the highest priority
        self.queue = []

    def _put(self, item):
        """Put a new item in the queue.

        If an equivalent task (same id, see `Task.__eq__`) is already
        queued, merge the two instead of inserting a duplicate.
        """
        for i, task in enumerate(self.queue):
            # equivalent task
            if task == item:
                # if new task has a higher priority, remove the one already
                # queued so the new priority will be considered
                if task < item:
                    item.merge(task)
                    del self.queue[i]
                    break
                # else keep it so current order is kept
                task.merge(item)
                return
        insort_left(self.queue, item)

    def _get(self):
        """Get an item from the queue (the highest-priority task)"""
        return self.queue.pop()

    def __iter__(self):
        # iterate tasks from lowest to highest priority
        return iter(self.queue)

    def remove(self, tid):
        """remove a specific task from the queue

        Raises ValueError if no queued task has id `tid`.
        """
        # XXX acquire lock
        for i, task in enumerate(self):
            if task.id == tid:
                self.queue.pop(i)
                return
        # FIX: original message read "not task of id %s" (grammar bug)
        raise ValueError('no task of id %s in queue' % tid)


class Task(object):
    """A task with an identifier and a priority.

    Tasks *order* by priority but compare *equal* by id, so a queue
    never holds two tasks with the same id.
    """

    def __init__(self, tid, priority=LOW):
        # task id
        self.id = tid
        # task priority
        self.priority = priority

    def __repr__(self):
        return '<Task %s @%#x>' % (self.id, id(self))

    def __cmp__(self, other):
        # Python 2 only (cmp() does not exist on Python 3, where
        # __lt__/__eq__ below are used instead)
        return cmp(self.priority, other.priority)

    def __lt__(self, other):
        return self.priority < other.priority

    def __eq__(self, other):
        return self.id == other.id

    __hash__ = object.__hash__

    def merge(self, other):
        """hook called when an equivalent task is (re)put in the queue"""
        pass

# -*-
coding: utf-8 -*- +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Run tests. + +This will find all modules whose name match a given prefix in the test +directory, and run them. Various command line options provide +additional facilities. + +Command line options: + + -v verbose -- run tests in verbose mode with output to stdout + -q quiet -- don't print anything except if a test fails + -t testdir -- directory where the tests will be found + -x exclude -- add a test to exclude + -p profile -- profiled execution + -d dbc -- enable design-by-contract + -m match -- only run test matching the tag pattern which follow + +If no non-option arguments are present, prefixes used are 'test', +'regrtest', 'smoketest' and 'unittest'. 
+ +""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" +# modified copy of some functions from test/regrtest.py from PyXml +# disable camel case warning +# pylint: disable=C0103 + +import sys +import os, os.path as osp +import re +import traceback +import inspect +import difflib +import tempfile +import math +import warnings +from shutil import rmtree +from operator import itemgetter +from itertools import dropwhile +from inspect import isgeneratorfunction + +from six import string_types +from six.moves import builtins, range, configparser, input + +from logilab.common.deprecation import deprecated + +import unittest as unittest_legacy +if not getattr(unittest_legacy, "__package__", None): + try: + import unittest2 as unittest + from unittest2 import SkipTest + except ImportError: + raise ImportError("You have to install python-unittest2 to use %s" % __name__) +else: + import unittest + from unittest import SkipTest + +from functools import wraps + +from logilab.common.debugger import Debugger, colorize_source +from logilab.common.decorators import cached, classproperty +from logilab.common import textutils + + +__all__ = ['main', 'unittest_main', 'find_tests', 'run_test', 'spawn'] + +DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest', + 'func', 'validation') + +is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction) + +# used by unittest to count the number of relevant levels in the traceback +__unittest = 1 + + +def with_tempdir(callable): + """A decorator ensuring no temporary file left when the function return + Work only for temporary file create with the tempfile module""" + if isgeneratorfunction(callable): + def proxy(*args, **kwargs): + old_tmpdir = tempfile.gettempdir() + new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") + tempfile.tempdir = new_tmpdir + try: + for x in callable(*args, **kwargs): + yield x + finally: + try: + rmtree(new_tmpdir, ignore_errors=True) + 
finally: + tempfile.tempdir = old_tmpdir + return proxy + + @wraps(callable) + def proxy(*args, **kargs): + + old_tmpdir = tempfile.gettempdir() + new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-") + tempfile.tempdir = new_tmpdir + try: + return callable(*args, **kargs) + finally: + try: + rmtree(new_tmpdir, ignore_errors=True) + finally: + tempfile.tempdir = old_tmpdir + return proxy + +def in_tempdir(callable): + """A decorator moving the enclosed function inside the tempfile.tempfdir + """ + @wraps(callable) + def proxy(*args, **kargs): + + old_cwd = os.getcwd() + os.chdir(tempfile.tempdir) + try: + return callable(*args, **kargs) + finally: + os.chdir(old_cwd) + return proxy + +def within_tempdir(callable): + """A decorator run the enclosed function inside a tmpdir removed after execution + """ + proxy = with_tempdir(in_tempdir(callable)) + proxy.__name__ = callable.__name__ + return proxy + +def find_tests(testdir, + prefixes=DEFAULT_PREFIXES, suffix=".py", + excludes=(), + remove_suffix=True): + """ + Return a list of all applicable test modules. 
+ """ + tests = [] + for name in os.listdir(testdir): + if not suffix or name.endswith(suffix): + for prefix in prefixes: + if name.startswith(prefix): + if remove_suffix and name.endswith(suffix): + name = name[:-len(suffix)] + if name not in excludes: + tests.append(name) + tests.sort() + return tests + + +## PostMortem Debug facilities ##### +def start_interactive_mode(result): + """starts an interactive shell so that the user can inspect errors + """ + debuggers = result.debuggers + descrs = result.error_descrs + result.fail_descrs + if len(debuggers) == 1: + # don't ask for test name if there's only one failure + debuggers[0].start() + else: + while True: + testindex = 0 + print("Choose a test to debug:") + # order debuggers in the same way than errors were printed + print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr) + in enumerate(descrs)])) + print("Type 'exit' (or ^D) to quit") + print() + try: + todebug = input('Enter a test name: ') + if todebug.strip().lower() == 'exit': + print() + break + else: + try: + testindex = int(todebug) + debugger = debuggers[descrs[testindex][0]] + except (ValueError, IndexError): + print("ERROR: invalid test number %r" % (todebug, )) + else: + debugger.start() + except (EOFError, KeyboardInterrupt): + print() + break + + +# test utils ################################################################## + +class SkipAwareTestResult(unittest._TextTestResult): + + def __init__(self, stream, descriptions, verbosity, + exitfirst=False, pdbmode=False, cvg=None, colorize=False): + super(SkipAwareTestResult, self).__init__(stream, + descriptions, verbosity) + self.skipped = [] + self.debuggers = [] + self.fail_descrs = [] + self.error_descrs = [] + self.exitfirst = exitfirst + self.pdbmode = pdbmode + self.cvg = cvg + self.colorize = colorize + self.pdbclass = Debugger + self.verbose = verbosity > 1 + + def descrs_for(self, flavour): + return getattr(self, '%s_descrs' % flavour.lower()) + + def _create_pdb(self, test_descr, 
flavour): + self.descrs_for(flavour).append( (len(self.debuggers), test_descr) ) + if self.pdbmode: + self.debuggers.append(self.pdbclass(sys.exc_info()[2])) + + def _iter_valid_frames(self, frames): + """only consider non-testlib frames when formatting traceback""" + lgc_testlib = osp.abspath(__file__) + std_testlib = osp.abspath(unittest.__file__) + invalid = lambda fi: osp.abspath(fi[1]) in (lgc_testlib, std_testlib) + for frameinfo in dropwhile(invalid, frames): + yield frameinfo + + def _exc_info_to_string(self, err, test): + """Converts a sys.exc_info()-style tuple of values into a string. + + This method is overridden here because we want to colorize + lines if --color is passed, and display local variables if + --verbose is passed + """ + exctype, exc, tb = err + output = ['Traceback (most recent call last)'] + frames = inspect.getinnerframes(tb) + colorize = self.colorize + frames = enumerate(self._iter_valid_frames(frames)) + for index, (frame, filename, lineno, funcname, ctx, ctxindex) in frames: + filename = osp.abspath(filename) + if ctx is None: # pyc files or C extensions for instance + source = '<no source available>' + else: + source = ''.join(ctx) + if colorize: + filename = textutils.colorize_ansi(filename, 'magenta') + source = colorize_source(source) + output.append(' File "%s", line %s, in %s' % (filename, lineno, funcname)) + output.append(' %s' % source.strip()) + if self.verbose: + output.append('%r == %r' % (dir(frame), test.__module__)) + output.append('') + output.append(' ' + ' local variables '.center(66, '-')) + for varname, value in sorted(frame.f_locals.items()): + output.append(' %s: %r' % (varname, value)) + if varname == 'self': # special handy processing for self + for varname, value in sorted(vars(value).items()): + output.append(' self.%s: %r' % (varname, value)) + output.append(' ' + '-' * 66) + output.append('') + output.append(''.join(traceback.format_exception_only(exctype, exc))) + return '\n'.join(output) + + def 
addError(self, test, err): + """err -> (exc_type, exc, tcbk)""" + exc_type, exc, _ = err + if isinstance(exc, SkipTest): + assert exc_type == SkipTest + self.addSkip(test, exc) + else: + if self.exitfirst: + self.shouldStop = True + descr = self.getDescription(test) + super(SkipAwareTestResult, self).addError(test, err) + self._create_pdb(descr, 'error') + + def addFailure(self, test, err): + if self.exitfirst: + self.shouldStop = True + descr = self.getDescription(test) + super(SkipAwareTestResult, self).addFailure(test, err) + self._create_pdb(descr, 'fail') + + def addSkip(self, test, reason): + self.skipped.append((test, reason)) + if self.showAll: + self.stream.writeln("SKIPPED") + elif self.dots: + self.stream.write('S') + + def printErrors(self): + super(SkipAwareTestResult, self).printErrors() + self.printSkippedList() + + def printSkippedList(self): + # format (test, err) compatible with unittest2 + for test, err in self.skipped: + descr = self.getDescription(test) + self.stream.writeln(self.separator1) + self.stream.writeln("%s: %s" % ('SKIPPED', descr)) + self.stream.writeln("\t%s" % err) + + def printErrorList(self, flavour, errors): + for (_, descr), (test, err) in zip(self.descrs_for(flavour), errors): + self.stream.writeln(self.separator1) + self.stream.writeln("%s: %s" % (flavour, descr)) + self.stream.writeln(self.separator2) + self.stream.writeln(err) + self.stream.writeln('no stdout'.center(len(self.separator2))) + self.stream.writeln('no stderr'.center(len(self.separator2))) + +# Add deprecation warnings about new api used by module level fixtures in unittest2 +# http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule +class _DebugResult(object): # simplify import statement among unittest flavors.. + "Used by the TestSuite to hold previous class when running in debug." 
+ _previousTestClass = None + _moduleSetUpFailed = False + shouldStop = False + +from logilab.common.decorators import monkeypatch +@monkeypatch(unittest.TestSuite) +def _handleModuleTearDown(self, result): + previousModule = self._get_previous_module(result) + if previousModule is None: + return + if result._moduleSetUpFailed: + return + try: + module = sys.modules[previousModule] + except KeyError: + return + # add testlib specific deprecation warning and switch to new api + if hasattr(module, 'teardown_module'): + warnings.warn('Please rename teardown_module() to tearDownModule() instead.', + DeprecationWarning) + setattr(module, 'tearDownModule', module.teardown_module) + # end of monkey-patching + tearDownModule = getattr(module, 'tearDownModule', None) + if tearDownModule is not None: + try: + tearDownModule() + except Exception as e: + if isinstance(result, _DebugResult): + raise + errorName = 'tearDownModule (%s)' % previousModule + self._addClassOrModuleLevelException(result, e, errorName) + +@monkeypatch(unittest.TestSuite) +def _handleModuleFixture(self, test, result): + previousModule = self._get_previous_module(result) + currentModule = test.__class__.__module__ + if currentModule == previousModule: + return + self._handleModuleTearDown(result) + result._moduleSetUpFailed = False + try: + module = sys.modules[currentModule] + except KeyError: + return + # add testlib specific deprecation warning and switch to new api + if hasattr(module, 'setup_module'): + warnings.warn('Please rename setup_module() to setUpModule() instead.', + DeprecationWarning) + setattr(module, 'setUpModule', module.setup_module) + # end of monkey-patching + setUpModule = getattr(module, 'setUpModule', None) + if setUpModule is not None: + try: + setUpModule() + except Exception as e: + if isinstance(result, _DebugResult): + raise + result._moduleSetUpFailed = True + errorName = 'setUpModule (%s)' % currentModule + self._addClassOrModuleLevelException(result, e, errorName) + +# 
backward compatibility: TestSuite might be imported from lgc.testlib +TestSuite = unittest.TestSuite + +class keywords(dict): + """Keyword args (**kwargs) support for generative tests.""" + +class starargs(tuple): + """Variable arguments (*args) for generative tests.""" + def __new__(cls, *args): + return tuple.__new__(cls, args) + +unittest_main = unittest.main + + +class InnerTestSkipped(SkipTest): + """raised when a test is skipped""" + pass + +def parse_generative_args(params): + args = [] + varargs = () + kwargs = {} + flags = 0 # 2 <=> starargs, 4 <=> kwargs + for param in params: + if isinstance(param, starargs): + varargs = param + if flags: + raise TypeError('found starargs after keywords !') + flags |= 2 + args += list(varargs) + elif isinstance(param, keywords): + kwargs = param + if flags & 4: + raise TypeError('got multiple keywords parameters') + flags |= 4 + elif flags & 2 or flags & 4: + raise TypeError('found parameters after kwargs or args') + else: + args.append(param) + + return args, kwargs + + +class InnerTest(tuple): + def __new__(cls, name, *data): + instance = tuple.__new__(cls, data) + instance.name = name + return instance + +class Tags(set): + """A set of tag able validate an expression""" + + def __init__(self, *tags, **kwargs): + self.inherit = kwargs.pop('inherit', True) + if kwargs: + raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys()) + + if len(tags) == 1 and not isinstance(tags[0], string_types): + tags = tags[0] + super(Tags, self).__init__(tags, **kwargs) + + def __getitem__(self, key): + return key in self + + def match(self, exp): + return eval(exp, {}, self) + + +# duplicate definition from unittest2 of the _deprecate decorator +def _deprecate(original_func): + def deprecated_func(*args, **kwargs): + warnings.warn( + ('Please use %s instead.' 
% original_func.__name__), + DeprecationWarning, 2) + return original_func(*args, **kwargs) + return deprecated_func + +class TestCase(unittest.TestCase): + """A unittest.TestCase extension with some additional methods.""" + maxDiff = None + pdbclass = Debugger + tags = Tags() + + def __init__(self, methodName='runTest'): + super(TestCase, self).__init__(methodName) + self.__exc_info = sys.exc_info + self.__testMethodName = self._testMethodName + self._current_test_descr = None + self._options_ = None + + @classproperty + @cached + def datadir(cls): # pylint: disable=E0213 + """helper attribute holding the standard test's data directory + + NOTE: this is a logilab's standard + """ + mod = __import__(cls.__module__) + return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data') + # cache it (use a class method to cache on class since TestCase is + # instantiated for each test run) + + @classmethod + def datapath(cls, *fname): + """joins the object's datadir and `fname`""" + return osp.join(cls.datadir, *fname) + + def set_description(self, descr): + """sets the current test's description. 
+ This can be useful for generative tests because it allows to specify + a description per yield + """ + self._current_test_descr = descr + + # override default's unittest.py feature + def shortDescription(self): + """override default unittest shortDescription to handle correctly + generative tests + """ + if self._current_test_descr is not None: + return self._current_test_descr + return super(TestCase, self).shortDescription() + + def quiet_run(self, result, func, *args, **kwargs): + try: + func(*args, **kwargs) + except (KeyboardInterrupt, SystemExit): + raise + except unittest.SkipTest as e: + if hasattr(result, 'addSkip'): + result.addSkip(self, str(e)) + else: + warnings.warn("TestResult has no addSkip method, skips not reported", + RuntimeWarning, 2) + result.addSuccess(self) + return False + except: + result.addError(self, self.__exc_info()) + return False + return True + + def _get_test_method(self): + """return the test method""" + return getattr(self, self._testMethodName) + + def optval(self, option, default=None): + """return the option value or default if the option is not define""" + return getattr(self._options_, option, default) + + def __call__(self, result=None, runcondition=None, options=None): + """rewrite TestCase.__call__ to support generative tests + This is mostly a copy/paste from unittest.py (i.e same + variable names, same logic, except for the generative tests part) + """ + from logilab.common.pytest import FILE_RESTART + if result is None: + result = self.defaultTestResult() + result.pdbclass = self.pdbclass + self._options_ = options + # if result.cvg: + # result.cvg.start() + testMethod = self._get_test_method() + if (getattr(self.__class__, "__unittest_skip__", False) or + getattr(testMethod, "__unittest_skip__", False)): + # If the class or method was skipped. 
+ try: + skip_why = (getattr(self.__class__, '__unittest_skip_why__', '') + or getattr(testMethod, '__unittest_skip_why__', '')) + self._addSkip(result, skip_why) + finally: + result.stopTest(self) + return + if runcondition and not runcondition(testMethod): + return # test is skipped + result.startTest(self) + try: + if not self.quiet_run(result, self.setUp): + return + generative = isgeneratorfunction(testMethod) + # generative tests + if generative: + self._proceed_generative(result, testMethod, + runcondition) + else: + status = self._proceed(result, testMethod) + success = (status == 0) + if not self.quiet_run(result, self.tearDown): + return + if not generative and success: + if hasattr(options, "exitfirst") and options.exitfirst: + # add this test to restart file + try: + restartfile = open(FILE_RESTART, 'a') + try: + descr = '.'.join((self.__class__.__module__, + self.__class__.__name__, + self._testMethodName)) + restartfile.write(descr+os.linesep) + finally: + restartfile.close() + except Exception: + print("Error while saving succeeded test into", + osp.join(os.getcwd(), FILE_RESTART), + file=sys.__stderr__) + raise + result.addSuccess(self) + finally: + # if result.cvg: + # result.cvg.stop() + result.stopTest(self) + + def _proceed_generative(self, result, testfunc, runcondition=None): + # cancel startTest()'s increment + result.testsRun -= 1 + success = True + try: + for params in testfunc(): + if runcondition and not runcondition(testfunc, + skipgenerator=False): + if not (isinstance(params, InnerTest) + and runcondition(params)): + continue + if not isinstance(params, (tuple, list)): + params = (params, ) + func = params[0] + args, kwargs = parse_generative_args(params[1:]) + # increment test counter manually + result.testsRun += 1 + status = self._proceed(result, func, args, kwargs) + if status == 0: + result.addSuccess(self) + success = True + else: + success = False + # XXX Don't stop anymore if an error occured + #if status == 2: + # 
result.shouldStop = True + if result.shouldStop: # either on error or on exitfirst + error + break + except: + # if an error occurs between two yield + result.addError(self, self.__exc_info()) + success = False + return success + + def _proceed(self, result, testfunc, args=(), kwargs=None): + """proceed the actual test + returns 0 on success, 1 on failure, 2 on error + + Note: addSuccess can't be called here because we have to wait + for tearDown to be successfully executed to declare the test as + successful + """ + kwargs = kwargs or {} + try: + testfunc(*args, **kwargs) + except self.failureException: + result.addFailure(self, self.__exc_info()) + return 1 + except KeyboardInterrupt: + raise + except InnerTestSkipped as e: + result.addSkip(self, e) + return 1 + except SkipTest as e: + result.addSkip(self, e) + return 0 + except: + result.addError(self, self.__exc_info()) + return 2 + return 0 + + def defaultTestResult(self): + """return a new instance of the defaultTestResult""" + return SkipAwareTestResult() + + skip = _deprecate(unittest.TestCase.skipTest) + assertEquals = _deprecate(unittest.TestCase.assertEqual) + assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual) + assertAlmostEquals = _deprecate(unittest.TestCase.assertAlmostEqual) + assertNotAlmostEquals = _deprecate(unittest.TestCase.assertNotAlmostEqual) + + def innerSkip(self, msg=None): + """mark a generative test as skipped for the <msg> reason""" + msg = msg or 'test was skipped' + raise InnerTestSkipped(msg) + + @deprecated('Please use assertDictEqual instead.') + def assertDictEquals(self, dict1, dict2, msg=None, context=None): + """compares two dicts + + If the two dict differ, the first difference is shown in the error + message + :param dict1: a Python Dictionary + :param dict2: a Python Dictionary + :param msg: custom message (String) in case of failure + """ + dict1 = dict(dict1) + msgs = [] + for key, value in dict2.items(): + try: + if dict1[key] != value: + msgs.append('%r != 
%r for key %r' % (dict1[key], value, + key)) + del dict1[key] + except KeyError: + msgs.append('missing %r key' % key) + if dict1: + msgs.append('dict2 is lacking %r' % dict1) + if msg: + self.failureException(msg) + elif msgs: + if context is not None: + base = '%s\n' % context + else: + base = '' + self.fail(base + '\n'.join(msgs)) + + @deprecated('Please use assertCountEqual instead.') + def assertUnorderedIterableEquals(self, got, expected, msg=None): + """compares two iterable and shows difference between both + + :param got: the unordered Iterable that we found + :param expected: the expected unordered Iterable + :param msg: custom message (String) in case of failure + """ + got, expected = list(got), list(expected) + self.assertSetEqual(set(got), set(expected), msg) + if len(got) != len(expected): + if msg is None: + msg = ['Iterable have the same elements but not the same number', + '\t<element>\t<expected>i\t<got>'] + got_count = {} + expected_count = {} + for element in got: + got_count[element] = got_count.get(element, 0) + 1 + for element in expected: + expected_count[element] = expected_count.get(element, 0) + 1 + # we know that got_count.key() == expected_count.key() + # because of assertSetEqual + for element, count in got_count.iteritems(): + other_count = expected_count[element] + if other_count != count: + msg.append('\t%s\t%s\t%s' % (element, other_count, count)) + + self.fail(msg) + + assertUnorderedIterableEqual = assertUnorderedIterableEquals + assertUnordIterEquals = assertUnordIterEqual = assertUnorderedIterableEqual + + @deprecated('Please use assertSetEqual instead.') + def assertSetEquals(self,got,expected, msg=None): + """compares two sets and shows difference between both + + Don't use it for iterables other than sets. 
+ + :param got: the Set that we found + :param expected: the second Set to be compared to the first one + :param msg: custom message (String) in case of failure + """ + + if not(isinstance(got, set) and isinstance(expected, set)): + warnings.warn("the assertSetEquals function if now intended for set only."\ + "use assertUnorderedIterableEquals instead.", + DeprecationWarning, 2) + return self.assertUnorderedIterableEquals(got, expected, msg) + + items={} + items['missing'] = expected - got + items['unexpected'] = got - expected + if any(items.itervalues()): + if msg is None: + msg = '\n'.join('%s:\n\t%s' % (key, "\n\t".join(str(value) for value in values)) + for key, values in items.iteritems() if values) + self.fail(msg) + + @deprecated('Please use assertListEqual instead.') + def assertListEquals(self, list_1, list_2, msg=None): + """compares two lists + + If the two list differ, the first difference is shown in the error + message + + :param list_1: a Python List + :param list_2: a second Python List + :param msg: custom message (String) in case of failure + """ + _l1 = list_1[:] + for i, value in enumerate(list_2): + try: + if _l1[0] != value: + from pprint import pprint + pprint(list_1) + pprint(list_2) + self.fail('%r != %r for index %d' % (_l1[0], value, i)) + del _l1[0] + except IndexError: + if msg is None: + msg = 'list_1 has only %d elements, not %s '\ + '(at least %r missing)'% (i, len(list_2), value) + self.fail(msg) + if _l1: + if msg is None: + msg = 'list_2 is lacking %r' % _l1 + self.fail(msg) + + @deprecated('Non-standard. Please use assertMultiLineEqual instead.') + def assertLinesEquals(self, string1, string2, msg=None, striplines=False): + """compare two strings and assert that the text lines of the strings + are equal. 
+ + :param string1: a String + :param string2: a String + :param msg: custom message (String) in case of failure + :param striplines: Boolean to trigger line stripping before comparing + """ + lines1 = string1.splitlines() + lines2 = string2.splitlines() + if striplines: + lines1 = [l.strip() for l in lines1] + lines2 = [l.strip() for l in lines2] + self.assertListEqual(lines1, lines2, msg) + assertLineEqual = assertLinesEquals + + @deprecated('Non-standard: please copy test method to your TestCase class') + def assertXMLWellFormed(self, stream, msg=None, context=2): + """asserts the XML stream is well-formed (no DTD conformance check) + + :param context: number of context lines in standard message + (show all data if negative). + Only available with element tree + """ + try: + from xml.etree.ElementTree import parse + self._assertETXMLWellFormed(stream, parse, msg) + except ImportError: + from xml.sax import make_parser, SAXParseException + parser = make_parser() + try: + parser.parse(stream) + except SAXParseException as ex: + if msg is None: + stream.seek(0) + for _ in range(ex.getLineNumber()): + line = stream.readline() + pointer = ('' * (ex.getLineNumber() - 1)) + '^' + msg = 'XML stream not well formed: %s\n%s%s' % (ex, line, pointer) + self.fail(msg) + + @deprecated('Non-standard: please copy test method to your TestCase class') + def assertXMLStringWellFormed(self, xml_string, msg=None, context=2): + """asserts the XML string is well-formed (no DTD conformance check) + + :param context: number of context lines in standard message + (show all data if negative). 
+ Only available with element tree + """ + try: + from xml.etree.ElementTree import fromstring + except ImportError: + from elementtree.ElementTree import fromstring + self._assertETXMLWellFormed(xml_string, fromstring, msg) + + def _assertETXMLWellFormed(self, data, parse, msg=None, context=2): + """internal function used by /assertXML(String)?WellFormed/ functions + + :param data: xml_data + :param parse: appropriate parser function for this data + :param msg: error message + :param context: number of context lines in standard message + (show all data if negative). + Only available with element tree + """ + from xml.parsers.expat import ExpatError + try: + from xml.etree.ElementTree import ParseError + except ImportError: + # compatibility for <python2.7 + ParseError = ExpatError + try: + parse(data) + except (ExpatError, ParseError) as ex: + if msg is None: + if hasattr(data, 'readlines'): #file like object + data.seek(0) + lines = data.readlines() + else: + lines = data.splitlines(True) + nb_lines = len(lines) + context_lines = [] + + # catch when ParseError doesn't set valid lineno + if ex.lineno is not None: + if context < 0: + start = 1 + end = nb_lines + else: + start = max(ex.lineno-context, 1) + end = min(ex.lineno+context, nb_lines) + line_number_length = len('%i' % end) + line_pattern = " %%%ii: %%s" % line_number_length + + for line_no in range(start, ex.lineno): + context_lines.append(line_pattern % (line_no, lines[line_no-1])) + context_lines.append(line_pattern % (ex.lineno, lines[ex.lineno-1])) + context_lines.append('%s^\n' % (' ' * (1 + line_number_length + 2 +ex.offset))) + for line_no in range(ex.lineno+1, end+1): + context_lines.append(line_pattern % (line_no, lines[line_no-1])) + + rich_context = ''.join(context_lines) + msg = 'XML stream not well formed: %s\n%s' % (ex, rich_context) + self.fail(msg) + + @deprecated('Non-standard: please copy test method to your TestCase class') + def assertXMLEqualsTuple(self, element, tup): + """compare an 
ElementTree Element to a tuple formatted as follow: + (tagname, [attrib[, children[, text[, tail]]]])""" + # check tag + self.assertTextEquals(element.tag, tup[0]) + # check attrib + if len(element.attrib) or len(tup)>1: + if len(tup)<=1: + self.fail( "tuple %s has no attributes (%s expected)"%(tup, + dict(element.attrib))) + self.assertDictEqual(element.attrib, tup[1]) + # check children + if len(element) or len(tup)>2: + if len(tup)<=2: + self.fail( "tuple %s has no children (%i expected)"%(tup, + len(element))) + if len(element) != len(tup[2]): + self.fail( "tuple %s has %i children%s (%i expected)"%(tup, + len(tup[2]), + ('', 's')[len(tup[2])>1], len(element))) + for index in range(len(tup[2])): + self.assertXMLEqualsTuple(element[index], tup[2][index]) + #check text + if element.text or len(tup)>3: + if len(tup)<=3: + self.fail( "tuple %s has no text value (%r expected)"%(tup, + element.text)) + self.assertTextEquals(element.text, tup[3]) + #check tail + if element.tail or len(tup)>4: + if len(tup)<=4: + self.fail( "tuple %s has no tail value (%r expected)"%(tup, + element.tail)) + self.assertTextEquals(element.tail, tup[4]) + + def _difftext(self, lines1, lines2, junk=None, msg_prefix='Texts differ'): + junk = junk or (' ', '\t') + # result is a generator + result = difflib.ndiff(lines1, lines2, charjunk=lambda x: x in junk) + read = [] + for line in result: + read.append(line) + # lines that don't start with a ' ' are diff ones + if not line.startswith(' '): + self.fail('\n'.join(['%s\n'%msg_prefix]+read + list(result))) + + @deprecated('Non-standard. 
Please use assertMultiLineEqual instead.') + def assertTextEquals(self, text1, text2, junk=None, + msg_prefix='Text differ', striplines=False): + """compare two multiline strings (using difflib and splitlines()) + + :param text1: a Python BaseString + :param text2: a second Python Basestring + :param junk: List of Caracters + :param msg_prefix: String (message prefix) + :param striplines: Boolean to trigger line stripping before comparing + """ + msg = [] + if not isinstance(text1, string_types): + msg.append('text1 is not a string (%s)'%(type(text1))) + if not isinstance(text2, string_types): + msg.append('text2 is not a string (%s)'%(type(text2))) + if msg: + self.fail('\n'.join(msg)) + lines1 = text1.strip().splitlines(True) + lines2 = text2.strip().splitlines(True) + if striplines: + lines1 = [line.strip() for line in lines1] + lines2 = [line.strip() for line in lines2] + self._difftext(lines1, lines2, junk, msg_prefix) + assertTextEqual = assertTextEquals + + @deprecated('Non-standard: please copy test method to your TestCase class') + def assertStreamEquals(self, stream1, stream2, junk=None, + msg_prefix='Stream differ'): + """compare two streams (using difflib and readlines())""" + # if stream2 is stream2, readlines() on stream1 will also read lines + # in stream2, so they'll appear different, although they're not + if stream1 is stream2: + return + # make sure we compare from the beginning of the stream + stream1.seek(0) + stream2.seek(0) + # compare + self._difftext(stream1.readlines(), stream2.readlines(), junk, + msg_prefix) + + assertStreamEqual = assertStreamEquals + + @deprecated('Non-standard: please copy test method to your TestCase class') + def assertFileEquals(self, fname1, fname2, junk=(' ', '\t')): + """compares two files using difflib""" + self.assertStreamEqual(open(fname1), open(fname2), junk, + msg_prefix='Files differs\n-:%s\n+:%s\n'%(fname1, fname2)) + + assertFileEqual = assertFileEquals + + @deprecated('Non-standard: please copy test 
method to your TestCase class') + def assertDirEquals(self, path_a, path_b): + """compares two files using difflib""" + assert osp.exists(path_a), "%s doesn't exists" % path_a + assert osp.exists(path_b), "%s doesn't exists" % path_b + + all_a = [ (ipath[len(path_a):].lstrip('/'), idirs, ifiles) + for ipath, idirs, ifiles in os.walk(path_a)] + all_a.sort(key=itemgetter(0)) + + all_b = [ (ipath[len(path_b):].lstrip('/'), idirs, ifiles) + for ipath, idirs, ifiles in os.walk(path_b)] + all_b.sort(key=itemgetter(0)) + + iter_a, iter_b = iter(all_a), iter(all_b) + partial_iter = True + ipath_a, idirs_a, ifiles_a = data_a = None, None, None + while True: + try: + ipath_a, idirs_a, ifiles_a = datas_a = next(iter_a) + partial_iter = False + ipath_b, idirs_b, ifiles_b = datas_b = next(iter_b) + partial_iter = True + + + self.assertTrue(ipath_a == ipath_b, + "unexpected %s in %s while looking %s from %s" % + (ipath_a, path_a, ipath_b, path_b)) + + + errors = {} + sdirs_a = set(idirs_a) + sdirs_b = set(idirs_b) + errors["unexpected directories"] = sdirs_a - sdirs_b + errors["missing directories"] = sdirs_b - sdirs_a + + sfiles_a = set(ifiles_a) + sfiles_b = set(ifiles_b) + errors["unexpected files"] = sfiles_a - sfiles_b + errors["missing files"] = sfiles_b - sfiles_a + + + msgs = [ "%s: %s"% (name, items) + for name, items in errors.items() if items] + + if msgs: + msgs.insert(0, "%s and %s differ :" % ( + osp.join(path_a, ipath_a), + osp.join(path_b, ipath_b), + )) + self.fail("\n".join(msgs)) + + for files in (ifiles_a, ifiles_b): + files.sort() + + for index, path in enumerate(ifiles_a): + self.assertFileEquals(osp.join(path_a, ipath_a, path), + osp.join(path_b, ipath_b, ifiles_b[index])) + + except StopIteration: + break + + assertDirEqual = assertDirEquals + + def assertIsInstance(self, obj, klass, msg=None, strict=False): + """check if an object is an instance of a class + + :param obj: the Python Object to be checked + :param klass: the target class + :param msg: a 
String for a custom message + :param strict: if True, check that the class of <obj> is <klass>; + else check with 'isinstance' + """ + if strict: + warnings.warn('[API] Non-standard. Strict parameter has vanished', + DeprecationWarning, stacklevel=2) + if msg is None: + if strict: + msg = '%r is not of class %s but of %s' + else: + msg = '%r is not an instance of %s but of %s' + msg = msg % (obj, klass, type(obj)) + if strict: + self.assertTrue(obj.__class__ is klass, msg) + else: + self.assertTrue(isinstance(obj, klass), msg) + + @deprecated('Please use assertIsNone instead.') + def assertNone(self, obj, msg=None): + """assert obj is None + + :param obj: Python Object to be tested + """ + if msg is None: + msg = "reference to %r when None expected"%(obj,) + self.assertTrue( obj is None, msg ) + + @deprecated('Please use assertIsNotNone instead.') + def assertNotNone(self, obj, msg=None): + """assert obj is not None""" + if msg is None: + msg = "unexpected reference to None" + self.assertTrue( obj is not None, msg ) + + @deprecated('Non-standard. Please use assertAlmostEqual instead.') + def assertFloatAlmostEquals(self, obj, other, prec=1e-5, + relative=False, msg=None): + """compares if two floats have a distance smaller than expected + precision. + + :param obj: a Float + :param other: another Float to be comparted to <obj> + :param prec: a Float describing the precision + :param relative: boolean switching to relative/absolute precision + :param msg: a String for a custom message + """ + if msg is None: + msg = "%r != %r" % (obj, other) + if relative: + prec = prec*math.fabs(obj) + self.assertTrue(math.fabs(obj - other) < prec, msg) + + def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs): + """override default failUnlessRaises method to return the raised + exception instance. + + Fail unless an exception of class excClass is thrown + by callableObj when invoked with arguments args and keyword + arguments kwargs. 
If a different type of exception is + thrown, it will not be caught, and the test case will be + deemed to have suffered an error, exactly as for an + unexpected exception. + + CAUTION! There are subtle differences between Logilab and unittest2 + - exc is not returned in standard version + - context capabilities in standard version + - try/except/else construction (minor) + + :param excClass: the Exception to be raised + :param callableObj: a callable Object which should raise <excClass> + :param args: a List of arguments for <callableObj> + :param kwargs: a List of keyword arguments for <callableObj> + """ + # XXX cube vcslib : test_branches_from_app + if callableObj is None: + _assert = super(TestCase, self).assertRaises + return _assert(excClass, callableObj, *args, **kwargs) + try: + callableObj(*args, **kwargs) + except excClass as exc: + class ProxyException: + def __init__(self, obj): + self._obj = obj + def __getattr__(self, attr): + warn_msg = ("This exception was retrieved with the old testlib way " + "`exc = self.assertRaises(Exc, callable)`, please use " + "the context manager instead'") + warnings.warn(warn_msg, DeprecationWarning, 2) + return self._obj.__getattribute__(attr) + return ProxyException(exc) + else: + if hasattr(excClass, '__name__'): + excName = excClass.__name__ + else: + excName = str(excClass) + raise self.failureException("%s not raised" % excName) + + assertRaises = failUnlessRaises + + if sys.version_info >= (3,2): + assertItemsEqual = unittest.TestCase.assertCountEqual + else: + assertCountEqual = unittest.TestCase.assertItemsEqual + if sys.version_info < (2,7): + def assertIsNotNone(self, value, *args, **kwargs): + self.assertNotEqual(None, value, *args, **kwargs) + +TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')( + TestCase.assertItemsEqual) + +import doctest + +class SkippedSuite(unittest.TestSuite): + def test(self): + """just there to trigger test execution""" + 
self.skipped_test('doctest module has no DocTestSuite class') + + +class DocTestFinder(doctest.DocTestFinder): + + def __init__(self, *args, **kwargs): + self.skipped = kwargs.pop('skipped', ()) + doctest.DocTestFinder.__init__(self, *args, **kwargs) + + def _get_test(self, obj, name, module, globs, source_lines): + """override default _get_test method to be able to skip tests + according to skipped attribute's value + """ + if getattr(obj, '__name__', '') in self.skipped: + return None + return doctest.DocTestFinder._get_test(self, obj, name, module, + globs, source_lines) + + +class DocTest(TestCase): + """trigger module doctest + I don't know how to make unittest.main consider the DocTestSuite instance + without this hack + """ + skipped = () + def __call__(self, result=None, runcondition=None, options=None):\ + # pylint: disable=W0613 + try: + finder = DocTestFinder(skipped=self.skipped) + suite = doctest.DocTestSuite(self.module, test_finder=finder) + # XXX iirk + doctest.DocTestCase._TestCase__exc_info = sys.exc_info + except AttributeError: + suite = SkippedSuite() + # doctest may gork the builtins dictionnary + # This happen to the "_" entry used by gettext + old_builtins = builtins.__dict__.copy() + try: + return suite.run(result) + finally: + builtins.__dict__.clear() + builtins.__dict__.update(old_builtins) + run = __call__ + + def test(self): + """just there to trigger test execution""" + +MAILBOX = None + +class MockSMTP: + """fake smtplib.SMTP""" + + def __init__(self, host, port): + self.host = host + self.port = port + global MAILBOX + self.reveived = MAILBOX = [] + + def set_debuglevel(self, debuglevel): + """ignore debug level""" + + def sendmail(self, fromaddr, toaddres, body): + """push sent mail in the mailbox""" + self.reveived.append((fromaddr, toaddres, body)) + + def quit(self): + """ignore quit""" + + +class MockConfigParser(configparser.ConfigParser): + """fake ConfigParser.ConfigParser""" + + def __init__(self, options): + 
configparser.ConfigParser.__init__(self) + for section, pairs in options.iteritems(): + self.add_section(section) + for key, value in pairs.iteritems(): + self.set(section, key, value) + def write(self, _): + raise NotImplementedError() + + +class MockConnection: + """fake DB-API 2.0 connexion AND cursor (i.e. cursor() return self)""" + + def __init__(self, results): + self.received = [] + self.states = [] + self.results = results + + def cursor(self): + """Mock cursor method""" + return self + def execute(self, query, args=None): + """Mock execute method""" + self.received.append( (query, args) ) + def fetchone(self): + """Mock fetchone method""" + return self.results[0] + def fetchall(self): + """Mock fetchall method""" + return self.results + def commit(self): + """Mock commiy method""" + self.states.append( ('commit', len(self.received)) ) + def rollback(self): + """Mock rollback method""" + self.states.append( ('rollback', len(self.received)) ) + def close(self): + """Mock close method""" + pass + + +def mock_object(**params): + """creates an object using params to set attributes + >>> option = mock_object(verbose=False, index=range(5)) + >>> option.verbose + False + >>> option.index + [0, 1, 2, 3, 4] + """ + return type('Mock', (), params)() + + +def create_files(paths, chroot): + """Creates directories and files found in <path>. 
+ + :param paths: list of relative paths to files or directories + :param chroot: the root directory in which paths will be created + + >>> from os.path import isdir, isfile + >>> isdir('/tmp/a') + False + >>> create_files(['a/b/foo.py', 'a/b/c/', 'a/b/c/d/e.py'], '/tmp') + >>> isdir('/tmp/a') + True + >>> isdir('/tmp/a/b/c') + True + >>> isfile('/tmp/a/b/c/d/e.py') + True + >>> isfile('/tmp/a/b/foo.py') + True + """ + dirs, files = set(), set() + for path in paths: + path = osp.join(chroot, path) + filename = osp.basename(path) + # path is a directory path + if filename == '': + dirs.add(path) + # path is a filename path + else: + dirs.add(osp.dirname(path)) + files.add(path) + for dirpath in dirs: + if not osp.isdir(dirpath): + os.makedirs(dirpath) + for filepath in files: + open(filepath, 'w').close() + + +class AttrObject: # XXX cf mock_object + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + +def tag(*args, **kwargs): + """descriptor adding tag to a function""" + def desc(func): + assert not hasattr(func, 'tags') + func.tags = Tags(*args, **kwargs) + return func + return desc + +def require_version(version): + """ Compare version of python interpreter to the given one. Skip the test + if older. + """ + def check_require_version(f): + version_elements = version.split('.') + try: + compare = tuple([int(v) for v in version_elements]) + except ValueError: + raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version) + current = sys.version_info[:3] + if current < compare: + def new_f(self, *args, **kwargs): + self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current]))) + new_f.__name__ = f.__name__ + return new_f + else: + return f + return check_require_version + +def require_module(module): + """ Check if the given module is loaded. Skip the test if not. 
+ """ + def check_require_module(f): + try: + __import__(module) + return f + except ImportError: + def new_f(self, *args, **kwargs): + self.skipTest('%s can not be imported.' % module) + new_f.__name__ = f.__name__ + return new_f + return check_require_module + diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py new file mode 100644 index 0000000..9046f97 --- /dev/null +++ b/logilab/common/textutils.py @@ -0,0 +1,537 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Some text manipulation utility functions. 
+ + +:group text formatting: normalize_text, normalize_paragraph, pretty_match,\ +unquote, colorize_ansi +:group text manipulation: searchall, splitstrip +:sort: text formatting, text manipulation + +:type ANSI_STYLES: dict(str) +:var ANSI_STYLES: dictionary mapping style identifier to ANSI terminal code + +:type ANSI_COLORS: dict(str) +:var ANSI_COLORS: dictionary mapping color identifier to ANSI terminal code + +:type ANSI_PREFIX: str +:var ANSI_PREFIX: + ANSI terminal code notifying the start of an ANSI escape sequence + +:type ANSI_END: str +:var ANSI_END: + ANSI terminal code notifying the end of an ANSI escape sequence + +:type ANSI_RESET: str +:var ANSI_RESET: + ANSI terminal code resetting format defined by a previous ANSI escape sequence +""" +__docformat__ = "restructuredtext en" + +import sys +import re +import os.path as osp +from warnings import warn +from unicodedata import normalize as _uninormalize +try: + from os import linesep +except ImportError: + linesep = '\n' # gae + +from logilab.common.deprecation import deprecated + +MANUAL_UNICODE_MAP = { + u'\xa1': u'!', # INVERTED EXCLAMATION MARK + u'\u0142': u'l', # LATIN SMALL LETTER L WITH STROKE + u'\u2044': u'/', # FRACTION SLASH + u'\xc6': u'AE', # LATIN CAPITAL LETTER AE + u'\xa9': u'(c)', # COPYRIGHT SIGN + u'\xab': u'"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK + u'\xe6': u'ae', # LATIN SMALL LETTER AE + u'\xae': u'(r)', # REGISTERED SIGN + u'\u0153': u'oe', # LATIN SMALL LIGATURE OE + u'\u0152': u'OE', # LATIN CAPITAL LIGATURE OE + u'\xd8': u'O', # LATIN CAPITAL LETTER O WITH STROKE + u'\xf8': u'o', # LATIN SMALL LETTER O WITH STROKE + u'\xbb': u'"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK + u'\xdf': u'ss', # LATIN SMALL LETTER SHARP S + } + +def unormalize(ustring, ignorenonascii=None, substitute=None): + """replace diacritical characters with their corresponding ascii characters + + Convert the unicode string to its long normalized form (unicode character + will be transform into 
several characters) and keep the first one only. + The normal form KD (NFKD) will apply the compatibility decomposition, i.e. + replace all compatibility characters with their equivalents. + + :type substitute: str + :param substitute: replacement character to use if decomposition fails + + :see: Another project about ASCII transliterations of Unicode text + http://pypi.python.org/pypi/Unidecode + """ + # backward compatibility, ignorenonascii was a boolean + if ignorenonascii is not None: + warn("ignorenonascii is deprecated, use substitute named parameter instead", + DeprecationWarning, stacklevel=2) + if ignorenonascii: + substitute = '' + res = [] + for letter in ustring[:]: + try: + replacement = MANUAL_UNICODE_MAP[letter] + except KeyError: + replacement = _uninormalize('NFKD', letter)[0] + if ord(replacement) >= 2 ** 7: + if substitute is None: + raise ValueError("can't deal with non-ascii based characters") + replacement = substitute + res.append(replacement) + return u''.join(res) + +def unquote(string): + """remove optional quotes (simple or double) from the string + + :type string: str or unicode + :param string: an optionally quoted string + + :rtype: str or unicode + :return: the unquoted string (or the input string if it wasn't quoted) + """ + if not string: + return string + if string[0] in '"\'': + string = string[1:] + if string[-1] in '"\'': + string = string[:-1] + return string + + +_BLANKLINES_RGX = re.compile('\r?\n\r?\n') +_NORM_SPACES_RGX = re.compile('\s+') + +def normalize_text(text, line_len=80, indent='', rest=False): + """normalize a text to display it with a maximum line size and + optionally arbitrary indentation. Line jumps are normalized but blank + lines are kept. The indentation string may be used to insert a + comment (#) or a quoting (>) mark for instance. 
+ + :type text: str or unicode + :param text: the input text to normalize + + :type line_len: int + :param line_len: expected maximum line's length, default to 80 + + :type indent: str or unicode + :param indent: optional string to use as indentation + + :rtype: str or unicode + :return: + the input text normalized to fit on lines with a maximized size + inferior to `line_len`, and optionally prefixed by an + indentation string + """ + if rest: + normp = normalize_rest_paragraph + else: + normp = normalize_paragraph + result = [] + for text in _BLANKLINES_RGX.split(text): + result.append(normp(text, line_len, indent)) + return ('%s%s%s' % (linesep, indent, linesep)).join(result) + + +def normalize_paragraph(text, line_len=80, indent=''): + """normalize a text to display it with a maximum line size and + optionally arbitrary indentation. Line jumps are normalized. The + indentation string may be used top insert a comment mark for + instance. + + :type text: str or unicode + :param text: the input text to normalize + + :type line_len: int + :param line_len: expected maximum line's length, default to 80 + + :type indent: str or unicode + :param indent: optional string to use as indentation + + :rtype: str or unicode + :return: + the input text normalized to fit on lines with a maximized size + inferior to `line_len`, and optionally prefixed by an + indentation string + """ + text = _NORM_SPACES_RGX.sub(' ', text) + line_len = line_len - len(indent) + lines = [] + while text: + aline, text = splittext(text.strip(), line_len) + lines.append(indent + aline) + return linesep.join(lines) + +def normalize_rest_paragraph(text, line_len=80, indent=''): + """normalize a ReST text to display it with a maximum line size and + optionally arbitrary indentation. Line jumps are normalized. The + indentation string may be used top insert a comment mark for + instance. 
+ + :type text: str or unicode + :param text: the input text to normalize + + :type line_len: int + :param line_len: expected maximum line's length, default to 80 + + :type indent: str or unicode + :param indent: optional string to use as indentation + + :rtype: str or unicode + :return: + the input text normalized to fit on lines with a maximized size + inferior to `line_len`, and optionally prefixed by an + indentation string + """ + toreport = '' + lines = [] + line_len = line_len - len(indent) + for line in text.splitlines(): + line = toreport + _NORM_SPACES_RGX.sub(' ', line.strip()) + toreport = '' + while len(line) > line_len: + # too long line, need split + line, toreport = splittext(line, line_len) + lines.append(indent + line) + if toreport: + line = toreport + ' ' + toreport = '' + else: + line = '' + if line: + lines.append(indent + line.strip()) + return linesep.join(lines) + + +def splittext(text, line_len): + """split the given text on space according to the given max line size + + return a 2-uple: + * a line <= line_len if possible + * the rest of the text which has to be reported on another line + """ + if len(text) <= line_len: + return text, '' + pos = min(len(text)-1, line_len) + while pos > 0 and text[pos] != ' ': + pos -= 1 + if pos == 0: + pos = min(len(text), line_len) + while len(text) > pos and text[pos] != ' ': + pos += 1 + return text[:pos], text[pos+1:].strip() + + +def splitstrip(string, sep=','): + """return a list of stripped string by splitting the string given as + argument on `sep` (',' by default). Empty string are discarded. 
+ + >>> splitstrip('a, b, c , 4,,') + ['a', 'b', 'c', '4'] + >>> splitstrip('a') + ['a'] + >>> + + :type string: str or unicode + :param string: a csv line + + :type sep: str or unicode + :param sep: field separator, default to the comma (',') + + :rtype: str or unicode + :return: the unquoted string (or the input string if it wasn't quoted) + """ + return [word.strip() for word in string.split(sep) if word.strip()] + +get_csv = deprecated('get_csv is deprecated, use splitstrip')(splitstrip) + + +def split_url_or_path(url_or_path): + """return the latest component of a string containing either an url of the + form <scheme>://<path> or a local file system path + """ + if '://' in url_or_path: + return url_or_path.rstrip('/').rsplit('/', 1) + return osp.split(url_or_path.rstrip(osp.sep)) + + +def text_to_dict(text): + """parse multilines text containing simple 'key=value' lines and return a + dict of {'key': 'value'}. When the same key is encountered multiple time, + value is turned into a list containing all values. + + >>> d = text_to_dict('''multiple=1 + ... multiple= 2 + ... single =3 + ... 
''') + >>> d['single'] + '3' + >>> d['multiple'] + ['1', '2'] + + """ + res = {} + if not text: + return res + for line in text.splitlines(): + line = line.strip() + if line and not line.startswith('#'): + key, value = [w.strip() for w in line.split('=', 1)] + if key in res: + try: + res[key].append(value) + except AttributeError: + res[key] = [res[key], value] + else: + res[key] = value + return res + + +_BLANK_URE = r'(\s|,)+' +_BLANK_RE = re.compile(_BLANK_URE) +__VALUE_URE = r'-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))' +__UNITS_URE = r'[a-zA-Z]+' +_VALUE_RE = re.compile(r'(?P<value>%s)(?P<unit>%s)?'%(__VALUE_URE, __UNITS_URE)) +_VALIDATION_RE = re.compile(r'^((%s)(%s))*(%s)?$' % (__VALUE_URE, __UNITS_URE, + __VALUE_URE)) + +BYTE_UNITS = { + "b": 1, + "kb": 1024, + "mb": 1024 ** 2, + "gb": 1024 ** 3, + "tb": 1024 ** 4, +} + +TIME_UNITS = { + "ms": 0.0001, + "s": 1, + "min": 60, + "h": 60 * 60, + "d": 60 * 60 *24, +} + +def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE, + value_reg=_VALUE_RE): + """Parse the string applying the units defined in units + (e.g.: "1.5m",{'m',60} -> 80). + + :type string: str or unicode + :param string: the string to parse + + :type units: dict (or any object with __getitem__ using basestring key) + :param units: a dict mapping a unit string repr to its value + + :type inter: type + :param inter: used to parse every intermediate value (need __sum__) + + :type blank_reg: regexp + :param blank_reg: should match every blank char to ignore. + + :type value_reg: regexp with "value" and optional "unit" group + :param value_reg: match a value and it's unit into the + """ + if inter is None: + inter = final + fstring = _BLANK_RE.sub('', string) + if not (fstring and _VALIDATION_RE.match(fstring)): + raise ValueError("Invalid unit string: %r." 
% string) + values = [] + for match in value_reg.finditer(fstring): + dic = match.groupdict() + lit, unit = dic["value"], dic.get("unit") + value = inter(lit) + if unit is not None: + try: + value *= units[unit.lower()] + except KeyError: + raise KeyError('invalid unit %s. valid units are %s' % + (unit, units.keys())) + values.append(value) + return final(sum(values)) + + +_LINE_RGX = re.compile('\r\n|\r+|\n') + +def pretty_match(match, string, underline_char='^'): + """return a string with the match location underlined: + + >>> import re + >>> print(pretty_match(re.search('mange', 'il mange du bacon'), 'il mange du bacon')) + il mange du bacon + ^^^^^ + >>> + + :type match: _sre.SRE_match + :param match: object returned by re.match, re.search or re.finditer + + :type string: str or unicode + :param string: + the string on which the regular expression has been applied to + obtain the `match` object + + :type underline_char: str or unicode + :param underline_char: + character to use to underline the matched section, default to the + carret '^' + + :rtype: str or unicode + :return: + the original string with an inserted line to underline the match + location + """ + start = match.start() + end = match.end() + string = _LINE_RGX.sub(linesep, string) + start_line_pos = string.rfind(linesep, 0, start) + if start_line_pos == -1: + start_line_pos = 0 + result = [] + else: + result = [string[:start_line_pos]] + start_line_pos += len(linesep) + offset = start - start_line_pos + underline = ' ' * offset + underline_char * (end - start) + end_line_pos = string.find(linesep, end) + if end_line_pos == -1: + string = string[start_line_pos:] + result.append(string) + result.append(underline) + else: + end = string[end_line_pos + len(linesep):] + string = string[start_line_pos:end_line_pos] + result.append(string) + result.append(underline) + result.append(end) + return linesep.join(result).rstrip() + + +# Ansi colorization 
########################################################### + +ANSI_PREFIX = '\033[' +ANSI_END = 'm' +ANSI_RESET = '\033[0m' +ANSI_STYLES = { + 'reset': "0", + 'bold': "1", + 'italic': "3", + 'underline': "4", + 'blink': "5", + 'inverse': "7", + 'strike': "9", +} +ANSI_COLORS = { + 'reset': "0", + 'black': "30", + 'red': "31", + 'green': "32", + 'yellow': "33", + 'blue': "34", + 'magenta': "35", + 'cyan': "36", + 'white': "37", +} + +def _get_ansi_code(color=None, style=None): + """return ansi escape code corresponding to color and style + + :type color: str or None + :param color: + the color name (see `ANSI_COLORS` for available values) + or the color number when 256 colors are available + + :type style: str or None + :param style: + style string (see `ANSI_COLORS` for available values). To get + several style effects at the same time, use a coma as separator. + + :raise KeyError: if an unexistent color or style identifier is given + + :rtype: str + :return: the built escape code + """ + ansi_code = [] + if style: + style_attrs = splitstrip(style) + for effect in style_attrs: + ansi_code.append(ANSI_STYLES[effect]) + if color: + if color.isdigit(): + ansi_code.extend(['38', '5']) + ansi_code.append(color) + else: + ansi_code.append(ANSI_COLORS[color]) + if ansi_code: + return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END + return '' + +def colorize_ansi(msg, color=None, style=None): + """colorize message by wrapping it with ansi escape codes + + :type msg: str or unicode + :param msg: the message string to colorize + + :type color: str or None + :param color: + the color identifier (see `ANSI_COLORS` for available values) + + :type style: str or None + :param style: + style string (see `ANSI_COLORS` for available values). To get + several style effects at the same time, use a coma as separator. 
+ + :raise KeyError: if an unexistent color or style identifier is given + + :rtype: str or unicode + :return: the ansi escaped string + """ + # If both color and style are not defined, then leave the text as is + if color is None and style is None: + return msg + escape_code = _get_ansi_code(color, style) + # If invalid (or unknown) color, don't wrap msg with ansi codes + if escape_code: + return '%s%s%s' % (escape_code, msg, ANSI_RESET) + return msg + +DIFF_STYLE = {'separator': 'cyan', 'remove': 'red', 'add': 'green'} + +def diff_colorize_ansi(lines, out=sys.stdout, style=DIFF_STYLE): + for line in lines: + if line[:4] in ('--- ', '+++ '): + out.write(colorize_ansi(line, style['separator'])) + elif line[0] == '-': + out.write(colorize_ansi(line, style['remove'])) + elif line[0] == '+': + out.write(colorize_ansi(line, style['add'])) + elif line[:4] == '--- ': + out.write(colorize_ansi(line, style['separator'])) + elif line[:4] == '+++ ': + out.write(colorize_ansi(line, style['separator'])) + else: + out.write(line) + diff --git a/logilab/common/tree.py b/logilab/common/tree.py new file mode 100644 index 0000000..885eb0f --- /dev/null +++ b/logilab/common/tree.py @@ -0,0 +1,369 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. 
+# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. +"""Base class to represent a tree structure. + + + + +""" +__docformat__ = "restructuredtext en" + +import sys + +from logilab.common import flatten +from logilab.common.visitor import VisitedMixIn, FilteredIterator, no_filter + +## Exceptions ################################################################# + +class NodeNotFound(Exception): + """raised when a node has not been found""" + +EX_SIBLING_NOT_FOUND = "No such sibling as '%s'" +EX_CHILD_NOT_FOUND = "No such child as '%s'" +EX_NODE_NOT_FOUND = "No such node as '%s'" + + +# Base node ################################################################### + +class Node(object): + """a basic tree node, characterized by an id""" + + def __init__(self, nid=None) : + self.id = nid + # navigation + self.parent = None + self.children = [] + + def __iter__(self): + return iter(self.children) + + def __str__(self, indent=0): + s = ['%s%s %s' % (' '*indent, self.__class__.__name__, self.id)] + indent += 2 + for child in self.children: + try: + s.append(child.__str__(indent)) + except TypeError: + s.append(child.__str__()) + return '\n'.join(s) + + def is_leaf(self): + return not self.children + + def append(self, child): + """add a node to children""" + self.children.append(child) + child.parent = self + + def remove(self, child): + """remove a child node""" + self.children.remove(child) + child.parent = None + + def insert(self, index, child): + """insert a child node""" + self.children.insert(index, child) + child.parent = self + + def replace(self, old_child, new_child): + """replace a child node with another""" + i = self.children.index(old_child) + self.children.pop(i) + self.children.insert(i, new_child) + new_child.parent = self + + def get_sibling(self, nid): + """return the sibling node that has given id""" + try: + return self.parent.get_child_by_id(nid) + 
except NodeNotFound : + raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid) + + def next_sibling(self): + """ + return the next sibling for this node if any + """ + parent = self.parent + if parent is None: + # root node has no sibling + return None + index = parent.children.index(self) + try: + return parent.children[index+1] + except IndexError: + return None + + def previous_sibling(self): + """ + return the previous sibling for this node if any + """ + parent = self.parent + if parent is None: + # root node has no sibling + return None + index = parent.children.index(self) + if index > 0: + return parent.children[index-1] + return None + + def get_node_by_id(self, nid): + """ + return node in whole hierarchy that has given id + """ + root = self.root() + try: + return root.get_child_by_id(nid, 1) + except NodeNotFound : + raise NodeNotFound(EX_NODE_NOT_FOUND % nid) + + def get_child_by_id(self, nid, recurse=None): + """ + return child of given id + """ + if self.id == nid: + return self + for c in self.children : + if recurse: + try: + return c.get_child_by_id(nid, 1) + except NodeNotFound : + continue + if c.id == nid : + return c + raise NodeNotFound(EX_CHILD_NOT_FOUND % nid) + + def get_child_by_path(self, path): + """ + return child of given path (path is a list of ids) + """ + if len(path) > 0 and path[0] == self.id: + if len(path) == 1 : + return self + else : + for c in self.children : + try: + return c.get_child_by_path(path[1:]) + except NodeNotFound : + pass + raise NodeNotFound(EX_CHILD_NOT_FOUND % path) + + def depth(self): + """ + return depth of this node in the tree + """ + if self.parent is not None: + return 1 + self.parent.depth() + else : + return 0 + + def depth_down(self): + """ + return depth of the tree from this node + """ + if self.children: + return 1 + max([c.depth_down() for c in self.children]) + return 1 + + def width(self): + """ + return the width of the tree from this node + """ + return len(self.leaves()) + + def root(self): + """ + 
return the root node of the tree + """ + if self.parent is not None: + return self.parent.root() + return self + + def leaves(self): + """ + return a list with all the leaves nodes descendant from this node + """ + leaves = [] + if self.children: + for child in self.children: + leaves += child.leaves() + return leaves + else: + return [self] + + def flatten(self, _list=None): + """ + return a list with all the nodes descendant from this node + """ + if _list is None: + _list = [] + _list.append(self) + for c in self.children: + c.flatten(_list) + return _list + + def lineage(self): + """ + return list of parents up to root node + """ + lst = [self] + if self.parent is not None: + lst.extend(self.parent.lineage()) + return lst + +class VNode(Node, VisitedMixIn): + """a visitable node + """ + pass + + +class BinaryNode(VNode): + """a binary node (i.e. only two children + """ + def __init__(self, lhs=None, rhs=None) : + VNode.__init__(self) + if lhs is not None or rhs is not None: + assert lhs and rhs + self.append(lhs) + self.append(rhs) + + def remove(self, child): + """remove the child and replace this node with the other child + """ + self.children.remove(child) + self.parent.replace(self, self.children[0]) + + def get_parts(self): + """ + return the left hand side and the right hand side of this node + """ + return self.children[0], self.children[1] + + + +if sys.version_info[0:2] >= (2, 2): + list_class = list +else: + from UserList import UserList + list_class = UserList + +class ListNode(VNode, list_class): + """Used to manipulate Nodes as Lists + """ + def __init__(self): + list_class.__init__(self) + VNode.__init__(self) + self.children = self + + def __str__(self, indent=0): + return '%s%s %s' % (indent*' ', self.__class__.__name__, + ', '.join([str(v) for v in self])) + + def append(self, child): + """add a node to children""" + list_class.append(self, child) + child.parent = self + + def insert(self, index, child): + """add a node to children""" + 
list_class.insert(self, index, child) + child.parent = self + + def remove(self, child): + """add a node to children""" + list_class.remove(self, child) + child.parent = None + + def pop(self, index): + """add a node to children""" + child = list_class.pop(self, index) + child.parent = None + + def __iter__(self): + return list_class.__iter__(self) + +# construct list from tree #################################################### + +def post_order_list(node, filter_func=no_filter): + """ + create a list with tree nodes for which the <filter> function returned true + in a post order fashion + """ + l, stack = [], [] + poped, index = 0, 0 + while node: + if filter_func(node): + if node.children and not poped: + stack.append((node, index)) + index = 0 + node = node.children[0] + else: + l.append(node) + index += 1 + try: + node = stack[-1][0].children[index] + except IndexError: + node = None + else: + node = None + poped = 0 + if node is None and stack: + node, index = stack.pop() + poped = 1 + return l + +def pre_order_list(node, filter_func=no_filter): + """ + create a list with tree nodes for which the <filter> function returned true + in a pre order fashion + """ + l, stack = [], [] + poped, index = 0, 0 + while node: + if filter_func(node): + if not poped: + l.append(node) + if node.children and not poped: + stack.append((node, index)) + index = 0 + node = node.children[0] + else: + index += 1 + try: + node = stack[-1][0].children[index] + except IndexError: + node = None + else: + node = None + poped = 0 + if node is None and len(stack) > 1: + node, index = stack.pop() + poped = 1 + return l + +class PostfixedDepthFirstIterator(FilteredIterator): + """a postfixed depth first iterator, designed to be used with visitors + """ + def __init__(self, node, filter_func=None): + FilteredIterator.__init__(self, node, post_order_list, filter_func) + +class PrefixedDepthFirstIterator(FilteredIterator): + """a prefixed depth first iterator, designed to be used with 
visitors + """ + def __init__(self, node, filter_func=None): + FilteredIterator.__init__(self, node, pre_order_list, filter_func) + diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py new file mode 100644 index 0000000..a5e4799 --- /dev/null +++ b/logilab/common/umessage.py @@ -0,0 +1,194 @@ +# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. +# +# logilab-common is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the Free +# Software Foundation, either version 2.1 of the License, or (at your option) any +# later version. +# +# logilab-common is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License along +# with logilab-common. If not, see <http://www.gnu.org/licenses/>. 
+"""Unicode email support (extends email from stdlib)""" + +__docformat__ = "restructuredtext en" + +import email +from encodings import search_function +import sys +if sys.version_info >= (2, 5): + from email.utils import parseaddr, parsedate + from email.header import decode_header +else: + from email.Utils import parseaddr, parsedate + from email.Header import decode_header + +from datetime import datetime + +try: + from mx.DateTime import DateTime +except ImportError: + DateTime = datetime + +import logilab.common as lgc + + +def decode_QP(string): + parts = [] + for decoded, charset in decode_header(string): + if not charset : + charset = 'iso-8859-15' + parts.append(decoded.decode(charset, 'replace')) + + if sys.version_info < (3, 3): + # decoding was non-RFC compliant wrt to whitespace handling + # see http://bugs.python.org/issue1079 + return u' '.join(parts) + return u''.join(parts) + +def message_from_file(fd): + try: + return UMessage(email.message_from_file(fd)) + except email.Errors.MessageParseError: + return '' + +def message_from_string(string): + try: + return UMessage(email.message_from_string(string)) + except email.Errors.MessageParseError: + return '' + +class UMessage: + """Encapsulates an email.Message instance and returns only unicode objects. 
+ """ + + def __init__(self, message): + self.message = message + + # email.Message interface ################################################# + + def get(self, header, default=None): + value = self.message.get(header, default) + if value: + return decode_QP(value) + return value + + def __getitem__(self, header): + return self.get(header) + + def get_all(self, header, default=()): + return [decode_QP(val) for val in self.message.get_all(header, default) + if val is not None] + + def is_multipart(self): + return self.message.is_multipart() + + def get_boundary(self): + return self.message.get_boundary() + + def walk(self): + for part in self.message.walk(): + yield UMessage(part) + + if sys.version_info < (3, 0): + + def get_payload(self, index=None, decode=False): + message = self.message + if index is None: + payload = message.get_payload(index, decode) + if isinstance(payload, list): + return [UMessage(msg) for msg in payload] + if message.get_content_maintype() != 'text': + return payload + + charset = message.get_content_charset() or 'iso-8859-1' + if search_function(charset) is None: + charset = 'iso-8859-1' + return unicode(payload or '', charset, "replace") + else: + payload = UMessage(message.get_payload(index, decode)) + return payload + + def get_content_maintype(self): + return unicode(self.message.get_content_maintype()) + + def get_content_type(self): + return unicode(self.message.get_content_type()) + + def get_filename(self, failobj=None): + value = self.message.get_filename(failobj) + if value is failobj: + return value + try: + return unicode(value) + except UnicodeDecodeError: + return u'error decoding filename' + + else: + + def get_payload(self, index=None, decode=False): + message = self.message + if index is None: + payload = message.get_payload(index, decode) + if isinstance(payload, list): + return [UMessage(msg) for msg in payload] + return payload + else: + payload = UMessage(message.get_payload(index, decode)) + return payload + + def 
get_content_maintype(self): + return self.message.get_content_maintype() + + def get_content_type(self): + return self.message.get_content_type() + + def get_filename(self, failobj=None): + return self.message.get_filename(failobj) + + # other convenience methods ############################################### + + def headers(self): + """return an unicode string containing all the message's headers""" + values = [] + for header in self.message.keys(): + values.append(u'%s: %s' % (header, self.get(header))) + return '\n'.join(values) + + def multi_addrs(self, header): + """return a list of 2-uple (name, address) for the given address (which + is expected to be an header containing address such as from, to, cc...) + """ + persons = [] + for person in self.get_all(header, ()): + name, mail = parseaddr(person) + persons.append((name, mail)) + return persons + + def date(self, alternative_source=False, return_str=False): + """return a datetime object for the email's date or None if no date is + set or if it can't be parsed + """ + value = self.get('date') + if value is None and alternative_source: + unix_from = self.message.get_unixfrom() + if unix_from is not None: + try: + value = unix_from.split(" ", 2)[2] + except IndexError: + pass + if value is not None: + datetuple = parsedate(value) + if datetuple: + if lgc.USE_MX_DATETIME: + return DateTime(*datetuple[:6]) + return datetime(*datetuple[:6]) + elif not return_str: + return None + return value diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py new file mode 100644 index 0000000..d76ebe5 --- /dev/null +++ b/logilab/common/ureports/__init__.py @@ -0,0 +1,172 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. 
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""Universal report objects and some formatting drivers.

A way to create simple reports using python objects, primarily designed to be
formatted as text and html.
"""
__docformat__ = "restructuredtext en"

import sys

from logilab.common.compat import StringIO
from logilab.common.textutils import linesep


def get_nodes(node, klass):
    """return an iterator on all children node of the given klass"""
    for child in node.children:
        if isinstance(child, klass):
            yield child
        # recurse (FIXME: recursion controled by an option)
        for grandchild in get_nodes(child, klass):
            yield grandchild

def layout_title(layout):
    """try to return the layout's title as string, return None if not found
    """
    for child in layout.children:
        if isinstance(child, Title):
            return u' '.join([node.data for node in get_nodes(child, Text)])

def build_summary(layout, level=1):
    """make a summary for the report, including X level"""
    assert level > 0
    level -= 1
    summary = List(klass=u'summary')
    for child in layout.children:
        if not isinstance(child, Section):
            continue
        label = layout_title(child)
        if not label and not child.id:
            continue
        if not child.id:
            child.id = label.replace(' ', '-')
        node = Link(u'#'+child.id, label=label or child.id)
        # FIXME: Three following lines produce not very compliant
        # docbook: there are some useless <para><para>. They might be
        # replaced by the three commented lines but this then produces
        # a bug in html display...
        if level and [n for n in child.children if isinstance(n, Section)]:
            node = Paragraph([node, build_summary(child, level)])
        summary.append(node)
#        summary.append(node)
#        if level and [n for n in child.children if isinstance(n, Section)]:
#            summary.append(build_summary(child, level))
    return summary


class BaseWriter(object):
    """base class for ureport writers"""

    def format(self, layout, stream=None, encoding=None):
        """format and write the given layout into the stream object

        unicode policy: unicode strings may be found in the layout;
        try to call stream.write with it, but give it back encoded using
        the given encoding if it fails
        """
        if stream is None:
            stream = sys.stdout
        if not encoding:
            encoding = getattr(stream, 'encoding', 'UTF-8')
        self.encoding = encoding or 'UTF-8'
        self.__compute_funcs = []
        self.out = stream
        self.begin_format(layout)
        layout.accept(self)
        self.end_format(layout)

    def format_children(self, layout):
        """recurse on the layout children and call their accept method
        (see the Visitor pattern)
        """
        for child in getattr(layout, 'children', ()):
            child.accept(self)

    def writeln(self, string=u''):
        """write a line in the output buffer"""
        self.write(string + linesep)

    def write(self, string):
        """write a string in the output buffer"""
        try:
            self.out.write(string)
        except UnicodeEncodeError:
            self.out.write(string.encode(self.encoding))

    def begin_format(self, layout):
        """begin to format a layout"""
        self.section = 0

    def end_format(self, layout):
        """finished to format a layout"""

    def get_table_content(self, table):
        """trick to get table content without actually writing it

        return an aligned list of lists containing table cells values as string
        """
        result = [[]]
        cols = table.cols
        for cell in self.compute_content(table):
            if cols == 0:
                result.append([])
                cols = table.cols
            cols -= 1
            result[-1].append(cell)
        # pad the last row up to the table width so every row has the same
        # number of cells (`cols` is the *remaining* cell counter at this
        # point, comparing against it would under-fill partial rows)
        while len(result[-1]) < table.cols:
            result[-1].append(u'')
        return result

    def compute_content(self, layout):
        """trick to compute the formatting of children layout before actually
        writing it

        return an iterator on strings (one for each child element)
        """
        # temporarily redirect self.write / self.writeln to a per-child
        # StringIO so each child's rendering is captured as a string
        def write(data):
            try:
                stream.write(data)
            except UnicodeEncodeError:
                stream.write(data.encode(self.encoding))
        def writeln(data=u''):
            try:
                stream.write(data+linesep)
            except UnicodeEncodeError:
                stream.write(data.encode(self.encoding)+linesep)
        self.write = write
        self.writeln = writeln
        self.__compute_funcs.append((write, writeln))
        for child in layout.children:
            stream = StringIO()
            child.accept(self)
            yield stream.getvalue()
        self.__compute_funcs.pop()
        # restore the previous redirection level, or the plain instance
        # methods when no nested compute_content is pending
        try:
            self.write, self.writeln = self.__compute_funcs[-1]
        except IndexError:
            del self.write
            del self.writeln


# imported at the bottom to avoid a circular import (these modules import
# BaseWriter from this package)
from logilab.common.ureports.nodes import *
from logilab.common.ureports.text_writer import TextWriter
from logilab.common.ureports.html_writer import HTMLWriter
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""DocBook formatting drivers for ureports"""
__docformat__ = "restructuredtext en"

from six.moves import range

from logilab.common.ureports import HTMLWriter

class DocbookWriter(HTMLWriter):
    """format layouts as DocBook documents"""

    def begin_format(self, layout):
        """begin to format a layout (emit the DocBook prologue)"""
        # deliberately call the grand-parent (BaseWriter) implementation:
        # HTMLWriter.begin_format would emit <html>/<body> tags
        super(HTMLWriter, self).begin_format(layout)
        if self.snippet is None:
            self.writeln('<?xml version="1.0" encoding="ISO-8859-1"?>')
            self.writeln("""
<book xmlns:xi='http://www.w3.org/2001/XInclude'
      lang='fr'>
""")

    def end_format(self, layout):
        """finished to format a layout"""
        if self.snippet is None:
            self.writeln('</book>')

    def visit_section(self, layout):
        """display a section (using <chapter> (level 0) or <section>)"""
        if self.section == 0:
            tag = "chapter"
        else:
            tag = "section"
        self.section += 1
        self.writeln(self._indent('<%s%s>' % (tag, self.handle_attrs(layout))))
        self.format_children(layout)
        self.writeln(self._indent('</%s>'% tag))
        self.section -= 1

    def visit_title(self, layout):
        """display a title using <title>"""
        self.write(self._indent('  <title%s>' % self.handle_attrs(layout)))
        self.format_children(layout)
        self.writeln('</title>')

    def visit_table(self, layout):
        """display a table as docbook (<table>/<tgroup>/<row>/<entry>)"""
        self.writeln(self._indent('  <table%s><title>%s</title>' \
            % (self.handle_attrs(layout), layout.title)))
        self.writeln(self._indent('    <tgroup cols="%s">'% layout.cols))
        for i in range(layout.cols):
            self.writeln(self._indent('      <colspec colname="c%s" colwidth="1*"/>' % i))

        table_content = self.get_table_content(layout)
        # write headers
        if layout.cheaders:
            self.writeln(self._indent('      <thead>'))
            self._write_row(table_content[0])
            self.writeln(self._indent('      </thead>'))
            table_content = table_content[1:]
        elif layout.rcheaders:
            self.writeln(self._indent('      <thead>'))
            self._write_row(table_content[-1])
            self.writeln(self._indent('      </thead>'))
            table_content = table_content[:-1]
        # write body
        self.writeln(self._indent('      <tbody>'))
        for i in range(len(table_content)):
            row = table_content[i]
            self.writeln(self._indent('        <row>'))
            for j in range(len(row)):
                # NOTE(review): the fallback cell was probably originally a
                # non-breaking-space entity flattened by extraction -- verify
                cell = row[j] or ' '
                self.writeln(self._indent('          <entry>%s</entry>' % cell))
            self.writeln(self._indent('        </row>'))
        self.writeln(self._indent('      </tbody>'))
        self.writeln(self._indent('    </tgroup>'))
        self.writeln(self._indent('  </table>'))

    def _write_row(self, row):
        """write content of row (using <row> <entry>)"""
        self.writeln('          <row>')
        for j in range(len(row)):
            cell = row[j] or ' '
            self.writeln('            <entry>%s</entry>' % cell)
        self.writeln(self._indent('          </row>'))

    def visit_list(self, layout):
        """display a list (using <itemizedlist>)"""
        self.writeln(self._indent('  <itemizedlist%s>' % self.handle_attrs(layout)))
        for row in list(self.compute_content(layout)):
            self.writeln('    <listitem><para>%s</para></listitem>' % row)
        self.writeln(self._indent('  </itemizedlist>'))

    def visit_paragraph(self, layout):
        """display links (using <para>)"""
        self.write(self._indent('  <para>'))
        self.format_children(layout)
        self.writeln('</para>')

    def visit_span(self, layout):
        """display links (using <p>)"""
        #TODO: translate in docbook
        self.write('<literal %s>' % self.handle_attrs(layout))
        self.format_children(layout)
        self.write('</literal>')

    def visit_link(self, layout):
        """display links (using <ulink>)"""
        self.write('<ulink url="%s"%s>%s</ulink>' % (layout.url,
                                                     self.handle_attrs(layout),
                                                     layout.label))

    def visit_verbatimtext(self, layout):
        """display verbatim text (using <programlisting>)"""
        self.writeln(self._indent('  <programlisting>'))
        # escape markup characters; '&' first so freshly-written entities
        # are not themselves re-escaped
        self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))
        self.writeln(self._indent('  </programlisting>'))

    def visit_text(self, layout):
        """add some text"""
        self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))

    def _indent(self, string):
        """correctly indent string according to section"""
        return '  ' * 2*(self.section) + string
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""HTML formatting drivers for ureports""" +__docformat__ = "restructuredtext en" + +from cgi import escape + +from six.moves import range + +from logilab.common.ureports import BaseWriter + + +class HTMLWriter(BaseWriter): + """format layouts as HTML""" + + def __init__(self, snippet=None): + super(HTMLWriter, self).__init__() + self.snippet = snippet + + def handle_attrs(self, layout): + """get an attribute string from layout member attributes""" + attrs = u'' + klass = getattr(layout, 'klass', None) + if klass: + attrs += u' class="%s"' % klass + nid = getattr(layout, 'id', None) + if nid: + attrs += u' id="%s"' % nid + return attrs + + def begin_format(self, layout): + """begin to format a layout""" + super(HTMLWriter, self).begin_format(layout) + if self.snippet is None: + self.writeln(u'<html>') + self.writeln(u'<body>') + + def end_format(self, layout): + """finished to format a layout""" + if self.snippet is None: + self.writeln(u'</body>') + self.writeln(u'</html>') + + + def visit_section(self, layout): + """display a section as html, using div + h[section level]""" + self.section += 1 + self.writeln(u'<div%s>' % self.handle_attrs(layout)) + self.format_children(layout) + self.writeln(u'</div>') + self.section -= 1 + + def visit_title(self, layout): + """display a title using <hX>""" + self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout))) + self.format_children(layout) + self.writeln(u'</h%s>' % self.section) + + def visit_table(self, layout): + """display a table as html""" + self.writeln(u'<table%s>' % self.handle_attrs(layout)) + table_content = self.get_table_content(layout) + for i in range(len(table_content)): + row = table_content[i] + if i == 0 and layout.rheaders: + self.writeln(u'<tr class="header">') + elif i+1 == len(table_content) and layout.rrheaders: + self.writeln(u'<tr class="header">') + else: + self.writeln(u'<tr class="%s">' % (i%2 and 'even' or 'odd')) + for j in range(len(row)): + cell = row[j] or u' ' + if 
(layout.rheaders and i == 0) or \ + (layout.cheaders and j == 0) or \ + (layout.rrheaders and i+1 == len(table_content)) or \ + (layout.rcheaders and j+1 == len(row)): + self.writeln(u'<th>%s</th>' % cell) + else: + self.writeln(u'<td>%s</td>' % cell) + self.writeln(u'</tr>') + self.writeln(u'</table>') + + def visit_list(self, layout): + """display a list as html""" + self.writeln(u'<ul%s>' % self.handle_attrs(layout)) + for row in list(self.compute_content(layout)): + self.writeln(u'<li>%s</li>' % row) + self.writeln(u'</ul>') + + def visit_paragraph(self, layout): + """display links (using <p>)""" + self.write(u'<p>') + self.format_children(layout) + self.write(u'</p>') + + def visit_span(self, layout): + """display links (using <p>)""" + self.write(u'<span%s>' % self.handle_attrs(layout)) + self.format_children(layout) + self.write(u'</span>') + + def visit_link(self, layout): + """display links (using <a>)""" + self.write(u' <a href="%s"%s>%s</a>' % (layout.url, + self.handle_attrs(layout), + layout.label)) + def visit_verbatimtext(self, layout): + """display verbatim text (using <pre>)""" + self.write(u'<pre>') + self.write(layout.data.replace(u'&', u'&').replace(u'<', u'<')) + self.write(u'</pre>') + + def visit_text(self, layout): + """add some text""" + data = layout.data + if layout.escaped: + data = data.replace(u'&', u'&').replace(u'<', u'<') + self.write(data) diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py new file mode 100644 index 0000000..a9585b3 --- /dev/null +++ b/logilab/common/ureports/nodes.py @@ -0,0 +1,203 @@ +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. 
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""Micro reports objects.

A micro report is a tree of layout and content objects.
"""
__docformat__ = "restructuredtext en"

from logilab.common.tree import VNode

from six import string_types

class BaseComponent(VNode):
    """base report component

    attributes
    * id : the component's optional id
    * klass : the component's optional klass
    """
    def __init__(self, id=None, klass=None):
        VNode.__init__(self, id)
        self.klass = klass

class BaseLayout(BaseComponent):
    """base container node

    attributes
    * BaseComponent attributes
    * children : the sub-components of this layout
    """
    def __init__(self, children=(), **kwargs):
        super(BaseLayout, self).__init__(**kwargs)
        for child in children:
            if isinstance(child, BaseComponent):
                self.append(child)
            else:
                self.add_text(child)

    def append(self, child):
        """overridden to detect problems easily"""
        # a node must never be appended below itself (would create a cycle)
        assert child not in self.parents()
        VNode.append(self, child)

    def parents(self):
        """return the ancestor nodes"""
        assert self.parent is not self
        if self.parent is None:
            return []
        return [self.parent] + self.parent.parents()

    def add_text(self, text):
        """shortcut to add text data"""
        self.children.append(Text(text))


# non container nodes #########################################################

class Text(BaseComponent):
    """a text portion

    attributes :
    * BaseComponent attributes
    * data : the text value as an encoded or unicode string
    """
    def __init__(self, data, escaped=True, **kwargs):
        super(Text, self).__init__(**kwargs)
        assert isinstance(data, string_types), data.__class__
        self.escaped = escaped
        self.data = data

class VerbatimText(Text):
    """a verbatim text, display the raw data

    attributes :
    * BaseComponent attributes
    * data : the text value as an encoded or unicode string
    """

class Link(BaseComponent):
    """a labelled link

    attributes :
    * BaseComponent attributes
    * url : the link's target (REQUIRED)
    * label : the link's label as a string (use the url by default)
    """
    def __init__(self, url, label=None, **kwargs):
        super(Link, self).__init__(**kwargs)
        assert url
        self.url = url
        self.label = label or url


class Image(BaseComponent):
    """an embedded or a single image

    attributes :
    * BaseComponent attributes
    * filename : the image's filename (REQUIRED)
    * stream : the stream object containing the image data (REQUIRED)
    * title : the image's optional title
    """
    def __init__(self, filename, stream, title=None, **kwargs):
        super(Image, self).__init__(**kwargs)
        assert filename
        assert stream
        self.filename = filename
        self.stream = stream
        self.title = title


# container nodes #############################################################

class Section(BaseLayout):
    """a section

    attributes :
    * BaseLayout attributes

    a title may also be given to the constructor, it'll be added
    as a first element
    a description may also be given to the constructor, it'll be added
    as a first paragraph
    """
    def __init__(self, title=None, description=None, **kwargs):
        super(Section, self).__init__(**kwargs)
        # insert in reverse order so the title ends up before the description
        if description:
            self.insert(0, Paragraph([Text(description)]))
        if title:
            self.insert(0, Title(children=(title,)))

class Title(BaseLayout):
    """a title

    attributes :
    * BaseLayout attributes

    A title must not contain a section nor a paragraph!
    """

class Span(BaseLayout):
    """a span (in-line container)

    attributes :
    * BaseLayout attributes

    A span should only contain Text and Link nodes (in-line elements)
    """

class Paragraph(BaseLayout):
    """a simple text paragraph

    attributes :
    * BaseLayout attributes

    A paragraph must not contain a section !
    """

class Table(BaseLayout):
    """some tabular data

    attributes :
    * BaseLayout attributes
    * cols : the number of columns of the table (REQUIRED)
    * rheaders : the first row's elements are table's header
    * cheaders : the first col's elements are table's header
    * rrheaders : the last row's elements are table's header
    * rcheaders : the last col's elements are table's header
    * title : the table's optional title
    """
    def __init__(self, cols, title=None,
                 rheaders=0, cheaders=0, rrheaders=0, rcheaders=0,
                 **kwargs):
        super(Table, self).__init__(**kwargs)
        assert isinstance(cols, int)
        self.cols = cols
        self.title = title
        self.rheaders = rheaders
        self.cheaders = cheaders
        self.rrheaders = rrheaders
        self.rcheaders = rcheaders

class List(BaseLayout):
    """some list data

    attributes :
    * BaseLayout attributes
    """
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Text formatting drivers for ureports""" + +from __future__ import print_function + +__docformat__ = "restructuredtext en" + +from six.moves import range + +from logilab.common.textutils import linesep +from logilab.common.ureports import BaseWriter + + +TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^'] +BULLETS = [u'*', u'-'] + +class TextWriter(BaseWriter): + """format layouts as text + (ReStructured inspiration but not totally handled yet) + """ + def begin_format(self, layout): + super(TextWriter, self).begin_format(layout) + self.list_level = 0 + self.pending_urls = [] + + def visit_section(self, layout): + """display a section as text + """ + self.section += 1 + self.writeln() + self.format_children(layout) + if self.pending_urls: + self.writeln() + for label, url in self.pending_urls: + self.writeln(u'.. _`%s`: %s' % (label, url)) + self.pending_urls = [] + self.section -= 1 + self.writeln() + + def visit_title(self, layout): + title = u''.join(list(self.compute_content(layout))) + self.writeln(title) + try: + self.writeln(TITLE_UNDERLINES[self.section] * len(title)) + except IndexError: + print("FIXME TITLE TOO DEEP. 
TURNING TITLE INTO TEXT") + + def visit_paragraph(self, layout): + """enter a paragraph""" + self.format_children(layout) + self.writeln() + + def visit_span(self, layout): + """enter a span""" + self.format_children(layout) + + def visit_table(self, layout): + """display a table as text""" + table_content = self.get_table_content(layout) + # get columns width + cols_width = [0]*len(table_content[0]) + for row in table_content: + for index in range(len(row)): + col = row[index] + cols_width[index] = max(cols_width[index], len(col)) + if layout.klass == 'field': + self.field_table(layout, table_content, cols_width) + else: + self.default_table(layout, table_content, cols_width) + self.writeln() + + def default_table(self, layout, table_content, cols_width): + """format a table""" + cols_width = [size+1 for size in cols_width] + format_strings = u' '.join([u'%%-%ss'] * len(cols_width)) + format_strings = format_strings % tuple(cols_width) + format_strings = format_strings.split(' ') + table_linesep = u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n' + headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n' + # FIXME: layout.cheaders + self.write(table_linesep) + for i in range(len(table_content)): + self.write(u'|') + line = table_content[i] + for j in range(len(line)): + self.write(format_strings[j] % line[j]) + self.write(u'|') + if i == 0 and layout.rheaders: + self.write(headsep) + else: + self.write(table_linesep) + + def field_table(self, layout, table_content, cols_width): + """special case for field table""" + assert layout.cols == 2 + format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0]) + for field, value in table_content: + self.write(format_string % (field, value)) + + + def visit_list(self, layout): + """display a list layout as text""" + bullet = BULLETS[self.list_level % len(BULLETS)] + indent = ' ' * self.list_level + self.list_level += 1 + for child in layout.children: + self.write(u'%s%s%s ' % (linesep, indent, bullet)) + 
from __future__ import print_function

import logging
import re
import urllib2

import kerberos as krb

class GssapiAuthError(Exception):
    """raised on error during authentication process"""

# raw string: '\s' in a plain literal is an invalid escape sequence
RGX = re.compile(r'(?:.*,)*\s*Negotiate\s*([^,]*),?', re.I)

def get_negociate_value(headers):
    """return the value of the first Negotiate challenge found in the
    www-authenticate headers, or None if there is none
    """
    for authreq in headers.getheaders('www-authenticate'):
        match = RGX.search(authreq)
        if match:
            return match.group(1)

class HTTPGssapiAuthHandler(urllib2.BaseHandler):
    """Negotiate HTTP authentication using context from GSSAPI"""

    handler_order = 400 # before Digest Auth

    def __init__(self):
        self._reset()

    def _reset(self):
        # retry counter and GSSAPI security context
        self._retried = 0
        self._context = None

    def clean_context(self):
        """release the GSSAPI context, if any"""
        if self._context is not None:
            krb.authGSSClientClean(self._context)

    def http_error_401(self, req, fp, code, msg, headers):
        """handle a 401 by running the SPNEGO/Negotiate handshake;
        return the server response on success, None to give up
        """
        try:
            if self._retried > 5:
                raise urllib2.HTTPError(req.get_full_url(), 401,
                                        "negotiate auth failed", headers, None)
            self._retried += 1
            logging.debug('gssapi handler, try %s' % self._retried)
            negotiate = get_negociate_value(headers)
            if negotiate is None:
                logging.debug('no negociate found in a www-authenticate header')
                return None
            logging.debug('HTTPGssapiAuthHandler: negotiate 1 is %r' % negotiate)
            result, self._context = krb.authGSSClientInit("HTTP@%s" % req.get_host())
            if result < 1:
                raise GssapiAuthError("HTTPGssapiAuthHandler: init failed with %d" % result)
            result = krb.authGSSClientStep(self._context, negotiate)
            if result < 0:
                raise GssapiAuthError("HTTPGssapiAuthHandler: step 1 failed with %d" % result)
            client_response = krb.authGSSClientResponse(self._context)
            logging.debug('HTTPGssapiAuthHandler: client response is %s...' % client_response[:10])
            req.add_unredirected_header('Authorization', "Negotiate %s" % client_response)
            server_response = self.parent.open(req)
            # mutual authentication: verify the server's second challenge
            negotiate = get_negociate_value(server_response.info())
            if negotiate is None:
                logging.warning('HTTPGssapiAuthHandler: failed to authenticate server')
            else:
                logging.debug('HTTPGssapiAuthHandler negotiate 2: %s' % negotiate)
                result = krb.authGSSClientStep(self._context, negotiate)
                if result < 1:
                    raise GssapiAuthError("HTTPGssapiAuthHandler: step 2 failed with %d" % result)
            return server_response
        except GssapiAuthError as exc:
            logging.error(repr(exc))
        finally:
            self.clean_context()
            self._reset()

if __name__ == '__main__':
    import sys
    # debug
    import httplib
    httplib.HTTPConnection.debuglevel = 1
    httplib.HTTPSConnection.debuglevel = 1
    # debug
    import logging
    logging.basicConfig(level=logging.DEBUG)
    # handle cookies
    import cookielib
    cj = cookielib.CookieJar()
    ch = urllib2.HTTPCookieProcessor(cj)
    # test with url sys.argv[1]
    h = HTTPGssapiAuthHandler()
    response = urllib2.build_opener(h, ch).open(sys.argv[1])
    print('\nresponse: %s\n--------------\n' % response.code, response.info())
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""Functions to generate files readable with Georg Sander's vcg
(Visualization of Compiler Graphs).

You can download vcg at http://rw4.cs.uni-sb.de/~sander/html/gshome.html
Note that vcg exists as a debian package.

See vcg's documentation for explanation about the different values that
maybe used for the functions parameters.
"""
__docformat__ = "restructuredtext en"

import string

ATTRS_VAL = {
    'algos':       ('dfs', 'tree', 'minbackward',
                    'left_to_right', 'right_to_left',
                    'top_to_bottom', 'bottom_to_top',
                    'maxdepth', 'maxdepthslow', 'mindepth', 'mindepthslow',
                    'mindegree', 'minindegree', 'minoutdegree',
                    'maxdegree', 'maxindegree', 'maxoutdegree'),
    'booleans':    ('yes', 'no'),
    'colors':      ('black', 'white', 'blue', 'red', 'green', 'yellow',
                    'magenta', 'lightgrey',
                    'cyan', 'darkgrey', 'darkblue', 'darkred', 'darkgreen',
                    'darkyellow', 'darkmagenta', 'darkcyan', 'gold',
                    'lightblue', 'lightred', 'lightgreen', 'lightyellow',
                    'lightmagenta', 'lightcyan', 'lilac', 'turquoise',
                    'aquamarine', 'khaki', 'purple', 'yellowgreen', 'pink',
                    'orange', 'orchid'),
    'shapes':      ('box', 'ellipse', 'rhomb', 'triangle'),
    'textmodes':   ('center', 'left_justify', 'right_justify'),
    'arrowstyles': ('solid', 'line', 'none'),
    'linestyles':  ('continuous', 'dashed', 'dotted', 'invisible'),
    }

# meaning of possible values:
#   0 -> string
#   1 -> int
#   list -> value in list
GRAPH_ATTRS = {
    'title': 0,
    'label': 0,
    'color': ATTRS_VAL['colors'],
    'textcolor': ATTRS_VAL['colors'],
    'bordercolor': ATTRS_VAL['colors'],
    'width': 1,
    'height': 1,
    'borderwidth': 1,
    'textmode': ATTRS_VAL['textmodes'],
    'shape': ATTRS_VAL['shapes'],
    'shrink': 1,
    'stretch': 1,
    'orientation': ATTRS_VAL['algos'],
    'vertical_order': 1,
    'horizontal_order': 1,
    'xspace': 1,
    'yspace': 1,
    'layoutalgorithm': ATTRS_VAL['algos'],
    'late_edge_labels': ATTRS_VAL['booleans'],
    'display_edge_labels': ATTRS_VAL['booleans'],
    'dirty_edge_labels': ATTRS_VAL['booleans'],
    'finetuning': ATTRS_VAL['booleans'],
    'manhattan_edges': ATTRS_VAL['booleans'],
    'smanhattan_edges': ATTRS_VAL['booleans'],
    'port_sharing': ATTRS_VAL['booleans'],
    'edges': ATTRS_VAL['booleans'],
    'nodes': ATTRS_VAL['booleans'],
    'splines': ATTRS_VAL['booleans'],
    }
NODE_ATTRS = {
    'title': 0,
    'label': 0,
    'color': ATTRS_VAL['colors'],
    'textcolor': ATTRS_VAL['colors'],
    'bordercolor': ATTRS_VAL['colors'],
    'width': 1,
    'height': 1,
    'borderwidth': 1,
    'textmode': ATTRS_VAL['textmodes'],
    'shape': ATTRS_VAL['shapes'],
    'shrink': 1,
    'stretch': 1,
    'vertical_order': 1,
    'horizontal_order': 1,
    }
EDGE_ATTRS = {
    'sourcename': 0,
    'targetname': 0,
    'label': 0,
    'linestyle': ATTRS_VAL['linestyles'],
    'class': 1,
    'thickness': 0,
    'color': ATTRS_VAL['colors'],
    'textcolor': ATTRS_VAL['colors'],
    'arrowcolor': ATTRS_VAL['colors'],
    'backarrowcolor': ATTRS_VAL['colors'],
    'arrowsize': 1,
    'backarrowsize': 1,
    'arrowstyle': ATTRS_VAL['arrowstyles'],
    'backarrowstyle': ATTRS_VAL['arrowstyles'],
    'textmode': ATTRS_VAL['textmodes'],
    'priority': 1,
    'anchor': 1,
    'horizontal_order': 1,
    }


# Misc utilities ###############################################################

def latin_to_vcg(st):
    """Convert latin characters using vcg escape sequence.

    Every non-ASCII-letter character with a code point >= 192 is replaced
    by the vcg '\\fi<num>' escape.
    """
    for char in st:
        if char not in string.ascii_letters:
            try:
                num = ord(char)
                if num >= 192:
                    st = st.replace(char, r'\fi%d'%ord(char))
            except TypeError:
                # narrowed from a bare except: ord() only fails on
                # non single-character input
                pass
    return st


class VCGPrinter(object):
    """A vcg graph writer.
    """

    def __init__(self, output_stream):
        # output_stream: any object with a write(str) method
        self._stream = output_stream
        self._indent = ''

    def open_graph(self, **args):
        """open a vcg graph
        """
        self._stream.write('%sgraph:{\n'%self._indent)
        self._inc_indent()
        self._write_attributes(GRAPH_ATTRS, **args)

    def close_graph(self):
        """close a vcg graph
        """
        self._dec_indent()
        self._stream.write('%s}\n'%self._indent)

    def node(self, title, **args):
        """draw a node
        """
        self._stream.write('%snode: {title:"%s"' % (self._indent, title))
        self._write_attributes(NODE_ATTRS, **args)
        self._stream.write('}\n')

    def edge(self, from_node, to_node, edge_type='', **args):
        """draw an edge from a node to another.
        """
        self._stream.write(
            '%s%sedge: {sourcename:"%s" targetname:"%s"' % (
                self._indent, edge_type, from_node, to_node))
        self._write_attributes(EDGE_ATTRS, **args)
        self._stream.write('}\n')


    # private ##################################################################

    def _write_attributes(self, attributes_dict, **args):
        """write graph, node or edge attributes

        raise Exception when an attribute is unknown or its value is not
        allowed by attributes_dict
        """
        for key, value in args.items():
            try:
                _type = attributes_dict[key]
            except KeyError:
                raise Exception('''no such attribute %s
possible attributes are %s''' % (key, attributes_dict.keys()))

            if not _type:
                self._stream.write('%s%s:"%s"\n' % (self._indent, key, value))
            elif _type == 1:
                self._stream.write('%s%s:%s\n' % (self._indent, key,
                                                  int(value)))
            elif value in _type:
                self._stream.write('%s%s:%s\n' % (self._indent, key, value))
            else:
                raise Exception('''value %s isn\'t correct for attribute %s
correct values are %s''' % (value, key, _type))

    def _inc_indent(self):
        """increment indentation
        """
        self._indent = '  %s' % self._indent

    def _dec_indent(self):
        """decrement indentation
        """
        self._indent = self._indent[:-2]
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.
# (license header continued from the scraped page)
# See the GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""A generic visitor abstract implementation."""
__docformat__ = "restructuredtext en"

def no_filter(_):
    """Default node filter: accept every node."""
    return 1

# Iterators ###################################################################
class FilteredIterator(object):
    """Iterator over the nodes selected by `filter_func`.

    `list_func(node, filter_func)` must return the list of nodes to
    iterate; the iterator yields them in order and returns None (instead
    of raising StopIteration) once exhausted.
    """

    def __init__(self, node, list_func, filter_func=None):
        self._next = [(node, 0)]
        if filter_func is None:
            filter_func = no_filter
        self._list = list_func(node, filter_func)

    def __next__(self):
        try:
            return self._list.pop(0)
        # fixed: was a bare `except:`, which also swallowed
        # KeyboardInterrupt and programming errors
        except IndexError:
            return None

    next = __next__  # python 2 compatibility alias

# Base Visitor ################################################################
class Visitor(object):
    """Base class for visitors walking nodes produced by an iterator class.

    `iterator_class(node, filter_func)` must return an iterator whose
    ``next``/``__next__`` yields nodes and returns a falsy value when done.
    """

    def __init__(self, iterator_class, filter_func=None):
        self._iter_class = iterator_class
        self.filter = filter_func

    def visit(self, node, *args, **kargs):
        """
        launch the visit on a given node

        call 'open_visit' before the beginning of the visit, with extra args
        given
        when all nodes have been visited, call the 'close_visit' method
        """
        self.open_visit(node, *args, **kargs)
        return self.close_visit(self._visit(node))

    def _visit(self, node):
        """Iterate the nodes, calling accept() on each; return the last result."""
        iterator = self._get_iterator(node)
        n = next(iterator)
        while n:
            result = n.accept(self)
            n = next(iterator)
        return result

    def _get_iterator(self, node):
        return self._iter_class(node, self.filter)

    def open_visit(self, *args, **kargs):
        """
        method called at the beginning of the visit
        """
        pass

    def close_visit(self, result):
        """
        method called at the end of the visit
        """
        return result

# standard visited mixin ######################################################
class VisitedMixIn(object):
    """
    Visited interface allow node visitors to use the node
    """
    def get_visit_name(self):
        """
        return the visit name for the mixed class. When calling 'accept', the
        method <'visit_' + name returned by this method> will be called on the
        visitor
        """
        try:
            return self.TYPE.replace('-', '_')
        # fixed: was a bare `except:`; only the absence of a TYPE
        # attribute should fall back to the class name
        except AttributeError:
            return self.__class__.__name__.lower()

    def accept(self, visitor, *args, **kwargs):
        func = getattr(visitor, 'visit_%s' % self.get_visit_name())
        return func(self, *args, **kwargs)

    def leave(self, visitor, *args, **kwargs):
        func = getattr(visitor, 'leave_%s' % self.get_visit_name())
        return func(self, *args, **kwargs)

# --- scraped diff-page metadata for the next file (kept for context) ---
# diff --git a/logilab/common/xmlrpcutils.py b/logilab/common/xmlrpcutils.py
# new file mode 100644
# index 0000000..1d30d82
# --- /dev/null
# +++ b/logilab/common/xmlrpcutils.py
# @@ -0,0 +1,131 @@
# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
# This file is part of logilab-common.
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""XML-RPC utilities.""" +__docformat__ = "restructuredtext en" + +import xmlrpclib +from base64 import encodestring +#from cStringIO import StringIO + +ProtocolError = xmlrpclib.ProtocolError + +## class BasicAuthTransport(xmlrpclib.Transport): +## def __init__(self, username=None, password=None): +## self.username = username +## self.password = password +## self.verbose = None +## self.has_ssl = httplib.__dict__.has_key("HTTPConnection") + +## def request(self, host, handler, request_body, verbose=None): +## # issue XML-RPC request +## if self.has_ssl: +## if host.startswith("https:"): h = httplib.HTTPSConnection(host) +## else: h = httplib.HTTPConnection(host) +## else: h = httplib.HTTP(host) + +## h.putrequest("POST", handler) + +## # required by HTTP/1.1 +## if not self.has_ssl: # HTTPConnection already does 1.1 +## h.putheader("Host", host) +## h.putheader("Connection", "close") + +## if request_body: h.send(request_body) +## if self.has_ssl: +## response = h.getresponse() +## if response.status != 200: +## raise xmlrpclib.ProtocolError(host + handler, +## response.status, +## response.reason, +## response.msg) +## file = response.fp +## else: +## errcode, errmsg, headers = h.getreply() +## if errcode != 200: +## raise xmlrpclib.ProtocolError(host + handler, errcode, +## errmsg, headers) + +## file = h.getfile() + +## return self.parse_response(file) + + + +class AuthMixin: + """basic http authentication mixin for xmlrpc transports""" + + def __init__(self, username, password, encoding): + self.verbose = 0 + self.username = username + self.password = password + self.encoding = encoding + + def request(self, host, handler, request_body, verbose=0): + """issue XML-RPC request""" + h = self.make_connection(host) + h.putrequest("POST", handler) + # required by XML-RPC + h.putheader("User-Agent", self.user_agent) + h.putheader("Content-Type", "text/xml") + h.putheader("Content-Length", str(len(request_body))) + h.putheader("Host", host) + 
h.putheader("Connection", "close") + # basic auth + if self.username is not None and self.password is not None: + h.putheader("AUTHORIZATION", "Basic %s" % encodestring( + "%s:%s" % (self.username, self.password)).replace("\012", "")) + h.endheaders() + # send body + if request_body: + h.send(request_body) + # get and check reply + errcode, errmsg, headers = h.getreply() + if errcode != 200: + raise ProtocolError(host + handler, errcode, errmsg, headers) + file = h.getfile() +## # FIXME: encoding ??? iirc, this fix a bug in xmlrpclib but... +## data = h.getfile().read() +## if self.encoding != 'UTF-8': +## data = data.replace("version='1.0'", +## "version='1.0' encoding='%s'" % self.encoding) +## result = StringIO() +## result.write(data) +## result.seek(0) +## return self.parse_response(result) + return self.parse_response(file) + +class BasicAuthTransport(AuthMixin, xmlrpclib.Transport): + """basic http authentication transport""" + +class BasicAuthSafeTransport(AuthMixin, xmlrpclib.SafeTransport): + """basic https authentication transport""" + + +def connect(url, user=None, passwd=None, encoding='ISO-8859-1'): + """return an xml rpc server on <url>, using user / password if specified + """ + if user or passwd: + assert user and passwd is not None + if url.startswith('https://'): + transport = BasicAuthSafeTransport(user, passwd, encoding) + else: + transport = BasicAuthTransport(user, passwd, encoding) + else: + transport = None + server = xmlrpclib.ServerProxy(url, transport, encoding=encoding) + return server diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py new file mode 100644 index 0000000..d383b9d --- /dev/null +++ b/logilab/common/xmlutils.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved. +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This file is part of logilab-common. 
# (license header continued from the scraped page)
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""XML utilities.

This module contains useful functions for parsing and using XML data. For the
moment, there is only one function that can parse the data inside a processing
instruction and return a Python dictionary.
"""
__docformat__ = "restructuredtext en"

import re

# keyword="value" and keyword='value' pairs inside a processing
# instruction; raw strings avoid invalid-escape warnings on modern Python
# (the originals were plain strings containing \w, \-, \.)
RE_DOUBLE_QUOTE = re.compile(r'([\w\-\.]+)="([^"]+)"')
RE_SIMPLE_QUOTE = re.compile(r"([\w\-\.]+)='([^']+)'")

def parse_pi_data(pi_data):
    """
    Utility function that parses the data contained in an XML
    processing instruction and returns a dictionary of keywords and their
    associated values (most of the time, the processing instructions contain
    data like ``keyword="value"``, if a keyword is not associated to a value,
    for example ``keyword``, it will be associated to ``None``).

    :param pi_data: data contained in an XML processing instruction.
    :type pi_data: unicode

    :returns: Dictionary of the keywords (Unicode strings) associated to
              their values (Unicode strings) as they were defined in the
              data.
    :rtype: dict
    """
    results = {}
    for elt in pi_data.split():
        # match each pattern once instead of twice as the original did
        match = RE_DOUBLE_QUOTE.match(elt) or RE_SIMPLE_QUOTE.match(elt)
        if match is not None:
            kwd, val = match.groups()
        else:
            # bare keyword with no value
            kwd, val = elt, None
        results[kwd] = val
    return results