summaryrefslogtreecommitdiff
path: root/logilab/common
diff options
context:
space:
mode:
authorDavid Douard <david.douard@logilab.fr>2015-03-13 15:18:12 +0100
committerDavid Douard <david.douard@logilab.fr>2015-03-13 15:18:12 +0100
commit84ba0c13c480f1e0fb3853caa6bc8ee48dd13178 (patch)
tree61ef71cc521fdba98a5b496029caa009e346ec88 /logilab/common
parentb95ae183478e43f8a2229d6cbdfe79e389c0f6e3 (diff)
downloadlogilab-common-84ba0c13c480f1e0fb3853caa6bc8ee48dd13178.tar.gz
[layout] change the source directory layout
The logilab.common package now lives in a logilab/common directory to make it pip compliant. Related to #294479.
Diffstat (limited to 'logilab/common')
-rw-r--r--logilab/common/__init__.py175
-rw-r--r--logilab/common/__pkginfo__.py57
-rw-r--r--logilab/common/cache.py114
-rw-r--r--logilab/common/changelog.py238
-rw-r--r--logilab/common/clcommands.py334
-rw-r--r--logilab/common/cli.py211
-rw-r--r--logilab/common/compat.py78
-rw-r--r--logilab/common/configuration.py1105
-rw-r--r--logilab/common/contexts.py5
-rw-r--r--logilab/common/corbautils.py117
-rw-r--r--logilab/common/daemon.py101
-rw-r--r--logilab/common/date.py335
-rw-r--r--logilab/common/dbf.py231
-rw-r--r--logilab/common/debugger.py214
-rw-r--r--logilab/common/decorators.py281
-rw-r--r--logilab/common/deprecation.py189
-rw-r--r--logilab/common/fileutils.py404
-rw-r--r--logilab/common/graph.py282
-rw-r--r--logilab/common/interface.py71
-rw-r--r--logilab/common/logging_ext.py195
-rw-r--r--logilab/common/modutils.py702
-rw-r--r--logilab/common/optik_ext.py392
-rw-r--r--logilab/common/optparser.py92
-rw-r--r--logilab/common/proc.py277
-rw-r--r--logilab/common/pyro_ext.py180
-rw-r--r--logilab/common/pytest.py1199
-rw-r--r--logilab/common/registry.py1119
-rw-r--r--logilab/common/shellutils.py462
-rw-r--r--logilab/common/sphinx_ext.py87
-rw-r--r--logilab/common/sphinxutils.py122
-rw-r--r--logilab/common/table.py929
-rw-r--r--logilab/common/tasksqueue.py101
-rw-r--r--logilab/common/testlib.py1392
-rw-r--r--logilab/common/textutils.py537
-rw-r--r--logilab/common/tree.py369
-rw-r--r--logilab/common/umessage.py194
-rw-r--r--logilab/common/ureports/__init__.py172
-rw-r--r--logilab/common/ureports/docbook_writer.py140
-rw-r--r--logilab/common/ureports/html_writer.py133
-rw-r--r--logilab/common/ureports/nodes.py203
-rw-r--r--logilab/common/ureports/text_writer.py145
-rw-r--r--logilab/common/urllib2ext.py89
-rw-r--r--logilab/common/vcgutils.py216
-rw-r--r--logilab/common/visitor.py109
-rw-r--r--logilab/common/xmlrpcutils.py131
-rw-r--r--logilab/common/xmlutils.py61
46 files changed, 14290 insertions, 0 deletions
diff --git a/logilab/common/__init__.py b/logilab/common/__init__.py
new file mode 100644
index 0000000..6a7c8b4
--- /dev/null
+++ b/logilab/common/__init__.py
@@ -0,0 +1,175 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Logilab common library (aka Logilab's extension to the standard library).
+
+:type STD_BLACKLIST: tuple
+:var STD_BLACKLIST: directories ignored by default by the functions in
+ this package which have to recurse into directories
+
+:type IGNORED_EXTENSIONS: tuple
+:var IGNORED_EXTENSIONS: file extensions that may usually be ignored
+"""
+__docformat__ = "restructuredtext en"
+
+
+from logilab.common.__pkginfo__ import version as __version__
+
+STD_BLACKLIST = ('CVS', '.svn', '.hg', 'debian', 'dist', 'build')
+
+IGNORED_EXTENSIONS = ('.pyc', '.pyo', '.elc', '~', '.swp', '.orig')
+
+# set this to False if you've mx DateTime installed but you don't want your db
+# adapter to use it (should be set before you got a connection)
+USE_MX_DATETIME = True
+
+
+class attrdict(dict):
+ """A dictionary for which keys are also accessible as attributes."""
+ def __getattr__(self, attr):
+ try:
+ return self[attr]
+ except KeyError:
+ raise AttributeError(attr)
+
+class dictattr(dict):
+ def __init__(self, proxy):
+ self.__proxy = proxy
+
+ def __getitem__(self, attr):
+ try:
+ return getattr(self.__proxy, attr)
+ except AttributeError:
+ raise KeyError(attr)
+
+class nullobject(object):
+ def __repr__(self):
+ return '<nullobject>'
+ def __bool__(self):
+ return False
+ __nonzero__ = __bool__
+
+class tempattr(object):
+ def __init__(self, obj, attr, value):
+ self.obj = obj
+ self.attr = attr
+ self.value = value
+
+ def __enter__(self):
+ self.oldvalue = getattr(self.obj, self.attr)
+ setattr(self.obj, self.attr, self.value)
+ return self.obj
+
+ def __exit__(self, exctype, value, traceback):
+ setattr(self.obj, self.attr, self.oldvalue)
+
+
+
+# flatten -----
+# XXX move in a specific module and use yield instead
+# do not mix flatten and translate
+#
+# def iterable(obj):
+# try: iter(obj)
+# except: return False
+# return True
+#
+# def is_string_like(obj):
+# try: obj +''
+# except (TypeError, ValueError): return False
+# return True
+#
+#def is_scalar(obj):
+# return is_string_like(obj) or not iterable(obj)
+#
+#def flatten(seq):
+# for item in seq:
+# if is_scalar(item):
+# yield item
+# else:
+# for subitem in flatten(item):
+# yield subitem
+
+def flatten(iterable, tr_func=None, results=None):
+ """Flatten a list of list with any level.
+
+ If tr_func is not None, it should be a one argument function that'll be called
+ on each final element.
+
+ :rtype: list
+
+ >>> flatten([1, [2, 3]])
+ [1, 2, 3]
+ """
+ if results is None:
+ results = []
+ for val in iterable:
+ if isinstance(val, (list, tuple)):
+ flatten(val, tr_func, results)
+ elif tr_func is None:
+ results.append(val)
+ else:
+ results.append(tr_func(val))
+ return results
+
+
+# XXX is function below still used ?
+
+def make_domains(lists):
+ """
+ Given a list of lists, return a list of domain for each list to produce all
+ combinations of possibles values.
+
+ :rtype: list
+
+ Example:
+
+ >>> make_domains(['a', 'b'], ['c','d', 'e'])
+ [['a', 'b', 'a', 'b', 'a', 'b'], ['c', 'c', 'd', 'd', 'e', 'e']]
+ """
+ from six.moves import range
+ domains = []
+ for iterable in lists:
+ new_domain = iterable[:]
+ for i in range(len(domains)):
+ domains[i] = domains[i]*len(iterable)
+ if domains:
+ missing = (len(domains[0]) - len(iterable)) / len(iterable)
+ i = 0
+ for j in range(len(iterable)):
+ value = iterable[j]
+ for dummy in range(missing):
+ new_domain.insert(i, value)
+ i += 1
+ i += 1
+ domains.append(new_domain)
+ return domains
+
+
+# private stuff ################################################################
+
+def _handle_blacklist(blacklist, dirnames, filenames):
+ """remove files/directories in the black list
+
+ dirnames/filenames are usually from os.walk
+ """
+ for norecurs in blacklist:
+ if norecurs in dirnames:
+ dirnames.remove(norecurs)
+ elif norecurs in filenames:
+ filenames.remove(norecurs)
+
diff --git a/logilab/common/__pkginfo__.py b/logilab/common/__pkginfo__.py
new file mode 100644
index 0000000..55a2cc3
--- /dev/null
+++ b/logilab/common/__pkginfo__.py
@@ -0,0 +1,57 @@
+# copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""logilab.common packaging information"""
+__docformat__ = "restructuredtext en"
+import sys
+import os
+
+distname = 'logilab-common'
+modname = 'common'
+subpackage_of = 'logilab'
+subpackage_master = True
+
+numversion = (0, 63, 2)
+version = '.'.join([str(num) for num in numversion])
+
+license = 'LGPL' # 2.1 or later
+description = "collection of low-level Python packages and modules used by Logilab projects"
+web = "http://www.logilab.org/project/%s" % distname
+mailinglist = "mailto://python-projects@lists.logilab.org"
+author = "Logilab"
+author_email = "contact@logilab.fr"
+
+
+from os.path import join
+scripts = [join('bin', 'pytest')]
+include_dirs = [join('test', 'data')]
+
+install_requires = [
+ 'six >= 1.4.0',
+ ]
+test_require = ['pytz']
+
+if sys.version_info < (2, 7):
+ install_requires.append('unittest2 >= 0.5.1')
+if os.name == 'nt':
+ install_requires.append('colorama')
+
+classifiers = ["Topic :: Utilities",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 3",
+ ]
diff --git a/logilab/common/cache.py b/logilab/common/cache.py
new file mode 100644
index 0000000..11ed137
--- /dev/null
+++ b/logilab/common/cache.py
@@ -0,0 +1,114 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Cache module, with a least recently used algorithm for the management of the
+deletion of entries.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+from threading import Lock
+
+from logilab.common.decorators import locked
+
+_marker = object()
+
+class Cache(dict):
+ """A dictionary like cache.
+
+ inv:
+ len(self._usage) <= self.size
+ len(self.data) <= self.size
+ """
+
+ def __init__(self, size=100):
+ """ Warning : Cache.__init__() != dict.__init__().
+ Constructor does not take any arguments beside size.
+ """
+ assert size >= 0, 'cache size must be >= 0 (0 meaning no caching)'
+ self.size = size
+ self._usage = []
+ self._lock = Lock()
+ super(Cache, self).__init__()
+
+ def _acquire(self):
+ self._lock.acquire()
+
+ def _release(self):
+ self._lock.release()
+
+ def _update_usage(self, key):
+ if not self._usage:
+ self._usage.append(key)
+ elif self._usage[-1] != key:
+ try:
+ self._usage.remove(key)
+ except ValueError:
+ # we are inserting a new key
+ # check the size of the dictionary
+ # and remove the oldest item in the cache
+ if self.size and len(self._usage) >= self.size:
+ super(Cache, self).__delitem__(self._usage[0])
+ del self._usage[0]
+ self._usage.append(key)
+ else:
+ pass # key is already the most recently used key
+
+ def __getitem__(self, key):
+ value = super(Cache, self).__getitem__(key)
+ self._update_usage(key)
+ return value
+ __getitem__ = locked(_acquire, _release)(__getitem__)
+
+ def __setitem__(self, key, item):
+ # Just make sure that size > 0 before inserting a new item in the cache
+ if self.size > 0:
+ super(Cache, self).__setitem__(key, item)
+ self._update_usage(key)
+ __setitem__ = locked(_acquire, _release)(__setitem__)
+
+ def __delitem__(self, key):
+ super(Cache, self).__delitem__(key)
+ self._usage.remove(key)
+ __delitem__ = locked(_acquire, _release)(__delitem__)
+
+ def clear(self):
+ super(Cache, self).clear()
+ self._usage = []
+ clear = locked(_acquire, _release)(clear)
+
+ def pop(self, key, default=_marker):
+ if key in self:
+ self._usage.remove(key)
+ #if default is _marker:
+ # return super(Cache, self).pop(key)
+ return super(Cache, self).pop(key, default)
+ pop = locked(_acquire, _release)(pop)
+
+ def popitem(self):
+ raise NotImplementedError()
+
+ def setdefault(self, key, default=None):
+ raise NotImplementedError()
+
+ def update(self, other):
+ raise NotImplementedError()
+
+
diff --git a/logilab/common/changelog.py b/logilab/common/changelog.py
new file mode 100644
index 0000000..2fff2ed
--- /dev/null
+++ b/logilab/common/changelog.py
@@ -0,0 +1,238 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Manipulation of upstream change log files.
+
+The upstream change log files format handled is simpler than the one
+often used such as those generated by the default Emacs changelog mode.
+
+Sample ChangeLog format::
+
+ Change log for project Yoo
+ ==========================
+
+ --
+ * add a new functionality
+
+ 2002-02-01 -- 0.1.1
+ * fix bug #435454
+ * fix bug #434356
+
+ 2002-01-01 -- 0.1
+ * initial release
+
+
+There is 3 entries in this change log, one for each released version and one
+for the next version (i.e. the current entry).
+Each entry contains a set of messages corresponding to changes done in this
+release.
+All the non empty lines before the first entry are considered as the change
+log title.
+"""
+
+__docformat__ = "restructuredtext en"
+
+import sys
+from stat import S_IWRITE
+
+from six import string_types
+
+BULLET = '*'
+SUBBULLET = '-'
+INDENT = ' ' * 4
+
+class NoEntry(Exception):
+ """raised when we are unable to find an entry"""
+
+class EntryNotFound(Exception):
+ """raised when we are unable to find a given entry"""
+
+class Version(tuple):
+ """simple class to handle soft version number has a tuple while
+ correctly printing it as X.Y.Z
+ """
+ def __new__(cls, versionstr):
+ if isinstance(versionstr, string_types):
+ versionstr = versionstr.strip(' :') # XXX (syt) duh?
+ parsed = cls.parse(versionstr)
+ else:
+ parsed = versionstr
+ return tuple.__new__(cls, parsed)
+
+ @classmethod
+ def parse(cls, versionstr):
+ versionstr = versionstr.strip(' :')
+ try:
+ return [int(i) for i in versionstr.split('.')]
+ except ValueError as ex:
+ raise ValueError("invalid literal for version '%s' (%s)"%(versionstr, ex))
+
+ def __str__(self):
+ return '.'.join([str(i) for i in self])
+
+# upstream change log #########################################################
+
+class ChangeLogEntry(object):
+ """a change log entry, i.e. a set of messages associated to a version and
+ its release date
+ """
+ version_class = Version
+
+ def __init__(self, date=None, version=None, **kwargs):
+ self.__dict__.update(kwargs)
+ if version:
+ self.version = self.version_class(version)
+ else:
+ self.version = None
+ self.date = date
+ self.messages = []
+
+ def add_message(self, msg):
+ """add a new message"""
+ self.messages.append(([msg], []))
+
+ def complete_latest_message(self, msg_suite):
+ """complete the latest added message
+ """
+ if not self.messages:
+ raise ValueError('unable to complete last message as there is no previous message)')
+ if self.messages[-1][1]: # sub messages
+ self.messages[-1][1][-1].append(msg_suite)
+ else: # message
+ self.messages[-1][0].append(msg_suite)
+
+ def add_sub_message(self, sub_msg, key=None):
+ if not self.messages:
+ raise ValueError('unable to complete last message as there is no previous message)')
+ if key is None:
+ self.messages[-1][1].append([sub_msg])
+ else:
+ raise NotImplementedError("sub message to specific key are not implemented yet")
+
+ def write(self, stream=sys.stdout):
+ """write the entry to file """
+ stream.write('%s -- %s\n' % (self.date or '', self.version or ''))
+ for msg, sub_msgs in self.messages:
+ stream.write('%s%s %s\n' % (INDENT, BULLET, msg[0]))
+ stream.write(''.join(msg[1:]))
+ if sub_msgs:
+ stream.write('\n')
+ for sub_msg in sub_msgs:
+ stream.write('%s%s %s\n' % (INDENT * 2, SUBBULLET, sub_msg[0]))
+ stream.write(''.join(sub_msg[1:]))
+ stream.write('\n')
+
+ stream.write('\n\n')
+
+class ChangeLog(object):
+ """object representation of a whole ChangeLog file"""
+
+ entry_class = ChangeLogEntry
+
+ def __init__(self, changelog_file, title=''):
+ self.file = changelog_file
+ self.title = title
+ self.additional_content = ''
+ self.entries = []
+ self.load()
+
+ def __repr__(self):
+ return '<ChangeLog %s at %s (%s entries)>' % (self.file, id(self),
+ len(self.entries))
+
+ def add_entry(self, entry):
+ """add a new entry to the change log"""
+ self.entries.append(entry)
+
+ def get_entry(self, version='', create=None):
+ """ return a given changelog entry
+ if version is omitted, return the current entry
+ """
+ if not self.entries:
+ if version or not create:
+ raise NoEntry()
+ self.entries.append(self.entry_class())
+ if not version:
+ if self.entries[0].version and create is not None:
+ self.entries.insert(0, self.entry_class())
+ return self.entries[0]
+ version = self.version_class(version)
+ for entry in self.entries:
+ if entry.version == version:
+ return entry
+ raise EntryNotFound()
+
+ def add(self, msg, create=None):
+ """add a new message to the latest opened entry"""
+ entry = self.get_entry(create=create)
+ entry.add_message(msg)
+
+ def load(self):
+ """ read a logilab's ChangeLog from file """
+ try:
+ stream = open(self.file)
+ except IOError:
+ return
+ last = None
+ expect_sub = False
+ for line in stream.readlines():
+ sline = line.strip()
+ words = sline.split()
+ # if new entry
+ if len(words) == 1 and words[0] == '--':
+ expect_sub = False
+ last = self.entry_class()
+ self.add_entry(last)
+ # if old entry
+ elif len(words) == 3 and words[1] == '--':
+ expect_sub = False
+ last = self.entry_class(words[0], words[2])
+ self.add_entry(last)
+ # if title
+ elif sline and last is None:
+ self.title = '%s%s' % (self.title, line)
+ # if new entry
+ elif sline and sline[0] == BULLET:
+ expect_sub = False
+ last.add_message(sline[1:].strip())
+ # if new sub_entry
+ elif expect_sub and sline and sline[0] == SUBBULLET:
+ last.add_sub_message(sline[1:].strip())
+ # if new line for current entry
+ elif sline and last.messages:
+ last.complete_latest_message(line)
+ else:
+ expect_sub = True
+ self.additional_content += line
+ stream.close()
+
+ def format_title(self):
+ return '%s\n\n' % self.title.strip()
+
+ def save(self):
+ """write back change log"""
+ # filetutils isn't importable in appengine, so import locally
+ from logilab.common.fileutils import ensure_fs_mode
+ ensure_fs_mode(self.file, S_IWRITE)
+ self.write(open(self.file, 'w'))
+
+ def write(self, stream=sys.stdout):
+ """write changelog to stream"""
+ stream.write(self.format_title())
+ for entry in self.entries:
+ entry.write(stream)
+
diff --git a/logilab/common/clcommands.py b/logilab/common/clcommands.py
new file mode 100644
index 0000000..4778b99
--- /dev/null
+++ b/logilab/common/clcommands.py
@@ -0,0 +1,334 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Helper functions to support command line tools providing more than
+one command.
+
+e.g called as "tool command [options] args..." where <options> and <args> are
+command'specific
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import logging
+from os.path import basename
+
+from logilab.common.configuration import Configuration
+from logilab.common.logging_ext import init_log, get_threshold
+from logilab.common.deprecation import deprecated
+
+
+class BadCommandUsage(Exception):
+ """Raised when an unknown command is used or when a command is not
+ correctly used (bad options, too much / missing arguments...).
+
+ Trigger display of command usage.
+ """
+
+class CommandError(Exception):
+ """Raised when a command can't be processed and we want to display it and
+ exit, without traceback nor usage displayed.
+ """
+
+
+# command line access point ####################################################
+
+class CommandLine(dict):
+ """Usage:
+
+ >>> LDI = cli.CommandLine('ldi', doc='Logilab debian installer',
+ version=version, rcfile=RCFILE)
+ >>> LDI.register(MyCommandClass)
+ >>> LDI.register(MyOtherCommandClass)
+ >>> LDI.run(sys.argv[1:])
+
+ Arguments:
+
+ * `pgm`, the program name, default to `basename(sys.argv[0])`
+
+ * `doc`, a short description of the command line tool
+
+ * `copyright`, additional doc string that will be appended to the generated
+ doc
+
+ * `version`, version number of string of the tool. If specified, global
+ --version option will be available.
+
+ * `rcfile`, path to a configuration file. If specified, global --C/--rc-file
+ option will be available? self.rcfile = rcfile
+
+ * `logger`, logger to propagate to commands, default to
+ `logging.getLogger(self.pgm))`
+ """
+ def __init__(self, pgm=None, doc=None, copyright=None, version=None,
+ rcfile=None, logthreshold=logging.ERROR,
+ check_duplicated_command=True):
+ if pgm is None:
+ pgm = basename(sys.argv[0])
+ self.pgm = pgm
+ self.doc = doc
+ self.copyright = copyright
+ self.version = version
+ self.rcfile = rcfile
+ self.logger = None
+ self.logthreshold = logthreshold
+ self.check_duplicated_command = check_duplicated_command
+
+ def register(self, cls, force=False):
+ """register the given :class:`Command` subclass"""
+ assert not self.check_duplicated_command or force or not cls.name in self, \
+ 'a command %s is already defined' % cls.name
+ self[cls.name] = cls
+ return cls
+
+ def run(self, args):
+ """main command line access point:
+ * init logging
+ * handle global options (-h/--help, --version, -C/--rc-file)
+ * check command
+ * run command
+
+ Terminate by :exc:`SystemExit`
+ """
+ init_log(debug=True, # so that we use StreamHandler
+ logthreshold=self.logthreshold,
+ logformat='%(levelname)s: %(message)s')
+ try:
+ arg = args.pop(0)
+ except IndexError:
+ self.usage_and_exit(1)
+ if arg in ('-h', '--help'):
+ self.usage_and_exit(0)
+ if self.version is not None and arg in ('--version'):
+ print(self.version)
+ sys.exit(0)
+ rcfile = self.rcfile
+ if rcfile is not None and arg in ('-C', '--rc-file'):
+ try:
+ rcfile = args.pop(0)
+ arg = args.pop(0)
+ except IndexError:
+ self.usage_and_exit(1)
+ try:
+ command = self.get_command(arg)
+ except KeyError:
+ print('ERROR: no %s command' % arg)
+ print()
+ self.usage_and_exit(1)
+ try:
+ sys.exit(command.main_run(args, rcfile))
+ except KeyboardInterrupt as exc:
+ print('Interrupted', end=' ')
+ if str(exc):
+ print(': %s' % exc, end=' ')
+ print()
+ sys.exit(4)
+ except BadCommandUsage as err:
+ print('ERROR:', err)
+ print()
+ print(command.help())
+ sys.exit(1)
+
+ def create_logger(self, handler, logthreshold=None):
+ logger = logging.Logger(self.pgm)
+ logger.handlers = [handler]
+ if logthreshold is None:
+ logthreshold = get_threshold(self.logthreshold)
+ logger.setLevel(logthreshold)
+ return logger
+
+ def get_command(self, cmd, logger=None):
+ if logger is None:
+ logger = self.logger
+ if logger is None:
+ logger = self.logger = logging.getLogger(self.pgm)
+ logger.setLevel(get_threshold(self.logthreshold))
+ return self[cmd](logger)
+
+ def usage(self):
+ """display usage for the main program (i.e. when no command supplied)
+ and exit
+ """
+ print('usage:', self.pgm, end=' ')
+ if self.rcfile:
+ print('[--rc-file=<configuration file>]', end=' ')
+ print('<command> [options] <command argument>...')
+ if self.doc:
+ print('\n%s' % self.doc)
+ print('''
+Type "%(pgm)s <command> --help" for more information about a specific
+command. Available commands are :\n''' % self.__dict__)
+ max_len = max([len(cmd) for cmd in self])
+ padding = ' ' * max_len
+ for cmdname, cmd in sorted(self.items()):
+ if not cmd.hidden:
+ print(' ', (cmdname + padding)[:max_len], cmd.short_description())
+ if self.rcfile:
+ print('''
+Use --rc-file=<configuration file> / -C <configuration file> before the command
+to specify a configuration file. Default to %s.
+''' % self.rcfile)
+ print('''%(pgm)s -h/--help
+ display this usage information and exit''' % self.__dict__)
+ if self.version:
+ print('''%(pgm)s -v/--version
+ display version configuration and exit''' % self.__dict__)
+ if self.copyright:
+ print('\n', self.copyright)
+
+ def usage_and_exit(self, status):
+ self.usage()
+ sys.exit(status)
+
+
+# base command classes #########################################################
+
+class Command(Configuration):
+ """Base class for command line commands.
+
+ Class attributes:
+
+ * `name`, the name of the command
+
+ * `min_args`, minimum number of arguments, None if unspecified
+
+ * `max_args`, maximum number of arguments, None if unspecified
+
+ * `arguments`, string describing arguments, used in command usage
+
+ * `hidden`, boolean flag telling if the command should be hidden, e.g. does
+ not appear in help's commands list
+
+ * `options`, options list, as allowed by :mod:configuration
+ """
+
+ arguments = ''
+ name = ''
+ # hidden from help ?
+ hidden = False
+ # max/min args, None meaning unspecified
+ min_args = None
+ max_args = None
+
+ @classmethod
+ def description(cls):
+ return cls.__doc__.replace(' ', '')
+
+ @classmethod
+ def short_description(cls):
+ return cls.description().split('.')[0]
+
+ def __init__(self, logger):
+ usage = '%%prog %s %s\n\n%s' % (self.name, self.arguments,
+ self.description())
+ Configuration.__init__(self, usage=usage)
+ self.logger = logger
+
+ def check_args(self, args):
+ """check command's arguments are provided"""
+ if self.min_args is not None and len(args) < self.min_args:
+ raise BadCommandUsage('missing argument')
+ if self.max_args is not None and len(args) > self.max_args:
+ raise BadCommandUsage('too many arguments')
+
+ def main_run(self, args, rcfile=None):
+ """Run the command and return status 0 if everything went fine.
+
+ If :exc:`CommandError` is raised by the underlying command, simply log
+ the error and return status 2.
+
+ Any other exceptions, including :exc:`BadCommandUsage` will be
+ propagated.
+ """
+ if rcfile:
+ self.load_file_configuration(rcfile)
+ args = self.load_command_line_configuration(args)
+ try:
+ self.check_args(args)
+ self.run(args)
+ except CommandError as err:
+ self.logger.error(err)
+ return 2
+ return 0
+
+ def run(self, args):
+ """run the command with its specific arguments"""
+ raise NotImplementedError()
+
+
+class ListCommandsCommand(Command):
+ """list available commands, useful for bash completion."""
+ name = 'listcommands'
+ arguments = '[command]'
+ hidden = True
+
+ def run(self, args):
+ """run the command with its specific arguments"""
+ if args:
+ command = args.pop()
+ cmd = _COMMANDS[command]
+ for optname, optdict in cmd.options:
+ print('--help')
+ print('--' + optname)
+ else:
+ commands = sorted(_COMMANDS.keys())
+ for command in commands:
+ cmd = _COMMANDS[command]
+ if not cmd.hidden:
+ print(command)
+
+
+# deprecated stuff #############################################################
+
+_COMMANDS = CommandLine()
+
+DEFAULT_COPYRIGHT = '''\
+Copyright (c) 2004-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+http://www.logilab.fr/ -- mailto:contact@logilab.fr'''
+
+@deprecated('use cls.register(cli)')
+def register_commands(commands):
+ """register existing commands"""
+ for command_klass in commands:
+ _COMMANDS.register(command_klass)
+
+@deprecated('use args.pop(0)')
+def main_run(args, doc=None, copyright=None, version=None):
+ """command line tool: run command specified by argument list (without the
+ program name). Raise SystemExit with status 0 if everything went fine.
+
+ >>> main_run(sys.argv[1:])
+ """
+ _COMMANDS.doc = doc
+ _COMMANDS.copyright = copyright
+ _COMMANDS.version = version
+ _COMMANDS.run(args)
+
+@deprecated('use args.pop(0)')
+def pop_arg(args_list, expected_size_after=None, msg="Missing argument"):
+ """helper function to get and check command line arguments"""
+ try:
+ value = args_list.pop(0)
+ except IndexError:
+ raise BadCommandUsage(msg)
+ if expected_size_after is not None and len(args_list) > expected_size_after:
+ raise BadCommandUsage('too many arguments')
+ return value
+
diff --git a/logilab/common/cli.py b/logilab/common/cli.py
new file mode 100644
index 0000000..cdeef97
--- /dev/null
+++ b/logilab/common/cli.py
@@ -0,0 +1,211 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Command line interface helper classes.
+
+It provides some default commands, a help system, a default readline
+configuration with completion and persistent history.
+
+Example::
+
+ class BookShell(CLIHelper):
+
+ def __init__(self):
+ # quit and help are builtins
+ # CMD_MAP keys are commands, values are topics
+ self.CMD_MAP['pionce'] = _("Sommeil")
+ self.CMD_MAP['ronfle'] = _("Sommeil")
+ CLIHelper.__init__(self)
+
+ help_do_pionce = ("pionce", "pionce duree", _("met ton corps en veille"))
+ def do_pionce(self):
+ print('nap is good')
+
+ help_do_ronfle = ("ronfle", "ronfle volume", _("met les autres en veille"))
+ def do_ronfle(self):
+ print('fuuuuuuuuuuuu rhhhhhrhrhrrh')
+
+ cl = BookShell()
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+from six.moves import builtins, input
+
+if not hasattr(builtins, '_'):
+ builtins._ = str
+
+
+def init_readline(complete_method, histfile=None):
+ """Init the readline library if available."""
+ try:
+ import readline
+ readline.parse_and_bind("tab: complete")
+ readline.set_completer(complete_method)
+ string = readline.get_completer_delims().replace(':', '')
+ readline.set_completer_delims(string)
+ if histfile is not None:
+ try:
+ readline.read_history_file(histfile)
+ except IOError:
+ pass
+ import atexit
+ atexit.register(readline.write_history_file, histfile)
+ except:
+ print('readline is not available :-(')
+
+
+class Completer :
+ """Readline completer."""
+
+ def __init__(self, commands):
+ self.list = commands
+
+ def complete(self, text, state):
+ """Hook called by readline when <tab> is pressed."""
+ n = len(text)
+ matches = []
+ for cmd in self.list :
+ if cmd[:n] == text :
+ matches.append(cmd)
+ try:
+ return matches[state]
+ except IndexError:
+ return None
+
+
+class CLIHelper:
+ """An abstract command line interface client which recognize commands
+ and provide an help system.
+ """
+
+ CMD_MAP = {'help': _("Others"),
+ 'quit': _("Others"),
+ }
+ CMD_PREFIX = ''
+
+ def __init__(self, histfile=None) :
+ self._topics = {}
+ self.commands = None
+ self._completer = Completer(self._register_commands())
+ init_readline(self._completer.complete, histfile)
+
+ def run(self):
+ """loop on user input, exit on EOF"""
+ while True:
+ try:
+ line = input('>>> ')
+ except EOFError:
+ print
+ break
+ s_line = line.strip()
+ if not s_line:
+ continue
+ args = s_line.split()
+ if args[0] in self.commands:
+ try:
+ cmd = 'do_%s' % self.commands[args[0]]
+ getattr(self, cmd)(*args[1:])
+ except EOFError:
+ break
+ except:
+ import traceback
+ traceback.print_exc()
+ else:
+ try:
+ self.handle_line(s_line)
+ except:
+ import traceback
+ traceback.print_exc()
+
+ def handle_line(self, stripped_line):
+ """Method to overload in the concrete class (should handle
+ lines which are not commands).
+ """
+ raise NotImplementedError()
+
+
+ # private methods #########################################################
+
+ def _register_commands(self):
+ """ register available commands method and return the list of
+ commands name
+ """
+ self.commands = {}
+ self._command_help = {}
+ commands = [attr[3:] for attr in dir(self) if attr[:3] == 'do_']
+ for command in commands:
+ topic = self.CMD_MAP[command]
+ help_method = getattr(self, 'help_do_%s' % command)
+ self._topics.setdefault(topic, []).append(help_method)
+ self.commands[self.CMD_PREFIX + command] = command
+ self._command_help[command] = help_method
+ return self.commands.keys()
+
+ def _print_help(self, cmd, syntax, explanation):
+ print(_('Command %s') % cmd)
+ print(_('Syntax: %s') % syntax)
+ print('\t', explanation)
+ print()
+
+
+ # predefined commands #####################################################
+
+ def do_help(self, command=None) :
+ """base input of the help system"""
+ if command in self._command_help:
+ self._print_help(*self._command_help[command])
+ elif command is None or command not in self._topics:
+ print(_("Use help <topic> or help <command>."))
+ print(_("Available topics are:"))
+ topics = sorted(self._topics.keys())
+ for topic in topics:
+ print('\t', topic)
+ print()
+ print(_("Available commands are:"))
+ commands = self.commands.keys()
+ commands.sort()
+ for command in commands:
+ print('\t', command[len(self.CMD_PREFIX):])
+
+ else:
+ print(_('Available commands about %s:') % command)
+ print
+ for command_help_method in self._topics[command]:
+ try:
+ if callable(command_help_method):
+ self._print_help(*command_help_method())
+ else:
+ self._print_help(*command_help_method)
+ except:
+ import traceback
+ traceback.print_exc()
+ print('ERROR in help method %s'% (
+ command_help_method.__name__))
+
+ help_do_help = ("help", "help [topic|command]",
+ _("print help message for the given topic/command or \
+available topics when no argument"))
+
+ def do_quit(self):
+ """quit the CLI"""
+ raise EOFError()
+
+ def help_do_quit(self):
+ return ("quit", "quit", _("quit the application"))
diff --git a/logilab/common/compat.py b/logilab/common/compat.py
new file mode 100644
index 0000000..f2eb590
--- /dev/null
+++ b/logilab/common/compat.py
@@ -0,0 +1,78 @@
+# pylint: disable=E0601,W0622,W0611
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Wrappers around some builtins introduced in python 2.3, 2.4 and
+2.5, making them available in for earlier versions of python.
+
+See another compatibility snippets from other projects:
+
+ :mod:`lib2to3.fixes`
+ :mod:`coverage.backward`
+ :mod:`unittest2.compatibility`
+"""
+
+
+__docformat__ = "restructuredtext en"
+
+import os
+import sys
+import types
+from warnings import warn
+
+# not used here, but imported to preserve API
+from six.moves import builtins
+
+if sys.version_info < (3, 0):
+ str_to_bytes = str
+ def str_encode(string, encoding):
+ if isinstance(string, unicode):
+ return string.encode(encoding)
+ return str(string)
+else:
+ def str_to_bytes(string):
+ return str.encode(string)
+ # we have to ignore the encoding in py3k to be able to write a string into a
+ # TextIOWrapper or like object (which expect an unicode string)
+ def str_encode(string, encoding):
+ return str(string)
+
+# See also http://bugs.python.org/issue11776
+if sys.version_info[0] == 3:
+ def method_type(callable, instance, klass):
+ # api change. klass is no more considered
+ return types.MethodType(callable, instance)
+else:
+ # alias types otherwise
+ method_type = types.MethodType
+
+# Pythons 2 and 3 differ on where to get StringIO
+if sys.version_info < (3, 0):
+ from cStringIO import StringIO
+ FileIO = file
+ BytesIO = StringIO
+ reload = reload
+else:
+ from io import FileIO, BytesIO, StringIO
+ from imp import reload
+
+from logilab.common.deprecation import deprecated
+
+# Other projects import these from here, keep providing them for
+# backwards compat
+any = deprecated('use builtin "any"')(any)
+all = deprecated('use builtin "all"')(all)
diff --git a/logilab/common/configuration.py b/logilab/common/configuration.py
new file mode 100644
index 0000000..b292427
--- /dev/null
+++ b/logilab/common/configuration.py
@@ -0,0 +1,1105 @@
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Classes to handle advanced configuration in simple to complex applications.
+
+Allows to load the configuration from a file or from command line
+options, to generate a sample configuration file or to display
+program's usage. Fills the gap between optik/optparse and ConfigParser
+by adding data types (which are also available as a standalone optik
+extension in the `optik_ext` module).
+
+
+Quick start: simplest usage
+---------------------------
+
+.. python ::
+
+ >>> import sys
+ >>> from logilab.common.configuration import Configuration
+ >>> options = [('dothis', {'type':'yn', 'default': True, 'metavar': '<y or n>'}),
+ ... ('value', {'type': 'string', 'metavar': '<string>'}),
+ ... ('multiple', {'type': 'csv', 'default': ('yop',),
+ ... 'metavar': '<comma separated values>',
+ ... 'help': 'you can also document the option'}),
+ ... ('number', {'type': 'int', 'default':2, 'metavar':'<int>'}),
+ ... ]
+ >>> config = Configuration(options=options, name='My config')
+ >>> print config['dothis']
+ True
+ >>> print config['value']
+ None
+ >>> print config['multiple']
+ ('yop',)
+ >>> print config['number']
+ 2
+ >>> print config.help()
+ Usage: [options]
+
+ Options:
+ -h, --help show this help message and exit
+ --dothis=<y or n>
+ --value=<string>
+ --multiple=<comma separated values>
+ you can also document the option [current: none]
+ --number=<int>
+
+ >>> f = open('myconfig.ini', 'w')
+ >>> f.write('''[MY CONFIG]
+ ... number = 3
+ ... dothis = no
+ ... multiple = 1,2,3
+ ... ''')
+ >>> f.close()
+ >>> config.load_file_configuration('myconfig.ini')
+ >>> print config['dothis']
+ False
+ >>> print config['value']
+ None
+ >>> print config['multiple']
+ ['1', '2', '3']
+ >>> print config['number']
+ 3
+ >>> sys.argv = ['mon prog', '--value', 'bacon', '--multiple', '4,5,6',
+ ... 'nonoptionargument']
+ >>> print config.load_command_line_configuration()
+ ['nonoptionargument']
+ >>> print config['value']
+ bacon
+ >>> config.generate_config()
+ # class for simple configurations which don't need the
+ # manager / providers model and prefer delegation to inheritance
+ #
+ # configuration values are accessible through a dict like interface
+ #
+ [MY CONFIG]
+
+ dothis=no
+
+ value=bacon
+
+ # you can also document the option
+ multiple=4,5,6
+
+ number=3
+
+ Note : starting with Python 2.7 ConfigParser is able to take into
+ account the order of occurrences of the options into a file (by
+ using an OrderedDict). If you have two options changing some common
+ state, like a 'disable-all-stuff' and a 'enable-some-stuff-a', their
+ order of appearance will be significant : the last specified in the
+ file wins. For earlier version of python and logilab.common newer
+ than 0.61 the behaviour is unspecified.
+
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+__all__ = ('OptionsManagerMixIn', 'OptionsProviderMixIn',
+ 'ConfigurationMixIn', 'Configuration',
+ 'OptionsManager2ConfigurationAdapter')
+
+import os
+import sys
+import re
+from os.path import exists, expanduser
+from copy import copy
+from warnings import warn
+
+from six import string_types
+from six.moves import range, configparser as cp, input
+
+from logilab.common.compat import str_encode as _encode
+from logilab.common.deprecation import deprecated
+from logilab.common.textutils import normalize_text, unquote
+from logilab.common import optik_ext
+
+OptionError = optik_ext.OptionError
+
+REQUIRED = []
+
+class UnsupportedAction(Exception):
+ """raised by set_option when it doesn't know what to do for an action"""
+
+
+def _get_encoding(encoding, stream):
+ encoding = encoding or getattr(stream, 'encoding', None)
+ if not encoding:
+ import locale
+ encoding = locale.getpreferredencoding()
+ return encoding
+
+
+# validation functions ########################################################
+
+# validators will return the validated value or raise optparse.OptionValueError
+# XXX add to documentation
+
+def choice_validator(optdict, name, value):
+ """validate and return a converted value for option of type 'choice'
+ """
+ if not value in optdict['choices']:
+ msg = "option %s: invalid value: %r, should be in %s"
+ raise optik_ext.OptionValueError(msg % (name, value, optdict['choices']))
+ return value
+
+def multiple_choice_validator(optdict, name, value):
+ """validate and return a converted value for option of type 'choice'
+ """
+ choices = optdict['choices']
+ values = optik_ext.check_csv(None, name, value)
+ for value in values:
+ if not value in choices:
+ msg = "option %s: invalid value: %r, should be in %s"
+ raise optik_ext.OptionValueError(msg % (name, value, choices))
+ return values
+
+def csv_validator(optdict, name, value):
+ """validate and return a converted value for option of type 'csv'
+ """
+ return optik_ext.check_csv(None, name, value)
+
+def yn_validator(optdict, name, value):
+ """validate and return a converted value for option of type 'yn'
+ """
+ return optik_ext.check_yn(None, name, value)
+
+def named_validator(optdict, name, value):
+ """validate and return a converted value for option of type 'named'
+ """
+ return optik_ext.check_named(None, name, value)
+
+def file_validator(optdict, name, value):
+ """validate and return a filepath for option of type 'file'"""
+ return optik_ext.check_file(None, name, value)
+
+def color_validator(optdict, name, value):
+ """validate and return a valid color for option of type 'color'"""
+ return optik_ext.check_color(None, name, value)
+
+def password_validator(optdict, name, value):
+ """validate and return a string for option of type 'password'"""
+ return optik_ext.check_password(None, name, value)
+
+def date_validator(optdict, name, value):
+ """validate and return a mx DateTime object for option of type 'date'"""
+ return optik_ext.check_date(None, name, value)
+
+def time_validator(optdict, name, value):
+ """validate and return a time object for option of type 'time'"""
+ return optik_ext.check_time(None, name, value)
+
+def bytes_validator(optdict, name, value):
+ """validate and return an integer for option of type 'bytes'"""
+ return optik_ext.check_bytes(None, name, value)
+
+
+VALIDATORS = {'string': unquote,
+ 'int': int,
+ 'float': float,
+ 'file': file_validator,
+ 'font': unquote,
+ 'color': color_validator,
+ 'regexp': re.compile,
+ 'csv': csv_validator,
+ 'yn': yn_validator,
+ 'bool': yn_validator,
+ 'named': named_validator,
+ 'password': password_validator,
+ 'date': date_validator,
+ 'time': time_validator,
+ 'bytes': bytes_validator,
+ 'choice': choice_validator,
+ 'multiple_choice': multiple_choice_validator,
+ }
+
+def _call_validator(opttype, optdict, option, value):
+ if opttype not in VALIDATORS:
+ raise Exception('Unsupported type "%s"' % opttype)
+ try:
+ return VALIDATORS[opttype](optdict, option, value)
+ except TypeError:
+ try:
+ return VALIDATORS[opttype](value)
+ except optik_ext.OptionValueError:
+ raise
+ except:
+ raise optik_ext.OptionValueError('%s value (%r) should be of type %s' %
+ (option, value, opttype))
+
+# user input functions ########################################################
+
+# user input functions will ask the user for input on stdin then validate
+# the result and return the validated value or raise optparse.OptionValueError
+# XXX add to documentation
+
+def input_password(optdict, question='password:'):
+ from getpass import getpass
+ while True:
+ value = getpass(question)
+ value2 = getpass('confirm: ')
+ if value == value2:
+ return value
+ print('password mismatch, try again')
+
+def input_string(optdict, question):
+ value = input(question).strip()
+ return value or None
+
+def _make_input_function(opttype):
+ def input_validator(optdict, question):
+ while True:
+ value = input(question)
+ if not value.strip():
+ return None
+ try:
+ return _call_validator(opttype, optdict, None, value)
+ except optik_ext.OptionValueError as ex:
+ msg = str(ex).split(':', 1)[-1].strip()
+ print('bad value: %s' % msg)
+ return input_validator
+
+INPUT_FUNCTIONS = {
+ 'string': input_string,
+ 'password': input_password,
+ }
+
+for opttype in VALIDATORS.keys():
+ INPUT_FUNCTIONS.setdefault(opttype, _make_input_function(opttype))
+
+# utility functions ############################################################
+
+def expand_default(self, option):
+ """monkey patch OptionParser.expand_default since we have a particular
+ way to handle defaults to avoid overriding values in the configuration
+ file
+ """
+ if self.parser is None or not self.default_tag:
+ return option.help
+ optname = option._long_opts[0][2:]
+ try:
+ provider = self.parser.options_manager._all_options[optname]
+ except KeyError:
+ value = None
+ else:
+ optdict = provider.get_option_def(optname)
+ optname = provider.option_attrname(optname, optdict)
+ value = getattr(provider.config, optname, optdict)
+ value = format_option_value(optdict, value)
+ if value is optik_ext.NO_DEFAULT or not value:
+ value = self.NO_DEFAULT_VALUE
+ return option.help.replace(self.default_tag, str(value))
+
+
+def _validate(value, optdict, name=''):
+ """return a validated value for an option according to its type
+
+ optional argument name is only used for error message formatting
+ """
+ try:
+ _type = optdict['type']
+ except KeyError:
+ # FIXME
+ return value
+ return _call_validator(_type, optdict, name, value)
+convert = deprecated('[0.60] convert() was renamed _validate()')(_validate)
+
+# format and output functions ##################################################
+
+def comment(string):
+ """return string as a comment"""
+ lines = [line.strip() for line in string.splitlines()]
+ return '# ' + ('%s# ' % os.linesep).join(lines)
+
+def format_time(value):
+ if not value:
+ return '0'
+ if value != int(value):
+ return '%.2fs' % value
+ value = int(value)
+ nbmin, nbsec = divmod(value, 60)
+ if nbsec:
+ return '%ss' % value
+ nbhour, nbmin_ = divmod(nbmin, 60)
+ if nbmin_:
+ return '%smin' % nbmin
+ nbday, nbhour_ = divmod(nbhour, 24)
+ if nbhour_:
+ return '%sh' % nbhour
+ return '%sd' % nbday
+
+def format_bytes(value):
+ if not value:
+ return '0'
+ if value != int(value):
+ return '%.2fB' % value
+ value = int(value)
+ prevunit = 'B'
+ for unit in ('KB', 'MB', 'GB', 'TB'):
+ next, remain = divmod(value, 1024)
+ if remain:
+ return '%s%s' % (value, prevunit)
+ prevunit = unit
+ value = next
+ return '%s%s' % (value, unit)
+
+def format_option_value(optdict, value):
+ """return the user input's value from a 'compiled' value"""
+ if isinstance(value, (list, tuple)):
+ value = ','.join(value)
+ elif isinstance(value, dict):
+ value = ','.join(['%s:%s' % (k, v) for k, v in value.items()])
+ elif hasattr(value, 'match'): # optdict.get('type') == 'regexp'
+ # compiled regexp
+ value = value.pattern
+ elif optdict.get('type') == 'yn':
+ value = value and 'yes' or 'no'
+ elif isinstance(value, string_types) and value.isspace():
+ value = "'%s'" % value
+ elif optdict.get('type') == 'time' and isinstance(value, (float, int, long)):
+ value = format_time(value)
+ elif optdict.get('type') == 'bytes' and hasattr(value, '__int__'):
+ value = format_bytes(value)
+ return value
+
+def ini_format_section(stream, section, options, encoding=None, doc=None):
+ """format an options section using the INI format"""
+ encoding = _get_encoding(encoding, stream)
+ if doc:
+ print(_encode(comment(doc), encoding), file=stream)
+ print('[%s]' % section, file=stream)
+ ini_format(stream, options, encoding)
+
+def ini_format(stream, options, encoding):
+ """format options using the INI format"""
+ for optname, optdict, value in options:
+ value = format_option_value(optdict, value)
+ help = optdict.get('help')
+ if help:
+ help = normalize_text(help, line_len=79, indent='# ')
+ print(file=stream)
+ print(_encode(help, encoding), file=stream)
+ else:
+ print(file=stream)
+ if value is None:
+ print('#%s=' % optname, file=stream)
+ else:
+ value = _encode(value, encoding).strip()
+ print('%s=%s' % (optname, value), file=stream)
+
+format_section = ini_format_section
+
+def rest_format_section(stream, section, options, encoding=None, doc=None):
+ """format an options section using as ReST formatted output"""
+ encoding = _get_encoding(encoding, stream)
+ if section:
+ print('%s\n%s' % (section, "'"*len(section)), file=stream)
+ if doc:
+ print(_encode(normalize_text(doc, line_len=79, indent=''), encoding), file=stream)
+ print(file=stream)
+ for optname, optdict, value in options:
+ help = optdict.get('help')
+ print(':%s:' % optname, file=stream)
+ if help:
+ help = normalize_text(help, line_len=79, indent=' ')
+ print(_encode(help, encoding), file=stream)
+ if value:
+ value = _encode(format_option_value(optdict, value), encoding)
+ print(file=stream)
+ print(' Default: ``%s``' % value.replace("`` ", "```` ``"), file=stream)
+
+# Options Manager ##############################################################
+
+class OptionsManagerMixIn(object):
+ """MixIn to handle a configuration from both a configuration file and
+ command line options
+ """
+
+ def __init__(self, usage, config_file=None, version=None, quiet=0):
+ self.config_file = config_file
+ self.reset_parsers(usage, version=version)
+ # list of registered options providers
+ self.options_providers = []
+ # dictionary associating option name to checker
+ self._all_options = {}
+ self._short_options = {}
+ self._nocallback_options = {}
+ self._mygroups = dict()
+ # verbosity
+ self.quiet = quiet
+ self._maxlevel = 0
+
+ def reset_parsers(self, usage='', version=None):
+ # configuration file parser
+ self.cfgfile_parser = cp.ConfigParser()
+ # command line parser
+ self.cmdline_parser = optik_ext.OptionParser(usage=usage, version=version)
+ self.cmdline_parser.options_manager = self
+ self._optik_option_attrs = set(self.cmdline_parser.option_class.ATTRS)
+
+ def register_options_provider(self, provider, own_group=True):
+ """register an options provider"""
+ assert provider.priority <= 0, "provider's priority can't be >= 0"
+ for i in range(len(self.options_providers)):
+ if provider.priority > self.options_providers[i].priority:
+ self.options_providers.insert(i, provider)
+ break
+ else:
+ self.options_providers.append(provider)
+ non_group_spec_options = [option for option in provider.options
+ if 'group' not in option[1]]
+ groups = getattr(provider, 'option_groups', ())
+ if own_group and non_group_spec_options:
+ self.add_option_group(provider.name.upper(), provider.__doc__,
+ non_group_spec_options, provider)
+ else:
+ for opt, optdict in non_group_spec_options:
+ self.add_optik_option(provider, self.cmdline_parser, opt, optdict)
+ for gname, gdoc in groups:
+ gname = gname.upper()
+ goptions = [option for option in provider.options
+ if option[1].get('group', '').upper() == gname]
+ self.add_option_group(gname, gdoc, goptions, provider)
+
+ def add_option_group(self, group_name, doc, options, provider):
+ """add an option group including the listed options
+ """
+ assert options
+ # add option group to the command line parser
+ if group_name in self._mygroups:
+ group = self._mygroups[group_name]
+ else:
+ group = optik_ext.OptionGroup(self.cmdline_parser,
+ title=group_name.capitalize())
+ self.cmdline_parser.add_option_group(group)
+ group.level = provider.level
+ self._mygroups[group_name] = group
+ # add section to the config file
+ if group_name != "DEFAULT":
+ self.cfgfile_parser.add_section(group_name)
+ # add provider's specific options
+ for opt, optdict in options:
+ self.add_optik_option(provider, group, opt, optdict)
+
+ def add_optik_option(self, provider, optikcontainer, opt, optdict):
+ if 'inputlevel' in optdict:
+ warn('[0.50] "inputlevel" in option dictionary for %s is deprecated,'
+ ' use "level"' % opt, DeprecationWarning)
+ optdict['level'] = optdict.pop('inputlevel')
+ args, optdict = self.optik_option(provider, opt, optdict)
+ option = optikcontainer.add_option(*args, **optdict)
+ self._all_options[opt] = provider
+ self._maxlevel = max(self._maxlevel, option.level or 0)
+
+ def optik_option(self, provider, opt, optdict):
+ """get our personal option definition and return a suitable form for
+ use with optik/optparse
+ """
+ optdict = copy(optdict)
+ others = {}
+ if 'action' in optdict:
+ self._nocallback_options[provider] = opt
+ else:
+ optdict['action'] = 'callback'
+ optdict['callback'] = self.cb_set_provider_option
+ # default is handled here and *must not* be given to optik if you
+ # want the whole machinery to work
+ if 'default' in optdict:
+ if ('help' in optdict
+ and optdict.get('default') is not None
+ and not optdict['action'] in ('store_true', 'store_false')):
+ optdict['help'] += ' [current: %default]'
+ del optdict['default']
+ args = ['--' + str(opt)]
+ if 'short' in optdict:
+ self._short_options[optdict['short']] = opt
+ args.append('-' + optdict['short'])
+ del optdict['short']
+ # cleanup option definition dict before giving it to optik
+ for key in list(optdict.keys()):
+ if not key in self._optik_option_attrs:
+ optdict.pop(key)
+ return args, optdict
+
+ def cb_set_provider_option(self, option, opt, value, parser):
+ """optik callback for option setting"""
+ if opt.startswith('--'):
+ # remove -- on long option
+ opt = opt[2:]
+ else:
+ # short option, get its long equivalent
+ opt = self._short_options[opt[1:]]
+ # trick since we can't set action='store_true' on options
+ if value is None:
+ value = 1
+ self.global_set_option(opt, value)
+
+ def global_set_option(self, opt, value):
+ """set option on the correct option provider"""
+ self._all_options[opt].set_option(opt, value)
+
+ def generate_config(self, stream=None, skipsections=(), encoding=None):
+ """write a configuration file according to the current configuration
+ into the given stream or stdout
+ """
+ options_by_section = {}
+ sections = []
+ for provider in self.options_providers:
+ for section, options in provider.options_by_section():
+ if section is None:
+ section = provider.name
+ if section in skipsections:
+ continue
+ options = [(n, d, v) for (n, d, v) in options
+ if d.get('type') is not None]
+ if not options:
+ continue
+ if not section in sections:
+ sections.append(section)
+ alloptions = options_by_section.setdefault(section, [])
+ alloptions += options
+ stream = stream or sys.stdout
+ encoding = _get_encoding(encoding, stream)
+ printed = False
+ for section in sections:
+ if printed:
+ print('\n', file=stream)
+ format_section(stream, section.upper(), options_by_section[section],
+ encoding)
+ printed = True
+
+ def generate_manpage(self, pkginfo, section=1, stream=None):
+ """write a man page for the current configuration into the given
+ stream or stdout
+ """
+ self._monkeypatch_expand_default()
+ try:
+ optik_ext.generate_manpage(self.cmdline_parser, pkginfo,
+ section, stream=stream or sys.stdout,
+ level=self._maxlevel)
+ finally:
+ self._unmonkeypatch_expand_default()
+
+ # initialization methods ##################################################
+
+ def load_provider_defaults(self):
+ """initialize configuration using default values"""
+ for provider in self.options_providers:
+ provider.load_defaults()
+
+ def load_file_configuration(self, config_file=None):
+ """load the configuration from file"""
+ self.read_config_file(config_file)
+ self.load_config_file()
+
+ def read_config_file(self, config_file=None):
+ """read the configuration file but do not load it (i.e. dispatching
+ values to each options provider)
+ """
+ helplevel = 1
+ while helplevel <= self._maxlevel:
+ opt = '-'.join(['long'] * helplevel) + '-help'
+ if opt in self._all_options:
+ break # already processed
+ def helpfunc(option, opt, val, p, level=helplevel):
+ print(self.help(level))
+ sys.exit(0)
+ helpmsg = '%s verbose help.' % ' '.join(['more'] * helplevel)
+ optdict = {'action' : 'callback', 'callback' : helpfunc,
+ 'help' : helpmsg}
+ provider = self.options_providers[0]
+ self.add_optik_option(provider, self.cmdline_parser, opt, optdict)
+ provider.options += ( (opt, optdict), )
+ helplevel += 1
+ if config_file is None:
+ config_file = self.config_file
+ if config_file is not None:
+ config_file = expanduser(config_file)
+ if config_file and exists(config_file):
+ parser = self.cfgfile_parser
+ parser.read([config_file])
+ # normalize sections'title
+ for sect, values in parser._sections.items():
+ if not sect.isupper() and values:
+ parser._sections[sect.upper()] = values
+ elif not self.quiet:
+ msg = 'No config file found, using default configuration'
+ print(msg, file=sys.stderr)
+ return
+
+ def input_config(self, onlysection=None, inputlevel=0, stream=None):
+ """interactively get configuration values by asking to the user and generate
+ a configuration file
+ """
+ if onlysection is not None:
+ onlysection = onlysection.upper()
+ for provider in self.options_providers:
+ for section, option, optdict in provider.all_options():
+ if onlysection is not None and section != onlysection:
+ continue
+ if not 'type' in optdict:
+ # ignore action without type (callback, store_true...)
+ continue
+ provider.input_option(option, optdict, inputlevel)
+ # now we can generate the configuration file
+ if stream is not None:
+ self.generate_config(stream)
+
+ def load_config_file(self):
+ """dispatch values previously read from a configuration file to each
+ options provider)
+ """
+ parser = self.cfgfile_parser
+ for section in parser.sections():
+ for option, value in parser.items(section):
+ try:
+ self.global_set_option(option, value)
+ except (KeyError, OptionError):
+ # TODO handle here undeclared options appearing in the config file
+ continue
+
+ def load_configuration(self, **kwargs):
+ """override configuration according to given parameters
+ """
+ for opt, opt_value in kwargs.items():
+ opt = opt.replace('_', '-')
+ provider = self._all_options[opt]
+ provider.set_option(opt, opt_value)
+
+ def load_command_line_configuration(self, args=None):
+ """override configuration according to command line parameters
+
+ return additional arguments
+ """
+ self._monkeypatch_expand_default()
+ try:
+ if args is None:
+ args = sys.argv[1:]
+ else:
+ args = list(args)
+ (options, args) = self.cmdline_parser.parse_args(args=args)
+ for provider in self._nocallback_options.keys():
+ config = provider.config
+ for attr in config.__dict__.keys():
+ value = getattr(options, attr, None)
+ if value is None:
+ continue
+ setattr(config, attr, value)
+ return args
+ finally:
+ self._unmonkeypatch_expand_default()
+
+
+ # help methods ############################################################
+
+ def add_help_section(self, title, description, level=0):
+ """add a dummy option section for help purpose """
+ group = optik_ext.OptionGroup(self.cmdline_parser,
+ title=title.capitalize(),
+ description=description)
+ group.level = level
+ self._maxlevel = max(self._maxlevel, level)
+ self.cmdline_parser.add_option_group(group)
+
+ def _monkeypatch_expand_default(self):
+ # monkey patch optik_ext to deal with our default values
+ try:
+ self.__expand_default_backup = optik_ext.HelpFormatter.expand_default
+ optik_ext.HelpFormatter.expand_default = expand_default
+ except AttributeError:
+ # python < 2.4: nothing to be done
+ pass
+ def _unmonkeypatch_expand_default(self):
+ # remove monkey patch
+ if hasattr(optik_ext.HelpFormatter, 'expand_default'):
+ # unpatch optik_ext to avoid side effects
+ optik_ext.HelpFormatter.expand_default = self.__expand_default_backup
+
+ def help(self, level=0):
+ """return the usage string for available options """
+ self.cmdline_parser.formatter.output_level = level
+ self._monkeypatch_expand_default()
+ try:
+ return self.cmdline_parser.format_help()
+ finally:
+ self._unmonkeypatch_expand_default()
+
+
+class Method(object):
+ """used to ease late binding of default method (so you can define options
+ on the class using default methods on the configuration instance)
+ """
+ def __init__(self, methname):
+ self.method = methname
+ self._inst = None
+
+ def bind(self, instance):
+ """bind the method to its instance"""
+ if self._inst is None:
+ self._inst = instance
+
+ def __call__(self, *args, **kwargs):
+ assert self._inst, 'unbound method'
+ return getattr(self._inst, self.method)(*args, **kwargs)
+
+# Options Provider #############################################################
+
+class OptionsProviderMixIn(object):
+ """Mixin to provide options to an OptionsManager"""
+
+ # those attributes should be overridden
+ priority = -1
+ name = 'default'
+ options = ()
+ level = 0
+
+ def __init__(self):
+ self.config = optik_ext.Values()
+ for option in self.options:
+ try:
+ option, optdict = option
+ except ValueError:
+ raise Exception('Bad option: %r' % option)
+ if isinstance(optdict.get('default'), Method):
+ optdict['default'].bind(self)
+ elif isinstance(optdict.get('callback'), Method):
+ optdict['callback'].bind(self)
+ self.load_defaults()
+
+ def load_defaults(self):
+ """initialize the provider using default values"""
+ for opt, optdict in self.options:
+ action = optdict.get('action')
+ if action != 'callback':
+ # callback action have no default
+ default = self.option_default(opt, optdict)
+ if default is REQUIRED:
+ continue
+ self.set_option(opt, default, action, optdict)
+
+ def option_default(self, opt, optdict=None):
+ """return the default value for an option"""
+ if optdict is None:
+ optdict = self.get_option_def(opt)
+ default = optdict.get('default')
+ if callable(default):
+ default = default()
+ return default
+
+ def option_attrname(self, opt, optdict=None):
+ """get the config attribute corresponding to opt
+ """
+ if optdict is None:
+ optdict = self.get_option_def(opt)
+ return optdict.get('dest', opt.replace('-', '_'))
+ option_name = deprecated('[0.60] OptionsProviderMixIn.option_name() was renamed to option_attrname()')(option_attrname)
+
+ def option_value(self, opt):
+ """get the current value for the given option"""
+ return getattr(self.config, self.option_attrname(opt), None)
+
+ def set_option(self, opt, value, action=None, optdict=None):
+ """method called to set an option (registered in the options list)
+ """
+ if optdict is None:
+ optdict = self.get_option_def(opt)
+ if value is not None:
+ value = _validate(value, optdict, opt)
+ if action is None:
+ action = optdict.get('action', 'store')
+ if optdict.get('type') == 'named': # XXX need specific handling
+ optname = self.option_attrname(opt, optdict)
+ currentvalue = getattr(self.config, optname, None)
+ if currentvalue:
+ currentvalue.update(value)
+ value = currentvalue
+ if action == 'store':
+ setattr(self.config, self.option_attrname(opt, optdict), value)
+ elif action in ('store_true', 'count'):
+ setattr(self.config, self.option_attrname(opt, optdict), 0)
+ elif action == 'store_false':
+ setattr(self.config, self.option_attrname(opt, optdict), 1)
+ elif action == 'append':
+ opt = self.option_attrname(opt, optdict)
+ _list = getattr(self.config, opt, None)
+ if _list is None:
+ if isinstance(value, (list, tuple)):
+ _list = value
+ elif value is not None:
+ _list = []
+ _list.append(value)
+ setattr(self.config, opt, _list)
+ elif isinstance(_list, tuple):
+ setattr(self.config, opt, _list + (value,))
+ else:
+ _list.append(value)
+ elif action == 'callback':
+ optdict['callback'](None, opt, value, None)
+ else:
+ raise UnsupportedAction(action)
+
+ def input_option(self, option, optdict, inputlevel=99):
+ default = self.option_default(option, optdict)
+ if default is REQUIRED:
+ defaultstr = '(required): '
+ elif optdict.get('level', 0) > inputlevel:
+ return
+ elif optdict['type'] == 'password' or default is None:
+ defaultstr = ': '
+ else:
+ defaultstr = '(default: %s): ' % format_option_value(optdict, default)
+ print(':%s:' % option)
+ print(optdict.get('help') or option)
+ inputfunc = INPUT_FUNCTIONS[optdict['type']]
+ value = inputfunc(optdict, defaultstr)
+ while default is REQUIRED and not value:
+ print('please specify a value')
+ value = inputfunc(optdict, '%s: ' % option)
+ if value is None and default is not None:
+ value = default
+ self.set_option(option, value, optdict=optdict)
+
+ def get_option_def(self, opt):
+ """return the dictionary defining an option given it's name"""
+ assert self.options
+ for option in self.options:
+ if option[0] == opt:
+ return option[1]
+ raise OptionError('no such option %s in section %r'
+ % (opt, self.name), opt)
+
+
+ def all_options(self):
+ """return an iterator on available options for this provider
+ option are actually described by a 3-uple:
+ (section, option name, option dictionary)
+ """
+ for section, options in self.options_by_section():
+ if section is None:
+ if self.name is None:
+ continue
+ section = self.name.upper()
+ for option, optiondict, value in options:
+ yield section, option, optiondict
+
+ def options_by_section(self):
+ """return an iterator on options grouped by section
+
+ (section, [list of (optname, optdict, optvalue)])
+ """
+ sections = {}
+ for optname, optdict in self.options:
+ sections.setdefault(optdict.get('group'), []).append(
+ (optname, optdict, self.option_value(optname)))
+ if None in sections:
+ yield None, sections.pop(None)
+ for section, options in sections.items():
+ yield section.upper(), options
+
+ def options_and_values(self, options=None):
+ if options is None:
+ options = self.options
+ for optname, optdict in options:
+ yield (optname, optdict, self.option_value(optname))
+
+# configuration ################################################################
+
+class ConfigurationMixIn(OptionsManagerMixIn, OptionsProviderMixIn):
+ """basic mixin for simple configurations which don't need the
+ manager / providers model
+ """
+ def __init__(self, *args, **kwargs):
+ if not args:
+ kwargs.setdefault('usage', '')
+ kwargs.setdefault('quiet', 1)
+ OptionsManagerMixIn.__init__(self, *args, **kwargs)
+ OptionsProviderMixIn.__init__(self)
+ if not getattr(self, 'option_groups', None):
+ self.option_groups = []
+ for option, optdict in self.options:
+ try:
+ gdef = (optdict['group'].upper(), '')
+ except KeyError:
+ continue
+ if not gdef in self.option_groups:
+ self.option_groups.append(gdef)
+ self.register_options_provider(self, own_group=False)
+
+ def register_options(self, options):
+ """add some options to the configuration"""
+ options_by_group = {}
+ for optname, optdict in options:
+ options_by_group.setdefault(optdict.get('group', self.name.upper()), []).append((optname, optdict))
+ for group, options in options_by_group.items():
+ self.add_option_group(group, None, options, self)
+ self.options += tuple(options)
+
+ def load_defaults(self):
+ OptionsProviderMixIn.load_defaults(self)
+
+ def __iter__(self):
+ return iter(self.config.__dict__.iteritems())
+
+ def __getitem__(self, key):
+ try:
+ return getattr(self.config, self.option_attrname(key))
+ except (optik_ext.OptionValueError, AttributeError):
+ raise KeyError(key)
+
+ def __setitem__(self, key, value):
+ self.set_option(key, value)
+
+ def get(self, key, default=None):
+ try:
+ return getattr(self.config, self.option_attrname(key))
+ except (OptionError, AttributeError):
+ return default
+
+
+class Configuration(ConfigurationMixIn):
+ """class for simple configurations which don't need the
+ manager / providers model and prefer delegation to inheritance
+
+ configuration values are accessible through a dict like interface
+ """
+
+ def __init__(self, config_file=None, options=None, name=None,
+ usage=None, doc=None, version=None):
+ if options is not None:
+ self.options = options
+ if name is not None:
+ self.name = name
+ if doc is not None:
+ self.__doc__ = doc
+ super(Configuration, self).__init__(config_file=config_file, usage=usage, version=version)
+
+
+class OptionsManager2ConfigurationAdapter(object):
+ """Adapt an option manager to behave like a
+ `logilab.common.configuration.Configuration` instance
+ """
+ def __init__(self, provider):
+ self.config = provider
+
+ def __getattr__(self, key):
+ return getattr(self.config, key)
+
+ def __getitem__(self, key):
+ provider = self.config._all_options[key]
+ try:
+ return getattr(provider.config, provider.option_attrname(key))
+ except AttributeError:
+ raise KeyError(key)
+
+ def __setitem__(self, key, value):
+ self.config.global_set_option(self.config.option_attrname(key), value)
+
+ def get(self, key, default=None):
+ provider = self.config._all_options[key]
+ try:
+ return getattr(provider.config, provider.option_attrname(key))
+ except AttributeError:
+ return default
+
+# other functions ##############################################################
+
+def read_old_config(newconfig, changes, configfile):
+ """initialize newconfig from a deprecated configuration file
+
+ possible changes:
+ * ('renamed', oldname, newname)
+ * ('moved', option, oldgroup, newgroup)
+ * ('typechanged', option, oldtype, newvalue)
+ """
+ # build an index of changes
+ changesindex = {}
+ for action in changes:
+ if action[0] == 'moved':
+ option, oldgroup, newgroup = action[1:]
+ changesindex.setdefault(option, []).append((action[0], oldgroup, newgroup))
+ continue
+ if action[0] == 'renamed':
+ oldname, newname = action[1:]
+ changesindex.setdefault(newname, []).append((action[0], oldname))
+ continue
+ if action[0] == 'typechanged':
+ option, oldtype, newvalue = action[1:]
+ changesindex.setdefault(option, []).append((action[0], oldtype, newvalue))
+ continue
+ if action[1] in ('added', 'removed'):
+ continue # nothing to do here
+ raise Exception('unknown change %s' % action[0])
+ # build a config object able to read the old config
+ options = []
+ for optname, optdef in newconfig.options:
+ for action in changesindex.pop(optname, ()):
+ if action[0] == 'moved':
+ oldgroup, newgroup = action[1:]
+ optdef = optdef.copy()
+ optdef['group'] = oldgroup
+ elif action[0] == 'renamed':
+ optname = action[1]
+ elif action[0] == 'typechanged':
+ oldtype = action[1]
+ optdef = optdef.copy()
+ optdef['type'] = oldtype
+ options.append((optname, optdef))
+ if changesindex:
+ raise Exception('unapplied changes: %s' % changesindex)
+ oldconfig = Configuration(options=options, name=newconfig.name)
+ # read the old config
+ oldconfig.load_file_configuration(configfile)
+ # apply values reverting changes
+ changes.reverse()
+ done = set()
+ for action in changes:
+ if action[0] == 'renamed':
+ oldname, newname = action[1:]
+ newconfig[newname] = oldconfig[oldname]
+ done.add(newname)
+ elif action[0] == 'typechanged':
+ optname, oldtype, newvalue = action[1:]
+ newconfig[optname] = newvalue
+ done.add(optname)
+ for optname, optdef in newconfig.options:
+ if optdef.get('type') and not optname in done:
+ newconfig.set_option(optname, oldconfig[optname], optdict=optdef)
+
+
+def merge_options(options, optgroup=None):
+ """preprocess a list of options and remove duplicates, returning a new list
+ (tuple actually) of options.
+
+ Options dictionaries are copied to avoid later side-effect. Also, if
+ `otpgroup` argument is specified, ensure all options are in the given group.
+ """
+ alloptions = {}
+ options = list(options)
+ for i in range(len(options)-1, -1, -1):
+ optname, optdict = options[i]
+ if optname in alloptions:
+ options.pop(i)
+ alloptions[optname].update(optdict)
+ else:
+ optdict = optdict.copy()
+ options[i] = (optname, optdict)
+ alloptions[optname] = optdict
+ if optgroup is not None:
+ alloptions[optname]['group'] = optgroup
+ return tuple(options)
diff --git a/logilab/common/contexts.py b/logilab/common/contexts.py
new file mode 100644
index 0000000..d78c327
--- /dev/null
+++ b/logilab/common/contexts.py
@@ -0,0 +1,5 @@
+from warnings import warn
+warn('logilab.common.contexts module is deprecated, use logilab.common.shellutils instead',
+ DeprecationWarning, stacklevel=1)
+
+from logilab.common.shellutils import tempfile, pushd
diff --git a/logilab/common/corbautils.py b/logilab/common/corbautils.py
new file mode 100644
index 0000000..65c301d
--- /dev/null
+++ b/logilab/common/corbautils.py
@@ -0,0 +1,117 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""A set of utility function to ease the use of OmniORBpy.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+from omniORB import CORBA, PortableServer
+import CosNaming
+
+orb = None
+
+def get_orb():
+ """
+ returns a reference to the ORB.
+ The first call to the method initialized the ORB
+ This method is mainly used internally in the module.
+ """
+
+ global orb
+ if orb is None:
+ import sys
+ orb = CORBA.ORB_init(sys.argv, CORBA.ORB_ID)
+ return orb
+
+def get_root_context():
+ """
+ returns a reference to the NameService object.
+ This method is mainly used internally in the module.
+ """
+
+ orb = get_orb()
+ nss = orb.resolve_initial_references("NameService")
+ rootContext = nss._narrow(CosNaming.NamingContext)
+ assert rootContext is not None, "Failed to narrow root naming context"
+ return rootContext
+
+def register_object_name(object, namepath):
+ """
+ Registers a object in the NamingService.
+ The name path is a list of 2-uples (id,kind) giving the path.
+
+ For instance if the path of an object is [('foo',''),('bar','')],
+ it is possible to get a reference to the object using the URL
+ 'corbaname::hostname#foo/bar'.
+ [('logilab','rootmodule'),('chatbot','application'),('chatter','server')]
+ is mapped to
+ 'corbaname::hostname#logilab.rootmodule/chatbot.application/chatter.server'
+
+ The get_object_reference() function can be used to resolve such a URL.
+ """
+ context = get_root_context()
+ for id, kind in namepath[:-1]:
+ name = [CosNaming.NameComponent(id, kind)]
+ try:
+ context = context.bind_new_context(name)
+ except CosNaming.NamingContext.AlreadyBound as ex:
+ context = context.resolve(name)._narrow(CosNaming.NamingContext)
+ assert context is not None, \
+ 'test context exists but is not a NamingContext'
+
+ id, kind = namepath[-1]
+ name = [CosNaming.NameComponent(id, kind)]
+ try:
+ context.bind(name, object._this())
+ except CosNaming.NamingContext.AlreadyBound as ex:
+ context.rebind(name, object._this())
+
+def activate_POA():
+ """
+ This methods activates the Portable Object Adapter.
+ You need to call it to enable the reception of messages in your code,
+ on both the client and the server.
+ """
+ orb = get_orb()
+ poa = orb.resolve_initial_references('RootPOA')
+ poaManager = poa._get_the_POAManager()
+ poaManager.activate()
+
+def run_orb():
+ """
+ Enters the ORB mainloop on the server.
+ You should not call this method on the client.
+ """
+ get_orb().run()
+
+def get_object_reference(url):
+ """
+ Resolves a corbaname URL to an object proxy.
+ See register_object_name() for examples URLs
+ """
+ return get_orb().string_to_object(url)
+
+def get_object_string(host, namepath):
+ """given an host name and a name path as described in register_object_name,
+ return a corba string identifier
+ """
+ strname = '/'.join(['.'.join(path_elt) for path_elt in namepath])
+ return 'corbaname::%s#%s' % (host, strname)
diff --git a/logilab/common/daemon.py b/logilab/common/daemon.py
new file mode 100644
index 0000000..40319a4
--- /dev/null
+++ b/logilab/common/daemon.py
@@ -0,0 +1,101 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""A daemonize function (for Unices)"""
+
+__docformat__ = "restructuredtext en"
+
+import os
+import errno
+import signal
+import sys
+import time
+import warnings
+
+from six.moves import range
+
+def setugid(user):
+ """Change process user and group ID
+
+ Argument is a numeric user id or a user name"""
+ try:
+ from pwd import getpwuid
+ passwd = getpwuid(int(user))
+ except ValueError:
+ from pwd import getpwnam
+ passwd = getpwnam(user)
+
+ if hasattr(os, 'initgroups'): # python >= 2.7
+ os.initgroups(passwd.pw_name, passwd.pw_gid)
+ else:
+ import ctypes
+ if ctypes.CDLL(None).initgroups(passwd.pw_name, passwd.pw_gid) < 0:
+ err = ctypes.c_int.in_dll(ctypes.pythonapi,"errno").value
+ raise OSError(err, os.strerror(err), 'initgroups')
+ os.setgid(passwd.pw_gid)
+ os.setuid(passwd.pw_uid)
+ os.environ['HOME'] = passwd.pw_dir
+
+
+def daemonize(pidfile=None, uid=None, umask=0o77):
+ """daemonize a Unix process. Set paranoid umask by default.
+
+ Return 1 in the original process, 2 in the first fork, and None for the
+ second fork (eg daemon process).
+ """
+ # http://www.faqs.org/faqs/unix-faq/programmer/faq/
+ #
+ # fork so the parent can exit
+ if os.fork(): # launch child and...
+ return 1
+ # disconnect from tty and create a new session
+ os.setsid()
+ # fork again so the parent, (the session group leader), can exit.
+ # as a non-session group leader, we can never regain a controlling
+ # terminal.
+ if os.fork(): # launch child again.
+ return 2
+ # move to the root to avoit mount pb
+ os.chdir('/')
+ # redirect standard descriptors
+ null = os.open('/dev/null', os.O_RDWR)
+ for i in range(3):
+ try:
+ os.dup2(null, i)
+ except OSError as e:
+ if e.errno != errno.EBADF:
+ raise
+ os.close(null)
+ # filter warnings
+ warnings.filterwarnings('ignore')
+ # write pid in a file
+ if pidfile:
+ # ensure the directory where the pid-file should be set exists (for
+ # instance /var/run/cubicweb may be deleted on computer restart)
+ piddir = os.path.dirname(pidfile)
+ if not os.path.exists(piddir):
+ os.makedirs(piddir)
+ f = file(pidfile, 'w')
+ f.write(str(os.getpid()))
+ f.close()
+ # set umask if specified
+ if umask is not None:
+ os.umask(umask)
+ # change process uid
+ if uid:
+ setugid(uid)
+ return None
diff --git a/logilab/common/date.py b/logilab/common/date.py
new file mode 100644
index 0000000..a093a8a
--- /dev/null
+++ b/logilab/common/date.py
@@ -0,0 +1,335 @@
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Date manipulation helper functions."""
+from __future__ import division
+
+__docformat__ = "restructuredtext en"
+
+import math
+import re
+import sys
+from locale import getlocale, LC_TIME
+from datetime import date, time, datetime, timedelta
+from time import strptime as time_strptime
+from calendar import monthrange, timegm
+
+from six.moves import range
+
+try:
+ from mx.DateTime import RelativeDateTime, Date, DateTimeType
+except ImportError:
+ endOfMonth = None
+ DateTimeType = datetime
+else:
+ endOfMonth = RelativeDateTime(months=1, day=-1)
+
+# NOTE: should we implement a compatibility layer between date representations
+# as we have in lgc.db ?
+
+FRENCH_FIXED_HOLIDAYS = {
+ 'jour_an': '%s-01-01',
+ 'fete_travail': '%s-05-01',
+ 'armistice1945': '%s-05-08',
+ 'fete_nat': '%s-07-14',
+ 'assomption': '%s-08-15',
+ 'toussaint': '%s-11-01',
+ 'armistice1918': '%s-11-11',
+ 'noel': '%s-12-25',
+ }
+
+FRENCH_MOBILE_HOLIDAYS = {
+ 'paques2004': '2004-04-12',
+ 'ascension2004': '2004-05-20',
+ 'pentecote2004': '2004-05-31',
+
+ 'paques2005': '2005-03-28',
+ 'ascension2005': '2005-05-05',
+ 'pentecote2005': '2005-05-16',
+
+ 'paques2006': '2006-04-17',
+ 'ascension2006': '2006-05-25',
+ 'pentecote2006': '2006-06-05',
+
+ 'paques2007': '2007-04-09',
+ 'ascension2007': '2007-05-17',
+ 'pentecote2007': '2007-05-28',
+
+ 'paques2008': '2008-03-24',
+ 'ascension2008': '2008-05-01',
+ 'pentecote2008': '2008-05-12',
+
+ 'paques2009': '2009-04-13',
+ 'ascension2009': '2009-05-21',
+ 'pentecote2009': '2009-06-01',
+
+ 'paques2010': '2010-04-05',
+ 'ascension2010': '2010-05-13',
+ 'pentecote2010': '2010-05-24',
+
+ 'paques2011': '2011-04-25',
+ 'ascension2011': '2011-06-02',
+ 'pentecote2011': '2011-06-13',
+
+ 'paques2012': '2012-04-09',
+ 'ascension2012': '2012-05-17',
+ 'pentecote2012': '2012-05-28',
+ }
+
+# XXX this implementation cries for multimethod dispatching
+
+def get_step(dateobj, nbdays=1):
+ # assume date is either a python datetime or a mx.DateTime object
+ if isinstance(dateobj, date):
+ return ONEDAY * nbdays
+ return nbdays # mx.DateTime is ok with integers
+
+def datefactory(year, month, day, sampledate):
+ # assume date is either a python datetime or a mx.DateTime object
+ if isinstance(sampledate, datetime):
+ return datetime(year, month, day)
+ if isinstance(sampledate, date):
+ return date(year, month, day)
+ return Date(year, month, day)
+
+def weekday(dateobj):
+ # assume date is either a python datetime or a mx.DateTime object
+ if isinstance(dateobj, date):
+ return dateobj.weekday()
+ return dateobj.day_of_week
+
+def str2date(datestr, sampledate):
+ # NOTE: datetime.strptime is not an option until we drop py2.4 compat
+ year, month, day = [int(chunk) for chunk in datestr.split('-')]
+ return datefactory(year, month, day, sampledate)
+
+def days_between(start, end):
+ if isinstance(start, date):
+ delta = end - start
+ # datetime.timedelta.days is always an integer (floored)
+ if delta.seconds:
+ return delta.days + 1
+ return delta.days
+ else:
+ return int(math.ceil((end - start).days))
+
+def get_national_holidays(begin, end):
+ """return french national days off between begin and end"""
+ begin = datefactory(begin.year, begin.month, begin.day, begin)
+ end = datefactory(end.year, end.month, end.day, end)
+ holidays = [str2date(datestr, begin)
+ for datestr in FRENCH_MOBILE_HOLIDAYS.values()]
+ for year in range(begin.year, end.year+1):
+ for datestr in FRENCH_FIXED_HOLIDAYS.values():
+ date = str2date(datestr % year, begin)
+ if date not in holidays:
+ holidays.append(date)
+ return [day for day in holidays if begin <= day < end]
+
+def add_days_worked(start, days):
+ """adds date but try to only take days worked into account"""
+ step = get_step(start)
+ weeks, plus = divmod(days, 5)
+ end = start + ((weeks * 7) + plus) * step
+ if weekday(end) >= 5: # saturday or sunday
+ end += (2 * step)
+ end += len([x for x in get_national_holidays(start, end + step)
+ if weekday(x) < 5]) * step
+ if weekday(end) >= 5: # saturday or sunday
+ end += (2 * step)
+ return end
+
+def nb_open_days(start, end):
+ assert start <= end
+ step = get_step(start)
+ days = days_between(start, end)
+ weeks, plus = divmod(days, 7)
+ if weekday(start) > weekday(end):
+ plus -= 2
+ elif weekday(end) == 6:
+ plus -= 1
+ open_days = weeks * 5 + plus
+ nb_week_holidays = len([x for x in get_national_holidays(start, end+step)
+ if weekday(x) < 5 and x < end])
+ open_days -= nb_week_holidays
+ if open_days < 0:
+ return 0
+ return open_days
+
+def date_range(begin, end, incday=None, incmonth=None):
+ """yields each date between begin and end
+
+ :param begin: the start date
+ :param end: the end date
+ :param incr: the step to use to iterate over dates. Default is
+ one day.
+ :param include: None (means no exclusion) or a function taking a
+ date as parameter, and returning True if the date
+ should be included.
+
+ When using mx datetime, you should *NOT* use incmonth argument, use instead
+ oneDay, oneHour, oneMinute, oneSecond, oneWeek or endOfMonth (to enumerate
+ months) as `incday` argument
+ """
+ assert not (incday and incmonth)
+ begin = todate(begin)
+ end = todate(end)
+ if incmonth:
+ while begin < end:
+ yield begin
+ begin = next_month(begin, incmonth)
+ else:
+ incr = get_step(begin, incday or 1)
+ while begin < end:
+ yield begin
+ begin += incr
+
+# makes py datetime usable #####################################################
+
+ONEDAY = timedelta(days=1)
+ONEWEEK = timedelta(days=7)
+
+try:
+ strptime = datetime.strptime
+except AttributeError: # py < 2.5
+ from time import strptime as time_strptime
+ def strptime(value, format):
+ return datetime(*time_strptime(value, format)[:6])
+
+def strptime_time(value, format='%H:%M'):
+ return time(*time_strptime(value, format)[3:6])
+
+def todate(somedate):
+ """return a date from a date (leaving unchanged) or a datetime"""
+ if isinstance(somedate, datetime):
+ return date(somedate.year, somedate.month, somedate.day)
+ assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
+ return somedate
+
+def totime(somedate):
+ """return a time from a time (leaving unchanged), date or datetime"""
+ # XXX mx compat
+ if not isinstance(somedate, time):
+ return time(somedate.hour, somedate.minute, somedate.second)
+ assert isinstance(somedate, (time)), repr(somedate)
+ return somedate
+
+def todatetime(somedate):
+ """return a date from a date (leaving unchanged) or a datetime"""
+ # take care, datetime is a subclass of date
+ if isinstance(somedate, datetime):
+ return somedate
+ assert isinstance(somedate, (date, DateTimeType)), repr(somedate)
+ return datetime(somedate.year, somedate.month, somedate.day)
+
+def datetime2ticks(somedate):
+ return timegm(somedate.timetuple()) * 1000
+
+def ticks2datetime(ticks):
+ miliseconds, microseconds = divmod(ticks, 1000)
+ try:
+ return datetime.fromtimestamp(miliseconds)
+ except (ValueError, OverflowError):
+ epoch = datetime.fromtimestamp(0)
+ nb_days, seconds = divmod(int(miliseconds), 86400)
+ delta = timedelta(nb_days, seconds=seconds, microseconds=microseconds)
+ try:
+ return epoch + delta
+ except (ValueError, OverflowError):
+ raise
+
+def days_in_month(somedate):
+ return monthrange(somedate.year, somedate.month)[1]
+
+def days_in_year(somedate):
+ feb = date(somedate.year, 2, 1)
+ if days_in_month(feb) == 29:
+ return 366
+ else:
+ return 365
+
+def previous_month(somedate, nbmonth=1):
+ while nbmonth:
+ somedate = first_day(somedate) - ONEDAY
+ nbmonth -= 1
+ return somedate
+
+def next_month(somedate, nbmonth=1):
+ while nbmonth:
+ somedate = last_day(somedate) + ONEDAY
+ nbmonth -= 1
+ return somedate
+
+def first_day(somedate):
+ return date(somedate.year, somedate.month, 1)
+
+def last_day(somedate):
+ return date(somedate.year, somedate.month, days_in_month(somedate))
+
+def ustrftime(somedate, fmt='%Y-%m-%d'):
+ """like strftime, but returns a unicode string instead of an encoded
+ string which may be problematic with localized date.
+ """
+ if sys.version_info >= (3, 3):
+ # datetime.date.strftime() supports dates since year 1 in Python >=3.3.
+ return somedate.strftime(fmt)
+ else:
+ try:
+ if sys.version_info < (3, 0):
+ encoding = getlocale(LC_TIME)[1] or 'ascii'
+ return unicode(somedate.strftime(str(fmt)), encoding)
+ else:
+ return somedate.strftime(fmt)
+ except ValueError:
+ if somedate.year >= 1900:
+ raise
+ # datetime is not happy with dates before 1900
+ # we try to work around this, assuming a simple
+ # format string
+ fields = {'Y': somedate.year,
+ 'm': somedate.month,
+ 'd': somedate.day,
+ }
+ if isinstance(somedate, datetime):
+ fields.update({'H': somedate.hour,
+ 'M': somedate.minute,
+ 'S': somedate.second})
+ fmt = re.sub('%([YmdHMS])', r'%(\1)02d', fmt)
+ return unicode(fmt) % fields
+
+def utcdatetime(dt):
+ if dt.tzinfo is None:
+ return dt
+ return (dt.replace(tzinfo=None) - dt.utcoffset())
+
+def utctime(dt):
+ if dt.tzinfo is None:
+ return dt
+ return (dt + dt.utcoffset() + dt.dst()).replace(tzinfo=None)
+
+def datetime_to_seconds(date):
+ """return the number of seconds since the begining of the day for that date
+ """
+ return date.second+60*date.minute + 3600*date.hour
+
+def timedelta_to_days(delta):
+ """return the time delta as a number of seconds"""
+ return delta.days + delta.seconds / (3600*24)
+
+def timedelta_to_seconds(delta):
+ """return the time delta as a fraction of days"""
+ return delta.days*(3600*24) + delta.seconds
diff --git a/logilab/common/dbf.py b/logilab/common/dbf.py
new file mode 100644
index 0000000..ab142b2
--- /dev/null
+++ b/logilab/common/dbf.py
@@ -0,0 +1,231 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""This is a DBF reader which reads Visual Fox Pro DBF format with Memo field
+
+Usage:
+
+>>> rec = readDbf('test.dbf')
+>>> for line in rec:
+>>> print line['name']
+
+
+:date: 13/07/2007
+
+http://www.physics.ox.ac.uk/users/santoso/Software.Repository.html
+page says code is "available as is without any warranty or support".
+"""
+from __future__ import print_function
+
+import struct
+import os, os.path
+import sys
+import csv
+import tempfile
+
+from six.moves import range
+
+class Dbase:
+ def __init__(self):
+ self.fdb = None
+ self.fmemo = None
+ self.db_data = None
+ self.memo_data = None
+ self.fields = None
+ self.num_records = 0
+ self.header = None
+ self.memo_file = ''
+ self.memo_header = None
+ self.memo_block_size = 0
+ self.memo_header_len = 0
+
+ def _drop_after_NULL(self, txt):
+ for i in range(0, len(txt)):
+ if ord(struct.unpack('c', txt[i])[0])==0:
+ return txt[:i]
+ return txt
+
+ def _reverse_endian(self, num):
+ if not len(num):
+ return 0
+ val = struct.unpack('<L', num)
+ val = struct.pack('>L', val[0])
+ val = struct.unpack('>L', val)
+ return val[0]
+
+ def _assign_ids(self, lst, ids):
+ result = {}
+ idx = 0
+ for item in lst:
+ id = ids[idx]
+ result[id] = item
+ idx += 1
+ return result
+
+ def open(self, db_name):
+ filesize = os.path.getsize(db_name)
+ if filesize <= 68:
+ raise IOError('The file is not large enough to be a dbf file')
+
+ self.fdb = open(db_name, 'rb')
+
+ self.memo_file = ''
+ if os.path.isfile(db_name[0:-1] + 't'):
+ self.memo_file = db_name[0:-1] + 't'
+ elif os.path.isfile(db_name[0:-3] + 'fpt'):
+ self.memo_file = db_name[0:-3] + 'fpt'
+
+ if self.memo_file:
+ #Read memo file
+ self.fmemo = open(self.memo_file, 'rb')
+ self.memo_data = self.fmemo.read()
+ self.memo_header = self._assign_ids(struct.unpack('>6x1H', self.memo_data[:8]), ['Block size'])
+ block_size = self.memo_header['Block size']
+ if not block_size:
+ block_size = 512
+ self.memo_block_size = block_size
+ self.memo_header_len = block_size
+ memo_size = os.path.getsize(self.memo_file)
+
+ #Start reading data file
+ data = self.fdb.read(32)
+ self.header = self._assign_ids(struct.unpack('<B 3B L 2H 20x', data), ['id', 'Year', 'Month', 'Day', '# of Records', 'Header Size', 'Record Size'])
+ self.header['id'] = hex(self.header['id'])
+
+ self.num_records = self.header['# of Records']
+ data = self.fdb.read(self.header['Header Size']-34)
+ self.fields = {}
+ x = 0
+ header_pattern = '<11s c 4x B B 14x'
+ ids = ['Field Name', 'Field Type', 'Field Length', 'Field Precision']
+ pattern_len = 32
+ for offset in range(0, len(data), 32):
+ if ord(data[offset])==0x0d:
+ break
+ x += 1
+ data_subset = data[offset: offset+pattern_len]
+ if len(data_subset) < pattern_len:
+ data_subset += ' '*(pattern_len-len(data_subset))
+ self.fields[x] = self._assign_ids(struct.unpack(header_pattern, data_subset), ids)
+ self.fields[x]['Field Name'] = self._drop_after_NULL(self.fields[x]['Field Name'])
+
+ self.fdb.read(3)
+ if self.header['# of Records']:
+ data_size = (self.header['# of Records'] * self.header['Record Size']) - 1
+ self.db_data = self.fdb.read(data_size)
+ else:
+ self.db_data = ''
+ self.row_format = '<'
+ self.row_ids = []
+ self.row_len = 0
+ for key in self.fields:
+ field = self.fields[key]
+ self.row_format += '%ds ' % (field['Field Length'])
+ self.row_ids.append(field['Field Name'])
+ self.row_len += field['Field Length']
+
+ def close(self):
+ if self.fdb:
+ self.fdb.close()
+ if self.fmemo:
+ self.fmemo.close()
+
+ def get_numrecords(self):
+ return self.num_records
+
+ def get_record_with_names(self, rec_no):
+ """
+ This function accept record number from 0 to N-1
+ """
+ if rec_no < 0 or rec_no > self.num_records:
+ raise Exception('Unable to extract data outside the range')
+
+ offset = self.header['Record Size'] * rec_no
+ data = self.db_data[offset:offset+self.row_len]
+ record = self._assign_ids(struct.unpack(self.row_format, data), self.row_ids)
+
+ if self.memo_file:
+ for key in self.fields:
+ field = self.fields[key]
+ f_type = field['Field Type']
+ f_name = field['Field Name']
+ c_data = record[f_name]
+
+ if f_type=='M' or f_type=='G' or f_type=='B' or f_type=='P':
+ c_data = self._reverse_endian(c_data)
+ if c_data:
+ record[f_name] = self.read_memo(c_data-1).strip()
+ else:
+ record[f_name] = c_data.strip()
+ return record
+
+ def read_memo_record(self, num, in_length):
+ """
+ Read the record of given number. The second parameter is the length of
+ the record to read. It can be undefined, meaning read the whole record,
+ and it can be negative, meaning at most the length
+ """
+ if in_length < 0:
+ in_length = -self.memo_block_size
+
+ offset = self.memo_header_len + num * self.memo_block_size
+ self.fmemo.seek(offset)
+ if in_length<0:
+ in_length = -in_length
+ if in_length==0:
+ return ''
+ return self.fmemo.read(in_length)
+
+ def read_memo(self, num):
+ result = ''
+ buffer = self.read_memo_record(num, -1)
+ if len(buffer)<=0:
+ return ''
+ length = struct.unpack('>L', buffer[4:4+4])[0] + 8
+
+ block_size = self.memo_block_size
+ if length < block_size:
+ return buffer[8:length]
+ rest_length = length - block_size
+ rest_data = self.read_memo_record(num+1, rest_length)
+ if len(rest_data)<=0:
+ return ''
+ return buffer[8:] + rest_data
+
+def readDbf(filename):
+ """
+ Read the DBF file specified by the filename and
+ return the records as a list of dictionary.
+
+ :param: filename File name of the DBF
+ :return: List of rows
+ """
+ db = Dbase()
+ db.open(filename)
+ num = db.get_numrecords()
+ rec = []
+ for i in range(0, num):
+ record = db.get_record_with_names(i)
+ rec.append(record)
+ db.close()
+ return rec
+
+if __name__=='__main__':
+ rec = readDbf('dbf/sptable.dbf')
+ for line in rec:
+ print('%s %s' % (line['GENUS'].strip(), line['SPECIES'].strip()))
diff --git a/logilab/common/debugger.py b/logilab/common/debugger.py
new file mode 100644
index 0000000..1f540a1
--- /dev/null
+++ b/logilab/common/debugger.py
@@ -0,0 +1,214 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Customized version of pdb's default debugger.
+
+- sets up a history file
+- uses ipython if available to colorize lines of code
+- overrides list command to search for current block instead
+ of using 5 lines of context
+
+
+
+
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+try:
+ import readline
+except ImportError:
+ readline = None
+import os
+import os.path as osp
+import sys
+from pdb import Pdb
+import inspect
+
+from logilab.common.compat import StringIO
+
+try:
+ from IPython import PyColorize
+except ImportError:
+ def colorize(source, *args):
+ """fallback colorize function"""
+ return source
+ def colorize_source(source, *args):
+ return source
+else:
+ def colorize(source, start_lineno, curlineno):
+ """colorize and annotate source with linenos
+ (as in pdb's list command)
+ """
+ parser = PyColorize.Parser()
+ output = StringIO()
+ parser.format(source, output)
+ annotated = []
+ for index, line in enumerate(output.getvalue().splitlines()):
+ lineno = index + start_lineno
+ if lineno == curlineno:
+ annotated.append('%4s\t->\t%s' % (lineno, line))
+ else:
+ annotated.append('%4s\t\t%s' % (lineno, line))
+ return '\n'.join(annotated)
+
+ def colorize_source(source):
+ """colorize given source"""
+ parser = PyColorize.Parser()
+ output = StringIO()
+ parser.format(source, output)
+ return output.getvalue()
+
+
+def getsource(obj):
+ """Return the text of the source code for an object.
+
+ The argument may be a module, class, method, function, traceback, frame,
+ or code object. The source code is returned as a single string. An
+ IOError is raised if the source code cannot be retrieved."""
+ lines, lnum = inspect.getsourcelines(obj)
+ return ''.join(lines), lnum
+
+
+################################################################
+class Debugger(Pdb):
+ """custom debugger
+
+ - sets up a history file
+ - uses ipython if available to colorize lines of code
+ - overrides list command to search for current block instead
+ of using 5 lines of context
+ """
+ def __init__(self, tcbk=None):
+ Pdb.__init__(self)
+ self.reset()
+ if tcbk:
+ while tcbk.tb_next is not None:
+ tcbk = tcbk.tb_next
+ self._tcbk = tcbk
+ self._histfile = os.path.expanduser("~/.pdbhist")
+
+ def setup_history_file(self):
+ """if readline is available, read pdb history file
+ """
+ if readline is not None:
+ try:
+ # XXX try..except shouldn't be necessary
+ # read_history_file() can accept None
+ readline.read_history_file(self._histfile)
+ except IOError:
+ pass
+
+ def start(self):
+ """starts the interactive mode"""
+ self.interaction(self._tcbk.tb_frame, self._tcbk)
+
+ def setup(self, frame, tcbk):
+ """setup hook: set up history file"""
+ self.setup_history_file()
+ Pdb.setup(self, frame, tcbk)
+
+ def set_quit(self):
+ """quit hook: save commands in the history file"""
+ if readline is not None:
+ readline.write_history_file(self._histfile)
+ Pdb.set_quit(self)
+
+ def complete_p(self, text, line, begin_idx, end_idx):
+ """provide variable names completion for the ``p`` command"""
+ namespace = dict(self.curframe.f_globals)
+ namespace.update(self.curframe.f_locals)
+ if '.' in text:
+ return self.attr_matches(text, namespace)
+ return [varname for varname in namespace if varname.startswith(text)]
+
+
+ def attr_matches(self, text, namespace):
+ """implementation coming from rlcompleter.Completer.attr_matches
+ Compute matches when text contains a dot.
+
+ Assuming the text is of the form NAME.NAME....[NAME], and is
+ evaluatable in self.namespace, it will be evaluated and its attributes
+ (as revealed by dir()) are used as possible completions. (For class
+ instances, class members are also considered.)
+
+ WARNING: this can still invoke arbitrary C code, if an object
+ with a __getattr__ hook is evaluated.
+
+ """
+ import re
+ m = re.match(r"(\w+(\.\w+)*)\.(\w*)", text)
+ if not m:
+ return
+ expr, attr = m.group(1, 3)
+ object = eval(expr, namespace)
+ words = dir(object)
+ if hasattr(object, '__class__'):
+ words.append('__class__')
+ words = words + self.get_class_members(object.__class__)
+ matches = []
+ n = len(attr)
+ for word in words:
+ if word[:n] == attr and word != "__builtins__":
+ matches.append("%s.%s" % (expr, word))
+ return matches
+
+ def get_class_members(self, klass):
+ """implementation coming from rlcompleter.get_class_members"""
+ ret = dir(klass)
+ if hasattr(klass, '__bases__'):
+ for base in klass.__bases__:
+ ret = ret + self.get_class_members(base)
+ return ret
+
+ ## specific / overridden commands
+ def do_list(self, arg):
+ """overrides default list command to display the surrounding block
+ instead of 5 lines of context
+ """
+ self.lastcmd = 'list'
+ if not arg:
+ try:
+ source, start_lineno = getsource(self.curframe)
+ print(colorize(''.join(source), start_lineno,
+ self.curframe.f_lineno))
+ except KeyboardInterrupt:
+ pass
+ except IOError:
+ Pdb.do_list(self, arg)
+ else:
+ Pdb.do_list(self, arg)
+ do_l = do_list
+
+ def do_open(self, arg):
+ """opens source file corresponding to the current stack level"""
+ filename = self.curframe.f_code.co_filename
+ lineno = self.curframe.f_lineno
+ cmd = 'emacsclient --no-wait +%s %s' % (lineno, filename)
+ os.system(cmd)
+
+ do_o = do_open
+
+def pm():
+ """use our custom debugger"""
+ dbg = Debugger(sys.last_traceback)
+ dbg.start()
+
+def set_trace():
+ Debugger().set_trace(sys._getframe().f_back)
diff --git a/logilab/common/decorators.py b/logilab/common/decorators.py
new file mode 100644
index 0000000..beafa20
--- /dev/null
+++ b/logilab/common/decorators.py
@@ -0,0 +1,281 @@
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+""" A few useful function/method decorators. """
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import types
+from time import clock, time
+from inspect import isgeneratorfunction, getargspec
+
+from logilab.common.compat import method_type
+
+# XXX rewrite so we can use the decorator syntax when keyarg has to be specified
+
+class cached_decorator(object):
+ def __init__(self, cacheattr=None, keyarg=None):
+ self.cacheattr = cacheattr
+ self.keyarg = keyarg
+ def __call__(self, callableobj=None):
+ assert not isgeneratorfunction(callableobj), \
+ 'cannot cache generator function: %s' % callableobj
+ if len(getargspec(callableobj).args) == 1 or self.keyarg == 0:
+ cache = _SingleValueCache(callableobj, self.cacheattr)
+ elif self.keyarg:
+ cache = _MultiValuesKeyArgCache(callableobj, self.keyarg, self.cacheattr)
+ else:
+ cache = _MultiValuesCache(callableobj, self.cacheattr)
+ return cache.closure()
+
+class _SingleValueCache(object):
+ def __init__(self, callableobj, cacheattr=None):
+ self.callable = callableobj
+ if cacheattr is None:
+ self.cacheattr = '_%s_cache_' % callableobj.__name__
+ else:
+ assert cacheattr != callableobj.__name__
+ self.cacheattr = cacheattr
+
+ def __call__(__me, self, *args):
+ try:
+ return self.__dict__[__me.cacheattr]
+ except KeyError:
+ value = __me.callable(self, *args)
+ setattr(self, __me.cacheattr, value)
+ return value
+
+ def closure(self):
+ def wrapped(*args, **kwargs):
+ return self.__call__(*args, **kwargs)
+ wrapped.cache_obj = self
+ try:
+ wrapped.__doc__ = self.callable.__doc__
+ wrapped.__name__ = self.callable.__name__
+ except:
+ pass
+ return wrapped
+
+ def clear(self, holder):
+ holder.__dict__.pop(self.cacheattr, None)
+
+
+class _MultiValuesCache(_SingleValueCache):
+ def _get_cache(self, holder):
+ try:
+ _cache = holder.__dict__[self.cacheattr]
+ except KeyError:
+ _cache = {}
+ setattr(holder, self.cacheattr, _cache)
+ return _cache
+
+ def __call__(__me, self, *args, **kwargs):
+ _cache = __me._get_cache(self)
+ try:
+ return _cache[args]
+ except KeyError:
+ _cache[args] = __me.callable(self, *args)
+ return _cache[args]
+
+class _MultiValuesKeyArgCache(_MultiValuesCache):
+ def __init__(self, callableobj, keyarg, cacheattr=None):
+ super(_MultiValuesKeyArgCache, self).__init__(callableobj, cacheattr)
+ self.keyarg = keyarg
+
+ def __call__(__me, self, *args, **kwargs):
+ _cache = __me._get_cache(self)
+ key = args[__me.keyarg-1]
+ try:
+ return _cache[key]
+ except KeyError:
+ _cache[key] = __me.callable(self, *args, **kwargs)
+ return _cache[key]
+
+
+def cached(callableobj=None, keyarg=None, **kwargs):
+ """Simple decorator to cache result of method call."""
+ kwargs['keyarg'] = keyarg
+ decorator = cached_decorator(**kwargs)
+ if callableobj is None:
+ return decorator
+ else:
+ return decorator(callableobj)
+
+
+class cachedproperty(object):
+ """ Provides a cached property equivalent to the stacking of
+ @cached and @property, but more efficient.
+
+ After first usage, the <property_name> becomes part of the object's
+ __dict__. Doing:
+
+ del obj.<property_name> empties the cache.
+
+ Idea taken from the pyramid_ framework and the mercurial_ project.
+
+ .. _pyramid: http://pypi.python.org/pypi/pyramid
+ .. _mercurial: http://pypi.python.org/pypi/Mercurial
+ """
+ __slots__ = ('wrapped',)
+
+ def __init__(self, wrapped):
+ try:
+ wrapped.__name__
+ except AttributeError:
+ raise TypeError('%s must have a __name__ attribute' %
+ wrapped)
+ self.wrapped = wrapped
+
+ @property
+ def __doc__(self):
+ doc = getattr(self.wrapped, '__doc__', None)
+ return ('<wrapped by the cachedproperty decorator>%s'
+ % ('\n%s' % doc if doc else ''))
+
+ def __get__(self, inst, objtype=None):
+ if inst is None:
+ return self
+ val = self.wrapped(inst)
+ setattr(inst, self.wrapped.__name__, val)
+ return val
+
+
+def get_cache_impl(obj, funcname):
+ cls = obj.__class__
+ member = getattr(cls, funcname)
+ if isinstance(member, property):
+ member = member.fget
+ return member.cache_obj
+
+def clear_cache(obj, funcname):
+ """Clear a cache handled by the :func:`cached` decorator. If 'x' class has
+ @cached on its method `foo`, type
+
+ >>> clear_cache(x, 'foo')
+
+ to purge this method's cache on the instance.
+ """
+ get_cache_impl(obj, funcname).clear(obj)
+
+def copy_cache(obj, funcname, cacheobj):
+ """Copy cache for <funcname> from cacheobj to obj."""
+ cacheattr = get_cache_impl(obj, funcname).cacheattr
+ try:
+ setattr(obj, cacheattr, cacheobj.__dict__[cacheattr])
+ except KeyError:
+ pass
+
+
+class wproperty(object):
+ """Simple descriptor expecting to take a modifier function as first argument
+ and looking for a _<function name> to retrieve the attribute.
+ """
+ def __init__(self, setfunc):
+ self.setfunc = setfunc
+ self.attrname = '_%s' % setfunc.__name__
+
+ def __set__(self, obj, value):
+ self.setfunc(obj, value)
+
+ def __get__(self, obj, cls):
+ assert obj is not None
+ return getattr(obj, self.attrname)
+
+
+class classproperty(object):
+ """this is a simple property-like class but for class attributes.
+ """
+ def __init__(self, get):
+ self.get = get
+ def __get__(self, inst, cls):
+ return self.get(cls)
+
+
+class iclassmethod(object):
+ '''Descriptor for method which should be available as class method if called
+ on the class or instance method if called on an instance.
+ '''
+ def __init__(self, func):
+ self.func = func
+ def __get__(self, instance, objtype):
+ if instance is None:
+ return method_type(self.func, objtype, objtype.__class__)
+ return method_type(self.func, instance, objtype)
+ def __set__(self, instance, value):
+ raise AttributeError("can't set attribute")
+
+
+def timed(f):
+ def wrap(*args, **kwargs):
+ t = time()
+ c = clock()
+ res = f(*args, **kwargs)
+ print('%s clock: %.9f / time: %.9f' % (f.__name__,
+ clock() - c, time() - t))
+ return res
+ return wrap
+
+
+def locked(acquire, release):
+ """Decorator taking two methods to acquire/release a lock as argument,
+ returning a decorator function which will call the inner method after
+ having called acquire(self) et will call release(self) afterwards.
+ """
+ def decorator(f):
+ def wrapper(self, *args, **kwargs):
+ acquire(self)
+ try:
+ return f(self, *args, **kwargs)
+ finally:
+ release(self)
+ return wrapper
+ return decorator
+
+
+def monkeypatch(klass, methodname=None):
+ """Decorator extending class with the decorated callable. This is basically
+ a syntactic sugar vs class assignment.
+
+ >>> class A:
+ ... pass
+ >>> @monkeypatch(A)
+ ... def meth(self):
+ ... return 12
+ ...
+ >>> a = A()
+ >>> a.meth()
+ 12
+ >>> @monkeypatch(A, 'foo')
+ ... def meth(self):
+ ... return 12
+ ...
+ >>> a.foo()
+ 12
+ """
+ def decorator(func):
+ try:
+ name = methodname or func.__name__
+ except AttributeError:
+ raise AttributeError('%s has no __name__ attribute: '
+ 'you should provide an explicit `methodname`'
+ % func)
+ setattr(klass, name, func)
+ return func
+ return decorator
diff --git a/logilab/common/deprecation.py b/logilab/common/deprecation.py
new file mode 100644
index 0000000..1c81b63
--- /dev/null
+++ b/logilab/common/deprecation.py
@@ -0,0 +1,189 @@
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Deprecation utilities."""
+
+__docformat__ = "restructuredtext en"
+
+import sys
+from warnings import warn
+
+from logilab.common.changelog import Version
+
+
+class DeprecationWrapper(object):
+ """proxy to print a warning on access to any attribute of the wrapped object
+ """
+ def __init__(self, proxied, msg=None):
+ self._proxied = proxied
+ self._msg = msg
+
+ def __getattr__(self, attr):
+ warn(self._msg, DeprecationWarning, stacklevel=2)
+ return getattr(self._proxied, attr)
+
+ def __setattr__(self, attr, value):
+ if attr in ('_proxied', '_msg'):
+ self.__dict__[attr] = value
+ else:
+ warn(self._msg, DeprecationWarning, stacklevel=2)
+ setattr(self._proxied, attr, value)
+
+
+class DeprecationManager(object):
+ """Manage the deprecation message handling. Messages are dropped for
+ versions more recent than the 'compatible' version. Example::
+
+ deprecator = deprecation.DeprecationManager("module_name")
+ deprecator.compatibility('1.3')
+
+ deprecator.warn('1.2', "message.")
+
+ @deprecator.deprecated('1.2', 'Message')
+ def any_func():
+ pass
+
+ class AnyClass(object):
+ __metaclass__ = deprecator.class_deprecated('1.2')
+ """
+ def __init__(self, module_name=None):
+ """
+ """
+ self.module_name = module_name
+ self.compatible_version = None
+
+ def compatibility(self, compatible_version):
+ """Set the compatible version.
+ """
+ self.compatible_version = Version(compatible_version)
+
+ def deprecated(self, version=None, reason=None, stacklevel=2, name=None, doc=None):
+ """Display a deprecation message only if the version is older than the
+ compatible version.
+ """
+ def decorator(func):
+ message = reason or 'The function "%s" is deprecated'
+ if '%s' in message:
+ message %= func.__name__
+ def wrapped(*args, **kwargs):
+ self.warn(version, message, stacklevel+1)
+ return func(*args, **kwargs)
+ return wrapped
+ return decorator
+
+ def class_deprecated(self, version=None):
+ class metaclass(type):
+ """metaclass to print a warning on instantiation of a deprecated class"""
+
+ def __call__(cls, *args, **kwargs):
+ msg = getattr(cls, "__deprecation_warning__",
+ "%(cls)s is deprecated") % {'cls': cls.__name__}
+ self.warn(version, msg, stacklevel=3)
+ return type.__call__(cls, *args, **kwargs)
+ return metaclass
+
+ def moved(self, version, modpath, objname):
+ """use to tell that a callable has been moved to a new module.
+
+ It returns a callable wrapper, so that when its called a warning is printed
+ telling where the object can be found, import is done (and not before) and
+ the actual object is called.
+
+ NOTE: the usage is somewhat limited on classes since it will fail if the
+ wrapper is use in a class ancestors list, use the `class_moved` function
+ instead (which has no lazy import feature though).
+ """
+ def callnew(*args, **kwargs):
+ from logilab.common.modutils import load_module_from_name
+ message = "object %s has been moved to module %s" % (objname, modpath)
+ self.warn(version, message)
+ m = load_module_from_name(modpath)
+ return getattr(m, objname)(*args, **kwargs)
+ return callnew
+
+ def class_renamed(self, version, old_name, new_class, message=None):
+ clsdict = {}
+ if message is None:
+ message = '%s is deprecated, use %s' % (old_name, new_class.__name__)
+ clsdict['__deprecation_warning__'] = message
+ try:
+ # new-style class
+ return self.class_deprecated(version)(old_name, (new_class,), clsdict)
+ except (NameError, TypeError):
+ # old-style class
+ warn = self.warn
+ class DeprecatedClass(new_class):
+ """FIXME: There might be a better way to handle old/new-style class
+ """
+ def __init__(self, *args, **kwargs):
+ warn(version, message, stacklevel=3)
+ new_class.__init__(self, *args, **kwargs)
+ return DeprecatedClass
+
+ def class_moved(self, version, new_class, old_name=None, message=None):
+ """nice wrapper around class_renamed when a class has been moved into
+ another module
+ """
+ if old_name is None:
+ old_name = new_class.__name__
+ if message is None:
+ message = 'class %s is now available as %s.%s' % (
+ old_name, new_class.__module__, new_class.__name__)
+ return self.class_renamed(version, old_name, new_class, message)
+
+ def warn(self, version=None, reason="", stacklevel=2):
+ """Display a deprecation message only if the version is older than the
+ compatible version.
+ """
+ if (self.compatible_version is None
+ or version is None
+ or Version(version) < self.compatible_version):
+ if self.module_name and version:
+ reason = '[%s %s] %s' % (self.module_name, version, reason)
+ elif self.module_name:
+ reason = '[%s] %s' % (self.module_name, reason)
+ elif version:
+ reason = '[%s] %s' % (version, reason)
+ warn(reason, DeprecationWarning, stacklevel=stacklevel)
+
+_defaultdeprecator = DeprecationManager()
+
+def deprecated(reason=None, stacklevel=2, name=None, doc=None):
+ return _defaultdeprecator.deprecated(None, reason, stacklevel, name, doc)
+
+class_deprecated = _defaultdeprecator.class_deprecated()
+
+def moved(modpath, objname):
+ return _defaultdeprecator.moved(None, modpath, objname)
+moved.__doc__ = _defaultdeprecator.moved.__doc__
+
+def class_renamed(old_name, new_class, message=None):
+ """automatically creates a class which fires a DeprecationWarning
+ when instantiated.
+
+ >>> Set = class_renamed('Set', set, 'Set is now replaced by set')
+ >>> s = Set()
+ sample.py:57: DeprecationWarning: Set is now replaced by set
+ s = Set()
+ >>>
+ """
+ return _defaultdeprecator.class_renamed(None, old_name, new_class, message)
+
+def class_moved(new_class, old_name=None, message=None):
+ return _defaultdeprecator.class_moved(None, new_class, old_name, message)
+class_moved.__doc__ = _defaultdeprecator.class_moved.__doc__
+
diff --git a/logilab/common/fileutils.py b/logilab/common/fileutils.py
new file mode 100644
index 0000000..b30cf5f
--- /dev/null
+++ b/logilab/common/fileutils.py
@@ -0,0 +1,404 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""File and file-path manipulation utilities.
+
+:group path manipulation: first_level_directory, relative_path, is_binary,\
+get_by_ext, remove_dead_links
+:group file manipulation: norm_read, norm_open, lines, stream_lines, lines,\
+write_open_mode, ensure_fs_mode, export
+:sort: path manipulation, file manipulation
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import shutil
+import mimetypes
+from os.path import isabs, isdir, islink, split, exists, normpath, join
+from os.path import abspath
+from os import sep, mkdir, remove, listdir, stat, chmod, walk
+from stat import ST_MODE, S_IWRITE
+
+from logilab.common import STD_BLACKLIST as BASE_BLACKLIST, IGNORED_EXTENSIONS
+from logilab.common.shellutils import find
+from logilab.common.deprecation import deprecated
+from logilab.common.compat import FileIO
+
+def first_level_directory(path):
+ """Return the first level directory of a path.
+
+ >>> first_level_directory('home/syt/work')
+ 'home'
+ >>> first_level_directory('/home/syt/work')
+ '/'
+ >>> first_level_directory('work')
+ 'work'
+ >>>
+
+ :type path: str
+ :param path: the path for which we want the first level directory
+
+ :rtype: str
+ :return: the first level directory appearing in `path`
+ """
+ head, tail = split(path)
+ while head and tail:
+ head, tail = split(head)
+ if tail:
+ return tail
+ # path was absolute, head is the fs root
+ return head
+
+def abspath_listdir(path):
+ """Lists path's content using absolute paths.
+
+ >>> os.listdir('/home')
+ ['adim', 'alf', 'arthur', 'auc']
+ >>> abspath_listdir('/home')
+ ['/home/adim', '/home/alf', '/home/arthur', '/home/auc']
+ """
+ path = abspath(path)
+ return [join(path, filename) for filename in listdir(path)]
+
+
+def is_binary(filename):
+ """Return true if filename may be a binary file, according to it's
+ extension.
+
+ :type filename: str
+ :param filename: the name of the file
+
+ :rtype: bool
+ :return:
+ true if the file is a binary file (actually if it's mime type
+ isn't beginning by text/)
+ """
+ try:
+ return not mimetypes.guess_type(filename)[0].startswith('text')
+ except AttributeError:
+ return 1
+
+
+def write_open_mode(filename):
+ """Return the write mode that should used to open file.
+
+ :type filename: str
+ :param filename: the name of the file
+
+ :rtype: str
+ :return: the mode that should be use to open the file ('w' or 'wb')
+ """
+ if is_binary(filename):
+ return 'wb'
+ return 'w'
+
+
+def ensure_fs_mode(filepath, desired_mode=S_IWRITE):
+ """Check that the given file has the given mode(s) set, else try to
+ set it.
+
+ :type filepath: str
+ :param filepath: path of the file
+
+ :type desired_mode: int
+ :param desired_mode:
+ ORed flags describing the desired mode. Use constants from the
+ `stat` module for file permission's modes
+ """
+ mode = stat(filepath)[ST_MODE]
+ if not mode & desired_mode:
+ chmod(filepath, mode | desired_mode)
+
+
+# XXX (syt) unused? kill?
+class ProtectedFile(FileIO):
+ """A special file-object class that automatically does a 'chmod +w' when
+ needed.
+
+ XXX: for now, the way it is done allows 'normal file-objects' to be
+ created during the ProtectedFile object lifetime.
+ One way to circumvent this would be to chmod / unchmod on each
+ write operation.
+
+ One other way would be to :
+
+ - catch the IOError in the __init__
+
+ - if IOError, then create a StringIO object
+
+ - each write operation writes in this StringIO object
+
+ - on close()/del(), write/append the StringIO content to the file and
+ do the chmod only once
+ """
+ def __init__(self, filepath, mode):
+ self.original_mode = stat(filepath)[ST_MODE]
+ self.mode_changed = False
+ if mode in ('w', 'a', 'wb', 'ab'):
+ if not self.original_mode & S_IWRITE:
+ chmod(filepath, self.original_mode | S_IWRITE)
+ self.mode_changed = True
+ FileIO.__init__(self, filepath, mode)
+
+ def _restore_mode(self):
+ """restores the original mode if needed"""
+ if self.mode_changed:
+ chmod(self.name, self.original_mode)
+ # Don't re-chmod in case of several restore
+ self.mode_changed = False
+
+ def close(self):
+ """restore mode before closing"""
+ self._restore_mode()
+ FileIO.close(self)
+
+ def __del__(self):
+ if not self.closed:
+ self.close()
+
+
+class UnresolvableError(Exception):
+ """Exception raised by relative path when it's unable to compute relative
+ path between two paths.
+ """
+
+def relative_path(from_file, to_file):
+ """Try to get a relative path from `from_file` to `to_file`
+ (path will be absolute if to_file is an absolute file). This function
+ is useful to create link in `from_file` to `to_file`. This typical use
+ case is used in this function description.
+
+ If both files are relative, they're expected to be relative to the same
+ directory.
+
+ >>> relative_path( from_file='toto/index.html', to_file='index.html')
+ '../index.html'
+ >>> relative_path( from_file='index.html', to_file='toto/index.html')
+ 'toto/index.html'
+ >>> relative_path( from_file='tutu/index.html', to_file='toto/index.html')
+ '../toto/index.html'
+ >>> relative_path( from_file='toto/index.html', to_file='/index.html')
+ '/index.html'
+ >>> relative_path( from_file='/toto/index.html', to_file='/index.html')
+ '../index.html'
+ >>> relative_path( from_file='/toto/index.html', to_file='/toto/summary.html')
+ 'summary.html'
+ >>> relative_path( from_file='index.html', to_file='index.html')
+ ''
+ >>> relative_path( from_file='/index.html', to_file='toto/index.html')
+ Traceback (most recent call last):
+ File "<string>", line 1, in ?
+ File "<stdin>", line 37, in relative_path
+ UnresolvableError
+ >>> relative_path( from_file='/index.html', to_file='/index.html')
+ ''
+ >>>
+
+ :type from_file: str
+ :param from_file: source file (where links will be inserted)
+
+ :type to_file: str
+ :param to_file: target file (on which links point)
+
+ :raise UnresolvableError: if it has been unable to guess a correct path
+
+ :rtype: str
+ :return: the relative path of `to_file` from `from_file`
+ """
+ from_file = normpath(from_file)
+ to_file = normpath(to_file)
+ if from_file == to_file:
+ return ''
+ if isabs(to_file):
+ if not isabs(from_file):
+ return to_file
+ elif isabs(from_file):
+ raise UnresolvableError()
+ from_parts = from_file.split(sep)
+ to_parts = to_file.split(sep)
+ idem = 1
+ result = []
+ while len(from_parts) > 1:
+ dirname = from_parts.pop(0)
+ if idem and len(to_parts) > 1 and dirname == to_parts[0]:
+ to_parts.pop(0)
+ else:
+ idem = 0
+ result.append('..')
+ result += to_parts
+ return sep.join(result)
+
+
+def norm_read(path):
+ """Return the content of the file with normalized line feeds.
+
+ :type path: str
+ :param path: path to the file to read
+
+ :rtype: str
+ :return: the content of the file with normalized line feeds
+ """
+ return open(path, 'U').read()
+norm_read = deprecated("use \"open(path, 'U').read()\"")(norm_read)
+
+def norm_open(path):
+ """Return a stream for a file with content with normalized line feeds.
+
+ :type path: str
+ :param path: path to the file to open
+
+ :rtype: file or StringIO
+ :return: the opened file with normalized line feeds
+ """
+ return open(path, 'U')
+norm_open = deprecated("use \"open(path, 'U')\"")(norm_open)
+
+def lines(path, comments=None):
+ """Return a list of non empty lines in the file located at `path`.
+
+ :type path: str
+ :param path: path to the file
+
+ :type comments: str or None
+ :param comments:
+ optional string which can be used to comment a line in the file
+ (i.e. lines starting with this string won't be returned)
+
+ :rtype: list
+ :return:
+ a list of stripped line in the file, without empty and commented
+ lines
+
+ :warning: at some point this function will probably return an iterator
+ """
+ stream = open(path, 'U')
+ result = stream_lines(stream, comments)
+ stream.close()
+ return result
+
+
+def stream_lines(stream, comments=None):
+ """Return a list of non empty lines in the given `stream`.
+
+ :type stream: object implementing 'xreadlines' or 'readlines'
+ :param stream: file like object
+
+ :type comments: str or None
+ :param comments:
+ optional string which can be used to comment a line in the file
+ (i.e. lines starting with this string won't be returned)
+
+ :rtype: list
+ :return:
+ a list of stripped line in the file, without empty and commented
+ lines
+
+ :warning: at some point this function will probably return an iterator
+ """
+ try:
+ readlines = stream.xreadlines
+ except AttributeError:
+ readlines = stream.readlines
+ result = []
+ for line in readlines():
+ line = line.strip()
+ if line and (comments is None or not line.startswith(comments)):
+ result.append(line)
+ return result
+
+
+def export(from_dir, to_dir,
+ blacklist=BASE_BLACKLIST, ignore_ext=IGNORED_EXTENSIONS,
+ verbose=0):
+ """Make a mirror of `from_dir` in `to_dir`, omitting directories and
+ files listed in the black list or ending with one of the given
+ extensions.
+
+ :type from_dir: str
+ :param from_dir: directory to export
+
+ :type to_dir: str
+ :param to_dir: destination directory
+
+ :type blacklist: list or tuple
+ :param blacklist:
+ list of files or directories to ignore, default to the content of
+ `BASE_BLACKLIST`
+
+ :type ignore_ext: list or tuple
+ :param ignore_ext:
+ list of extensions to ignore, default to the content of
+ `IGNORED_EXTENSIONS`
+
+ :type verbose: bool
+ :param verbose:
+ flag indicating whether information about exported files should be
+ printed to stderr, default to False
+ """
+ try:
+ mkdir(to_dir)
+ except OSError:
+ pass # FIXME we should use "exists" if the point is about existing dir
+ # else (permission problems?) shouldn't return / raise ?
+ for directory, dirnames, filenames in walk(from_dir):
+ for norecurs in blacklist:
+ try:
+ dirnames.remove(norecurs)
+ except ValueError:
+ continue
+ for dirname in dirnames:
+ src = join(directory, dirname)
+ dest = to_dir + src[len(from_dir):]
+ if isdir(src):
+ if not exists(dest):
+ mkdir(dest)
+ for filename in filenames:
+ # don't include binary files
+ # endswith does not accept tuple in 2.4
+ if any([filename.endswith(ext) for ext in ignore_ext]):
+ continue
+ src = join(directory, filename)
+ dest = to_dir + src[len(from_dir):]
+ if verbose:
+ print(src, '->', dest, file=sys.stderr)
+ if exists(dest):
+ remove(dest)
+ shutil.copy2(src, dest)
+
+
+def remove_dead_links(directory, verbose=0):
+ """Recursively traverse directory and remove all dead links.
+
+ :type directory: str
+ :param directory: directory to cleanup
+
+ :type verbose: bool
+ :param verbose:
+ flag indicating whether information about deleted links should be
+ printed to stderr, default to False
+ """
+ for dirpath, dirname, filenames in walk(directory):
+ for filename in dirnames + filenames:
+ src = join(dirpath, filename)
+ if islink(src) and not exists(src):
+ if verbose:
+ print('remove dead link', src)
+ remove(src)
+
diff --git a/logilab/common/graph.py b/logilab/common/graph.py
new file mode 100644
index 0000000..cef1c98
--- /dev/null
+++ b/logilab/common/graph.py
@@ -0,0 +1,282 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Graph manipulation utilities.
+
+(dot generation adapted from pypy/translator/tool/make_dot.py)
+"""
+
+__docformat__ = "restructuredtext en"
+
+__metaclass__ = type
+
+import os.path as osp
+import os
+import sys
+import tempfile
+import codecs
+import errno
+
+def escape(value):
+ """Make <value> usable in a dot file."""
+ lines = [line.replace('"', '\\"') for line in value.split('\n')]
+ data = '\\l'.join(lines)
+ return '\\n' + data
+
+def target_info_from_filename(filename):
+ """Transforms /some/path/foo.png into ('/some/path', 'foo.png', 'png')."""
+ basename = osp.basename(filename)
+ storedir = osp.dirname(osp.abspath(filename))
+ target = filename.split('.')[-1]
+ return storedir, basename, target
+
+
+class DotBackend:
+ """Dot File backend."""
+ def __init__(self, graphname, rankdir=None, size=None, ratio=None,
+ charset='utf-8', renderer='dot', additionnal_param={}):
+ self.graphname = graphname
+ self.renderer = renderer
+ self.lines = []
+ self._source = None
+ self.emit("digraph %s {" % normalize_node_id(graphname))
+ if rankdir:
+ self.emit('rankdir=%s' % rankdir)
+ if ratio:
+ self.emit('ratio=%s' % ratio)
+ if size:
+ self.emit('size="%s"' % size)
+ if charset:
+ assert charset.lower() in ('utf-8', 'iso-8859-1', 'latin1'), \
+ 'unsupported charset %s' % charset
+ self.emit('charset="%s"' % charset)
+ for param in sorted(additionnal_param.items()):
+ self.emit('='.join(param))
+
+ def get_source(self):
+ """returns self._source"""
+ if self._source is None:
+ self.emit("}\n")
+ self._source = '\n'.join(self.lines)
+ del self.lines
+ return self._source
+
+ source = property(get_source)
+
+ def generate(self, outputfile=None, dotfile=None, mapfile=None):
+ """Generates a graph file.
+
+ :param outputfile: filename and path [defaults to graphname.png]
+ :param dotfile: filename and path [defaults to graphname.dot]
+
+ :rtype: str
+ :return: a path to the generated file
+ """
+ import subprocess # introduced in py 2.4
+ name = self.graphname
+ if not dotfile:
+ # if 'outputfile' is a dot file use it as 'dotfile'
+ if outputfile and outputfile.endswith(".dot"):
+ dotfile = outputfile
+ else:
+ dotfile = '%s.dot' % name
+ if outputfile is not None:
+ storedir, basename, target = target_info_from_filename(outputfile)
+ if target != "dot":
+ pdot, dot_sourcepath = tempfile.mkstemp(".dot", name)
+ os.close(pdot)
+ else:
+ dot_sourcepath = osp.join(storedir, dotfile)
+ else:
+ target = 'png'
+ pdot, dot_sourcepath = tempfile.mkstemp(".dot", name)
+ ppng, outputfile = tempfile.mkstemp(".png", name)
+ os.close(pdot)
+ os.close(ppng)
+ pdot = codecs.open(dot_sourcepath, 'w', encoding='utf8')
+ pdot.write(self.source)
+ pdot.close()
+ if target != 'dot':
+ if sys.platform == 'win32':
+ use_shell = True
+ else:
+ use_shell = False
+ try:
+ if mapfile:
+ subprocess.call([self.renderer, '-Tcmapx', '-o', mapfile, '-T', target, dot_sourcepath, '-o', outputfile],
+ shell=use_shell)
+ else:
+ subprocess.call([self.renderer, '-T', target,
+ dot_sourcepath, '-o', outputfile],
+ shell=use_shell)
+ except OSError as e:
+ if e.errno == errno.ENOENT:
+ e.strerror = 'File not found: {0}'.format(self.renderer)
+ raise
+ os.unlink(dot_sourcepath)
+ return outputfile
+
+ def emit(self, line):
+ """Adds <line> to final output."""
+ self.lines.append(line)
+
+ def emit_edge(self, name1, name2, **props):
+ """emit an edge from <name1> to <name2>.
+ edge properties: see http://www.graphviz.org/doc/info/attrs.html
+ """
+ attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()]
+ n_from, n_to = normalize_node_id(name1), normalize_node_id(name2)
+ self.emit('%s -> %s [%s];' % (n_from, n_to, ', '.join(sorted(attrs))) )
+
+ def emit_node(self, name, **props):
+ """emit a node with given properties.
+ node properties: see http://www.graphviz.org/doc/info/attrs.html
+ """
+ attrs = ['%s="%s"' % (prop, value) for prop, value in props.items()]
+ self.emit('%s [%s];' % (normalize_node_id(name), ', '.join(sorted(attrs))))
+
+def normalize_node_id(nid):
+ """Returns a suitable DOT node id for `nid`."""
+ return '"%s"' % nid
+
+class GraphGenerator:
+ def __init__(self, backend):
+ # the backend is responsible to output the graph in a particular format
+ self.backend = backend
+
+ # XXX doesn't like space in outpufile / mapfile
+ def generate(self, visitor, propshdlr, outputfile=None, mapfile=None):
+ # the visitor
+ # the property handler is used to get node and edge properties
+ # according to the graph and to the backend
+ self.propshdlr = propshdlr
+ for nodeid, node in visitor.nodes():
+ props = propshdlr.node_properties(node)
+ self.backend.emit_node(nodeid, **props)
+ for subjnode, objnode, edge in visitor.edges():
+ props = propshdlr.edge_properties(edge, subjnode, objnode)
+ self.backend.emit_edge(subjnode, objnode, **props)
+ return self.backend.generate(outputfile=outputfile, mapfile=mapfile)
+
+
+class UnorderableGraph(Exception):
+ pass
+
+def ordered_nodes(graph):
+ """takes a dependency graph dict as arguments and return an ordered tuple of
+ nodes starting with nodes without dependencies and up to the outermost node.
+
+ If there is some cycle in the graph, :exc:`UnorderableGraph` will be raised.
+
+ Also the given graph dict will be emptied.
+ """
+ # check graph consistency
+ cycles = get_cycles(graph)
+ if cycles:
+ cycles = '\n'.join([' -> '.join(cycle) for cycle in cycles])
+ raise UnorderableGraph('cycles in graph: %s' % cycles)
+ vertices = set(graph)
+ to_vertices = set()
+ for edges in graph.values():
+ to_vertices |= set(edges)
+ missing_vertices = to_vertices - vertices
+ if missing_vertices:
+ raise UnorderableGraph('missing vertices: %s' % ', '.join(missing_vertices))
+ # order vertices
+ order = []
+ order_set = set()
+ old_len = None
+ while graph:
+ if old_len == len(graph):
+ raise UnorderableGraph('unknown problem with %s' % graph)
+ old_len = len(graph)
+ deps_ok = []
+ for node, node_deps in graph.items():
+ for dep in node_deps:
+ if dep not in order_set:
+ break
+ else:
+ deps_ok.append(node)
+ order.append(deps_ok)
+ order_set |= set(deps_ok)
+ for node in deps_ok:
+ del graph[node]
+ result = []
+ for grp in reversed(order):
+ result.extend(sorted(grp))
+ return tuple(result)
+
+
+def get_cycles(graph_dict, vertices=None):
+ '''given a dictionary representing an ordered graph (i.e. key are vertices
+ and values is a list of destination vertices representing edges), return a
+ list of detected cycles
+ '''
+ if not graph_dict:
+ return ()
+ result = []
+ if vertices is None:
+ vertices = graph_dict.keys()
+ for vertice in vertices:
+ _get_cycles(graph_dict, [], set(), result, vertice)
+ return result
+
+def _get_cycles(graph_dict, path, visited, result, vertice):
+ """recursive function doing the real work for get_cycles"""
+ if vertice in path:
+ cycle = [vertice]
+ for node in path[::-1]:
+ if node == vertice:
+ break
+ cycle.insert(0, node)
+ # make a canonical representation
+ start_from = min(cycle)
+ index = cycle.index(start_from)
+ cycle = cycle[index:] + cycle[0:index]
+ # append it to result if not already in
+ if not cycle in result:
+ result.append(cycle)
+ return
+ path.append(vertice)
+ try:
+ for node in graph_dict[vertice]:
+ # don't check already visited nodes again
+ if node not in visited:
+ _get_cycles(graph_dict, path, visited, result, node)
+ visited.add(node)
+ except KeyError:
+ pass
+ path.pop()
+
+def has_path(graph_dict, fromnode, tonode, path=None):
+ """generic function taking a simple graph definition as a dictionary, with
+ node has key associated to a list of nodes directly reachable from it.
+
+ Return None if no path exists to go from `fromnode` to `tonode`, else the
+ first path found (as a list including the destination node at last)
+ """
+ if path is None:
+ path = []
+ elif fromnode in path:
+ return None
+ path.append(fromnode)
+ for destnode in graph_dict[fromnode]:
+ if destnode == tonode or has_path(graph_dict, destnode, tonode, path):
+ return path[1:] + [tonode]
+ path.pop()
+ return None
+
diff --git a/logilab/common/interface.py b/logilab/common/interface.py
new file mode 100644
index 0000000..3ea4ab7
--- /dev/null
+++ b/logilab/common/interface.py
@@ -0,0 +1,71 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Bases class for interfaces to provide 'light' interface handling.
+
+ TODO:
+ _ implements a check method which check that an object implements the
+ interface
+ _ Attribute objects
+
+ This module requires at least python 2.2
+"""
+__docformat__ = "restructuredtext en"
+
+
+class Interface(object):
+ """Base class for interfaces."""
+ def is_implemented_by(cls, instance):
+ return implements(instance, cls)
+ is_implemented_by = classmethod(is_implemented_by)
+
+
+def implements(obj, interface):
+ """Return true if the give object (maybe an instance or class) implements
+ the interface.
+ """
+ kimplements = getattr(obj, '__implements__', ())
+ if not isinstance(kimplements, (list, tuple)):
+ kimplements = (kimplements,)
+ for implementedinterface in kimplements:
+ if issubclass(implementedinterface, interface):
+ return True
+ return False
+
+
+def extend(klass, interface, _recurs=False):
+ """Add interface to klass'__implements__ if not already implemented in.
+
+ If klass is subclassed, ensure subclasses __implements__ it as well.
+
+ NOTE: klass should be e new class.
+ """
+ if not implements(klass, interface):
+ try:
+ kimplements = klass.__implements__
+ kimplementsklass = type(kimplements)
+ kimplements = list(kimplements)
+ except AttributeError:
+ kimplementsklass = tuple
+ kimplements = []
+ kimplements.append(interface)
+ klass.__implements__ = kimplementsklass(kimplements)
+ for subklass in klass.__subclasses__():
+ extend(subklass, interface, _recurs=True)
+ elif _recurs:
+ for subklass in klass.__subclasses__():
+ extend(subklass, interface, _recurs=True)
diff --git a/logilab/common/logging_ext.py b/logilab/common/logging_ext.py
new file mode 100644
index 0000000..3b6a580
--- /dev/null
+++ b/logilab/common/logging_ext.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Extends the logging module from the standard library."""
+
+__docformat__ = "restructuredtext en"
+
+import os
+import sys
+import logging
+
+from six import string_types
+
+from logilab.common.textutils import colorize_ansi
+
+
+def set_log_methods(cls, logger):
+ """bind standard logger's methods as methods on the class"""
+ cls.__logger = logger
+ for attr in ('debug', 'info', 'warning', 'error', 'critical', 'exception'):
+ setattr(cls, attr, getattr(logger, attr))
+
+
+def xxx_cyan(record):
+ if 'XXX' in record.message:
+ return 'cyan'
+
+class ColorFormatter(logging.Formatter):
+ """
+ A color Formatter for the logging standard module.
+
+ By default, colorize CRITICAL and ERROR in red, WARNING in orange, INFO in
+ green and DEBUG in yellow.
+
+ self.colors is customizable via the 'color' constructor argument (dictionary).
+
+ self.colorfilters is a list of functions that get the LogRecord
+ and return a color name or None.
+ """
+
+ def __init__(self, fmt=None, datefmt=None, colors=None):
+ logging.Formatter.__init__(self, fmt, datefmt)
+ self.colorfilters = []
+ self.colors = {'CRITICAL': 'red',
+ 'ERROR': 'red',
+ 'WARNING': 'magenta',
+ 'INFO': 'green',
+ 'DEBUG': 'yellow',
+ }
+ if colors is not None:
+ assert isinstance(colors, dict)
+ self.colors.update(colors)
+
+ def format(self, record):
+ msg = logging.Formatter.format(self, record)
+ if record.levelname in self.colors:
+ color = self.colors[record.levelname]
+ return colorize_ansi(msg, color)
+ else:
+ for cf in self.colorfilters:
+ color = cf(record)
+ if color:
+ return colorize_ansi(msg, color)
+ return msg
+
+def set_color_formatter(logger=None, **kw):
+ """
+ Install a color formatter on the 'logger'. If not given, it will
+ defaults to the default logger.
+
+ Any additional keyword will be passed as-is to the ColorFormatter
+ constructor.
+ """
+ if logger is None:
+ logger = logging.getLogger()
+ if not logger.handlers:
+ logging.basicConfig()
+ format_msg = logger.handlers[0].formatter._fmt
+ fmt = ColorFormatter(format_msg, **kw)
+ fmt.colorfilters.append(xxx_cyan)
+ logger.handlers[0].setFormatter(fmt)
+
+
+LOG_FORMAT = '%(asctime)s - (%(name)s) %(levelname)s: %(message)s'
+LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'
+
+def get_handler(debug=False, syslog=False, logfile=None, rotation_parameters=None):
+ """get an apropriate handler according to given parameters"""
+ if os.environ.get('APYCOT_ROOT'):
+ handler = logging.StreamHandler(sys.stdout)
+ if debug:
+ handler = logging.StreamHandler()
+ elif logfile is None:
+ if syslog:
+ from logging import handlers
+ handler = handlers.SysLogHandler()
+ else:
+ handler = logging.StreamHandler()
+ else:
+ try:
+ if rotation_parameters is None:
+ if os.name == 'posix' and sys.version_info >= (2, 6):
+ from logging.handlers import WatchedFileHandler
+ handler = WatchedFileHandler(logfile)
+ else:
+ handler = logging.FileHandler(logfile)
+ else:
+ from logging.handlers import TimedRotatingFileHandler
+ handler = TimedRotatingFileHandler(
+ logfile, **rotation_parameters)
+ except IOError:
+ handler = logging.StreamHandler()
+ return handler
+
+def get_threshold(debug=False, logthreshold=None):
+ if logthreshold is None:
+ if debug:
+ logthreshold = logging.DEBUG
+ else:
+ logthreshold = logging.ERROR
+ elif isinstance(logthreshold, string_types):
+ logthreshold = getattr(logging, THRESHOLD_MAP.get(logthreshold,
+ logthreshold))
+ return logthreshold
+
+def _colorable_terminal():
+ isatty = hasattr(sys.__stdout__, 'isatty') and sys.__stdout__.isatty()
+ if not isatty:
+ return False
+ if os.name == 'nt':
+ try:
+ from colorama import init as init_win32_colors
+ except ImportError:
+ return False
+ init_win32_colors()
+ return True
+
+def get_formatter(logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT):
+ if _colorable_terminal():
+ fmt = ColorFormatter(logformat, logdateformat)
+ def col_fact(record):
+ if 'XXX' in record.message:
+ return 'cyan'
+ if 'kick' in record.message:
+ return 'red'
+ fmt.colorfilters.append(col_fact)
+ else:
+ fmt = logging.Formatter(logformat, logdateformat)
+ return fmt
+
+def init_log(debug=False, syslog=False, logthreshold=None, logfile=None,
+ logformat=LOG_FORMAT, logdateformat=LOG_DATE_FORMAT, fmt=None,
+ rotation_parameters=None, handler=None):
+ """init the log service"""
+ logger = logging.getLogger()
+ if handler is None:
+ handler = get_handler(debug, syslog, logfile, rotation_parameters)
+ # only addHandler and removeHandler method while I would like a setHandler
+ # method, so do it this way :$
+ logger.handlers = [handler]
+ logthreshold = get_threshold(debug, logthreshold)
+ logger.setLevel(logthreshold)
+ if fmt is None:
+ if debug:
+ fmt = get_formatter(logformat=logformat, logdateformat=logdateformat)
+ else:
+ fmt = logging.Formatter(logformat, logdateformat)
+ handler.setFormatter(fmt)
+ return handler
+
+# map logilab.common.logger thresholds to logging thresholds
+THRESHOLD_MAP = {'LOG_DEBUG': 'DEBUG',
+ 'LOG_INFO': 'INFO',
+ 'LOG_NOTICE': 'INFO',
+ 'LOG_WARN': 'WARNING',
+ 'LOG_WARNING': 'WARNING',
+ 'LOG_ERR': 'ERROR',
+ 'LOG_ERROR': 'ERROR',
+ 'LOG_CRIT': 'CRITICAL',
+ }
diff --git a/logilab/common/modutils.py b/logilab/common/modutils.py
new file mode 100644
index 0000000..a426a3a
--- /dev/null
+++ b/logilab/common/modutils.py
@@ -0,0 +1,702 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Python modules manipulation utility functions.
+
+:type PY_SOURCE_EXTS: tuple(str)
+:var PY_SOURCE_EXTS: list of possible python source file extension
+
+:type STD_LIB_DIR: str
+:var STD_LIB_DIR: directory where standard modules are located
+
+:type BUILTIN_MODULES: dict
+:var BUILTIN_MODULES: dictionary with builtin module names has key
+"""
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import os
+from os.path import splitext, join, abspath, isdir, dirname, exists, basename
+from imp import find_module, load_module, C_BUILTIN, PY_COMPILED, PKG_DIRECTORY
+from distutils.sysconfig import get_config_var, get_python_lib, get_python_version
+from distutils.errors import DistutilsPlatformError
+
+from six.moves import range
+
+try:
+ import zipimport
+except ImportError:
+ zipimport = None
+
+ZIPFILE = object()
+
+from logilab.common import STD_BLACKLIST, _handle_blacklist
+
+# Notes about STD_LIB_DIR
+# Consider arch-specific installation for STD_LIB_DIR definition
+# :mod:`distutils.sysconfig` contains to much hardcoded values to rely on
+#
+# :see: `Problems with /usr/lib64 builds <http://bugs.python.org/issue1294959>`_
+# :see: `FHS <http://www.pathname.com/fhs/pub/fhs-2.3.html#LIBLTQUALGTALTERNATEFORMATESSENTIAL>`_
+if sys.platform.startswith('win'):
+ PY_SOURCE_EXTS = ('py', 'pyw')
+ PY_COMPILED_EXTS = ('dll', 'pyd')
+else:
+ PY_SOURCE_EXTS = ('py',)
+ PY_COMPILED_EXTS = ('so',)
+
+try:
+ STD_LIB_DIR = get_python_lib(standard_lib=1)
+# get_python_lib(standard_lib=1) is not available on pypy, set STD_LIB_DIR to
+# non-valid path, see https://bugs.pypy.org/issue1164
+except DistutilsPlatformError:
+ STD_LIB_DIR = '//'
+
+EXT_LIB_DIR = get_python_lib()
+
+BUILTIN_MODULES = dict(zip(sys.builtin_module_names,
+ [1]*len(sys.builtin_module_names)))
+
+
+class NoSourceFile(Exception):
+ """exception raised when we are not able to get a python
+ source file for a precompiled file
+ """
+
+class LazyObject(object):
+ def __init__(self, module, obj):
+ self.module = module
+ self.obj = obj
+ self._imported = None
+
+ def _getobj(self):
+ if self._imported is None:
+ self._imported = getattr(load_module_from_name(self.module),
+ self.obj)
+ return self._imported
+
+ def __getattribute__(self, attr):
+ try:
+ return super(LazyObject, self).__getattribute__(attr)
+ except AttributeError as ex:
+ return getattr(self._getobj(), attr)
+
+ def __call__(self, *args, **kwargs):
+ return self._getobj()(*args, **kwargs)
+
+
+def load_module_from_name(dotted_name, path=None, use_sys=1):
+ """Load a Python module from its name.
+
+ :type dotted_name: str
+ :param dotted_name: python name of a module or package
+
+ :type path: list or None
+ :param path:
+ optional list of path where the module or package should be
+ searched (use sys.path if nothing or None is given)
+
+ :type use_sys: bool
+ :param use_sys:
+ boolean indicating whether the sys.modules dictionary should be
+ used or not
+
+
+ :raise ImportError: if the module or package is not found
+
+ :rtype: module
+ :return: the loaded module
+ """
+ return load_module_from_modpath(dotted_name.split('.'), path, use_sys)
+
+
+def load_module_from_modpath(parts, path=None, use_sys=1):
+ """Load a python module from its splitted name.
+
+ :type parts: list(str) or tuple(str)
+ :param parts:
+ python name of a module or package splitted on '.'
+
+ :type path: list or None
+ :param path:
+ optional list of path where the module or package should be
+ searched (use sys.path if nothing or None is given)
+
+ :type use_sys: bool
+ :param use_sys:
+ boolean indicating whether the sys.modules dictionary should be used or not
+
+ :raise ImportError: if the module or package is not found
+
+ :rtype: module
+ :return: the loaded module
+ """
+ if use_sys:
+ try:
+ return sys.modules['.'.join(parts)]
+ except KeyError:
+ pass
+ modpath = []
+ prevmodule = None
+ for part in parts:
+ modpath.append(part)
+ curname = '.'.join(modpath)
+ module = None
+ if len(modpath) != len(parts):
+ # even with use_sys=False, should try to get outer packages from sys.modules
+ module = sys.modules.get(curname)
+ elif use_sys:
+ # because it may have been indirectly loaded through a parent
+ module = sys.modules.get(curname)
+ if module is None:
+ mp_file, mp_filename, mp_desc = find_module(part, path)
+ module = load_module(curname, mp_file, mp_filename, mp_desc)
+ if prevmodule:
+ setattr(prevmodule, part, module)
+ _file = getattr(module, '__file__', '')
+ if not _file and len(modpath) != len(parts):
+ raise ImportError('no module in %s' % '.'.join(parts[len(modpath):]) )
+ path = [dirname( _file )]
+ prevmodule = module
+ return module
+
+
+def load_module_from_file(filepath, path=None, use_sys=1, extrapath=None):
+ """Load a Python module from it's path.
+
+ :type filepath: str
+ :param filepath: path to the python module or package
+
+ :type path: list or None
+ :param path:
+ optional list of path where the module or package should be
+ searched (use sys.path if nothing or None is given)
+
+ :type use_sys: bool
+ :param use_sys:
+ boolean indicating whether the sys.modules dictionary should be
+ used or not
+
+
+ :raise ImportError: if the module or package is not found
+
+ :rtype: module
+ :return: the loaded module
+ """
+ modpath = modpath_from_file(filepath, extrapath)
+ return load_module_from_modpath(modpath, path, use_sys)
+
+
+def _check_init(path, mod_path):
+ """check there are some __init__.py all along the way"""
+ for part in mod_path:
+ path = join(path, part)
+ if not _has_init(path):
+ return False
+ return True
+
+
+def modpath_from_file(filename, extrapath=None):
+ """given a file path return the corresponding splitted module's name
+ (i.e name of a module or package splitted on '.')
+
+ :type filename: str
+ :param filename: file's path for which we want the module's name
+
+ :type extrapath: dict
+ :param extrapath:
+ optional extra search path, with path as key and package name for the path
+ as value. This is usually useful to handle package splitted in multiple
+ directories using __path__ trick.
+
+
+ :raise ImportError:
+ if the corresponding module's name has not been found
+
+ :rtype: list(str)
+ :return: the corresponding splitted module's name
+ """
+ base = splitext(abspath(filename))[0]
+ if extrapath is not None:
+ for path_ in extrapath:
+ path = abspath(path_)
+ if path and base[:len(path)] == path:
+ submodpath = [pkg for pkg in base[len(path):].split(os.sep)
+ if pkg]
+ if _check_init(path, submodpath[:-1]):
+ return extrapath[path_].split('.') + submodpath
+ for path in sys.path:
+ path = abspath(path)
+ if path and base.startswith(path):
+ modpath = [pkg for pkg in base[len(path):].split(os.sep) if pkg]
+ if _check_init(path, modpath[:-1]):
+ return modpath
+ raise ImportError('Unable to find module for %s in %s' % (
+ filename, ', \n'.join(sys.path)))
+
+
+
+def file_from_modpath(modpath, path=None, context_file=None):
+ """given a mod path (i.e. splitted module / package name), return the
+ corresponding file, giving priority to source file over precompiled
+ file if it exists
+
+ :type modpath: list or tuple
+ :param modpath:
+ splitted module's name (i.e name of a module or package splitted
+ on '.')
+ (this means explicit relative imports that start with dots have
+ empty strings in this list!)
+
+ :type path: list or None
+ :param path:
+ optional list of path where the module or package should be
+ searched (use sys.path if nothing or None is given)
+
+ :type context_file: str or None
+ :param context_file:
+ context file to consider, necessary if the identifier has been
+ introduced using a relative import unresolvable in the actual
+ context (i.e. modutils)
+
+ :raise ImportError: if there is no such module in the directory
+
+ :rtype: str or None
+ :return:
+ the path to the module's file or None if it's an integrated
+ builtin module such as 'sys'
+ """
+ if context_file is not None:
+ context = dirname(context_file)
+ else:
+ context = context_file
+ if modpath[0] == 'xml':
+ # handle _xmlplus
+ try:
+ return _file_from_modpath(['_xmlplus'] + modpath[1:], path, context)
+ except ImportError:
+ return _file_from_modpath(modpath, path, context)
+ elif modpath == ['os', 'path']:
+ # FIXME: currently ignoring search_path...
+ return os.path.__file__
+ return _file_from_modpath(modpath, path, context)
+
+
+
+def get_module_part(dotted_name, context_file=None):
+ """given a dotted name return the module part of the name :
+
+ >>> get_module_part('logilab.common.modutils.get_module_part')
+ 'logilab.common.modutils'
+
+ :type dotted_name: str
+ :param dotted_name: full name of the identifier we are interested in
+
+ :type context_file: str or None
+ :param context_file:
+ context file to consider, necessary if the identifier has been
+ introduced using a relative import unresolvable in the actual
+ context (i.e. modutils)
+
+
+ :raise ImportError: if there is no such module in the directory
+
+ :rtype: str or None
+ :return:
+ the module part of the name or None if we have not been able at
+ all to import the given name
+
+ XXX: deprecated, since it doesn't handle package precedence over module
+ (see #10066)
+ """
+ # os.path trick
+ if dotted_name.startswith('os.path'):
+ return 'os.path'
+ parts = dotted_name.split('.')
+ if context_file is not None:
+ # first check for builtin module which won't be considered latter
+ # in that case (path != None)
+ if parts[0] in BUILTIN_MODULES:
+ if len(parts) > 2:
+ raise ImportError(dotted_name)
+ return parts[0]
+ # don't use += or insert, we want a new list to be created !
+ path = None
+ starti = 0
+ if parts[0] == '':
+ assert context_file is not None, \
+ 'explicit relative import, but no context_file?'
+ path = [] # prevent resolving the import non-relatively
+ starti = 1
+ while parts[starti] == '': # for all further dots: change context
+ starti += 1
+ context_file = dirname(context_file)
+ for i in range(starti, len(parts)):
+ try:
+ file_from_modpath(parts[starti:i+1],
+ path=path, context_file=context_file)
+ except ImportError:
+ if not i >= max(1, len(parts) - 2):
+ raise
+ return '.'.join(parts[:i])
+ return dotted_name
+
+
+def get_modules(package, src_directory, blacklist=STD_BLACKLIST):
+ """given a package directory return a list of all available python
+ modules in the package and its subpackages
+
+ :type package: str
+ :param package: the python name for the package
+
+ :type src_directory: str
+ :param src_directory:
+ path of the directory corresponding to the package
+
+ :type blacklist: list or tuple
+ :param blacklist:
+ optional list of files or directory to ignore, default to
+ the value of `logilab.common.STD_BLACKLIST`
+
+ :rtype: list
+ :return:
+ the list of all available python modules in the package and its
+ subpackages
+ """
+ modules = []
+ for directory, dirnames, filenames in os.walk(src_directory):
+ _handle_blacklist(blacklist, dirnames, filenames)
+ # check for __init__.py
+ if not '__init__.py' in filenames:
+ dirnames[:] = ()
+ continue
+ if directory != src_directory:
+ dir_package = directory[len(src_directory):].replace(os.sep, '.')
+ modules.append(package + dir_package)
+ for filename in filenames:
+ if _is_python_file(filename) and filename != '__init__.py':
+ src = join(directory, filename)
+ module = package + src[len(src_directory):-3]
+ modules.append(module.replace(os.sep, '.'))
+ return modules
+
+
+
+def get_module_files(src_directory, blacklist=STD_BLACKLIST):
+ """given a package directory return a list of all available python
+ module's files in the package and its subpackages
+
+ :type src_directory: str
+ :param src_directory:
+ path of the directory corresponding to the package
+
+ :type blacklist: list or tuple
+ :param blacklist:
+ optional list of files or directory to ignore, default to the value of
+ `logilab.common.STD_BLACKLIST`
+
+ :rtype: list
+ :return:
+ the list of all available python module's files in the package and
+ its subpackages
+ """
+ files = []
+ for directory, dirnames, filenames in os.walk(src_directory):
+ _handle_blacklist(blacklist, dirnames, filenames)
+ # check for __init__.py
+ if not '__init__.py' in filenames:
+ dirnames[:] = ()
+ continue
+ for filename in filenames:
+ if _is_python_file(filename):
+ src = join(directory, filename)
+ files.append(src)
+ return files
+
+
+def get_source_file(filename, include_no_ext=False):
+ """given a python module's file name return the matching source file
+ name (the filename will be returned identically if it's a already an
+ absolute path to a python source file...)
+
+ :type filename: str
+ :param filename: python module's file name
+
+
+ :raise NoSourceFile: if no source file exists on the file system
+
+ :rtype: str
+ :return: the absolute path of the source file if it exists
+ """
+ base, orig_ext = splitext(abspath(filename))
+ for ext in PY_SOURCE_EXTS:
+ source_path = '%s.%s' % (base, ext)
+ if exists(source_path):
+ return source_path
+ if include_no_ext and not orig_ext and exists(base):
+ return base
+ raise NoSourceFile(filename)
+
+
+def cleanup_sys_modules(directories):
+ """remove submodules of `directories` from `sys.modules`"""
+ cleaned = []
+ for modname, module in list(sys.modules.items()):
+ modfile = getattr(module, '__file__', None)
+ if modfile:
+ for directory in directories:
+ if modfile.startswith(directory):
+ cleaned.append(modname)
+ del sys.modules[modname]
+ break
+ return cleaned
+
+
+def is_python_source(filename):
+ """
+ rtype: bool
+ return: True if the filename is a python source file
+ """
+ return splitext(filename)[1][1:] in PY_SOURCE_EXTS
+
+
+
+def is_standard_module(modname, std_path=(STD_LIB_DIR,)):
+ """try to guess if a module is a standard python module (by default,
+ see `std_path` parameter's description)
+
+ :type modname: str
+ :param modname: name of the module we are interested in
+
+ :type std_path: list(str) or tuple(str)
+ :param std_path: list of path considered has standard
+
+
+ :rtype: bool
+ :return:
+ true if the module:
+ - is located on the path listed in one of the directory in `std_path`
+ - is a built-in module
+ """
+ modname = modname.split('.')[0]
+ try:
+ filename = file_from_modpath([modname])
+ except ImportError as ex:
+ # import failed, i'm probably not so wrong by supposing it's
+ # not standard...
+ return 0
+ # modules which are not living in a file are considered standard
+ # (sys and __builtin__ for instance)
+ if filename is None:
+ return 1
+ filename = abspath(filename)
+ if filename.startswith(EXT_LIB_DIR):
+ return 0
+ for path in std_path:
+ if filename.startswith(abspath(path)):
+ return 1
+ return False
+
+
+
+def is_relative(modname, from_file):
+ """return true if the given module name is relative to the given
+ file name
+
+ :type modname: str
+ :param modname: name of the module we are interested in
+
+ :type from_file: str
+ :param from_file:
+ path of the module from which modname has been imported
+
+ :rtype: bool
+ :return:
+ true if the module has been imported relatively to `from_file`
+ """
+ if not isdir(from_file):
+ from_file = dirname(from_file)
+ if from_file in sys.path:
+ return False
+ try:
+ find_module(modname.split('.')[0], [from_file])
+ return True
+ except ImportError:
+ return False
+
+
+# internal only functions #####################################################
+
+def _file_from_modpath(modpath, path=None, context=None):
+ """given a mod path (i.e. splitted module / package name), return the
+ corresponding file
+
+ this function is used internally, see `file_from_modpath`'s
+ documentation for more information
+ """
+ assert len(modpath) > 0
+ if context is not None:
+ try:
+ mtype, mp_filename = _module_file(modpath, [context])
+ except ImportError:
+ mtype, mp_filename = _module_file(modpath, path)
+ else:
+ mtype, mp_filename = _module_file(modpath, path)
+ if mtype == PY_COMPILED:
+ try:
+ return get_source_file(mp_filename)
+ except NoSourceFile:
+ return mp_filename
+ elif mtype == C_BUILTIN:
+ # integrated builtin module
+ return None
+ elif mtype == PKG_DIRECTORY:
+ mp_filename = _has_init(mp_filename)
+ return mp_filename
+
+def _search_zip(modpath, pic):
+ for filepath, importer in pic.items():
+ if importer is not None:
+ if importer.find_module(modpath[0]):
+ if not importer.find_module('/'.join(modpath)):
+ raise ImportError('No module named %s in %s/%s' % (
+ '.'.join(modpath[1:]), filepath, modpath))
+ return ZIPFILE, abspath(filepath) + '/' + '/'.join(modpath), filepath
+ raise ImportError('No module named %s' % '.'.join(modpath))
+
+try:
+ import pkg_resources
+except ImportError:
+ pkg_resources = None
+
+def _module_file(modpath, path=None):
+ """get a module type / file path
+
+ :type modpath: list or tuple
+ :param modpath:
+ splitted module's name (i.e name of a module or package splitted
+ on '.'), with leading empty strings for explicit relative import
+
+ :type path: list or None
+ :param path:
+ optional list of path where the module or package should be
+ searched (use sys.path if nothing or None is given)
+
+
+ :rtype: tuple(int, str)
+ :return: the module type flag and the file path for a module
+ """
+ # egg support compat
+ try:
+ pic = sys.path_importer_cache
+ _path = (path is None and sys.path or path)
+ for __path in _path:
+ if not __path in pic:
+ try:
+ pic[__path] = zipimport.zipimporter(__path)
+ except zipimport.ZipImportError:
+ pic[__path] = None
+ checkeggs = True
+ except AttributeError:
+ checkeggs = False
+ # pkg_resources support (aka setuptools namespace packages)
+ if (pkg_resources is not None
+ and modpath[0] in pkg_resources._namespace_packages
+ and modpath[0] in sys.modules
+ and len(modpath) > 1):
+ # setuptools has added into sys.modules a module object with proper
+ # __path__, get back information from there
+ module = sys.modules[modpath.pop(0)]
+ path = module.__path__
+ imported = []
+ while modpath:
+ modname = modpath[0]
+ # take care to changes in find_module implementation wrt builtin modules
+ #
+ # Python 2.6.6 (r266:84292, Sep 11 2012, 08:34:23)
+ # >>> imp.find_module('posix')
+ # (None, 'posix', ('', '', 6))
+ #
+ # Python 3.3.1 (default, Apr 26 2013, 12:08:46)
+ # >>> imp.find_module('posix')
+ # (None, None, ('', '', 6))
+ try:
+ _, mp_filename, mp_desc = find_module(modname, path)
+ except ImportError:
+ if checkeggs:
+ return _search_zip(modpath, pic)[:2]
+ raise
+ else:
+ if checkeggs and mp_filename:
+ fullabspath = [abspath(x) for x in _path]
+ try:
+ pathindex = fullabspath.index(dirname(abspath(mp_filename)))
+ emtype, emp_filename, zippath = _search_zip(modpath, pic)
+ if pathindex > _path.index(zippath):
+ # an egg takes priority
+ return emtype, emp_filename
+ except ValueError:
+ # XXX not in _path
+ pass
+ except ImportError:
+ pass
+ checkeggs = False
+ imported.append(modpath.pop(0))
+ mtype = mp_desc[2]
+ if modpath:
+ if mtype != PKG_DIRECTORY:
+ raise ImportError('No module %s in %s' % ('.'.join(modpath),
+ '.'.join(imported)))
+ # XXX guess if package is using pkgutil.extend_path by looking for
+ # those keywords in the first four Kbytes
+ try:
+ with open(join(mp_filename, '__init__.py')) as stream:
+ data = stream.read(4096)
+ except IOError:
+ path = [mp_filename]
+ else:
+ if 'pkgutil' in data and 'extend_path' in data:
+ # extend_path is called, search sys.path for module/packages
+ # of this name see pkgutil.extend_path documentation
+ path = [join(p, *imported) for p in sys.path
+ if isdir(join(p, *imported))]
+ else:
+ path = [mp_filename]
+ return mtype, mp_filename
+
+def _is_python_file(filename):
+ """return true if the given filename should be considered as a python file
+
+ .pyc and .pyo are ignored
+ """
+ for ext in ('.py', '.so', '.pyd', '.pyw'):
+ if filename.endswith(ext):
+ return True
+ return False
+
+
+def _has_init(directory):
+ """if the given directory has a valid __init__ file, return its path,
+ else return None
+ """
+ mod_or_pack = join(directory, '__init__')
+ for ext in PY_SOURCE_EXTS + ('pyc', 'pyo'):
+ if exists(mod_or_pack + '.' + ext):
+ return mod_or_pack + '.' + ext
+ return None
diff --git a/logilab/common/optik_ext.py b/logilab/common/optik_ext.py
new file mode 100644
index 0000000..1fd2a7f
--- /dev/null
+++ b/logilab/common/optik_ext.py
@@ -0,0 +1,392 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Add an abstraction level to transparently import optik classes from optparse
+(python >= 2.3) or the optik package.
+
+It also defines three new types for optik/optparse command line parser :
+
+ * regexp
+ argument of this type will be converted using re.compile
+ * csv
+ argument of this type will be converted using split(',')
+ * yn
+ argument of this type will be true if 'y' or 'yes', false if 'n' or 'no'
+ * named
+ argument of this type are in the form <NAME>=<VALUE> or <NAME>:<VALUE>
+ * password
+ argument of this type wont be converted but this is used by other tools
+ such as interactive prompt for configuration to double check value and
+ use an invisible field
+ * multiple_choice
+ same as default "choice" type but multiple choices allowed
+ * file
+ argument of this type wont be converted but checked that the given file exists
+ * color
+ argument of this type wont be converted but checked its either a
+ named color or a color specified using hexadecimal notation (preceded by a #)
+ * time
+ argument of this type will be converted to a float value in seconds
+ according to time units (ms, s, min, h, d)
+ * bytes
+ argument of this type will be converted to a float value in bytes
+ according to byte units (b, kb, mb, gb, tb)
+"""
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import re
+import sys
+import time
+from copy import copy
+from os.path import exists
+
+# python >= 2.3
+from optparse import OptionParser as BaseParser, Option as BaseOption, \
+ OptionGroup, OptionContainer, OptionValueError, OptionError, \
+ Values, HelpFormatter, NO_DEFAULT, SUPPRESS_HELP
+
+try:
+ from mx import DateTime
+ HAS_MX_DATETIME = True
+except ImportError:
+ HAS_MX_DATETIME = False
+
+from logilab.common.textutils import splitstrip, TIME_UNITS, BYTE_UNITS, \
+ apply_units
+
+
+def check_regexp(option, opt, value):
+ """check a regexp value by trying to compile it
+ return the compiled regexp
+ """
+ if hasattr(value, 'pattern'):
+ return value
+ try:
+ return re.compile(value)
+ except ValueError:
+ raise OptionValueError(
+ "option %s: invalid regexp value: %r" % (opt, value))
+
+def check_csv(option, opt, value):
+ """check a csv value by trying to split it
+ return the list of separated values
+ """
+ if isinstance(value, (list, tuple)):
+ return value
+ try:
+ return splitstrip(value)
+ except ValueError:
+ raise OptionValueError(
+ "option %s: invalid csv value: %r" % (opt, value))
+
+def check_yn(option, opt, value):
+ """check a yn value
+ return true for yes and false for no
+ """
+ if isinstance(value, int):
+ return bool(value)
+ if value in ('y', 'yes'):
+ return True
+ if value in ('n', 'no'):
+ return False
+ msg = "option %s: invalid yn value %r, should be in (y, yes, n, no)"
+ raise OptionValueError(msg % (opt, value))
+
+def check_named(option, opt, value):
+ """check a named value
+ return a dictionary containing (name, value) associations
+ """
+ if isinstance(value, dict):
+ return value
+ values = []
+ for value in check_csv(option, opt, value):
+ if value.find('=') != -1:
+ values.append(value.split('=', 1))
+ elif value.find(':') != -1:
+ values.append(value.split(':', 1))
+ if values:
+ return dict(values)
+ msg = "option %s: invalid named value %r, should be <NAME>=<VALUE> or \
+<NAME>:<VALUE>"
+ raise OptionValueError(msg % (opt, value))
+
+def check_password(option, opt, value):
+ """check a password value (can't be empty)
+ """
+ # no actual checking, monkey patch if you want more
+ return value
+
+def check_file(option, opt, value):
+ """check a file value
+ return the filepath
+ """
+ if exists(value):
+ return value
+ msg = "option %s: file %r does not exist"
+ raise OptionValueError(msg % (opt, value))
+
+# XXX use python datetime
+def check_date(option, opt, value):
+ """check a file value
+ return the filepath
+ """
+ try:
+ return DateTime.strptime(value, "%Y/%m/%d")
+ except DateTime.Error :
+ raise OptionValueError(
+ "expected format of %s is yyyy/mm/dd" % opt)
+
+def check_color(option, opt, value):
+ """check a color value and returns it
+ /!\ does *not* check color labels (like 'red', 'green'), only
+ checks hexadecimal forms
+ """
+ # Case (1) : color label, we trust the end-user
+ if re.match('[a-z0-9 ]+$', value, re.I):
+ return value
+ # Case (2) : only accepts hexadecimal forms
+ if re.match('#[a-f0-9]{6}', value, re.I):
+ return value
+ # Else : not a color label neither a valid hexadecimal form => error
+ msg = "option %s: invalid color : %r, should be either hexadecimal \
+ value or predefined color"
+ raise OptionValueError(msg % (opt, value))
+
+def check_time(option, opt, value):
+ if isinstance(value, (int, long, float)):
+ return value
+ return apply_units(value, TIME_UNITS)
+
+def check_bytes(option, opt, value):
+ if hasattr(value, '__int__'):
+ return value
+ return apply_units(value, BYTE_UNITS)
+
+
+class Option(BaseOption):
+ """override optik.Option to add some new option types
+ """
+ TYPES = BaseOption.TYPES + ('regexp', 'csv', 'yn', 'named', 'password',
+ 'multiple_choice', 'file', 'color',
+ 'time', 'bytes')
+ ATTRS = BaseOption.ATTRS + ['hide', 'level']
+ TYPE_CHECKER = copy(BaseOption.TYPE_CHECKER)
+ TYPE_CHECKER['regexp'] = check_regexp
+ TYPE_CHECKER['csv'] = check_csv
+ TYPE_CHECKER['yn'] = check_yn
+ TYPE_CHECKER['named'] = check_named
+ TYPE_CHECKER['multiple_choice'] = check_csv
+ TYPE_CHECKER['file'] = check_file
+ TYPE_CHECKER['color'] = check_color
+ TYPE_CHECKER['password'] = check_password
+ TYPE_CHECKER['time'] = check_time
+ TYPE_CHECKER['bytes'] = check_bytes
+ if HAS_MX_DATETIME:
+ TYPES += ('date',)
+ TYPE_CHECKER['date'] = check_date
+
+ def __init__(self, *opts, **attrs):
+ BaseOption.__init__(self, *opts, **attrs)
+ if hasattr(self, "hide") and self.hide:
+ self.help = SUPPRESS_HELP
+
+ def _check_choice(self):
+ """FIXME: need to override this due to optik misdesign"""
+ if self.type in ("choice", "multiple_choice"):
+ if self.choices is None:
+ raise OptionError(
+ "must supply a list of choices for type 'choice'", self)
+ elif not isinstance(self.choices, (tuple, list)):
+ raise OptionError(
+ "choices must be a list of strings ('%s' supplied)"
+ % str(type(self.choices)).split("'")[1], self)
+ elif self.choices is not None:
+ raise OptionError(
+ "must not supply choices for type %r" % self.type, self)
+ BaseOption.CHECK_METHODS[2] = _check_choice
+
+
+ def process(self, opt, value, values, parser):
+ # First, convert the value(s) to the right type. Howl if any
+ # value(s) are bogus.
+ value = self.convert_value(opt, value)
+ if self.type == 'named':
+ existant = getattr(values, self.dest)
+ if existant:
+ existant.update(value)
+ value = existant
+ # And then take whatever action is expected of us.
+ # This is a separate method to make life easier for
+ # subclasses to add new actions.
+ return self.take_action(
+ self.action, self.dest, opt, value, values, parser)
+
+
+class OptionParser(BaseParser):
+ """override optik.OptionParser to use our Option class
+ """
+ def __init__(self, option_class=Option, *args, **kwargs):
+ BaseParser.__init__(self, option_class=Option, *args, **kwargs)
+
+ def format_option_help(self, formatter=None):
+ if formatter is None:
+ formatter = self.formatter
+ outputlevel = getattr(formatter, 'output_level', 0)
+ formatter.store_option_strings(self)
+ result = []
+ result.append(formatter.format_heading("Options"))
+ formatter.indent()
+ if self.option_list:
+ result.append(OptionContainer.format_option_help(self, formatter))
+ result.append("\n")
+ for group in self.option_groups:
+ if group.level <= outputlevel and (
+ group.description or level_options(group, outputlevel)):
+ result.append(group.format_help(formatter))
+ result.append("\n")
+ formatter.dedent()
+ # Drop the last "\n", or the header if no options or option groups:
+ return "".join(result[:-1])
+
+
+OptionGroup.level = 0
+
+def level_options(group, outputlevel):
+ return [option for option in group.option_list
+ if (getattr(option, 'level', 0) or 0) <= outputlevel
+ and not option.help is SUPPRESS_HELP]
+
+def format_option_help(self, formatter):
+ result = []
+ outputlevel = getattr(formatter, 'output_level', 0) or 0
+ for option in level_options(self, outputlevel):
+ result.append(formatter.format_option(option))
+ return "".join(result)
+OptionContainer.format_option_help = format_option_help
+
+
+class ManHelpFormatter(HelpFormatter):
+ """Format help using man pages ROFF format"""
+
+ def __init__ (self,
+ indent_increment=0,
+ max_help_position=24,
+ width=79,
+ short_first=0):
+ HelpFormatter.__init__ (
+ self, indent_increment, max_help_position, width, short_first)
+
+ def format_heading(self, heading):
+ return '.SH %s\n' % heading.upper()
+
+ def format_description(self, description):
+ return description
+
+ def format_option(self, option):
+ try:
+ optstring = option.option_strings
+ except AttributeError:
+ optstring = self.format_option_strings(option)
+ if option.help:
+ help_text = self.expand_default(option)
+ help = ' '.join([l.strip() for l in help_text.splitlines()])
+ else:
+ help = ''
+ return '''.IP "%s"
+%s
+''' % (optstring, help)
+
+ def format_head(self, optparser, pkginfo, section=1):
+ long_desc = ""
+ try:
+ pgm = optparser._get_prog_name()
+ except AttributeError:
+ # py >= 2.4.X (dunno which X exactly, at least 2)
+ pgm = optparser.get_prog_name()
+ short_desc = self.format_short_description(pgm, pkginfo.description)
+ if hasattr(pkginfo, "long_desc"):
+ long_desc = self.format_long_description(pgm, pkginfo.long_desc)
+ return '%s\n%s\n%s\n%s' % (self.format_title(pgm, section),
+ short_desc, self.format_synopsis(pgm),
+ long_desc)
+
+ def format_title(self, pgm, section):
+ date = '-'.join([str(num) for num in time.localtime()[:3]])
+ return '.TH %s %s "%s" %s' % (pgm, section, date, pgm)
+
+ def format_short_description(self, pgm, short_desc):
+ return '''.SH NAME
+.B %s
+\- %s
+''' % (pgm, short_desc.strip())
+
+ def format_synopsis(self, pgm):
+ return '''.SH SYNOPSIS
+.B %s
+[
+.I OPTIONS
+] [
+.I <arguments>
+]
+''' % pgm
+
+ def format_long_description(self, pgm, long_desc):
+ long_desc = '\n'.join([line.lstrip()
+ for line in long_desc.splitlines()])
+ long_desc = long_desc.replace('\n.\n', '\n\n')
+ if long_desc.lower().startswith(pgm):
+ long_desc = long_desc[len(pgm):]
+ return '''.SH DESCRIPTION
+.B %s
+%s
+''' % (pgm, long_desc.strip())
+
+ def format_tail(self, pkginfo):
+ tail = '''.SH SEE ALSO
+/usr/share/doc/pythonX.Y-%s/
+
+.SH BUGS
+Please report bugs on the project\'s mailing list:
+%s
+
+.SH AUTHOR
+%s <%s>
+''' % (getattr(pkginfo, 'debian_name', pkginfo.modname),
+ pkginfo.mailinglist, pkginfo.author, pkginfo.author_email)
+
+ if hasattr(pkginfo, "copyright"):
+ tail += '''
+.SH COPYRIGHT
+%s
+''' % pkginfo.copyright
+
+ return tail
+
+def generate_manpage(optparser, pkginfo, section=1, stream=sys.stdout, level=0):
+ """generate a man page from an optik parser"""
+ formatter = ManHelpFormatter()
+ formatter.output_level = level
+ formatter.parser = optparser
+ print(formatter.format_head(optparser, pkginfo, section), file=stream)
+ print(optparser.format_option_help(formatter), file=stream)
+ print(formatter.format_tail(pkginfo), file=stream)
+
+
+__all__ = ('OptionParser', 'Option', 'OptionGroup', 'OptionValueError',
+ 'Values')
diff --git a/logilab/common/optparser.py b/logilab/common/optparser.py
new file mode 100644
index 0000000..aa17750
--- /dev/null
+++ b/logilab/common/optparser.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Extend OptionParser with commands.
+
+Example:
+
+>>> parser = OptionParser()
+>>> parser.usage = '%prog COMMAND [options] <arg> ...'
+>>> parser.add_command('build', 'mymod.build')
+>>> parser.add_command('clean', run_clean, add_opt_clean)
+>>> run, options, args = parser.parse_command(sys.argv[1:])
+>>> return run(options, args[1:])
+
+With mymod.build that defines two functions run and add_options
+"""
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+from warnings import warn
+warn('lgc.optparser module is deprecated, use lgc.clcommands instead', DeprecationWarning,
+ stacklevel=2)
+
+import sys
+import optparse
+
+class OptionParser(optparse.OptionParser):
+
+ def __init__(self, *args, **kwargs):
+ optparse.OptionParser.__init__(self, *args, **kwargs)
+ self._commands = {}
+ self.min_args, self.max_args = 0, 1
+
+ def add_command(self, name, mod_or_funcs, help=''):
+ """name of the command, name of module or tuple of functions
+ (run, add_options)
+ """
+ assert isinstance(mod_or_funcs, str) or isinstance(mod_or_funcs, tuple), \
+ "mod_or_funcs has to be a module name or a tuple of functions"
+ self._commands[name] = (mod_or_funcs, help)
+
+ def print_main_help(self):
+ optparse.OptionParser.print_help(self)
+ print('\ncommands:')
+ for cmdname, (_, help) in self._commands.items():
+ print('% 10s - %s' % (cmdname, help))
+
+ def parse_command(self, args):
+ if len(args) == 0:
+ self.print_main_help()
+ sys.exit(1)
+ cmd = args[0]
+ args = args[1:]
+ if cmd not in self._commands:
+ if cmd in ('-h', '--help'):
+ self.print_main_help()
+ sys.exit(0)
+ elif self.version is not None and cmd == "--version":
+ self.print_version()
+ sys.exit(0)
+ self.error('unknown command')
+ self.prog = '%s %s' % (self.prog, cmd)
+ mod_or_f, help = self._commands[cmd]
+ # optparse inserts self.description between usage and options help
+ self.description = help
+ if isinstance(mod_or_f, str):
+ exec('from %s import run, add_options' % mod_or_f)
+ else:
+ run, add_options = mod_or_f
+ add_options(self)
+ (options, args) = self.parse_args(args)
+ if not (self.min_args <= len(args) <= self.max_args):
+ self.error('incorrect number of arguments')
+ return run, options, args
+
+
diff --git a/logilab/common/proc.py b/logilab/common/proc.py
new file mode 100644
index 0000000..c27356c
--- /dev/null
+++ b/logilab/common/proc.py
@@ -0,0 +1,277 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""module providing:
+* process information (linux specific: rely on /proc)
+* a class for resource control (memory / time / cpu time)
+
+This module doesn't work on windows platforms (only tested on linux)
+
+:organization: Logilab
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+import os
+import stat
+from resource import getrlimit, setrlimit, RLIMIT_CPU, RLIMIT_AS
+from signal import signal, SIGXCPU, SIGKILL, SIGUSR2, SIGUSR1
+from threading import Timer, currentThread, Thread, Event
+from time import time
+
+from logilab.common.tree import Node
+
+class NoSuchProcess(Exception): pass
+
+def proc_exists(pid):
+ """check the a pid is registered in /proc
+ raise NoSuchProcess exception if not
+ """
+ if not os.path.exists('/proc/%s' % pid):
+ raise NoSuchProcess()
+
+PPID = 3
+UTIME = 13
+STIME = 14
+CUTIME = 15
+CSTIME = 16
+VSIZE = 22
+
+class ProcInfo(Node):
+ """provide access to process information found in /proc"""
+
+ def __init__(self, pid):
+ self.pid = int(pid)
+ Node.__init__(self, self.pid)
+ proc_exists(self.pid)
+ self.file = '/proc/%s/stat' % self.pid
+ self.ppid = int(self.status()[PPID])
+
+ def memory_usage(self):
+ """return the memory usage of the process in Ko"""
+ try :
+ return int(self.status()[VSIZE])
+ except IOError:
+ return 0
+
+ def lineage_memory_usage(self):
+ return self.memory_usage() + sum([child.lineage_memory_usage()
+ for child in self.children])
+
+ def time(self, children=0):
+ """return the number of jiffies that this process has been scheduled
+ in user and kernel mode"""
+ status = self.status()
+ time = int(status[UTIME]) + int(status[STIME])
+ if children:
+ time += int(status[CUTIME]) + int(status[CSTIME])
+ return time
+
+ def status(self):
+ """return the list of fields found in /proc/<pid>/stat"""
+ return open(self.file).read().split()
+
+ def name(self):
+ """return the process name found in /proc/<pid>/stat
+ """
+ return self.status()[1].strip('()')
+
+ def age(self):
+ """return the age of the process
+ """
+ return os.stat(self.file)[stat.ST_MTIME]
+
+class ProcInfoLoader:
+ """manage process information"""
+
+ def __init__(self):
+ self._loaded = {}
+
+ def list_pids(self):
+ """return a list of existent process ids"""
+ for subdir in os.listdir('/proc'):
+ if subdir.isdigit():
+ yield int(subdir)
+
+ def load(self, pid):
+ """get a ProcInfo object for a given pid"""
+ pid = int(pid)
+ try:
+ return self._loaded[pid]
+ except KeyError:
+ procinfo = ProcInfo(pid)
+ procinfo.manager = self
+ self._loaded[pid] = procinfo
+ return procinfo
+
+
+ def load_all(self):
+ """load all processes information"""
+ for pid in self.list_pids():
+ try:
+ procinfo = self.load(pid)
+ if procinfo.parent is None and procinfo.ppid:
+ pprocinfo = self.load(procinfo.ppid)
+ pprocinfo.append(procinfo)
+ except NoSuchProcess:
+ pass
+
+
+try:
+ class ResourceError(BaseException):
+ """Error raise when resource limit is reached"""
+ limit = "Unknown Resource Limit"
+except NameError:
+ class ResourceError(Exception):
+ """Error raise when resource limit is reached"""
+ limit = "Unknown Resource Limit"
+
+
+class XCPUError(ResourceError):
+ """Error raised when CPU Time limit is reached"""
+ limit = "CPU Time"
+
+class LineageMemoryError(ResourceError):
+ """Error raised when the total amount of memory used by a process and
+ it's child is reached"""
+ limit = "Lineage total Memory"
+
+class TimeoutError(ResourceError):
+ """Error raised when the process is running for to much time"""
+ limit = "Real Time"
+
+# Can't use subclass because the StandardError MemoryError raised
+RESOURCE_LIMIT_EXCEPTION = (ResourceError, MemoryError)
+
+
+class MemorySentinel(Thread):
+ """A class checking a process don't use too much memory in a separated
+ daemonic thread
+ """
+ def __init__(self, interval, memory_limit, gpid=os.getpid()):
+ Thread.__init__(self, target=self._run, name="Test.Sentinel")
+ self.memory_limit = memory_limit
+ self._stop = Event()
+ self.interval = interval
+ self.setDaemon(True)
+ self.gpid = gpid
+
+ def stop(self):
+ """stop ap"""
+ self._stop.set()
+
+ def _run(self):
+ pil = ProcInfoLoader()
+ while not self._stop.isSet():
+ if self.memory_limit <= pil.load(self.gpid).lineage_memory_usage():
+ os.killpg(self.gpid, SIGUSR1)
+ self._stop.wait(self.interval)
+
+
+class ResourceController:
+
+ def __init__(self, max_cpu_time=None, max_time=None, max_memory=None,
+ max_reprieve=60):
+ if SIGXCPU == -1:
+ raise RuntimeError("Unsupported platform")
+ self.max_time = max_time
+ self.max_memory = max_memory
+ self.max_cpu_time = max_cpu_time
+ self._reprieve = max_reprieve
+ self._timer = None
+ self._msentinel = None
+ self._old_max_memory = None
+ self._old_usr1_hdlr = None
+ self._old_max_cpu_time = None
+ self._old_usr2_hdlr = None
+ self._old_sigxcpu_hdlr = None
+ self._limit_set = 0
+ self._abort_try = 0
+ self._start_time = None
+ self._elapse_time = 0
+
+ def _hangle_sig_timeout(self, sig, frame):
+ raise TimeoutError()
+
+ def _hangle_sig_memory(self, sig, frame):
+ if self._abort_try < self._reprieve:
+ self._abort_try += 1
+ raise LineageMemoryError("Memory limit reached")
+ else:
+ os.killpg(os.getpid(), SIGKILL)
+
+ def _handle_sigxcpu(self, sig, frame):
+ if self._abort_try < self._reprieve:
+ self._abort_try += 1
+ raise XCPUError("Soft CPU time limit reached")
+ else:
+ os.killpg(os.getpid(), SIGKILL)
+
+ def _time_out(self):
+ if self._abort_try < self._reprieve:
+ self._abort_try += 1
+ os.killpg(os.getpid(), SIGUSR2)
+ if self._limit_set > 0:
+ self._timer = Timer(1, self._time_out)
+ self._timer.start()
+ else:
+ os.killpg(os.getpid(), SIGKILL)
+
+ def setup_limit(self):
+ """set up the process limit"""
+ assert currentThread().getName() == 'MainThread'
+ os.setpgrp()
+ if self._limit_set <= 0:
+ if self.max_time is not None:
+ self._old_usr2_hdlr = signal(SIGUSR2, self._hangle_sig_timeout)
+ self._timer = Timer(max(1, int(self.max_time) - self._elapse_time),
+ self._time_out)
+ self._start_time = int(time())
+ self._timer.start()
+ if self.max_cpu_time is not None:
+ self._old_max_cpu_time = getrlimit(RLIMIT_CPU)
+ cpu_limit = (int(self.max_cpu_time), self._old_max_cpu_time[1])
+ self._old_sigxcpu_hdlr = signal(SIGXCPU, self._handle_sigxcpu)
+ setrlimit(RLIMIT_CPU, cpu_limit)
+ if self.max_memory is not None:
+ self._msentinel = MemorySentinel(1, int(self.max_memory) )
+ self._old_max_memory = getrlimit(RLIMIT_AS)
+ self._old_usr1_hdlr = signal(SIGUSR1, self._hangle_sig_memory)
+ as_limit = (int(self.max_memory), self._old_max_memory[1])
+ setrlimit(RLIMIT_AS, as_limit)
+ self._msentinel.start()
+ self._limit_set += 1
+
+ def clean_limit(self):
+ """reinstall the old process limit"""
+ if self._limit_set > 0:
+ if self.max_time is not None:
+ self._timer.cancel()
+ self._elapse_time += int(time())-self._start_time
+ self._timer = None
+ signal(SIGUSR2, self._old_usr2_hdlr)
+ if self.max_cpu_time is not None:
+ setrlimit(RLIMIT_CPU, self._old_max_cpu_time)
+ signal(SIGXCPU, self._old_sigxcpu_hdlr)
+ if self.max_memory is not None:
+ self._msentinel.stop()
+ self._msentinel = None
+ setrlimit(RLIMIT_AS, self._old_max_memory)
+ signal(SIGUSR1, self._old_usr1_hdlr)
+ self._limit_set -= 1
diff --git a/logilab/common/pyro_ext.py b/logilab/common/pyro_ext.py
new file mode 100644
index 0000000..5204b1b
--- /dev/null
+++ b/logilab/common/pyro_ext.py
@@ -0,0 +1,180 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Python Remote Object utilities
+
+Main functions available:
+
+* `register_object` to expose arbitrary object through pyro using delegation
+ approach and register it in the nameserver.
+* `ns_unregister` unregister an object identifier from the nameserver.
+* `ns_get_proxy` get a pyro proxy from a nameserver object identifier.
+"""
+
+__docformat__ = "restructuredtext en"
+
+import logging
+import tempfile
+
+from Pyro import core, naming, errors, util, config
+
+_LOGGER = logging.getLogger('pyro')
+_MARKER = object()
+
+config.PYRO_STORAGE = tempfile.gettempdir()
+
+def ns_group_and_id(idstr, defaultnsgroup=_MARKER):
+ try:
+ nsgroup, nsid = idstr.rsplit('.', 1)
+ except ValueError:
+ if defaultnsgroup is _MARKER:
+ nsgroup = config.PYRO_NS_DEFAULTGROUP
+ else:
+ nsgroup = defaultnsgroup
+ nsid = idstr
+ if nsgroup is not None and not nsgroup.startswith(':'):
+ nsgroup = ':' + nsgroup
+ return nsgroup, nsid
+
+def host_and_port(hoststr):
+ if not hoststr:
+ return None, None
+ try:
+ hoststr, port = hoststr.split(':')
+ except ValueError:
+ port = None
+ else:
+ port = int(port)
+ return hoststr, port
+
+_DAEMONS = {}
+_PYRO_OBJS = {}
+def _get_daemon(daemonhost, start=True):
+ if not daemonhost in _DAEMONS:
+ if not start:
+ raise Exception('no daemon for %s' % daemonhost)
+ if not _DAEMONS:
+ core.initServer(banner=0)
+ host, port = host_and_port(daemonhost)
+ daemon = core.Daemon(host=host, port=port)
+ _DAEMONS[daemonhost] = daemon
+ return _DAEMONS[daemonhost]
+
+
+def locate_ns(nshost):
+ """locate and return the pyro name server to the daemon"""
+ core.initClient(banner=False)
+ return naming.NameServerLocator().getNS(*host_and_port(nshost))
+
+
+def register_object(object, nsid, defaultnsgroup=_MARKER,
+ daemonhost=None, nshost=None, use_pyrons=True):
+ """expose the object as a pyro object and register it in the name-server
+
+ if use_pyrons is False, then the object is exposed, but no
+ attempt to register it to a pyro nameserver is made.
+
+ return the pyro daemon object
+ """
+ nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
+ daemon = _get_daemon(daemonhost)
+ if use_pyrons:
+ nsd = locate_ns(nshost)
+ # make sure our namespace group exists
+ try:
+ nsd.createGroup(nsgroup)
+ except errors.NamingError:
+ pass
+ daemon.useNameServer(nsd)
+ # use Delegation approach
+ impl = core.ObjBase()
+ impl.delegateTo(object)
+ qnsid = '%s.%s' % (nsgroup, nsid)
+ uri = daemon.connect(impl, qnsid)
+ _PYRO_OBJS[qnsid] = str(uri)
+ _LOGGER.info('registered %s a pyro object using group %s and id %s',
+ object, nsgroup, nsid)
+ return daemon
+
+def get_object_uri(qnsid):
+ return _PYRO_OBJS[qnsid]
+
+def ns_unregister(nsid, defaultnsgroup=_MARKER, nshost=None):
+ """unregister the object with the given nsid from the pyro name server"""
+ nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
+ try:
+ nsd = locate_ns(nshost)
+ except errors.PyroError as ex:
+ # name server not responding
+ _LOGGER.error('can\'t locate pyro name server: %s', ex)
+ else:
+ try:
+ nsd.unregister('%s.%s' % (nsgroup, nsid))
+ _LOGGER.info('%s unregistered from pyro name server', nsid)
+ except errors.NamingError:
+ _LOGGER.warning('%s not registered in pyro name server', nsid)
+
+
+def ns_reregister(nsid, defaultnsgroup=_MARKER, nshost=None):
+ """reregister a pyro object into the name server. You only have to specify
+ the name-server id of the object (though you MUST have gone through
+ `register_object` for the given object previously).
+
+ This is especially useful for long running server while the name server may
+ have been restarted, and its records lost.
+ """
+ nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
+ qnsid = '%s.%s' % (nsgroup, nsid)
+ nsd = locate_ns(nshost)
+ try:
+ nsd.unregister(qnsid)
+ except errors.NamingError:
+ # make sure our namespace group exists
+ try:
+ nsd.createGroup(nsgroup)
+ except errors.NamingError:
+ pass
+ nsd.register(qnsid, _PYRO_OBJS[qnsid])
+
+def ns_get_proxy(nsid, defaultnsgroup=_MARKER, nshost=None):
+ """
+ if nshost is None, the nameserver is found by a broadcast.
+ """
+ # resolve the Pyro object
+ nsgroup, nsid = ns_group_and_id(nsid, defaultnsgroup)
+ try:
+ nsd = locate_ns(nshost)
+ pyrouri = nsd.resolve('%s.%s' % (nsgroup, nsid))
+ except errors.ProtocolError as ex:
+ raise errors.PyroError(
+ 'Could not connect to the Pyro name server (host: %s)' % nshost)
+ except errors.NamingError:
+ raise errors.PyroError(
+ 'Could not get proxy for %s (not registered in Pyro), '
+ 'you may have to restart your server-side application' % nsid)
+ return core.getProxyForURI(pyrouri)
+
+def get_proxy(pyro_uri):
+ """get a proxy for the passed pyro uri without using a nameserver
+ """
+ return core.getProxyForURI(pyro_uri)
+
+def set_pyro_log_threshold(level):
+ pyrologger = logging.getLogger('Pyro.%s' % str(id(util.Log)))
+ # remove handlers so only the root handler is used
+ pyrologger.handlers = []
+ pyrologger.setLevel(level)
diff --git a/logilab/common/pytest.py b/logilab/common/pytest.py
new file mode 100644
index 0000000..58515a9
--- /dev/null
+++ b/logilab/common/pytest.py
@@ -0,0 +1,1199 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""pytest is a tool that eases test running and debugging.
+
+To be able to use pytest, you should either write tests using
+the logilab.common.testlib's framework or the unittest module of the
+Python's standard library.
+
+You can customize pytest's behaviour by defining a ``pytestconf.py`` file
+somewhere in your test directory. In this file, you can add options or
+change the way tests are run.
+
+To add command line options, you must define a ``update_parser`` function in
+your ``pytestconf.py`` file. The function must accept a single parameter
+that will be the OptionParser's instance to customize.
+
+If you wish to customize the tester, you'll have to define a class named
+``CustomPyTester``. This class should extend the default `PyTester` class
+defined in the pytest module. Take a look at the `PyTester` and `DjangoTester`
+classes for more information about what can be done.
+
+For instance, if you wish to add a custom -l option to specify a loglevel, you
+could define the following ``pytestconf.py`` file ::
+
+ import logging
+ from logilab.common.pytest import PyTester
+
+ def update_parser(parser):
+ parser.add_option('-l', '--loglevel', dest='loglevel', action='store',
+ choices=('debug', 'info', 'warning', 'error', 'critical'),
+ default='critical', help="the default log level possible choices are "
+ "('debug', 'info', 'warning', 'error', 'critical')")
+ return parser
+
+
+ class CustomPyTester(PyTester):
+ def __init__(self, cvg, options):
+ super(CustomPyTester, self).__init__(cvg, options)
+ loglevel = options.loglevel.upper()
+ logger = logging.getLogger('erudi')
+ logger.setLevel(logging.getLevelName(loglevel))
+
+
+In your TestCase class you can then get the value of a specific option with
+the ``optval`` method::
+
+ class MyTestCase(TestCase):
+ def test_foo(self):
+ loglevel = self.optval('loglevel')
+ # ...
+
+
+You can also tag your tag your test for fine filtering
+
+With those tag::
+
+ from logilab.common.testlib import tag, TestCase
+
+ class Exemple(TestCase):
+
+ @tag('rouge', 'carre')
+ def toto(self):
+ pass
+
+ @tag('carre', 'vert')
+ def tata(self):
+ pass
+
+ @tag('rouge')
+ def titi(test):
+ pass
+
+you can filter the function with a simple python expression
+
+ * ``toto`` and ``titi`` match ``rouge``
+ * ``toto``, ``tata`` and ``titi``, match ``rouge or carre``
+ * ``tata`` and ``titi`` match``rouge ^ carre``
+ * ``titi`` match ``rouge and not carre``
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+PYTEST_DOC = """%prog [OPTIONS] [testfile [testpattern]]
+
+examples:
+
+pytest path/to/mytests.py
+pytest path/to/mytests.py TheseTests
+pytest path/to/mytests.py TheseTests.test_thisone
+pytest path/to/mytests.py -m '(not long and database) or regr'
+
+pytest one (will run both test_thisone and test_thatone)
+pytest path/to/mytests.py -s not (will skip test_notthisone)
+"""
+
+ENABLE_DBC = False
+FILE_RESTART = ".pytest.restart"
+
+import os, sys, re
+import os.path as osp
+from time import time, clock
+import warnings
+import types
+from inspect import isgeneratorfunction, isclass
+from contextlib import contextmanager
+
+from logilab.common.fileutils import abspath_listdir
+from logilab.common import textutils
+from logilab.common import testlib, STD_BLACKLIST
+# use the same unittest module as testlib
+from logilab.common.testlib import unittest, start_interactive_mode
+from logilab.common.deprecation import deprecated
+import doctest
+
+import unittest as unittest_legacy
+if not getattr(unittest_legacy, "__package__", None):
+ try:
+ import unittest2.suite as unittest_suite
+ except ImportError:
+ sys.exit("You have to install python-unittest2 to use this module")
+else:
+ import unittest.suite as unittest_suite
+
+try:
+ import django
+ from logilab.common.modutils import modpath_from_file, load_module_from_modpath
+ DJANGO_FOUND = True
+except ImportError:
+ DJANGO_FOUND = False
+
+CONF_FILE = 'pytestconf.py'
+
+## coverage pausing tools
+
+@contextmanager
+def replace_trace(trace=None):
+ """A context manager that temporary replaces the trace function"""
+ oldtrace = sys.gettrace()
+ sys.settrace(trace)
+ try:
+ yield
+ finally:
+ # specific hack to work around a bug in pycoverage, see
+ # https://bitbucket.org/ned/coveragepy/issue/123
+ if (oldtrace is not None and not callable(oldtrace) and
+ hasattr(oldtrace, 'pytrace')):
+ oldtrace = oldtrace.pytrace
+ sys.settrace(oldtrace)
+
+
+def pause_trace():
+ """A context manager that temporary pauses any tracing"""
+ return replace_trace()
+
+class TraceController(object):
+ ctx_stack = []
+
+ @classmethod
+ @deprecated('[lgc 0.63.1] Use the pause_trace() context manager')
+ def pause_tracing(cls):
+ cls.ctx_stack.append(pause_trace())
+ cls.ctx_stack[-1].__enter__()
+
+ @classmethod
+ @deprecated('[lgc 0.63.1] Use the pause_trace() context manager')
+ def resume_tracing(cls):
+ cls.ctx_stack.pop().__exit__(None, None, None)
+
+
+pause_tracing = TraceController.pause_tracing
+resume_tracing = TraceController.resume_tracing
+
+
+def nocoverage(func):
+ """Function decorator that pauses tracing functions"""
+ if hasattr(func, 'uncovered'):
+ return func
+ func.uncovered = True
+
+ def not_covered(*args, **kwargs):
+ with pause_trace():
+ return func(*args, **kwargs)
+ not_covered.uncovered = True
+ return not_covered
+
+## end of coverage pausing tools
+
+
+TESTFILE_RE = re.compile("^((unit)?test.*|smoketest)\.py$")
+def this_is_a_testfile(filename):
+ """returns True if `filename` seems to be a test file"""
+ return TESTFILE_RE.match(osp.basename(filename))
+
+TESTDIR_RE = re.compile("^(unit)?tests?$")
+def this_is_a_testdir(dirpath):
+ """returns True if `filename` seems to be a test directory"""
+ return TESTDIR_RE.match(osp.basename(dirpath))
+
+
+def load_pytest_conf(path, parser):
+ """loads a ``pytestconf.py`` file and update default parser
+ and / or tester.
+ """
+ namespace = {}
+ exec(open(path, 'rb').read(), namespace)
+ if 'update_parser' in namespace:
+ namespace['update_parser'](parser)
+ return namespace.get('CustomPyTester', PyTester)
+
+
+def project_root(parser, projdir=os.getcwd()):
+ """try to find project's root and add it to sys.path"""
+ previousdir = curdir = osp.abspath(projdir)
+ testercls = PyTester
+ conf_file_path = osp.join(curdir, CONF_FILE)
+ if osp.isfile(conf_file_path):
+ testercls = load_pytest_conf(conf_file_path, parser)
+ while this_is_a_testdir(curdir) or \
+ osp.isfile(osp.join(curdir, '__init__.py')):
+ newdir = osp.normpath(osp.join(curdir, os.pardir))
+ if newdir == curdir:
+ break
+ previousdir = curdir
+ curdir = newdir
+ conf_file_path = osp.join(curdir, CONF_FILE)
+ if osp.isfile(conf_file_path):
+ testercls = load_pytest_conf(conf_file_path, parser)
+ return previousdir, testercls
+
+
+class GlobalTestReport(object):
+ """this class holds global test statistics"""
+ def __init__(self):
+ self.ran = 0
+ self.skipped = 0
+ self.failures = 0
+ self.errors = 0
+ self.ttime = 0
+ self.ctime = 0
+ self.modulescount = 0
+ self.errmodules = []
+
+ def feed(self, filename, testresult, ttime, ctime):
+ """integrates new test information into internal statistics"""
+ ran = testresult.testsRun
+ self.ran += ran
+ self.skipped += len(getattr(testresult, 'skipped', ()))
+ self.failures += len(testresult.failures)
+ self.errors += len(testresult.errors)
+ self.ttime += ttime
+ self.ctime += ctime
+ self.modulescount += 1
+ if not testresult.wasSuccessful():
+ problems = len(testresult.failures) + len(testresult.errors)
+ self.errmodules.append((filename[:-3], problems, ran))
+
+ def failed_to_test_module(self, filename):
+ """called when the test module could not be imported by unittest
+ """
+ self.errors += 1
+ self.modulescount += 1
+ self.ran += 1
+ self.errmodules.append((filename[:-3], 1, 1))
+
+ def skip_module(self, filename):
+ self.modulescount += 1
+ self.ran += 1
+ self.errmodules.append((filename[:-3], 0, 0))
+
+ def __str__(self):
+ """this is just presentation stuff"""
+ line1 = ['Ran %s test cases in %.2fs (%.2fs CPU)'
+ % (self.ran, self.ttime, self.ctime)]
+ if self.errors:
+ line1.append('%s errors' % self.errors)
+ if self.failures:
+ line1.append('%s failures' % self.failures)
+ if self.skipped:
+ line1.append('%s skipped' % self.skipped)
+ modulesok = self.modulescount - len(self.errmodules)
+ if self.errors or self.failures:
+ line2 = '%s modules OK (%s failed)' % (modulesok,
+ len(self.errmodules))
+ descr = ', '.join(['%s [%s/%s]' % info for info in self.errmodules])
+ line3 = '\nfailures: %s' % descr
+ elif modulesok:
+ line2 = 'All %s modules OK' % modulesok
+ line3 = ''
+ else:
+ return ''
+ return '%s\n%s%s' % (', '.join(line1), line2, line3)
+
+
+
+def remove_local_modules_from_sys(testdir):
+ """remove all modules from cache that come from `testdir`
+
+ This is used to avoid strange side-effects when using the
+ testall() mode of pytest.
+ For instance, if we run pytest on this tree::
+
+ A/test/test_utils.py
+ B/test/test_utils.py
+
+ we **have** to clean sys.modules to make sure the correct test_utils
+ module is ran in B
+ """
+ for modname, mod in list(sys.modules.items()):
+ if mod is None:
+ continue
+ if not hasattr(mod, '__file__'):
+ # this is the case of some built-in modules like sys, imp, marshal
+ continue
+ modfile = mod.__file__
+ # if modfile is not an absolute path, it was probably loaded locally
+ # during the tests
+ if not osp.isabs(modfile) or modfile.startswith(testdir):
+ del sys.modules[modname]
+
+
+
+class PyTester(object):
+ """encapsulates testrun logic"""
+
+ def __init__(self, cvg, options):
+ self.report = GlobalTestReport()
+ self.cvg = cvg
+ self.options = options
+ self.firstwrite = True
+ self._errcode = None
+
+ def show_report(self):
+ """prints the report and returns appropriate exitcode"""
+ # everything has been ran, print report
+ print("*" * 79)
+ print(self.report)
+
+ def get_errcode(self):
+ # errcode set explicitly
+ if self._errcode is not None:
+ return self._errcode
+ return self.report.failures + self.report.errors
+
+ def set_errcode(self, errcode):
+ self._errcode = errcode
+ errcode = property(get_errcode, set_errcode)
+
+ def testall(self, exitfirst=False):
+ """walks through current working directory, finds something
+ which can be considered as a testdir and runs every test there
+ """
+ here = os.getcwd()
+ for dirname, dirs, _ in os.walk(here):
+ for skipped in STD_BLACKLIST:
+ if skipped in dirs:
+ dirs.remove(skipped)
+ basename = osp.basename(dirname)
+ if this_is_a_testdir(basename):
+ print("going into", dirname)
+ # we found a testdir, let's explore it !
+ if not self.testonedir(dirname, exitfirst):
+ break
+ dirs[:] = []
+ if self.report.ran == 0:
+ print("no test dir found testing here:", here)
+ # if no test was found during the visit, consider
+ # the local directory as a test directory even if
+ # it doesn't have a traditional test directory name
+ self.testonedir(here)
+
+ def testonedir(self, testdir, exitfirst=False):
+ """finds each testfile in the `testdir` and runs it
+
+ return true when all tests has been executed, false if exitfirst and
+ some test has failed.
+ """
+ for filename in abspath_listdir(testdir):
+ if this_is_a_testfile(filename):
+ if self.options.exitfirst and not self.options.restart:
+ # overwrite restart file
+ try:
+ restartfile = open(FILE_RESTART, "w")
+ restartfile.close()
+ except Exception:
+ print("Error while overwriting succeeded test file :",
+ osp.join(os.getcwd(), FILE_RESTART),
+ file=sys.__stderr__)
+ raise
+ # run test and collect information
+ prog = self.testfile(filename, batchmode=True)
+ if exitfirst and (prog is None or not prog.result.wasSuccessful()):
+ return False
+ self.firstwrite = True
+ # clean local modules
+ remove_local_modules_from_sys(testdir)
+ return True
+
+ def testfile(self, filename, batchmode=False):
+ """runs every test in `filename`
+
+ :param filename: an absolute path pointing to a unittest file
+ """
+ here = os.getcwd()
+ dirname = osp.dirname(filename)
+ if dirname:
+ os.chdir(dirname)
+ # overwrite restart file if it has not been done already
+ if self.options.exitfirst and not self.options.restart and self.firstwrite:
+ try:
+ restartfile = open(FILE_RESTART, "w")
+ restartfile.close()
+ except Exception:
+ print("Error while overwriting succeeded test file :",
+ osp.join(os.getcwd(), FILE_RESTART), file=sys.__stderr__)
+ raise
+ modname = osp.basename(filename)[:-3]
+ print((' %s ' % osp.basename(filename)).center(70, '='),
+ file=sys.__stderr__)
+ try:
+ tstart, cstart = time(), clock()
+ try:
+ testprog = SkipAwareTestProgram(modname, batchmode=batchmode, cvg=self.cvg,
+ options=self.options, outstream=sys.stderr)
+ except KeyboardInterrupt:
+ raise
+ except SystemExit as exc:
+ self.errcode = exc.code
+ raise
+ except testlib.SkipTest:
+ print("Module skipped:", filename)
+ self.report.skip_module(filename)
+ return None
+ except Exception:
+ self.report.failed_to_test_module(filename)
+ print('unhandled exception occurred while testing', modname,
+ file=sys.stderr)
+ import traceback
+ traceback.print_exc(file=sys.stderr)
+ return None
+
+ tend, cend = time(), clock()
+ ttime, ctime = (tend - tstart), (cend - cstart)
+ self.report.feed(filename, testprog.result, ttime, ctime)
+ return testprog
+ finally:
+ if dirname:
+ os.chdir(here)
+
+
+
+class DjangoTester(PyTester):
+
+ def load_django_settings(self, dirname):
+ """try to find project's setting and load it"""
+ curdir = osp.abspath(dirname)
+ previousdir = curdir
+ while not osp.isfile(osp.join(curdir, 'settings.py')) and \
+ osp.isfile(osp.join(curdir, '__init__.py')):
+ newdir = osp.normpath(osp.join(curdir, os.pardir))
+ if newdir == curdir:
+ raise AssertionError('could not find settings.py')
+ previousdir = curdir
+ curdir = newdir
+ # late django initialization
+ settings = load_module_from_modpath(modpath_from_file(osp.join(curdir, 'settings.py')))
+ from django.core.management import setup_environ
+ setup_environ(settings)
+ settings.DEBUG = False
+ self.settings = settings
+ # add settings dir to pythonpath since it's the project's root
+ if curdir not in sys.path:
+ sys.path.insert(1, curdir)
+
+ def before_testfile(self):
+ # Those imports must be done **after** setup_environ was called
+ from django.test.utils import setup_test_environment
+ from django.test.utils import create_test_db
+ setup_test_environment()
+ create_test_db(verbosity=0)
+ self.dbname = self.settings.TEST_DATABASE_NAME
+
+ def after_testfile(self):
+ # Those imports must be done **after** setup_environ was called
+ from django.test.utils import teardown_test_environment
+ from django.test.utils import destroy_test_db
+ teardown_test_environment()
+ print('destroying', self.dbname)
+ destroy_test_db(self.dbname, verbosity=0)
+
+ def testall(self, exitfirst=False):
+ """walks through current working directory, finds something
+ which can be considered as a testdir and runs every test there
+ """
+ for dirname, dirs, files in os.walk(os.getcwd()):
+ for skipped in ('CVS', '.svn', '.hg'):
+ if skipped in dirs:
+ dirs.remove(skipped)
+ if 'tests.py' in files:
+ if not self.testonedir(dirname, exitfirst):
+ break
+ dirs[:] = []
+ else:
+ basename = osp.basename(dirname)
+ if basename in ('test', 'tests'):
+ print("going into", dirname)
+ # we found a testdir, let's explore it !
+ if not self.testonedir(dirname, exitfirst):
+ break
+ dirs[:] = []
+
+ def testonedir(self, testdir, exitfirst=False):
+ """finds each testfile in the `testdir` and runs it
+
+ return true when all tests has been executed, false if exitfirst and
+ some test has failed.
+ """
+ # special django behaviour : if tests are splitted in several files,
+ # remove the main tests.py file and tests each test file separately
+ testfiles = [fpath for fpath in abspath_listdir(testdir)
+ if this_is_a_testfile(fpath)]
+ if len(testfiles) > 1:
+ try:
+ testfiles.remove(osp.join(testdir, 'tests.py'))
+ except ValueError:
+ pass
+ for filename in testfiles:
+ # run test and collect information
+ prog = self.testfile(filename, batchmode=True)
+ if exitfirst and (prog is None or not prog.result.wasSuccessful()):
+ return False
+ # clean local modules
+ remove_local_modules_from_sys(testdir)
+ return True
+
+ def testfile(self, filename, batchmode=False):
+ """runs every test in `filename`
+
+ :param filename: an absolute path pointing to a unittest file
+ """
+ here = os.getcwd()
+ dirname = osp.dirname(filename)
+ if dirname:
+ os.chdir(dirname)
+ self.load_django_settings(dirname)
+ modname = osp.basename(filename)[:-3]
+ print((' %s ' % osp.basename(filename)).center(70, '='),
+ file=sys.stderr)
+ try:
+ try:
+ tstart, cstart = time(), clock()
+ self.before_testfile()
+ testprog = SkipAwareTestProgram(modname, batchmode=batchmode, cvg=self.cvg)
+ tend, cend = time(), clock()
+ ttime, ctime = (tend - tstart), (cend - cstart)
+ self.report.feed(filename, testprog.result, ttime, ctime)
+ return testprog
+ except SystemExit:
+ raise
+ except Exception as exc:
+ import traceback
+ traceback.print_exc()
+ self.report.failed_to_test_module(filename)
+ print('unhandled exception occurred while testing', modname)
+ print('error: %s' % exc)
+ return None
+ finally:
+ self.after_testfile()
+ if dirname:
+ os.chdir(here)
+
+
+def make_parser():
+ """creates the OptionParser instance
+ """
+ from optparse import OptionParser
+ parser = OptionParser(usage=PYTEST_DOC)
+
+ parser.newargs = []
+ def rebuild_cmdline(option, opt, value, parser):
+ """carry the option to unittest_main"""
+ parser.newargs.append(opt)
+
+ def rebuild_and_store(option, opt, value, parser):
+ """carry the option to unittest_main and store
+ the value on current parser
+ """
+ parser.newargs.append(opt)
+ setattr(parser.values, option.dest, True)
+
+ def capture_and_rebuild(option, opt, value, parser):
+ warnings.simplefilter('ignore', DeprecationWarning)
+ rebuild_cmdline(option, opt, value, parser)
+
+ # pytest options
+ parser.add_option('-t', dest='testdir', default=None,
+ help="directory where the tests will be found")
+ parser.add_option('-d', dest='dbc', default=False,
+ action="store_true", help="enable design-by-contract")
+ # unittest_main options provided and passed through pytest
+ parser.add_option('-v', '--verbose', callback=rebuild_cmdline,
+ action="callback", help="Verbose output")
+ parser.add_option('-i', '--pdb', callback=rebuild_and_store,
+ dest="pdb", action="callback",
+ help="Enable test failure inspection")
+ parser.add_option('-x', '--exitfirst', callback=rebuild_and_store,
+ dest="exitfirst", default=False,
+ action="callback", help="Exit on first failure "
+ "(only make sense when pytest run one test file)")
+ parser.add_option('-R', '--restart', callback=rebuild_and_store,
+ dest="restart", default=False,
+ action="callback",
+ help="Restart tests from where it failed (implies exitfirst) "
+ "(only make sense if tests previously ran with exitfirst only)")
+ parser.add_option('--color', callback=rebuild_cmdline,
+ action="callback",
+ help="colorize tracebacks")
+ parser.add_option('-s', '--skip',
+ # XXX: I wish I could use the callback action but it
+ # doesn't seem to be able to get the value
+ # associated to the option
+ action="store", dest="skipped", default=None,
+ help="test names matching this name will be skipped "
+ "to skip several patterns, use commas")
+ parser.add_option('-q', '--quiet', callback=rebuild_cmdline,
+ action="callback", help="Minimal output")
+ parser.add_option('-P', '--profile', default=None, dest='profile',
+ help="Profile execution and store data in the given file")
+ parser.add_option('-m', '--match', default=None, dest='tags_pattern',
+ help="only execute test whose tag match the current pattern")
+
+ if DJANGO_FOUND:
+ parser.add_option('-J', '--django', dest='django', default=False,
+ action="store_true",
+ help='use pytest for django test cases')
+ return parser
+
+
+def parseargs(parser):
+ """Parse the command line and return (options processed), (options to pass to
+ unittest_main()), (explicitfile or None).
+ """
+ # parse the command line
+ options, args = parser.parse_args()
+ filenames = [arg for arg in args if arg.endswith('.py')]
+ if filenames:
+ if len(filenames) > 1:
+ parser.error("only one filename is acceptable")
+ explicitfile = filenames[0]
+ args.remove(explicitfile)
+ else:
+ explicitfile = None
+ # someone wants DBC
+ testlib.ENABLE_DBC = options.dbc
+ newargs = parser.newargs
+ if options.skipped:
+ newargs.extend(['--skip', options.skipped])
+ # restart implies exitfirst
+ if options.restart:
+ options.exitfirst = True
+ # append additional args to the new sys.argv and let unittest_main
+ # do the rest
+ newargs += args
+ return options, explicitfile
+
+
+
+def run():
+ parser = make_parser()
+ rootdir, testercls = project_root(parser)
+ options, explicitfile = parseargs(parser)
+ # mock a new command line
+ sys.argv[1:] = parser.newargs
+ cvg = None
+ if not '' in sys.path:
+ sys.path.insert(0, '')
+ if DJANGO_FOUND and options.django:
+ tester = DjangoTester(cvg, options)
+ else:
+ tester = testercls(cvg, options)
+ if explicitfile:
+ cmd, args = tester.testfile, (explicitfile,)
+ elif options.testdir:
+ cmd, args = tester.testonedir, (options.testdir, options.exitfirst)
+ else:
+ cmd, args = tester.testall, (options.exitfirst,)
+ try:
+ try:
+ if options.profile:
+ import hotshot
+ prof = hotshot.Profile(options.profile)
+ prof.runcall(cmd, *args)
+ prof.close()
+ print('profile data saved in', options.profile)
+ else:
+ cmd(*args)
+ except SystemExit:
+ raise
+ except:
+ import traceback
+ traceback.print_exc()
+ finally:
+ tester.show_report()
+ sys.exit(tester.errcode)
+
+class SkipAwareTestProgram(unittest.TestProgram):
+ # XXX: don't try to stay close to unittest.py, use optparse
+ USAGE = """\
+Usage: %(progName)s [options] [test] [...]
+
+Options:
+ -h, --help Show this message
+ -v, --verbose Verbose output
+ -i, --pdb Enable test failure inspection
+ -x, --exitfirst Exit on first failure
+ -s, --skip skip test matching this pattern (no regexp for now)
+ -q, --quiet Minimal output
+ --color colorize tracebacks
+
+ -m, --match Run only test whose tag match this pattern
+
+ -P, --profile FILE: Run the tests using cProfile and saving results
+ in FILE
+
+Examples:
+ %(progName)s - run default set of tests
+ %(progName)s MyTestSuite - run suite 'MyTestSuite'
+ %(progName)s MyTestCase.testSomething - run MyTestCase.testSomething
+ %(progName)s MyTestCase - run all 'test*' test methods
+ in MyTestCase
+"""
+ def __init__(self, module='__main__', defaultTest=None, batchmode=False,
+ cvg=None, options=None, outstream=sys.stderr):
+ self.batchmode = batchmode
+ self.cvg = cvg
+ self.options = options
+ self.outstream = outstream
+ super(SkipAwareTestProgram, self).__init__(
+ module=module, defaultTest=defaultTest,
+ testLoader=NonStrictTestLoader())
+
+ def parseArgs(self, argv):
+ self.pdbmode = False
+ self.exitfirst = False
+ self.skipped_patterns = []
+ self.test_pattern = None
+ self.tags_pattern = None
+ self.colorize = False
+ self.profile_name = None
+ import getopt
+ try:
+ options, args = getopt.getopt(argv[1:], 'hHvixrqcp:s:m:P:',
+ ['help', 'verbose', 'quiet', 'pdb',
+ 'exitfirst', 'restart',
+ 'skip=', 'color', 'match=', 'profile='])
+ for opt, value in options:
+ if opt in ('-h', '-H', '--help'):
+ self.usageExit()
+ if opt in ('-i', '--pdb'):
+ self.pdbmode = True
+ if opt in ('-x', '--exitfirst'):
+ self.exitfirst = True
+ if opt in ('-r', '--restart'):
+ self.restart = True
+ self.exitfirst = True
+ if opt in ('-q', '--quiet'):
+ self.verbosity = 0
+ if opt in ('-v', '--verbose'):
+ self.verbosity = 2
+ if opt in ('-s', '--skip'):
+ self.skipped_patterns = [pat.strip() for pat in
+ value.split(', ')]
+ if opt == '--color':
+ self.colorize = True
+ if opt in ('-m', '--match'):
+ #self.tags_pattern = value
+ self.options["tag_pattern"] = value
+ if opt in ('-P', '--profile'):
+ self.profile_name = value
+ self.testLoader.skipped_patterns = self.skipped_patterns
+ if len(args) == 0 and self.defaultTest is None:
+ suitefunc = getattr(self.module, 'suite', None)
+ if isinstance(suitefunc, (types.FunctionType,
+ types.MethodType)):
+ self.test = self.module.suite()
+ else:
+ self.test = self.testLoader.loadTestsFromModule(self.module)
+ return
+ if len(args) > 0:
+ self.test_pattern = args[0]
+ self.testNames = args
+ else:
+ self.testNames = (self.defaultTest, )
+ self.createTests()
+ except getopt.error as msg:
+ self.usageExit(msg)
+
+ def runTests(self):
+ if self.profile_name:
+ import cProfile
+ cProfile.runctx('self._runTests()', globals(), locals(), self.profile_name )
+ else:
+ return self._runTests()
+
+ def _runTests(self):
+ self.testRunner = SkipAwareTextTestRunner(verbosity=self.verbosity,
+ stream=self.outstream,
+ exitfirst=self.exitfirst,
+ pdbmode=self.pdbmode,
+ cvg=self.cvg,
+ test_pattern=self.test_pattern,
+ skipped_patterns=self.skipped_patterns,
+ colorize=self.colorize,
+ batchmode=self.batchmode,
+ options=self.options)
+
+ def removeSucceededTests(obj, succTests):
+ """ Recursive function that removes succTests from
+ a TestSuite or TestCase
+ """
+ if isinstance(obj, unittest.TestSuite):
+ removeSucceededTests(obj._tests, succTests)
+ if isinstance(obj, list):
+ for el in obj[:]:
+ if isinstance(el, unittest.TestSuite):
+ removeSucceededTests(el, succTests)
+ elif isinstance(el, unittest.TestCase):
+ descr = '.'.join((el.__class__.__module__,
+ el.__class__.__name__,
+ el._testMethodName))
+ if descr in succTests:
+ obj.remove(el)
+ # take care, self.options may be None
+ if getattr(self.options, 'restart', False):
+ # retrieve succeeded tests from FILE_RESTART
+ try:
+ restartfile = open(FILE_RESTART, 'r')
+ try:
+ succeededtests = list(elem.rstrip('\n\r') for elem in
+ restartfile.readlines())
+ removeSucceededTests(self.test, succeededtests)
+ finally:
+ restartfile.close()
+ except Exception as ex:
+ raise Exception("Error while reading succeeded tests into %s: %s"
+ % (osp.join(os.getcwd(), FILE_RESTART), ex))
+
+ result = self.testRunner.run(self.test)
+ # help garbage collection: we want TestSuite, which hold refs to every
+ # executed TestCase, to be gc'ed
+ del self.test
+ if getattr(result, "debuggers", None) and \
+ getattr(self, "pdbmode", None):
+ start_interactive_mode(result)
+ if not getattr(self, "batchmode", None):
+ sys.exit(not result.wasSuccessful())
+ self.result = result
+
+
+class SkipAwareTextTestRunner(unittest.TextTestRunner):
+
+ def __init__(self, stream=sys.stderr, verbosity=1,
+ exitfirst=False, pdbmode=False, cvg=None, test_pattern=None,
+ skipped_patterns=(), colorize=False, batchmode=False,
+ options=None):
+ super(SkipAwareTextTestRunner, self).__init__(stream=stream,
+ verbosity=verbosity)
+ self.exitfirst = exitfirst
+ self.pdbmode = pdbmode
+ self.cvg = cvg
+ self.test_pattern = test_pattern
+ self.skipped_patterns = skipped_patterns
+ self.colorize = colorize
+ self.batchmode = batchmode
+ self.options = options
+
+ def _this_is_skipped(self, testedname):
+ return any([(pat in testedname) for pat in self.skipped_patterns])
+
+ def _runcondition(self, test, skipgenerator=True):
+ if isinstance(test, testlib.InnerTest):
+ testname = test.name
+ else:
+ if isinstance(test, testlib.TestCase):
+ meth = test._get_test_method()
+ testname = '%s.%s' % (test.__name__, meth.__name__)
+ elif isinstance(test, types.FunctionType):
+ func = test
+ testname = func.__name__
+ elif isinstance(test, types.MethodType):
+ cls = test.__self__.__class__
+ testname = '%s.%s' % (cls.__name__, test.__name__)
+ else:
+ return True # Not sure when this happens
+ if isgeneratorfunction(test) and skipgenerator:
+ return self.does_match_tags(test) # Let inner tests decide at run time
+ if self._this_is_skipped(testname):
+ return False # this was explicitly skipped
+ if self.test_pattern is not None:
+ try:
+ classpattern, testpattern = self.test_pattern.split('.')
+ klass, name = testname.split('.')
+ if classpattern not in klass or testpattern not in name:
+ return False
+ except ValueError:
+ if self.test_pattern not in testname:
+ return False
+
+ return self.does_match_tags(test)
+
+ def does_match_tags(self, test):
+ if self.options is not None:
+ tags_pattern = getattr(self.options, 'tags_pattern', None)
+ if tags_pattern is not None:
+ tags = getattr(test, 'tags', testlib.Tags())
+ if tags.inherit and isinstance(test, types.MethodType):
+ tags = tags | getattr(test.im_class, 'tags', testlib.Tags())
+ return tags.match(tags_pattern)
+ return True # no pattern
+
+ def _makeResult(self):
+ return testlib.SkipAwareTestResult(self.stream, self.descriptions,
+ self.verbosity, self.exitfirst,
+ self.pdbmode, self.cvg, self.colorize)
+
+ def run(self, test):
+ "Run the given test case or test suite."
+ result = self._makeResult()
+ startTime = time()
+ test(result, runcondition=self._runcondition, options=self.options)
+ stopTime = time()
+ timeTaken = stopTime - startTime
+ result.printErrors()
+ if not self.batchmode:
+ self.stream.writeln(result.separator2)
+ run = result.testsRun
+ self.stream.writeln("Ran %d test%s in %.3fs" %
+ (run, run != 1 and "s" or "", timeTaken))
+ self.stream.writeln()
+ if not result.wasSuccessful():
+ if self.colorize:
+ self.stream.write(textutils.colorize_ansi("FAILED", color='red'))
+ else:
+ self.stream.write("FAILED")
+ else:
+ if self.colorize:
+ self.stream.write(textutils.colorize_ansi("OK", color='green'))
+ else:
+ self.stream.write("OK")
+ failed, errored, skipped = map(len, (result.failures,
+ result.errors,
+ result.skipped))
+
+ det_results = []
+ for name, value in (("failures", result.failures),
+ ("errors",result.errors),
+ ("skipped", result.skipped)):
+ if value:
+ det_results.append("%s=%i" % (name, len(value)))
+ if det_results:
+ self.stream.write(" (")
+ self.stream.write(', '.join(det_results))
+ self.stream.write(")")
+ self.stream.writeln("")
+ return result
+
+class NonStrictTestLoader(unittest.TestLoader):
+ """
+ Overrides default testloader to be able to omit classname when
+ specifying tests to run on command line.
+
+ For example, if the file test_foo.py contains ::
+
+ class FooTC(TestCase):
+ def test_foo1(self): # ...
+ def test_foo2(self): # ...
+ def test_bar1(self): # ...
+
+ class BarTC(TestCase):
+ def test_bar2(self): # ...
+
+ 'python test_foo.py' will run the 3 tests in FooTC
+ 'python test_foo.py FooTC' will run the 3 tests in FooTC
+ 'python test_foo.py test_foo' will run test_foo1 and test_foo2
+ 'python test_foo.py test_foo1' will run test_foo1
+ 'python test_foo.py test_bar' will run FooTC.test_bar1 and BarTC.test_bar2
+ """
+
+ def __init__(self):
+ self.skipped_patterns = ()
+
+ # some magic here to accept empty list by extending
+ # and to provide callable capability
+ def loadTestsFromNames(self, names, module=None):
+ suites = []
+ for name in names:
+ suites.extend(self.loadTestsFromName(name, module))
+ return self.suiteClass(suites)
+
+ def _collect_tests(self, module):
+ tests = {}
+ for obj in vars(module).values():
+ if isclass(obj) and issubclass(obj, unittest.TestCase):
+ classname = obj.__name__
+ if classname[0] == '_' or self._this_is_skipped(classname):
+ continue
+ methodnames = []
+ # obj is a TestCase class
+ for attrname in dir(obj):
+ if attrname.startswith(self.testMethodPrefix):
+ attr = getattr(obj, attrname)
+ if callable(attr):
+ methodnames.append(attrname)
+ # keep track of class (obj) for convenience
+ tests[classname] = (obj, methodnames)
+ return tests
+
+ def loadTestsFromSuite(self, module, suitename):
+ try:
+ suite = getattr(module, suitename)()
+ except AttributeError:
+ return []
+ assert hasattr(suite, '_tests'), \
+ "%s.%s is not a valid TestSuite" % (module.__name__, suitename)
+ # python2.3 does not implement __iter__ on suites, we need to return
+ # _tests explicitly
+ return suite._tests
+
+ def loadTestsFromName(self, name, module=None):
+ parts = name.split('.')
+ if module is None or len(parts) > 2:
+ # let the base class do its job here
+ return [super(NonStrictTestLoader, self).loadTestsFromName(name)]
+ tests = self._collect_tests(module)
+ collected = []
+ if len(parts) == 1:
+ pattern = parts[0]
+ if callable(getattr(module, pattern, None)
+ ) and pattern not in tests:
+ # consider it as a suite
+ return self.loadTestsFromSuite(module, pattern)
+ if pattern in tests:
+ # case python unittest_foo.py MyTestTC
+ klass, methodnames = tests[pattern]
+ for methodname in methodnames:
+ collected = [klass(methodname)
+ for methodname in methodnames]
+ else:
+ # case python unittest_foo.py something
+ for klass, methodnames in tests.values():
+ # skip methodname if matched by skipped_patterns
+ for skip_pattern in self.skipped_patterns:
+ methodnames = [methodname
+ for methodname in methodnames
+ if skip_pattern not in methodname]
+ collected += [klass(methodname)
+ for methodname in methodnames
+ if pattern in methodname]
+ elif len(parts) == 2:
+ # case "MyClass.test_1"
+ classname, pattern = parts
+ klass, methodnames = tests.get(classname, (None, []))
+ for methodname in methodnames:
+ collected = [klass(methodname) for methodname in methodnames
+ if pattern in methodname]
+ return collected
+
+ def _this_is_skipped(self, testedname):
+ return any([(pat in testedname) for pat in self.skipped_patterns])
+
+ def getTestCaseNames(self, testCaseClass):
+ """Return a sorted sequence of method names found within testCaseClass
+ """
+ is_skipped = self._this_is_skipped
+ classname = testCaseClass.__name__
+ if classname[0] == '_' or is_skipped(classname):
+ return []
+ testnames = super(NonStrictTestLoader, self).getTestCaseNames(
+ testCaseClass)
+ return [testname for testname in testnames if not is_skipped(testname)]
+
+
+# The 2 functions below are modified versions of the TestSuite.run method
+# that is provided with unittest2 for python 2.6, in unittest2/suite.py
+# It is used to monkeypatch the original implementation to support
+# extra runcondition and options arguments (see in testlib.py)
+
+def _ts_run(self, result, runcondition=None, options=None):
+ self._wrapped_run(result, runcondition=runcondition, options=options)
+ self._tearDownPreviousClass(None, result)
+ self._handleModuleTearDown(result)
+ return result
+
+def _ts_wrapped_run(self, result, debug=False, runcondition=None, options=None):
+ for test in self:
+ if result.shouldStop:
+ break
+ if unittest_suite._isnotsuite(test):
+ self._tearDownPreviousClass(test, result)
+ self._handleModuleFixture(test, result)
+ self._handleClassSetUp(test, result)
+ result._previousTestClass = test.__class__
+ if (getattr(test.__class__, '_classSetupFailed', False) or
+ getattr(result, '_moduleSetUpFailed', False)):
+ continue
+
+ # --- modifications to deal with _wrapped_run ---
+ # original code is:
+ #
+ # if not debug:
+ # test(result)
+ # else:
+ # test.debug()
+ if hasattr(test, '_wrapped_run'):
+ try:
+ test._wrapped_run(result, debug, runcondition=runcondition, options=options)
+ except TypeError:
+ test._wrapped_run(result, debug)
+ elif not debug:
+ try:
+ test(result, runcondition, options)
+ except TypeError:
+ test(result)
+ else:
+ test.debug()
+ # --- end of modifications to deal with _wrapped_run ---
+ return result
+
+if sys.version_info >= (2, 7):
+ # The function below implements a modified version of the
+ # TestSuite.run method that is provided with python 2.7, in
+ # unittest/suite.py
+ def _ts_run(self, result, debug=False, runcondition=None, options=None):
+ topLevel = False
+ if getattr(result, '_testRunEntered', False) is False:
+ result._testRunEntered = topLevel = True
+
+ self._wrapped_run(result, debug, runcondition, options)
+
+ if topLevel:
+ self._tearDownPreviousClass(None, result)
+ self._handleModuleTearDown(result)
+ result._testRunEntered = False
+ return result
+
+
+def enable_dbc(*args):
+ """
+ Without arguments, return True if contracts can be enabled and should be
+ enabled (see option -d), return False otherwise.
+
+ With arguments, return False if contracts can't or shouldn't be enabled,
+ otherwise weave ContractAspect with items passed as arguments.
+ """
+ if not ENABLE_DBC:
+ return False
+ try:
+ from logilab.aspects.weaver import weaver
+ from logilab.aspects.lib.contracts import ContractAspect
+ except ImportError:
+ sys.stderr.write(
+ 'Warning: logilab.aspects is not available. Contracts disabled.')
+ return False
+ for arg in args:
+ weaver.weave_module(arg, ContractAspect)
+ return True
+
+
+# monkeypatch unittest and doctest (ouch !)
+unittest._TextTestResult = testlib.SkipAwareTestResult
+unittest.TextTestRunner = SkipAwareTextTestRunner
+unittest.TestLoader = NonStrictTestLoader
+unittest.TestProgram = SkipAwareTestProgram
+
+if sys.version_info >= (2, 4):
+ doctest.DocTestCase.__bases__ = (testlib.TestCase,)
+ # XXX check python2.6 compatibility
+ #doctest.DocTestCase._cleanups = []
+ #doctest.DocTestCase._out = []
+else:
+ unittest.FunctionTestCase.__bases__ = (testlib.TestCase,)
+unittest.TestSuite.run = _ts_run
+unittest.TestSuite._wrapped_run = _ts_wrapped_run
diff --git a/logilab/common/registry.py b/logilab/common/registry.py
new file mode 100644
index 0000000..a52b2eb
--- /dev/null
+++ b/logilab/common/registry.py
@@ -0,0 +1,1119 @@
+# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of Logilab-common.
+#
+# Logilab-common is free software: you can redistribute it and/or modify it
+# under the terms of the GNU Lesser General Public License as published by the
+# Free Software Foundation, either version 2.1 of the License, or (at your
+# option) any later version.
+#
+# Logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with Logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""This module provides bases for predicates dispatching (the pattern in use
+here is similar to what's refered as multi-dispatch or predicate-dispatch in the
+literature, though a bit different since the idea is to select across different
+implementation 'e.g. classes), not to dispatch a message to a function or
+method. It contains the following classes:
+
+* :class:`RegistryStore`, the top level object which loads implementation
+ objects and stores them into registries. You'll usually use it to access
+ registries and their contained objects;
+
+* :class:`Registry`, the base class which contains objects semantically grouped
+ (for instance, sharing a same API, hence the 'implementation' name). You'll
+ use it to select the proper implementation according to a context. Notice you
+ may use registries on their own without using the store.
+
+.. Note::
+
+ implementation objects are usually designed to be accessed through the
+ registry and not by direct instantiation, besides to use it as base classe.
+
+The selection procedure is delegated to a selector, which is responsible for
+scoring the object according to some context. At the end of the selection, if an
+implementation has been found, an instance of this class is returned. A selector
+is built from one or more predicates combined together using AND, OR, NOT
+operators (actually `&`, `|` and `~`). You'll thus find some base classes to
+build predicates:
+
+* :class:`Predicate`, the abstract base predicate class
+
+* :class:`AndPredicate`, :class:`OrPredicate`, :class:`NotPredicate`, which you
+ shouldn't have to use directly. You'll use `&`, `|` and '~' operators between
+ predicates directly
+
+* :func:`objectify_predicate`
+
+You'll eventually find one concrete predicate: :class:`yes`
+
+.. autoclass:: RegistryStore
+.. autoclass:: Registry
+
+Predicates
+----------
+.. autoclass:: Predicate
+.. autofunc:: objectify_predicate
+.. autoclass:: yes
+
+Debugging
+---------
+.. autoclass:: traced_selection
+
+Exceptions
+----------
+.. autoclass:: RegistryException
+.. autoclass:: RegistryNotFound
+.. autoclass:: ObjectNotFound
+.. autoclass:: NoSelectableObject
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import sys
+import types
+import weakref
+import traceback as tb
+from os import listdir, stat
+from os.path import join, isdir, exists
+from logging import getLogger
+from warnings import warn
+
+from six import string_types, add_metaclass
+
+from logilab.common.modutils import modpath_from_file
+from logilab.common.logging_ext import set_log_methods
+from logilab.common.decorators import classproperty
+
+
+class RegistryException(Exception):
+ """Base class for registry exception."""
+
+class RegistryNotFound(RegistryException):
+ """Raised when an unknown registry is requested.
+
+ This is usually a programming/typo error.
+ """
+
+class ObjectNotFound(RegistryException):
+ """Raised when an unregistered object is requested.
+
+ This may be a programming/typo or a misconfiguration error.
+ """
+
+class NoSelectableObject(RegistryException):
+ """Raised when no object is selectable for a given context."""
+ def __init__(self, args, kwargs, objects):
+ self.args = args
+ self.kwargs = kwargs
+ self.objects = objects
+
+ def __str__(self):
+ return ('args: %s, kwargs: %s\ncandidates: %s'
+ % (self.args, self.kwargs.keys(), self.objects))
+
+
+def _modname_from_path(path, extrapath=None):
+ modpath = modpath_from_file(path, extrapath)
+ # omit '__init__' from package's name to avoid loading that module
+ # once for each name when it is imported by some other object
+ # module. This supposes import in modules are done as::
+ #
+ # from package import something
+ #
+ # not::
+ #
+ # from package.__init__ import something
+ #
+ # which seems quite correct.
+ if modpath[-1] == '__init__':
+ modpath.pop()
+ return '.'.join(modpath)
+
+
+def _toload_info(path, extrapath, _toload=None):
+ """Return a dictionary of <modname>: <modpath> and an ordered list of
+ (file, module name) to load
+ """
+ if _toload is None:
+ assert isinstance(path, list)
+ _toload = {}, []
+ for fileordir in path:
+ if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
+ subfiles = [join(fileordir, fname) for fname in listdir(fileordir)]
+ _toload_info(subfiles, extrapath, _toload)
+ elif fileordir[-3:] == '.py':
+ modname = _modname_from_path(fileordir, extrapath)
+ _toload[0][modname] = fileordir
+ _toload[1].append((fileordir, modname))
+ return _toload
+
+
+class RegistrableObject(object):
+ """This is the base class for registrable objects which are selected
+ according to a context.
+
+ :attr:`__registry__`
+ name of the registry for this object (string like 'views',
+ 'templates'...). You may want to define `__registries__` directly if your
+ object should be registered in several registries.
+
+ :attr:`__regid__`
+ object's identifier in the registry (string like 'main',
+ 'primary', 'folder_box')
+
+ :attr:`__select__`
+ class'selector
+
+ Moreover, the `__abstract__` attribute may be set to True to indicate that a
+ class is abstract and should not be registered.
+
+ You don't have to inherit from this class to put it in a registry (having
+ `__regid__` and `__select__` is enough), though this is needed for classes
+ that should be automatically registered.
+ """
+
+ __registry__ = None
+ __regid__ = None
+ __select__ = None
+ __abstract__ = True # see doc snipppets below (in Registry class)
+
+ @classproperty
+ def __registries__(cls):
+ if cls.__registry__ is None:
+ return ()
+ return (cls.__registry__,)
+
+
+class RegistrableInstance(RegistrableObject):
+ """Inherit this class if you want instances of the classes to be
+ automatically registered.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ """Add a __module__ attribute telling the module where the instance was
+ created, for automatic registration.
+ """
+ obj = super(RegistrableInstance, cls).__new__(cls)
+ # XXX subclass must no override __new__
+ filepath = tb.extract_stack(limit=2)[0][0]
+ obj.__module__ = _modname_from_path(filepath)
+ return obj
+
+
+class Registry(dict):
+ """The registry store a set of implementations associated to identifier:
+
+ * to each identifier are associated a list of implementations
+
+ * to select an implementation of a given identifier, you should use one of the
+ :meth:`select` or :meth:`select_or_none` method
+
+ * to select a list of implementations for a context, you should use the
+ :meth:`possible_objects` method
+
+ * dictionary like access to an identifier will return the bare list of
+ implementations for this identifier.
+
+ To be usable in a registry, the only requirement is to have a `__select__`
+ attribute.
+
+ At the end of the registration process, the :meth:`__registered__`
+ method is called on each registered object which have them, given the
+ registry in which it's registered as argument.
+
+ Registration methods:
+
+ .. automethod: register
+ .. automethod: unregister
+
+ Selection methods:
+
+ .. automethod: select
+ .. automethod: select_or_none
+ .. automethod: possible_objects
+ .. automethod: object_by_id
+ """
+ def __init__(self, debugmode):
+ super(Registry, self).__init__()
+ self.debugmode = debugmode
+
+ def __getitem__(self, name):
+ """return the registry (list of implementation objects) associated to
+ this name
+ """
+ try:
+ return super(Registry, self).__getitem__(name)
+ except KeyError:
+ exc = ObjectNotFound(name)
+ exc.__traceback__ = sys.exc_info()[-1]
+ raise exc
+
+ @classmethod
+ def objid(cls, obj):
+ """returns a unique identifier for an object stored in the registry"""
+ return '%s.%s' % (obj.__module__, cls.objname(obj))
+
+ @classmethod
+ def objname(cls, obj):
+ """returns a readable name for an object stored in the registry"""
+ return getattr(obj, '__name__', id(obj))
+
+ def initialization_completed(self):
+ """call method __registered__() on registered objects when the callback
+ is defined"""
+ for objects in self.values():
+ for objectcls in objects:
+ registered = getattr(objectcls, '__registered__', None)
+ if registered:
+ registered(self)
+ if self.debugmode:
+ wrap_predicates(_lltrace)
+
+ def register(self, obj, oid=None, clear=False):
+ """base method to add an object in the registry"""
+ assert not '__abstract__' in obj.__dict__, obj
+ assert obj.__select__, obj
+ oid = oid or obj.__regid__
+ assert oid, ('no explicit name supplied to register object %s, '
+ 'which has no __regid__ set' % obj)
+ if clear:
+ objects = self[oid] = []
+ else:
+ objects = self.setdefault(oid, [])
+ assert not obj in objects, 'object %s is already registered' % obj
+ objects.append(obj)
+
+ def register_and_replace(self, obj, replaced):
+ """remove <replaced> and register <obj>"""
+ # XXXFIXME this is a duplication of unregister()
+ # remove register_and_replace in favor of unregister + register
+ # or simplify by calling unregister then register here
+ if not isinstance(replaced, string_types):
+ replaced = self.objid(replaced)
+ # prevent from misspelling
+ assert obj is not replaced, 'replacing an object by itself: %s' % obj
+ registered_objs = self.get(obj.__regid__, ())
+ for index, registered in enumerate(registered_objs):
+ if self.objid(registered) == replaced:
+ del registered_objs[index]
+ break
+ else:
+ self.warning('trying to replace %s that is not registered with %s',
+ replaced, obj)
+ self.register(obj)
+
+ def unregister(self, obj):
+ """remove object <obj> from this registry"""
+ objid = self.objid(obj)
+ oid = obj.__regid__
+ for registered in self.get(oid, ()):
+ # use self.objid() to compare objects because vreg will probably
+ # have its own version of the object, loaded through execfile
+ if self.objid(registered) == objid:
+ self[oid].remove(registered)
+ break
+ else:
+ self.warning('can\'t remove %s, no id %s in the registry',
+ objid, oid)
+
+ def all_objects(self):
+ """return a list containing all objects in this registry.
+ """
+ result = []
+ for objs in self.values():
+ result += objs
+ return result
+
+ # dynamic selection methods ################################################
+
+ def object_by_id(self, oid, *args, **kwargs):
+ """return object with the `oid` identifier. Only one object is expected
+ to be found.
+
+ raise :exc:`ObjectNotFound` if there are no object with id `oid` in this
+ registry
+
+ raise :exc:`AssertionError` if there is more than one object there
+ """
+ objects = self[oid]
+ assert len(objects) == 1, objects
+ return objects[0](*args, **kwargs)
+
+ def select(self, __oid, *args, **kwargs):
+ """return the most specific object among those with the given oid
+ according to the given context.
+
+ raise :exc:`ObjectNotFound` if there are no object with id `oid` in this
+ registry
+
+ raise :exc:`NoSelectableObject` if no object can be selected
+ """
+ obj = self._select_best(self[__oid], *args, **kwargs)
+ if obj is None:
+ raise NoSelectableObject(args, kwargs, self[__oid] )
+ return obj
+
+ def select_or_none(self, __oid, *args, **kwargs):
+ """return the most specific object among those with the given oid
+ according to the given context, or None if no object applies.
+ """
+ try:
+ return self._select_best(self[__oid], *args, **kwargs)
+ except ObjectNotFound:
+ return None
+
+ def possible_objects(self, *args, **kwargs):
+ """return an iterator on possible objects in this registry for the given
+ context
+ """
+ for objects in self.values():
+ obj = self._select_best(objects, *args, **kwargs)
+ if obj is None:
+ continue
+ yield obj
+
+ def _select_best(self, objects, *args, **kwargs):
+ """return an instance of the most specific object according
+ to parameters
+
+ return None if not object apply (don't raise `NoSelectableObject` since
+ it's costly when searching objects using `possible_objects`
+ (e.g. searching for hooks).
+ """
+ score, winners = 0, None
+ for obj in objects:
+ objectscore = obj.__select__(obj, *args, **kwargs)
+ if objectscore > score:
+ score, winners = objectscore, [obj]
+ elif objectscore > 0 and objectscore == score:
+ winners.append(obj)
+ if winners is None:
+ return None
+ if len(winners) > 1:
+ # log in production environement / test, error while debugging
+ msg = 'select ambiguity: %s\n(args: %s, kwargs: %s)'
+ if self.debugmode:
+ # raise bare exception in debug mode
+ raise Exception(msg % (winners, args, kwargs.keys()))
+ self.error(msg, winners, args, kwargs.keys())
+ # return the result of calling the object
+ return self.selected(winners[0], args, kwargs)
+
+ def selected(self, winner, args, kwargs):
+ """override here if for instance you don't want "instanciation"
+ """
+ return winner(*args, **kwargs)
+
+ # these are overridden by set_log_methods below
+ # only defining here to prevent pylint from complaining
+ info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None
+
+
+def obj_registries(cls, registryname=None):
+ """return a tuple of registry names (see __registries__)"""
+ if registryname:
+ return (registryname,)
+ return cls.__registries__
+
+
+class RegistryStore(dict):
+ """This class is responsible for loading objects and storing them
+ in their registry which is created on the fly as needed.
+
+ It handles dynamic registration of objects and provides a
+ convenient api to access them. To be recognized as an object that
+ should be stored into one of the store's registry
+ (:class:`Registry`), an object must provide the following
+ attributes, used control how they interact with the registry:
+
+ :attr:`__registries__`
+ list of registry names (string like 'views', 'templates'...) into which
+ the object should be registered
+
+ :attr:`__regid__`
+ object identifier in the registry (string like 'main',
+ 'primary', 'folder_box')
+
+ :attr:`__select__`
+ the object predicate selectors
+
+ Moreover, the :attr:`__abstract__` attribute may be set to `True`
+ to indicate that an object is abstract and should not be registered
+ (such inherited attributes not considered).
+
+ .. Note::
+
+ When using the store to load objects dynamically, you *always* have
+ to use **super()** to get the methods and attributes of the
+ superclasses, and not use the class identifier. If not, you'll get into
+ trouble at reload time.
+
+ For example, instead of writing::
+
+ class Thing(Parent):
+ __regid__ = 'athing'
+ __select__ = yes()
+
+ def f(self, arg1):
+ Parent.f(self, arg1)
+
+ You must write::
+
+ class Thing(Parent):
+ __regid__ = 'athing'
+ __select__ = yes()
+
+ def f(self, arg1):
+ super(Thing, self).f(arg1)
+
+ Controlling object registration
+ -------------------------------
+
+ Dynamic loading is triggered by calling the
+ :meth:`register_objects` method, given a list of directories to
+ inspect for python modules.
+
+ .. automethod: register_objects
+
+ For each module, by default, all compatible objects are registered
+ automatically. However if some objects come as replacement of
+ other objects, or have to be included only if some condition is
+ met, you'll have to define a `registration_callback(vreg)`
+ function in the module and explicitly register **all objects** in
+ this module, using the api defined below.
+
+
+ .. automethod:: RegistryStore.register_all
+ .. automethod:: RegistryStore.register_and_replace
+ .. automethod:: RegistryStore.register
+ .. automethod:: RegistryStore.unregister
+
+ .. Note::
+ Once the function `registration_callback(vreg)` is implemented in a
+ module, all the objects from this module have to be explicitly
+ registered as it disables the automatic object registration.
+
+
+ Examples:
+
+ .. sourcecode:: python
+
+ def registration_callback(store):
+ # register everything in the module except BabarClass
+ store.register_all(globals().values(), __name__, (BabarClass,))
+
+ # conditionally register BabarClass
+ if 'babar_relation' in store.schema:
+ store.register(BabarClass)
+
+ In this example, we register all application object classes defined in the module
+ except `BabarClass`. This class is then registered only if the 'babar_relation'
+ relation type is defined in the instance schema.
+
+ .. sourcecode:: python
+
+ def registration_callback(store):
+ store.register(Elephant)
+ # replace Babar by Celeste
+ store.register_and_replace(Celeste, Babar)
+
+ In this example, we explicitly register classes one by one:
+
+ * the `Elephant` class
+ * the `Celeste` to replace `Babar`
+
+ If at some point we register a new appobject class in this module, it won't be
+ registered at all without modification to the `registration_callback`
+ implementation. The first example will register it though, thanks to the call
+ to the `register_all` method.
+
+ Controlling registry instantiation
+ ----------------------------------
+
+ The `REGISTRY_FACTORY` class dictionary allows to specify which class should
+ be instantiated for a given registry name. The class associated to `None`
+ key will be the class used when there is no specific class for a name.
+ """
+
+ def __init__(self, debugmode=False):
+ super(RegistryStore, self).__init__()
+ self.debugmode = debugmode
+
+ def reset(self):
+ """clear all registries managed by this store"""
+ # don't use self.clear, we want to keep existing subdictionaries
+ for subdict in self.values():
+ subdict.clear()
+ self._lastmodifs = {}
+
+ def __getitem__(self, name):
+ """return the registry (dictionary of class objects) associated to
+ this name
+ """
+ try:
+ return super(RegistryStore, self).__getitem__(name)
+ except KeyError:
+ exc = RegistryNotFound(name)
+ exc.__traceback__ = sys.exc_info()[-1]
+ raise exc
+
+ # methods for explicit (un)registration ###################################
+
+ # default class, when no specific class set
+ REGISTRY_FACTORY = {None: Registry}
+
+ def registry_class(self, regid):
+ """return existing registry named regid or use factory to create one and
+ return it"""
+ try:
+ return self.REGISTRY_FACTORY[regid]
+ except KeyError:
+ return self.REGISTRY_FACTORY[None]
+
+ def setdefault(self, regid):
+ try:
+ return self[regid]
+ except RegistryNotFound:
+ self[regid] = self.registry_class(regid)(self.debugmode)
+ return self[regid]
+
+ def register_all(self, objects, modname, butclasses=()):
+ """register registrable objects into `objects`.
+
+ Registrable objects are properly configured subclasses of
+ :class:`RegistrableObject`. Objects which are not defined in the module
+ `modname` or which are in `butclasses` won't be registered.
+
+ Typical usage is:
+
+ .. sourcecode:: python
+
+ store.register_all(globals().values(), __name__, (ClassIWantToRegisterExplicitly,))
+
+ So you get partially automatic registration, keeping manual registration
+ for some object (to use
+ :meth:`~logilab.common.registry.RegistryStore.register_and_replace` for
+ instance).
+ """
+ assert isinstance(modname, string_types), \
+ 'modname expected to be a module name (ie string), got %r' % modname
+ for obj in objects:
+ if self.is_registrable(obj) and obj.__module__ == modname and not obj in butclasses:
+ if isinstance(obj, type):
+ self._load_ancestors_then_object(modname, obj, butclasses)
+ else:
+ self.register(obj)
+
+ def register(self, obj, registryname=None, oid=None, clear=False):
+ """register `obj` implementation into `registryname` or
+ `obj.__registries__` if not specified, with identifier `oid` or
+ `obj.__regid__` if not specified.
+
+ If `clear` is true, all objects with the same identifier will be
+ previously unregistered.
+ """
+ assert not obj.__dict__.get('__abstract__'), obj
+ for registryname in obj_registries(obj, registryname):
+ registry = self.setdefault(registryname)
+ registry.register(obj, oid=oid, clear=clear)
+ self.debug("register %s in %s['%s']",
+ registry.objname(obj), registryname, oid or obj.__regid__)
+ self._loadedmods.setdefault(obj.__module__, {})[registry.objid(obj)] = obj
+
+ def unregister(self, obj, registryname=None):
+ """unregister `obj` object from the registry `registryname` or
+ `obj.__registries__` if not specified.
+ """
+ for registryname in obj_registries(obj, registryname):
+ registry = self[registryname]
+ registry.unregister(obj)
+ self.debug("unregister %s from %s['%s']",
+ registry.objname(obj), registryname, obj.__regid__)
+
+ def register_and_replace(self, obj, replaced, registryname=None):
+ """register `obj` object into `registryname` or
+ `obj.__registries__` if not specified. If found, the `replaced` object
+ will be unregistered first (else a warning will be issued as it is
+ generally unexpected).
+ """
+ for registryname in obj_registries(obj, registryname):
+ registry = self[registryname]
+ registry.register_and_replace(obj, replaced)
+ self.debug("register %s in %s['%s'] instead of %s",
+ registry.objname(obj), registryname, obj.__regid__,
+ registry.objname(replaced))
+
+ # initialization methods ###################################################
+
+ def init_registration(self, path, extrapath=None):
+ """reset registry and walk down path to return list of (path, name)
+ file modules to be loaded"""
+ # XXX make this private by renaming it to _init_registration ?
+ self.reset()
+ # compute list of all modules that have to be loaded
+ self._toloadmods, filemods = _toload_info(path, extrapath)
+ # XXX is _loadedmods still necessary ? It seems like it's useful
+ # to avoid loading same module twice, especially with the
+ # _load_ancestors_then_object logic but this needs to be checked
+ self._loadedmods = {}
+ return filemods
+
+ def register_objects(self, path, extrapath=None):
+ """register all objects found walking down <path>"""
+ # load views from each directory in the instance's path
+ # XXX inline init_registration ?
+ filemods = self.init_registration(path, extrapath)
+ for filepath, modname in filemods:
+ self.load_file(filepath, modname)
+ self.initialization_completed()
+
+ def initialization_completed(self):
+ """call initialization_completed() on all known registries"""
+ for reg in self.values():
+ reg.initialization_completed()
+
+ def _mdate(self, filepath):
+ """ return the modification date of a file path """
+ try:
+ return stat(filepath)[-2]
+ except OSError:
+ # this typically happens on emacs backup files (.#foo.py)
+ self.warning('Unable to load %s. It is likely to be a backup file',
+ filepath)
+ return None
+
+ def is_reload_needed(self, path):
+ """return True if something module changed and the registry should be
+ reloaded
+ """
+ lastmodifs = self._lastmodifs
+ for fileordir in path:
+ if isdir(fileordir) and exists(join(fileordir, '__init__.py')):
+ if self.is_reload_needed([join(fileordir, fname)
+ for fname in listdir(fileordir)]):
+ return True
+ elif fileordir[-3:] == '.py':
+ mdate = self._mdate(fileordir)
+ if mdate is None:
+ continue # backup file, see _mdate implementation
+ elif "flymake" in fileordir:
+ # flymake + pylint in use, don't consider these they will corrupt the registry
+ continue
+ if fileordir not in lastmodifs or lastmodifs[fileordir] < mdate:
+ self.info('File %s changed since last visit', fileordir)
+ return True
+ return False
+
+ def load_file(self, filepath, modname):
+ """ load registrable objects (if any) from a python file """
+ from logilab.common.modutils import load_module_from_name
+ if modname in self._loadedmods:
+ return
+ self._loadedmods[modname] = {}
+ mdate = self._mdate(filepath)
+ if mdate is None:
+ return # backup file, see _mdate implementation
+ elif "flymake" in filepath:
+ # flymake + pylint in use, don't consider these they will corrupt the registry
+ return
+ # set update time before module loading, else we get some reloading
+ # weirdness in case of syntax error or other error while importing the
+ # module
+ self._lastmodifs[filepath] = mdate
+ # load the module
+ module = load_module_from_name(modname)
+ self.load_module(module)
+
+ def load_module(self, module):
+ """Automatically handle module objects registration.
+
+ Instances are registered as soon as they are hashable and have the
+ following attributes:
+
+ * __regid__ (a string)
+ * __select__ (a callable)
+ * __registries__ (a tuple/list of string)
+
+ For classes this is a bit more complicated :
+
+ - first ensure parent classes are already registered
+
+ - class with __abstract__ == True in their local dictionary are skipped
+
+ - object class needs to have registries and identifier properly set to a
+ non empty string to be registered.
+ """
+ self.info('loading %s from %s', module.__name__, module.__file__)
+ if hasattr(module, 'registration_callback'):
+ module.registration_callback(self)
+ else:
+ self.register_all(vars(module).values(), module.__name__)
+
+ def _load_ancestors_then_object(self, modname, objectcls, butclasses=()):
+ """handle class registration according to rules defined in
+ :meth:`load_module`
+ """
+ # backward compat, we used to allow whatever else than classes
+ if not isinstance(objectcls, type):
+ if self.is_registrable(objectcls) and objectcls.__module__ == modname:
+ self.register(objectcls)
+ return
+ # imported classes
+ objmodname = objectcls.__module__
+ if objmodname != modname:
+ # The module of the object is not the same as the currently
+ # worked on module, or this is actually an instance, which
+ # has no module at all
+ if objmodname in self._toloadmods:
+ # if this is still scheduled for loading, let's proceed immediately,
+ # but using the object module
+ self.load_file(self._toloadmods[objmodname], objmodname)
+ return
+ # ensure object hasn't been already processed
+ clsid = '%s.%s' % (modname, objectcls.__name__)
+ if clsid in self._loadedmods[modname]:
+ return
+ self._loadedmods[modname][clsid] = objectcls
+ # ensure ancestors are registered
+ for parent in objectcls.__bases__:
+ self._load_ancestors_then_object(modname, parent, butclasses)
+ # ensure object is registrable
+ if objectcls in butclasses or not self.is_registrable(objectcls):
+ return
+ # backward compat
+ reg = self.setdefault(obj_registries(objectcls)[0])
+ if reg.objname(objectcls)[0] == '_':
+ warn("[lgc 0.59] object whose name start with '_' won't be "
+ "skipped anymore at some point, use __abstract__ = True "
+ "instead (%s)" % objectcls, DeprecationWarning)
+ return
+ # register, finally
+ self.register(objectcls)
+
+ @classmethod
+ def is_registrable(cls, obj):
+ """ensure `obj` should be registered
+
+ as arbitrary stuff may be registered, do a lot of check and warn about
+ weird cases (think to dumb proxy objects)
+ """
+ if isinstance(obj, type):
+ if not issubclass(obj, RegistrableObject):
+ # ducktyping backward compat
+ if not (getattr(obj, '__registries__', None)
+ and getattr(obj, '__regid__', None)
+ and getattr(obj, '__select__', None)):
+ return False
+ elif issubclass(obj, RegistrableInstance):
+ return False
+ elif not isinstance(obj, RegistrableInstance):
+ return False
+ if not obj.__regid__:
+ return False # no regid
+ registries = obj.__registries__
+ if not registries:
+ return False # no registries
+ selector = obj.__select__
+ if not selector:
+ return False # no selector
+ if obj.__dict__.get('__abstract__', False):
+ return False
+ # then detect potential problems that should be warned
+ if not isinstance(registries, (tuple, list)):
+ cls.warning('%s has __registries__ which is not a list or tuple', obj)
+ return False
+ if not callable(selector):
+ cls.warning('%s has not callable __select__', obj)
+ return False
+ return True
+
+ # these are overridden by set_log_methods below
+ # only defining here to prevent pylint from complaining
+ info = warning = error = critical = exception = debug = lambda msg, *a, **kw: None
+
+
+# init logging
+set_log_methods(RegistryStore, getLogger('registry.store'))
+set_log_methods(Registry, getLogger('registry'))
+
+
+# helpers for debugging selectors
+TRACED_OIDS = None
+
+def _trace_selector(cls, selector, args, ret):
+ vobj = args[0]
+ if TRACED_OIDS == 'all' or vobj.__regid__ in TRACED_OIDS:
+ print('%s -> %s for %s(%s)' % (cls, ret, vobj, vobj.__regid__))
+
+def _lltrace(selector):
+ """use this decorator on your predicates so they become traceable with
+ :class:`traced_selection`
+ """
+ def traced(cls, *args, **kwargs):
+ ret = selector(cls, *args, **kwargs)
+ if TRACED_OIDS is not None:
+ _trace_selector(cls, selector, args, ret)
+ return ret
+ traced.__name__ = selector.__name__
+ traced.__doc__ = selector.__doc__
+ return traced
+
+class traced_selection(object): # pylint: disable=C0103
+ """
+ Typical usage is :
+
+ .. sourcecode:: python
+
+ >>> from logilab.common.registry import traced_selection
+ >>> with traced_selection():
+ ... # some code in which you want to debug selectors
+ ... # for all objects
+
+ This will yield lines like this in the logs::
+
+ selector one_line_rset returned 0 for <class 'elephant.Babar'>
+
+ You can also give to :class:`traced_selection` the identifiers of objects on
+ which you want to debug selection ('oid1' and 'oid2' in the example above).
+
+ .. sourcecode:: python
+
+ >>> with traced_selection( ('regid1', 'regid2') ):
+ ... # some code in which you want to debug selectors
+ ... # for objects with __regid__ 'regid1' and 'regid2'
+
+ A potentially useful point to set up such a tracing function is
+ the `logilab.common.registry.Registry.select` method body.
+ """
+
+ def __init__(self, traced='all'):
+ self.traced = traced
+
+ def __enter__(self):
+ global TRACED_OIDS
+ TRACED_OIDS = self.traced
+
+ def __exit__(self, exctype, exc, traceback):
+ global TRACED_OIDS
+ TRACED_OIDS = None
+ return traceback is None
+
+# selector base classes and operations ########################################
+
+def objectify_predicate(selector_func):
+ """Most of the time, a simple score function is enough to build a selector.
+ The :func:`objectify_predicate` decorator turn it into a proper selector
+ class::
+
+ @objectify_predicate
+ def one(cls, req, rset=None, **kwargs):
+ return 1
+
+ class MyView(View):
+ __select__ = View.__select__ & one()
+
+ """
+ return type(selector_func.__name__, (Predicate,),
+ {'__doc__': selector_func.__doc__,
+ '__call__': lambda self, *a, **kw: selector_func(*a, **kw)})
+
+
+_PREDICATES = {}
+
+def wrap_predicates(decorator):
+ for predicate in _PREDICATES.values():
+ if not '_decorators' in predicate.__dict__:
+ predicate._decorators = set()
+ if decorator in predicate._decorators:
+ continue
+ predicate._decorators.add(decorator)
+ predicate.__call__ = decorator(predicate.__call__)
+
+class PredicateMetaClass(type):
+ def __new__(mcs, *args, **kwargs):
+ # use __new__ so subclasses doesn't have to call Predicate.__init__
+ inst = type.__new__(mcs, *args, **kwargs)
+ proxy = weakref.proxy(inst, lambda p: _PREDICATES.pop(id(p)))
+ _PREDICATES[id(proxy)] = proxy
+ return inst
+
+
+@add_metaclass(PredicateMetaClass)
+class Predicate(object):
+ """base class for selector classes providing implementation
+ for operators ``&``, ``|`` and ``~``
+
+ This class is only here to give access to binary operators, the selector
+ logic itself should be implemented in the :meth:`__call__` method. Notice it
+ should usually accept any arbitrary arguments (the context), though that may
+ vary depending on your usage of the registry.
+
+ a selector is called to help choosing the correct object for a
+ particular context by returning a score (`int`) telling how well
+ the implementation given as first argument fit to the given context.
+
+ 0 score means that the class doesn't apply.
+ """
+
+ @property
+ def func_name(self):
+ # backward compatibility
+ return self.__class__.__name__
+
+ def search_selector(self, selector):
+ """search for the given selector, selector instance or tuple of
+ selectors in the selectors tree. Return None if not found.
+ """
+ if self is selector:
+ return self
+ if (isinstance(selector, type) or isinstance(selector, tuple)) and \
+ isinstance(self, selector):
+ return self
+ return None
+
+ def __str__(self):
+ return self.__class__.__name__
+
+ def __and__(self, other):
+ return AndPredicate(self, other)
+ def __rand__(self, other):
+ return AndPredicate(other, self)
+ def __iand__(self, other):
+ return AndPredicate(self, other)
+ def __or__(self, other):
+ return OrPredicate(self, other)
+ def __ror__(self, other):
+ return OrPredicate(other, self)
+ def __ior__(self, other):
+ return OrPredicate(self, other)
+
+ def __invert__(self):
+ return NotPredicate(self)
+
+ # XXX (function | function) or (function & function) not managed yet
+
+ def __call__(self, cls, *args, **kwargs):
+ return NotImplementedError("selector %s must implement its logic "
+ "in its __call__ method" % self.__class__)
+
+ def __repr__(self):
+ return u'<Predicate %s at %x>' % (self.__class__.__name__, id(self))
+
+
+class MultiPredicate(Predicate):
+ """base class for compound selector classes"""
+
+ def __init__(self, *selectors):
+ self.selectors = self.merge_selectors(selectors)
+
+ def __str__(self):
+ return '%s(%s)' % (self.__class__.__name__,
+ ','.join(str(s) for s in self.selectors))
+
+ @classmethod
+ def merge_selectors(cls, selectors):
+ """deal with selector instanciation when necessary and merge
+ multi-selectors if possible:
+
+ AndPredicate(AndPredicate(sel1, sel2), AndPredicate(sel3, sel4))
+ ==> AndPredicate(sel1, sel2, sel3, sel4)
+ """
+ merged_selectors = []
+ for selector in selectors:
+ # XXX do we really want magic-transformations below?
+ # if so, wanna warn about them?
+ if isinstance(selector, types.FunctionType):
+ selector = objectify_predicate(selector)()
+ if isinstance(selector, type) and issubclass(selector, Predicate):
+ selector = selector()
+ assert isinstance(selector, Predicate), selector
+ if isinstance(selector, cls):
+ merged_selectors += selector.selectors
+ else:
+ merged_selectors.append(selector)
+ return merged_selectors
+
+ def search_selector(self, selector):
+ """search for the given selector or selector instance (or tuple of
+ selectors) in the selectors tree. Return None if not found
+ """
+ for childselector in self.selectors:
+ if childselector is selector:
+ return childselector
+ found = childselector.search_selector(selector)
+ if found is not None:
+ return found
+ # if not found in children, maybe we are looking for self?
+ return super(MultiPredicate, self).search_selector(selector)
+
+
+class AndPredicate(MultiPredicate):
+ """and-chained selectors"""
+ def __call__(self, cls, *args, **kwargs):
+ score = 0
+ for selector in self.selectors:
+ partscore = selector(cls, *args, **kwargs)
+ if not partscore:
+ return 0
+ score += partscore
+ return score
+
+
+class OrPredicate(MultiPredicate):
+ """or-chained selectors"""
+ def __call__(self, cls, *args, **kwargs):
+ for selector in self.selectors:
+ partscore = selector(cls, *args, **kwargs)
+ if partscore:
+ return partscore
+ return 0
+
+class NotPredicate(Predicate):
+ """negation selector"""
+ def __init__(self, selector):
+ self.selector = selector
+
+ def __call__(self, cls, *args, **kwargs):
+ score = self.selector(cls, *args, **kwargs)
+ return int(not score)
+
+ def __str__(self):
+ return 'NOT(%s)' % self.selector
+
+
+class yes(Predicate): # pylint: disable=C0103
+ """Return the score given as parameter, with a default score of 0.5 so any
+ other selector take precedence.
+
+ Usually used for objects which can be selected whatever the context, or
+ also sometimes to add arbitrary points to a score.
+
+ Take care, `yes(0)` could be named 'no'...
+ """
+ def __init__(self, score=0.5):
+ self.score = score
+
+ def __call__(self, *args, **kwargs):
+ return self.score
+
+
+# deprecated stuff #############################################################
+
+from logilab.common.deprecation import deprecated
+
+@deprecated('[lgc 0.59] use Registry.objid class method instead')
+def classid(cls):
+ return '%s.%s' % (cls.__module__, cls.__name__)
+
+@deprecated('[lgc 0.59] use obj_registries function instead')
+def class_registries(cls, registryname):
+ return obj_registries(cls, registryname)
+
diff --git a/logilab/common/shellutils.py b/logilab/common/shellutils.py
new file mode 100644
index 0000000..4e68956
--- /dev/null
+++ b/logilab/common/shellutils.py
@@ -0,0 +1,462 @@
+# copyright 2003-2014 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""shell/term utilities, useful to write some python scripts instead of shell
+scripts.
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+import os
+import glob
+import shutil
+import stat
+import sys
+import tempfile
+import time
+import fnmatch
+import errno
+import string
+import random
+import subprocess
+from os.path import exists, isdir, islink, basename, join
+
+from six import string_types
+from six.moves import range, input as raw_input
+
+from logilab.common import STD_BLACKLIST, _handle_blacklist
+from logilab.common.compat import str_to_bytes
+from logilab.common.deprecation import deprecated
+
+try:
+ from logilab.common.proc import ProcInfo, NoSuchProcess
+except ImportError:
+ # windows platform
+ class NoSuchProcess(Exception): pass
+
+ def ProcInfo(pid):
+ raise NoSuchProcess()
+
+
+class tempdir(object):
+
+ def __enter__(self):
+ self.path = tempfile.mkdtemp()
+ return self.path
+
+ def __exit__(self, exctype, value, traceback):
+ # rmtree in all cases
+ shutil.rmtree(self.path)
+ return traceback is None
+
+
+class pushd(object):
+ def __init__(self, directory):
+ self.directory = directory
+
+ def __enter__(self):
+ self.cwd = os.getcwd()
+ os.chdir(self.directory)
+ return self.directory
+
+ def __exit__(self, exctype, value, traceback):
+ os.chdir(self.cwd)
+
+
+def chown(path, login=None, group=None):
+ """Same as `os.chown` function but accepting user login or group name as
+ argument. If login or group is omitted, it's left unchanged.
+
+ Note: you must own the file to chown it (or be root). Otherwise OSError is raised.
+ """
+ if login is None:
+ uid = -1
+ else:
+ try:
+ uid = int(login)
+ except ValueError:
+ import pwd # Platforms: Unix
+ uid = pwd.getpwnam(login).pw_uid
+ if group is None:
+ gid = -1
+ else:
+ try:
+ gid = int(group)
+ except ValueError:
+ import grp
+ gid = grp.getgrnam(group).gr_gid
+ os.chown(path, uid, gid)
+
+def mv(source, destination, _action=shutil.move):
+ """A shell-like mv, supporting wildcards.
+ """
+ sources = glob.glob(source)
+ if len(sources) > 1:
+ assert isdir(destination)
+ for filename in sources:
+ _action(filename, join(destination, basename(filename)))
+ else:
+ try:
+ source = sources[0]
+ except IndexError:
+ raise OSError('No file matching %s' % source)
+ if isdir(destination) and exists(destination):
+ destination = join(destination, basename(source))
+ try:
+ _action(source, destination)
+ except OSError as ex:
+ raise OSError('Unable to move %r to %r (%s)' % (
+ source, destination, ex))
+
+def rm(*files):
+ """A shell-like rm, supporting wildcards.
+ """
+ for wfile in files:
+ for filename in glob.glob(wfile):
+ if islink(filename):
+ os.remove(filename)
+ elif isdir(filename):
+ shutil.rmtree(filename)
+ else:
+ os.remove(filename)
+
+def cp(source, destination):
+ """A shell-like cp, supporting wildcards.
+ """
+ mv(source, destination, _action=shutil.copy)
+
+def find(directory, exts, exclude=False, blacklist=STD_BLACKLIST):
+ """Recursively find files ending with the given extensions from the directory.
+
+ :type directory: str
+ :param directory:
+ directory where the search should start
+
+ :type exts: basestring or list or tuple
+ :param exts:
+ extensions or lists or extensions to search
+
+ :type exclude: boolean
+ :param exts:
+ if this argument is True, returning files NOT ending with the given
+ extensions
+
+ :type blacklist: list or tuple
+ :param blacklist:
+ optional list of files or directory to ignore, default to the value of
+ `logilab.common.STD_BLACKLIST`
+
+ :rtype: list
+ :return:
+ the list of all matching files
+ """
+ if isinstance(exts, string_types):
+ exts = (exts,)
+ if exclude:
+ def match(filename, exts):
+ for ext in exts:
+ if filename.endswith(ext):
+ return False
+ return True
+ else:
+ def match(filename, exts):
+ for ext in exts:
+ if filename.endswith(ext):
+ return True
+ return False
+ files = []
+ for dirpath, dirnames, filenames in os.walk(directory):
+ _handle_blacklist(blacklist, dirnames, filenames)
+ # don't append files if the directory is blacklisted
+ dirname = basename(dirpath)
+ if dirname in blacklist:
+ continue
+ files.extend([join(dirpath, f) for f in filenames if match(f, exts)])
+ return files
+
+
+def globfind(directory, pattern, blacklist=STD_BLACKLIST):
+ """Recursively finds files matching glob `pattern` under `directory`.
+
+ This is an alternative to `logilab.common.shellutils.find`.
+
+ :type directory: str
+ :param directory:
+ directory where the search should start
+
+ :type pattern: basestring
+ :param pattern:
+ the glob pattern (e.g *.py, foo*.py, etc.)
+
+ :type blacklist: list or tuple
+ :param blacklist:
+ optional list of files or directory to ignore, default to the value of
+ `logilab.common.STD_BLACKLIST`
+
+ :rtype: iterator
+ :return:
+ iterator over the list of all matching files
+ """
+ for curdir, dirnames, filenames in os.walk(directory):
+ _handle_blacklist(blacklist, dirnames, filenames)
+ for fname in fnmatch.filter(filenames, pattern):
+ yield join(curdir, fname)
+
+def unzip(archive, destdir):
+ import zipfile
+ if not exists(destdir):
+ os.mkdir(destdir)
+ zfobj = zipfile.ZipFile(archive)
+ for name in zfobj.namelist():
+ if name.endswith('/'):
+ os.mkdir(join(destdir, name))
+ else:
+ outfile = open(join(destdir, name), 'wb')
+ outfile.write(zfobj.read(name))
+ outfile.close()
+
+
+class Execute:
+ """This is a deadlock safe version of popen2 (no stdin), that returns
+ an object with errorlevel, out and err.
+ """
+
+ def __init__(self, command):
+ cmd = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ self.out, self.err = cmd.communicate()
+ self.status = os.WEXITSTATUS(cmd.returncode)
+
+Execute = deprecated('Use subprocess.Popen instead')(Execute)
+
+
+def acquire_lock(lock_file, max_try=10, delay=10, max_delay=3600):
+ """Acquire a lock represented by a file on the file system
+
+ If the process written in lock file doesn't exist anymore, we remove the
+ lock file immediately
+ If age of the lock_file is greater than max_delay, then we raise a UserWarning
+ """
+ count = abs(max_try)
+ while count:
+ try:
+ fd = os.open(lock_file, os.O_EXCL | os.O_RDWR | os.O_CREAT)
+ os.write(fd, str_to_bytes(str(os.getpid())) )
+ os.close(fd)
+ return True
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ try:
+ fd = open(lock_file, "r")
+ pid = int(fd.readline())
+ pi = ProcInfo(pid)
+ age = (time.time() - os.stat(lock_file)[stat.ST_MTIME])
+ if age / max_delay > 1 :
+ raise UserWarning("Command '%s' (pid %s) has locked the "
+ "file '%s' for %s minutes"
+ % (pi.name(), pid, lock_file, age/60))
+ except UserWarning:
+ raise
+ except NoSuchProcess:
+ os.remove(lock_file)
+ except Exception:
+ # The try block is not essential. can be skipped.
+ # Note: ProcInfo object is only available for linux
+ # process information are not accessible...
+ # or lock_file is no more present...
+ pass
+ else:
+ raise
+ count -= 1
+ time.sleep(delay)
+ else:
+ raise Exception('Unable to acquire %s' % lock_file)
+
+def release_lock(lock_file):
+ """Release a lock represented by a file on the file system."""
+ os.remove(lock_file)
+
+
+class ProgressBar(object):
+ """A simple text progression bar."""
+
+ def __init__(self, nbops, size=20, stream=sys.stdout, title=''):
+ if title:
+ self._fstr = '\r%s [%%-%ss]' % (title, int(size))
+ else:
+ self._fstr = '\r[%%-%ss]' % int(size)
+ self._stream = stream
+ self._total = nbops
+ self._size = size
+ self._current = 0
+ self._progress = 0
+ self._current_text = None
+ self._last_text_write_size = 0
+
+ def _get_text(self):
+ return self._current_text
+
+ def _set_text(self, text=None):
+ if text != self._current_text:
+ self._current_text = text
+ self.refresh()
+
+ def _del_text(self):
+ self.text = None
+
+ text = property(_get_text, _set_text, _del_text)
+
+ def update(self, offset=1, exact=False):
+ """Move FORWARD to new cursor position (cursor will never go backward).
+
+ :offset: fraction of ``size``
+
+ :exact:
+
+ - False: offset relative to current cursor position if True
+ - True: offset as an asbsolute position
+
+ """
+ if exact:
+ self._current = offset
+ else:
+ self._current += offset
+
+ progress = int((float(self._current)/float(self._total))*self._size)
+ if progress > self._progress:
+ self._progress = progress
+ self.refresh()
+
+ def refresh(self):
+ """Refresh the progression bar display."""
+ self._stream.write(self._fstr % ('=' * min(self._progress, self._size)) )
+ if self._last_text_write_size or self._current_text:
+ template = ' %%-%is' % (self._last_text_write_size)
+ text = self._current_text
+ if text is None:
+ text = ''
+ self._stream.write(template % text)
+ self._last_text_write_size = len(text.rstrip())
+ self._stream.flush()
+
+ def finish(self):
+ self._stream.write('\n')
+ self._stream.flush()
+
+
+class DummyProgressBar(object):
+ __slot__ = ('text',)
+
+ def refresh(self):
+ pass
+ def update(self):
+ pass
+ def finish(self):
+ pass
+
+
+_MARKER = object()
+class progress(object):
+
+ def __init__(self, nbops=_MARKER, size=_MARKER, stream=_MARKER, title=_MARKER, enabled=True):
+ self.nbops = nbops
+ self.size = size
+ self.stream = stream
+ self.title = title
+ self.enabled = enabled
+
+ def __enter__(self):
+ if self.enabled:
+ kwargs = {}
+ for attr in ('nbops', 'size', 'stream', 'title'):
+ value = getattr(self, attr)
+ if value is not _MARKER:
+ kwargs[attr] = value
+ self.pb = ProgressBar(**kwargs)
+ else:
+ self.pb = DummyProgressBar()
+ return self.pb
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.pb.finish()
+
+class RawInput(object):
+
+ def __init__(self, input=None, printer=None):
+ self._input = input or raw_input
+ self._print = printer
+
+ def ask(self, question, options, default):
+ assert default in options
+ choices = []
+ for option in options:
+ if option == default:
+ label = option[0].upper()
+ else:
+ label = option[0].lower()
+ if len(option) > 1:
+ label += '(%s)' % option[1:].lower()
+ choices.append((option, label))
+ prompt = "%s [%s]: " % (question,
+ '/'.join([opt[1] for opt in choices]))
+ tries = 3
+ while tries > 0:
+ answer = self._input(prompt).strip().lower()
+ if not answer:
+ return default
+ possible = [option for option, label in choices
+ if option.lower().startswith(answer)]
+ if len(possible) == 1:
+ return possible[0]
+ elif len(possible) == 0:
+ msg = '%s is not an option.' % answer
+ else:
+ msg = ('%s is an ambiguous answer, do you mean %s ?' % (
+ answer, ' or '.join(possible)))
+ if self._print:
+ self._print(msg)
+ else:
+ print(msg)
+ tries -= 1
+ raise Exception('unable to get a sensible answer')
+
+ def confirm(self, question, default_is_yes=True):
+ default = default_is_yes and 'y' or 'n'
+ answer = self.ask(question, ('y', 'n'), default)
+ return answer == 'y'
+
+ASK = RawInput()
+
+
+def getlogin():
+ """avoid using os.getlogin() because of strange tty / stdin problems
+ (man 3 getlogin)
+ Another solution would be to use $LOGNAME, $USER or $USERNAME
+ """
+ if sys.platform != 'win32':
+ import pwd # Platforms: Unix
+ return pwd.getpwuid(os.getuid())[0]
+ else:
+ return os.environ['USERNAME']
+
+def generate_password(length=8, vocab=string.ascii_letters + string.digits):
+ """dumb password generation function"""
+ pwd = ''
+ for i in range(length):
+ pwd += random.choice(vocab)
+ return pwd
diff --git a/logilab/common/sphinx_ext.py b/logilab/common/sphinx_ext.py
new file mode 100644
index 0000000..a24608c
--- /dev/null
+++ b/logilab/common/sphinx_ext.py
@@ -0,0 +1,87 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+from logilab.common.decorators import monkeypatch
+
+from sphinx.ext import autodoc
+
+class DocstringOnlyModuleDocumenter(autodoc.ModuleDocumenter):
+ objtype = 'docstring'
+ def format_signature(self):
+ pass
+ def add_directive_header(self, sig):
+ pass
+ def document_members(self, all_members=False):
+ pass
+
+ def resolve_name(self, modname, parents, path, base):
+ if modname is not None:
+ return modname, parents + [base]
+ return (path or '') + base, []
+
+
+#autodoc.add_documenter(DocstringOnlyModuleDocumenter)
+
+def setup(app):
+ app.add_autodocumenter(DocstringOnlyModuleDocumenter)
+
+
+
+from sphinx.ext.autodoc import (ViewList, Options, AutodocReporter, nodes,
+ assemble_option_dict, nested_parse_with_titles)
+
+@monkeypatch(autodoc.AutoDirective)
+def run(self):
+ self.filename_set = set() # a set of dependent filenames
+ self.reporter = self.state.document.reporter
+ self.env = self.state.document.settings.env
+ self.warnings = []
+ self.result = ViewList()
+
+ # find out what documenter to call
+ objtype = self.name[4:]
+ doc_class = self._registry[objtype]
+ # process the options with the selected documenter's option_spec
+ self.genopt = Options(assemble_option_dict(
+ self.options.items(), doc_class.option_spec))
+ # generate the output
+ documenter = doc_class(self, self.arguments[0])
+ documenter.generate(more_content=self.content)
+ if not self.result:
+ return self.warnings
+
+ # record all filenames as dependencies -- this will at least
+ # partially make automatic invalidation possible
+ for fn in self.filename_set:
+ self.env.note_dependency(fn)
+
+ # use a custom reporter that correctly assigns lines to source
+ # filename/description and lineno
+ old_reporter = self.state.memo.reporter
+ self.state.memo.reporter = AutodocReporter(self.result,
+ self.state.memo.reporter)
+ if self.name in ('automodule', 'autodocstring'):
+ node = nodes.section()
+ # necessary so that the child nodes get the right source/line set
+ node.document = self.state.document
+ nested_parse_with_titles(self.state, self.result, node)
+ else:
+ node = nodes.paragraph()
+ node.document = self.state.document
+ self.state.nested_parse(self.result, 0, node)
+ self.state.memo.reporter = old_reporter
+ return self.warnings + node.children
diff --git a/logilab/common/sphinxutils.py b/logilab/common/sphinxutils.py
new file mode 100644
index 0000000..ab6e8a1
--- /dev/null
+++ b/logilab/common/sphinxutils.py
@@ -0,0 +1,122 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Sphinx utils
+
+ModuleGenerator: Generate a file that lists all the modules of a list of
+packages in order to pull all the docstring.
+This should not be used in a makefile to systematically generate sphinx
+documentation!
+
+Typical usage:
+
+>>> from logilab.common.sphinxutils import ModuleGenerator
+>>> mgen = ModuleGenerator('logilab common', '/home/adim/src/logilab/common')
+>>> mgen.generate('api_logilab_common.rst', exclude_dirs=('test',))
+"""
+
+import os, sys
+import os.path as osp
+import inspect
+
+from logilab.common import STD_BLACKLIST
+from logilab.common.shellutils import globfind
+from logilab.common.modutils import load_module_from_file, modpath_from_file
+
+def module_members(module):
+ members = []
+ for name, value in inspect.getmembers(module):
+ if getattr(value, '__module__', None) == module.__name__:
+ members.append( (name, value) )
+ return sorted(members)
+
+
+def class_members(klass):
+ return sorted([name for name in vars(klass)
+ if name not in ('__doc__', '__module__',
+ '__dict__', '__weakref__')])
+
+class ModuleGenerator:
+ file_header = """.. -*- coding: utf-8 -*-\n\n%s\n"""
+ module_def = """
+:mod:`%s`
+=======%s
+
+.. automodule:: %s
+ :members: %s
+"""
+ class_def = """
+
+.. autoclass:: %s
+ :members: %s
+
+"""
+
+ def __init__(self, project_title, code_dir):
+ self.title = project_title
+ self.code_dir = osp.abspath(code_dir)
+
+ def generate(self, dest_file, exclude_dirs=STD_BLACKLIST):
+ """make the module file"""
+ self.fn = open(dest_file, 'w')
+ num = len(self.title) + 6
+ title = "=" * num + "\n %s API\n" % self.title + "=" * num
+ self.fn.write(self.file_header % title)
+ self.gen_modules(exclude_dirs=exclude_dirs)
+ self.fn.close()
+
+ def gen_modules(self, exclude_dirs):
+ """generate all modules"""
+ for module in self.find_modules(exclude_dirs):
+ modname = module.__name__
+ classes = []
+ modmembers = []
+ for objname, obj in module_members(module):
+ if inspect.isclass(obj):
+ classmembers = class_members(obj)
+ classes.append( (objname, classmembers) )
+ else:
+ modmembers.append(objname)
+ self.fn.write(self.module_def % (modname, '=' * len(modname),
+ modname,
+ ', '.join(modmembers)))
+ for klass, members in classes:
+ self.fn.write(self.class_def % (klass, ', '.join(members)))
+
+ def find_modules(self, exclude_dirs):
+ basepath = osp.dirname(self.code_dir)
+ basedir = osp.basename(basepath) + osp.sep
+ if basedir not in sys.path:
+ sys.path.insert(1, basedir)
+ for filepath in globfind(self.code_dir, '*.py', exclude_dirs):
+ if osp.basename(filepath) in ('setup.py', '__pkginfo__.py'):
+ continue
+ try:
+ module = load_module_from_file(filepath)
+ except: # module might be broken or magic
+ dotted_path = modpath_from_file(filepath)
+ module = type('.'.join(dotted_path), (), {}) # mock it
+ yield module
+
+
+if __name__ == '__main__':
+ # example :
+ title, code_dir, outfile = sys.argv[1:]
+ generator = ModuleGenerator(title, code_dir)
+ # XXX modnames = ['logilab']
+ generator.generate(outfile, ('test', 'tests', 'examples',
+ 'data', 'doc', '.hg', 'migration'))
diff --git a/logilab/common/table.py b/logilab/common/table.py
new file mode 100644
index 0000000..2f3df69
--- /dev/null
+++ b/logilab/common/table.py
@@ -0,0 +1,929 @@
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Table management module."""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+from six.moves import range
+
+class Table(object):
+ """Table defines a data table with column and row names.
+ inv:
+ len(self.data) <= len(self.row_names)
+ forall(self.data, lambda x: len(x) <= len(self.col_names))
+ """
+
+ def __init__(self, default_value=0, col_names=None, row_names=None):
+ self.col_names = []
+ self.row_names = []
+ self.data = []
+ self.default_value = default_value
+ if col_names:
+ self.create_columns(col_names)
+ if row_names:
+ self.create_rows(row_names)
+
+ def _next_row_name(self):
+ return 'row%s' % (len(self.row_names)+1)
+
+ def __iter__(self):
+ return iter(self.data)
+
+ def __eq__(self, other):
+ if other is None:
+ return False
+ else:
+ return list(self) == list(other)
+
+ __hash__ = object.__hash__
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __len__(self):
+ return len(self.row_names)
+
+ ## Rows / Columns creation #################################################
+ def create_rows(self, row_names):
+ """Appends row_names to the list of existing rows
+ """
+ self.row_names.extend(row_names)
+ for row_name in row_names:
+ self.data.append([self.default_value]*len(self.col_names))
+
+ def create_columns(self, col_names):
+ """Appends col_names to the list of existing columns
+ """
+ for col_name in col_names:
+ self.create_column(col_name)
+
+ def create_row(self, row_name=None):
+ """Creates a rowname to the row_names list
+ """
+ row_name = row_name or self._next_row_name()
+ self.row_names.append(row_name)
+ self.data.append([self.default_value]*len(self.col_names))
+
+
+ def create_column(self, col_name):
+ """Creates a colname to the col_names list
+ """
+ self.col_names.append(col_name)
+ for row in self.data:
+ row.append(self.default_value)
+
+ ## Sort by column ##########################################################
+ def sort_by_column_id(self, col_id, method = 'asc'):
+ """Sorts the table (in-place) according to data stored in col_id
+ """
+ try:
+ col_index = self.col_names.index(col_id)
+ self.sort_by_column_index(col_index, method)
+ except ValueError:
+ raise KeyError("Col (%s) not found in table" % (col_id))
+
+
+ def sort_by_column_index(self, col_index, method = 'asc'):
+ """Sorts the table 'in-place' according to data stored in col_index
+
+ method should be in ('asc', 'desc')
+ """
+ sort_list = sorted([(row[col_index], row, row_name)
+ for row, row_name in zip(self.data, self.row_names)])
+ # Sorting sort_list will sort according to col_index
+ # If we want reverse sort, then reverse list
+ if method.lower() == 'desc':
+ sort_list.reverse()
+
+ # Rebuild data / row names
+ self.data = []
+ self.row_names = []
+ for val, row, row_name in sort_list:
+ self.data.append(row)
+ self.row_names.append(row_name)
+
+ def groupby(self, colname, *others):
+ """builds indexes of data
+ :returns: nested dictionaries pointing to actual rows
+ """
+ groups = {}
+ colnames = (colname,) + others
+ col_indexes = [self.col_names.index(col_id) for col_id in colnames]
+ for row in self.data:
+ ptr = groups
+ for col_index in col_indexes[:-1]:
+ ptr = ptr.setdefault(row[col_index], {})
+ ptr = ptr.setdefault(row[col_indexes[-1]],
+ Table(default_value=self.default_value,
+ col_names=self.col_names))
+ ptr.append_row(tuple(row))
+ return groups
+
+ def select(self, colname, value):
+ grouped = self.groupby(colname)
+ try:
+ return grouped[value]
+ except KeyError:
+ return []
+
+ def remove(self, colname, value):
+ col_index = self.col_names.index(colname)
+ for row in self.data[:]:
+ if row[col_index] == value:
+ self.data.remove(row)
+
+
+ ## The 'setter' part #######################################################
+ def set_cell(self, row_index, col_index, data):
+ """sets value of cell 'row_indew', 'col_index' to data
+ """
+ self.data[row_index][col_index] = data
+
+
+ def set_cell_by_ids(self, row_id, col_id, data):
+ """sets value of cell mapped by row_id and col_id to data
+ Raises a KeyError if row_id or col_id are not found in the table
+ """
+ try:
+ row_index = self.row_names.index(row_id)
+ except ValueError:
+ raise KeyError("Row (%s) not found in table" % (row_id))
+ else:
+ try:
+ col_index = self.col_names.index(col_id)
+ self.data[row_index][col_index] = data
+ except ValueError:
+ raise KeyError("Column (%s) not found in table" % (col_id))
+
+
+ def set_row(self, row_index, row_data):
+ """sets the 'row_index' row
+ pre:
+ type(row_data) == types.ListType
+ len(row_data) == len(self.col_names)
+ """
+ self.data[row_index] = row_data
+
+
+ def set_row_by_id(self, row_id, row_data):
+ """sets the 'row_id' column
+ pre:
+ type(row_data) == types.ListType
+ len(row_data) == len(self.row_names)
+ Raises a KeyError if row_id is not found
+ """
+ try:
+ row_index = self.row_names.index(row_id)
+ self.set_row(row_index, row_data)
+ except ValueError:
+ raise KeyError('Row (%s) not found in table' % (row_id))
+
+
+ def append_row(self, row_data, row_name=None):
+ """Appends a row to the table
+ pre:
+ type(row_data) == types.ListType
+ len(row_data) == len(self.col_names)
+ """
+ row_name = row_name or self._next_row_name()
+ self.row_names.append(row_name)
+ self.data.append(row_data)
+ return len(self.data) - 1
+
+ def insert_row(self, index, row_data, row_name=None):
+ """Appends row_data before 'index' in the table. To make 'insert'
+ behave like 'list.insert', inserting in an out of range index will
+ insert row_data to the end of the list
+ pre:
+ type(row_data) == types.ListType
+ len(row_data) == len(self.col_names)
+ """
+ row_name = row_name or self._next_row_name()
+ self.row_names.insert(index, row_name)
+ self.data.insert(index, row_data)
+
+
+ def delete_row(self, index):
+ """Deletes the 'index' row in the table, and returns it.
+ Raises an IndexError if index is out of range
+ """
+ self.row_names.pop(index)
+ return self.data.pop(index)
+
+
+ def delete_row_by_id(self, row_id):
+ """Deletes the 'row_id' row in the table.
+ Raises a KeyError if row_id was not found.
+ """
+ try:
+ row_index = self.row_names.index(row_id)
+ self.delete_row(row_index)
+ except ValueError:
+ raise KeyError('Row (%s) not found in table' % (row_id))
+
+
+ def set_column(self, col_index, col_data):
+ """sets the 'col_index' column
+ pre:
+ type(col_data) == types.ListType
+ len(col_data) == len(self.row_names)
+ """
+
+ for row_index, cell_data in enumerate(col_data):
+ self.data[row_index][col_index] = cell_data
+
+
+ def set_column_by_id(self, col_id, col_data):
+ """sets the 'col_id' column
+ pre:
+ type(col_data) == types.ListType
+ len(col_data) == len(self.col_names)
+ Raises a KeyError if col_id is not found
+ """
+ try:
+ col_index = self.col_names.index(col_id)
+ self.set_column(col_index, col_data)
+ except ValueError:
+ raise KeyError('Column (%s) not found in table' % (col_id))
+
+
+ def append_column(self, col_data, col_name):
+ """Appends the 'col_index' column
+ pre:
+ type(col_data) == types.ListType
+ len(col_data) == len(self.row_names)
+ """
+ self.col_names.append(col_name)
+ for row_index, cell_data in enumerate(col_data):
+ self.data[row_index].append(cell_data)
+
+
+ def insert_column(self, index, col_data, col_name):
+ """Appends col_data before 'index' in the table. To make 'insert'
+ behave like 'list.insert', inserting in an out of range index will
+ insert col_data to the end of the list
+ pre:
+ type(col_data) == types.ListType
+ len(col_data) == len(self.row_names)
+ """
+ self.col_names.insert(index, col_name)
+ for row_index, cell_data in enumerate(col_data):
+ self.data[row_index].insert(index, cell_data)
+
+
+ def delete_column(self, index):
+ """Deletes the 'index' column in the table, and returns it.
+ Raises an IndexError if index is out of range
+ """
+ self.col_names.pop(index)
+ return [row.pop(index) for row in self.data]
+
+
+ def delete_column_by_id(self, col_id):
+ """Deletes the 'col_id' col in the table.
+ Raises a KeyError if col_id was not found.
+ """
+ try:
+ col_index = self.col_names.index(col_id)
+ self.delete_column(col_index)
+ except ValueError:
+ raise KeyError('Column (%s) not found in table' % (col_id))
+
+
+ ## The 'getter' part #######################################################
+
+ def get_shape(self):
+ """Returns a tuple which represents the table's shape
+ """
+ return len(self.row_names), len(self.col_names)
+ shape = property(get_shape)
+
+ def __getitem__(self, indices):
+ """provided for convenience"""
+ rows, multirows = None, False
+ cols, multicols = None, False
+ if isinstance(indices, tuple):
+ rows = indices[0]
+ if len(indices) > 1:
+ cols = indices[1]
+ else:
+ rows = indices
+ # define row slice
+ if isinstance(rows, str):
+ try:
+ rows = self.row_names.index(rows)
+ except ValueError:
+ raise KeyError("Row (%s) not found in table" % (rows))
+ if isinstance(rows, int):
+ rows = slice(rows, rows+1)
+ multirows = False
+ else:
+ rows = slice(None)
+ multirows = True
+ # define col slice
+ if isinstance(cols, str):
+ try:
+ cols = self.col_names.index(cols)
+ except ValueError:
+ raise KeyError("Column (%s) not found in table" % (cols))
+ if isinstance(cols, int):
+ cols = slice(cols, cols+1)
+ multicols = False
+ else:
+ cols = slice(None)
+ multicols = True
+ # get sub-table
+ tab = Table()
+ tab.default_value = self.default_value
+ tab.create_rows(self.row_names[rows])
+ tab.create_columns(self.col_names[cols])
+ for idx, row in enumerate(self.data[rows]):
+ tab.set_row(idx, row[cols])
+ if multirows :
+ if multicols:
+ return tab
+ else:
+ return [item[0] for item in tab.data]
+ else:
+ if multicols:
+ return tab.data[0]
+ else:
+ return tab.data[0][0]
+
+ def get_cell_by_ids(self, row_id, col_id):
+ """Returns the element at [row_id][col_id]
+ """
+ try:
+ row_index = self.row_names.index(row_id)
+ except ValueError:
+ raise KeyError("Row (%s) not found in table" % (row_id))
+ else:
+ try:
+ col_index = self.col_names.index(col_id)
+ except ValueError:
+ raise KeyError("Column (%s) not found in table" % (col_id))
+ return self.data[row_index][col_index]
+
+ def get_row_by_id(self, row_id):
+ """Returns the 'row_id' row
+ """
+ try:
+ row_index = self.row_names.index(row_id)
+ except ValueError:
+ raise KeyError("Row (%s) not found in table" % (row_id))
+ return self.data[row_index]
+
+ def get_column_by_id(self, col_id, distinct=False):
+ """Returns the 'col_id' col
+ """
+ try:
+ col_index = self.col_names.index(col_id)
+ except ValueError:
+ raise KeyError("Column (%s) not found in table" % (col_id))
+ return self.get_column(col_index, distinct)
+
+ def get_columns(self):
+ """Returns all the columns in the table
+ """
+ return [self[:, index] for index in range(len(self.col_names))]
+
+ def get_column(self, col_index, distinct=False):
+ """get a column by index"""
+ col = [row[col_index] for row in self.data]
+ if distinct:
+ col = list(set(col))
+ return col
+
+ def apply_stylesheet(self, stylesheet):
+ """Applies the stylesheet to this table
+ """
+ for instruction in stylesheet.instructions:
+ eval(instruction)
+
+
+ def transpose(self):
+ """Keeps the self object intact, and returns the transposed (rotated)
+ table.
+ """
+ transposed = Table()
+ transposed.create_rows(self.col_names)
+ transposed.create_columns(self.row_names)
+ for col_index, column in enumerate(self.get_columns()):
+ transposed.set_row(col_index, column)
+ return transposed
+
+
+ def pprint(self):
+ """returns a string representing the table in a pretty
+ printed 'text' format.
+ """
+ # The maximum row name (to know the start_index of the first col)
+ max_row_name = 0
+ for row_name in self.row_names:
+ if len(row_name) > max_row_name:
+ max_row_name = len(row_name)
+ col_start = max_row_name + 5
+
+ lines = []
+ # Build the 'first' line <=> the col_names one
+ # The first cell <=> an empty one
+ col_names_line = [' '*col_start]
+ for col_name in self.col_names:
+ col_names_line.append(col_name + ' '*5)
+ lines.append('|' + '|'.join(col_names_line) + '|')
+ max_line_length = len(lines[0])
+
+ # Build the table
+ for row_index, row in enumerate(self.data):
+ line = []
+ # First, build the row_name's cell
+ row_name = self.row_names[row_index]
+ line.append(row_name + ' '*(col_start-len(row_name)))
+
+ # Then, build all the table's cell for this line.
+ for col_index, cell in enumerate(row):
+ col_name_length = len(self.col_names[col_index]) + 5
+ data = str(cell)
+ line.append(data + ' '*(col_name_length - len(data)))
+ lines.append('|' + '|'.join(line) + '|')
+ if len(lines[-1]) > max_line_length:
+ max_line_length = len(lines[-1])
+
+ # Wrap the table with '-' to make a frame
+ lines.insert(0, '-'*max_line_length)
+ lines.append('-'*max_line_length)
+ return '\n'.join(lines)
+
+
+ def __repr__(self):
+ return repr(self.data)
+
+ def as_text(self):
+ data = []
+ # We must convert cells into strings before joining them
+ for row in self.data:
+ data.append([str(cell) for cell in row])
+ lines = ['\t'.join(row) for row in data]
+ return '\n'.join(lines)
+
+
+
+class TableStyle:
+ """Defines a table's style
+ """
+
+ def __init__(self, table):
+
+ self._table = table
+ self.size = dict([(col_name, '1*') for col_name in table.col_names])
+ # __row_column__ is a special key to define the first column which
+ # actually has no name (<=> left most column <=> row names column)
+ self.size['__row_column__'] = '1*'
+ self.alignment = dict([(col_name, 'right')
+ for col_name in table.col_names])
+ self.alignment['__row_column__'] = 'right'
+
+ # We shouldn't have to create an entry for
+ # the 1st col (the row_column one)
+ self.units = dict([(col_name, '') for col_name in table.col_names])
+ self.units['__row_column__'] = ''
+
+ # XXX FIXME : params order should be reversed for all set() methods
+ def set_size(self, value, col_id):
+ """sets the size of the specified col_id to value
+ """
+ self.size[col_id] = value
+
+ def set_size_by_index(self, value, col_index):
+ """Allows to set the size according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ self.size[col_id] = value
+
+
+ def set_alignment(self, value, col_id):
+ """sets the alignment of the specified col_id to value
+ """
+ self.alignment[col_id] = value
+
+
+ def set_alignment_by_index(self, value, col_index):
+ """Allows to set the alignment according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ self.alignment[col_id] = value
+
+
+ def set_unit(self, value, col_id):
+ """sets the unit of the specified col_id to value
+ """
+ self.units[col_id] = value
+
+
+ def set_unit_by_index(self, value, col_index):
+ """Allows to set the unit according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ (Note that in the 'unit' case, you shouldn't have to set a unit
+ for the 1st column (the __row__column__ one))
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ self.units[col_id] = value
+
+
+ def get_size(self, col_id):
+ """Returns the size of the specified col_id
+ """
+ return self.size[col_id]
+
+
+ def get_size_by_index(self, col_index):
+ """Allows to get the size according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ return self.size[col_id]
+
+
+ def get_alignment(self, col_id):
+ """Returns the alignment of the specified col_id
+ """
+ return self.alignment[col_id]
+
+
+ def get_alignment_by_index(self, col_index):
+ """Allors to get the alignment according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ return self.alignment[col_id]
+
+
+ def get_unit(self, col_id):
+ """Returns the unit of the specified col_id
+ """
+ return self.units[col_id]
+
+
+ def get_unit_by_index(self, col_index):
+ """Allors to get the unit according to the column index rather than
+ using the column's id.
+ BE CAREFUL : the '0' column is the '__row_column__' one !
+ """
+ if col_index == 0:
+ col_id = '__row_column__'
+ else:
+ col_id = self._table.col_names[col_index-1]
+
+ return self.units[col_id]
+
+
+import re
+CELL_PROG = re.compile("([0-9]+)_([0-9]+)")
+
+class TableStyleSheet:
+ """A simple Table stylesheet
+ Rules are expressions where cells are defined by the row_index
+ and col_index separated by an underscore ('_').
+ For example, suppose you want to say that the (2,5) cell must be
+ the sum of its two preceding cells in the row, you would create
+ the following rule :
+ 2_5 = 2_3 + 2_4
+ You can also use all the math.* operations you want. For example:
+ 2_5 = sqrt(2_3**2 + 2_4**2)
+ """
+
+ def __init__(self, rules = None):
+ rules = rules or []
+ self.rules = []
+ self.instructions = []
+ for rule in rules:
+ self.add_rule(rule)
+
+
+ def add_rule(self, rule):
+ """Adds a rule to the stylesheet rules
+ """
+ try:
+ source_code = ['from math import *']
+ source_code.append(CELL_PROG.sub(r'self.data[\1][\2]', rule))
+ self.instructions.append(compile('\n'.join(source_code),
+ 'table.py', 'exec'))
+ self.rules.append(rule)
+ except SyntaxError:
+ print("Bad Stylesheet Rule : %s [skipped]" % rule)
+
+
+ def add_rowsum_rule(self, dest_cell, row_index, start_col, end_col):
+ """Creates and adds a rule to sum over the row at row_index from
+ start_col to end_col.
+ dest_cell is a tuple of two elements (x,y) of the destination cell
+ No check is done for indexes ranges.
+ pre:
+ start_col >= 0
+ end_col > start_col
+ """
+ cell_list = ['%d_%d'%(row_index, index) for index in range(start_col,
+ end_col + 1)]
+ rule = '%d_%d=' % dest_cell + '+'.join(cell_list)
+ self.add_rule(rule)
+
+
+ def add_rowavg_rule(self, dest_cell, row_index, start_col, end_col):
+ """Creates and adds a rule to make the row average (from start_col
+ to end_col)
+ dest_cell is a tuple of two elements (x,y) of the destination cell
+ No check is done for indexes ranges.
+ pre:
+ start_col >= 0
+ end_col > start_col
+ """
+ cell_list = ['%d_%d'%(row_index, index) for index in range(start_col,
+ end_col + 1)]
+ num = (end_col - start_col + 1)
+ rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num
+ self.add_rule(rule)
+
+
+ def add_colsum_rule(self, dest_cell, col_index, start_row, end_row):
+ """Creates and adds a rule to sum over the col at col_index from
+ start_row to end_row.
+ dest_cell is a tuple of two elements (x,y) of the destination cell
+ No check is done for indexes ranges.
+ pre:
+ start_row >= 0
+ end_row > start_row
+ """
+ cell_list = ['%d_%d'%(index, col_index) for index in range(start_row,
+ end_row + 1)]
+ rule = '%d_%d=' % dest_cell + '+'.join(cell_list)
+ self.add_rule(rule)
+
+
+ def add_colavg_rule(self, dest_cell, col_index, start_row, end_row):
+ """Creates and adds a rule to make the col average (from start_row
+ to end_row)
+ dest_cell is a tuple of two elements (x,y) of the destination cell
+ No check is done for indexes ranges.
+ pre:
+ start_row >= 0
+ end_row > start_row
+ """
+ cell_list = ['%d_%d'%(index, col_index) for index in range(start_row,
+ end_row + 1)]
+ num = (end_row - start_row + 1)
+ rule = '%d_%d=' % dest_cell + '('+'+'.join(cell_list)+')/%f'%num
+ self.add_rule(rule)
+
+
+
+class TableCellRenderer:
+ """Defines a simple text renderer
+ """
+
+ def __init__(self, **properties):
+ """keywords should be properties with an associated boolean as value.
+ For example :
+ renderer = TableCellRenderer(units = True, alignment = False)
+ An unspecified property will have a 'False' value by default.
+ Possible properties are :
+ alignment, unit
+ """
+ self.properties = properties
+
+
+ def render_cell(self, cell_coord, table, table_style):
+ """Renders the cell at 'cell_coord' in the table, using table_style
+ """
+ row_index, col_index = cell_coord
+ cell_value = table.data[row_index][col_index]
+ final_content = self._make_cell_content(cell_value,
+ table_style, col_index +1)
+ return self._render_cell_content(final_content,
+ table_style, col_index + 1)
+
+
+ def render_row_cell(self, row_name, table, table_style):
+ """Renders the cell for 'row_id' row
+ """
+ cell_value = row_name
+ return self._render_cell_content(cell_value, table_style, 0)
+
+
+ def render_col_cell(self, col_name, table, table_style):
+ """Renders the cell for 'col_id' row
+ """
+ cell_value = col_name
+ col_index = table.col_names.index(col_name)
+ return self._render_cell_content(cell_value, table_style, col_index +1)
+
+
+
+ def _render_cell_content(self, content, table_style, col_index):
+ """Makes the appropriate rendering for this cell content.
+ Rendering properties will be searched using the
+ *table_style.get_xxx_by_index(col_index)' methods
+
+ **This method should be overridden in the derived renderer classes.**
+ """
+ return content
+
+
+ def _make_cell_content(self, cell_content, table_style, col_index):
+ """Makes the cell content (adds decoration data, like units for
+ example)
+ """
+ final_content = cell_content
+ if 'skip_zero' in self.properties:
+ replacement_char = self.properties['skip_zero']
+ else:
+ replacement_char = 0
+ if replacement_char and final_content == 0:
+ return replacement_char
+
+ try:
+ units_on = self.properties['units']
+ if units_on:
+ final_content = self._add_unit(
+ cell_content, table_style, col_index)
+ except KeyError:
+ pass
+
+ return final_content
+
+
+ def _add_unit(self, cell_content, table_style, col_index):
+ """Adds unit to the cell_content if needed
+ """
+ unit = table_style.get_unit_by_index(col_index)
+ return str(cell_content) + " " + unit
+
+
+
+class DocbookRenderer(TableCellRenderer):
+ """Defines how to render a cell for a docboook table
+ """
+
+ def define_col_header(self, col_index, table_style):
+ """Computes the colspec element according to the style
+ """
+ size = table_style.get_size_by_index(col_index)
+ return '<colspec colname="c%d" colwidth="%s"/>\n' % \
+ (col_index, size)
+
+
+ def _render_cell_content(self, cell_content, table_style, col_index):
+ """Makes the appropriate rendering for this cell content.
+ Rendering properties will be searched using the
+ table_style.get_xxx_by_index(col_index)' methods.
+ """
+ try:
+ align_on = self.properties['alignment']
+ alignment = table_style.get_alignment_by_index(col_index)
+ if align_on:
+ return "<entry align='%s'>%s</entry>\n" % \
+ (alignment, cell_content)
+ except KeyError:
+ # KeyError <=> Default alignment
+ return "<entry>%s</entry>\n" % cell_content
+
+
+class TableWriter:
+ """A class to write tables
+ """
+
+ def __init__(self, stream, table, style, **properties):
+ self._stream = stream
+ self.style = style or TableStyle(table)
+ self._table = table
+ self.properties = properties
+ self.renderer = None
+
+
+ def set_style(self, style):
+ """sets the table's associated style
+ """
+ self.style = style
+
+
+ def set_renderer(self, renderer):
+ """sets the way to render cell
+ """
+ self.renderer = renderer
+
+
+ def update_properties(self, **properties):
+ """Updates writer's properties (for cell rendering)
+ """
+ self.properties.update(properties)
+
+
+ def write_table(self, title = ""):
+ """Writes the table
+ """
+ raise NotImplementedError("write_table must be implemented !")
+
+
+
+class DocbookTableWriter(TableWriter):
+ """Defines an implementation of TableWriter to write a table in Docbook
+ """
+
+ def _write_headers(self):
+ """Writes col headers
+ """
+ # Define col_headers (colstpec elements)
+ for col_index in range(len(self._table.col_names)+1):
+ self._stream.write(self.renderer.define_col_header(col_index,
+ self.style))
+
+ self._stream.write("<thead>\n<row>\n")
+ # XXX FIXME : write an empty entry <=> the first (__row_column) column
+ self._stream.write('<entry></entry>\n')
+ for col_name in self._table.col_names:
+ self._stream.write(self.renderer.render_col_cell(
+ col_name, self._table,
+ self.style))
+
+ self._stream.write("</row>\n</thead>\n")
+
+
+ def _write_body(self):
+ """Writes the table body
+ """
+ self._stream.write('<tbody>\n')
+
+ for row_index, row in enumerate(self._table.data):
+ self._stream.write('<row>\n')
+ row_name = self._table.row_names[row_index]
+ # Write the first entry (row_name)
+ self._stream.write(self.renderer.render_row_cell(row_name,
+ self._table,
+ self.style))
+
+ for col_index, cell in enumerate(row):
+ self._stream.write(self.renderer.render_cell(
+ (row_index, col_index),
+ self._table, self.style))
+
+ self._stream.write('</row>\n')
+
+ self._stream.write('</tbody>\n')
+
+
+ def write_table(self, title = ""):
+ """Writes the table
+ """
+ self._stream.write('<table>\n<title>%s></title>\n'%(title))
+ self._stream.write(
+ '<tgroup cols="%d" align="left" colsep="1" rowsep="1">\n'%
+ (len(self._table.col_names)+1))
+ self._write_headers()
+ self._write_body()
+
+ self._stream.write('</tgroup>\n</table>\n')
+
+
diff --git a/logilab/common/tasksqueue.py b/logilab/common/tasksqueue.py
new file mode 100644
index 0000000..ed74cf5
--- /dev/null
+++ b/logilab/common/tasksqueue.py
@@ -0,0 +1,101 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Prioritized tasks queue"""
+
+__docformat__ = "restructuredtext en"
+
+from bisect import insort_left
+
+from six.moves import queue
+
+LOW = 0
+MEDIUM = 10
+HIGH = 100
+
+PRIORITY = {
+ 'LOW': LOW,
+ 'MEDIUM': MEDIUM,
+ 'HIGH': HIGH,
+ }
+REVERSE_PRIORITY = dict((values, key) for key, values in PRIORITY.items())
+
+
+
+class PrioritizedTasksQueue(queue.Queue):
+
+ def _init(self, maxsize):
+ """Initialize the queue representation"""
+ self.maxsize = maxsize
+ # ordered list of task, from the lowest to the highest priority
+ self.queue = []
+
+ def _put(self, item):
+ """Put a new item in the queue"""
+ for i, task in enumerate(self.queue):
+ # equivalent task
+ if task == item:
+ # if new task has a higher priority, remove the one already
+ # queued so the new priority will be considered
+ if task < item:
+ item.merge(task)
+ del self.queue[i]
+ break
+ # else keep it so current order is kept
+ task.merge(item)
+ return
+ insort_left(self.queue, item)
+
+ def _get(self):
+ """Get an item from the queue"""
+ return self.queue.pop()
+
+ def __iter__(self):
+ return iter(self.queue)
+
+ def remove(self, tid):
+ """remove a specific task from the queue"""
+ # XXX acquire lock
+ for i, task in enumerate(self):
+ if task.id == tid:
+ self.queue.pop(i)
+ return
+ raise ValueError('not task of id %s in queue' % tid)
+
+class Task(object):
+ def __init__(self, tid, priority=LOW):
+ # task id
+ self.id = tid
+ # task priority
+ self.priority = priority
+
+ def __repr__(self):
+ return '<Task %s @%#x>' % (self.id, id(self))
+
+ def __cmp__(self, other):
+ return cmp(self.priority, other.priority)
+
+ def __lt__(self, other):
+ return self.priority < other.priority
+
+ def __eq__(self, other):
+ return self.id == other.id
+
+ __hash__ = object.__hash__
+
+ def merge(self, other):
+ pass
diff --git a/logilab/common/testlib.py b/logilab/common/testlib.py
new file mode 100644
index 0000000..31efe56
--- /dev/null
+++ b/logilab/common/testlib.py
@@ -0,0 +1,1392 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Run tests.
+
+This will find all modules whose name match a given prefix in the test
+directory, and run them. Various command line options provide
+additional facilities.
+
+Command line options:
+
+ -v verbose -- run tests in verbose mode with output to stdout
+ -q quiet -- don't print anything except if a test fails
+ -t testdir -- directory where the tests will be found
+ -x exclude -- add a test to exclude
+ -p profile -- profiled execution
+ -d dbc -- enable design-by-contract
+ -m match -- only run test matching the tag pattern which follow
+
+If no non-option arguments are present, prefixes used are 'test',
+'regrtest', 'smoketest' and 'unittest'.
+
+"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+# modified copy of some functions from test/regrtest.py from PyXml
+# disable camel case warning
+# pylint: disable=C0103
+
+import sys
+import os, os.path as osp
+import re
+import traceback
+import inspect
+import difflib
+import tempfile
+import math
+import warnings
+from shutil import rmtree
+from operator import itemgetter
+from itertools import dropwhile
+from inspect import isgeneratorfunction
+
+from six import string_types
+from six.moves import builtins, range, configparser, input
+
+from logilab.common.deprecation import deprecated
+
+import unittest as unittest_legacy
+if not getattr(unittest_legacy, "__package__", None):
+ try:
+ import unittest2 as unittest
+ from unittest2 import SkipTest
+ except ImportError:
+ raise ImportError("You have to install python-unittest2 to use %s" % __name__)
+else:
+ import unittest
+ from unittest import SkipTest
+
+from functools import wraps
+
+from logilab.common.debugger import Debugger, colorize_source
+from logilab.common.decorators import cached, classproperty
+from logilab.common import textutils
+
+
+__all__ = ['main', 'unittest_main', 'find_tests', 'run_test', 'spawn']
+
+DEFAULT_PREFIXES = ('test', 'regrtest', 'smoketest', 'unittest',
+ 'func', 'validation')
+
+is_generator = deprecated('[lgc 0.63] use inspect.isgeneratorfunction')(isgeneratorfunction)
+
+# used by unittest to count the number of relevant levels in the traceback
+__unittest = 1
+
+
+def with_tempdir(callable):
+ """A decorator ensuring no temporary file left when the function return
+ Work only for temporary file create with the tempfile module"""
+ if isgeneratorfunction(callable):
+ def proxy(*args, **kwargs):
+ old_tmpdir = tempfile.gettempdir()
+ new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
+ tempfile.tempdir = new_tmpdir
+ try:
+ for x in callable(*args, **kwargs):
+ yield x
+ finally:
+ try:
+ rmtree(new_tmpdir, ignore_errors=True)
+ finally:
+ tempfile.tempdir = old_tmpdir
+ return proxy
+
+ @wraps(callable)
+ def proxy(*args, **kargs):
+
+ old_tmpdir = tempfile.gettempdir()
+ new_tmpdir = tempfile.mkdtemp(prefix="temp-lgc-")
+ tempfile.tempdir = new_tmpdir
+ try:
+ return callable(*args, **kargs)
+ finally:
+ try:
+ rmtree(new_tmpdir, ignore_errors=True)
+ finally:
+ tempfile.tempdir = old_tmpdir
+ return proxy
+
+def in_tempdir(callable):
+ """A decorator moving the enclosed function inside the tempfile.tempfdir
+ """
+ @wraps(callable)
+ def proxy(*args, **kargs):
+
+ old_cwd = os.getcwd()
+ os.chdir(tempfile.tempdir)
+ try:
+ return callable(*args, **kargs)
+ finally:
+ os.chdir(old_cwd)
+ return proxy
+
+def within_tempdir(callable):
+ """A decorator run the enclosed function inside a tmpdir removed after execution
+ """
+ proxy = with_tempdir(in_tempdir(callable))
+ proxy.__name__ = callable.__name__
+ return proxy
+
+def find_tests(testdir,
+ prefixes=DEFAULT_PREFIXES, suffix=".py",
+ excludes=(),
+ remove_suffix=True):
+ """
+ Return a list of all applicable test modules.
+ """
+ tests = []
+ for name in os.listdir(testdir):
+ if not suffix or name.endswith(suffix):
+ for prefix in prefixes:
+ if name.startswith(prefix):
+ if remove_suffix and name.endswith(suffix):
+ name = name[:-len(suffix)]
+ if name not in excludes:
+ tests.append(name)
+ tests.sort()
+ return tests
+
+
+## PostMortem Debug facilities #####
+def start_interactive_mode(result):
+ """starts an interactive shell so that the user can inspect errors
+ """
+ debuggers = result.debuggers
+ descrs = result.error_descrs + result.fail_descrs
+ if len(debuggers) == 1:
+ # don't ask for test name if there's only one failure
+ debuggers[0].start()
+ else:
+ while True:
+ testindex = 0
+ print("Choose a test to debug:")
+ # order debuggers in the same way than errors were printed
+ print("\n".join(['\t%s : %s' % (i, descr) for i, (_, descr)
+ in enumerate(descrs)]))
+ print("Type 'exit' (or ^D) to quit")
+ print()
+ try:
+ todebug = input('Enter a test name: ')
+ if todebug.strip().lower() == 'exit':
+ print()
+ break
+ else:
+ try:
+ testindex = int(todebug)
+ debugger = debuggers[descrs[testindex][0]]
+ except (ValueError, IndexError):
+ print("ERROR: invalid test number %r" % (todebug, ))
+ else:
+ debugger.start()
+ except (EOFError, KeyboardInterrupt):
+ print()
+ break
+
+
+# test utils ##################################################################
+
+class SkipAwareTestResult(unittest._TextTestResult):
+
+ def __init__(self, stream, descriptions, verbosity,
+ exitfirst=False, pdbmode=False, cvg=None, colorize=False):
+ super(SkipAwareTestResult, self).__init__(stream,
+ descriptions, verbosity)
+ self.skipped = []
+ self.debuggers = []
+ self.fail_descrs = []
+ self.error_descrs = []
+ self.exitfirst = exitfirst
+ self.pdbmode = pdbmode
+ self.cvg = cvg
+ self.colorize = colorize
+ self.pdbclass = Debugger
+ self.verbose = verbosity > 1
+
+ def descrs_for(self, flavour):
+ return getattr(self, '%s_descrs' % flavour.lower())
+
+ def _create_pdb(self, test_descr, flavour):
+ self.descrs_for(flavour).append( (len(self.debuggers), test_descr) )
+ if self.pdbmode:
+ self.debuggers.append(self.pdbclass(sys.exc_info()[2]))
+
+ def _iter_valid_frames(self, frames):
+ """only consider non-testlib frames when formatting traceback"""
+ lgc_testlib = osp.abspath(__file__)
+ std_testlib = osp.abspath(unittest.__file__)
+ invalid = lambda fi: osp.abspath(fi[1]) in (lgc_testlib, std_testlib)
+ for frameinfo in dropwhile(invalid, frames):
+ yield frameinfo
+
+ def _exc_info_to_string(self, err, test):
+ """Converts a sys.exc_info()-style tuple of values into a string.
+
+ This method is overridden here because we want to colorize
+ lines if --color is passed, and display local variables if
+ --verbose is passed
+ """
+ exctype, exc, tb = err
+ output = ['Traceback (most recent call last)']
+ frames = inspect.getinnerframes(tb)
+ colorize = self.colorize
+ frames = enumerate(self._iter_valid_frames(frames))
+ for index, (frame, filename, lineno, funcname, ctx, ctxindex) in frames:
+ filename = osp.abspath(filename)
+ if ctx is None: # pyc files or C extensions for instance
+ source = '<no source available>'
+ else:
+ source = ''.join(ctx)
+ if colorize:
+ filename = textutils.colorize_ansi(filename, 'magenta')
+ source = colorize_source(source)
+ output.append(' File "%s", line %s, in %s' % (filename, lineno, funcname))
+ output.append(' %s' % source.strip())
+ if self.verbose:
+ output.append('%r == %r' % (dir(frame), test.__module__))
+ output.append('')
+ output.append(' ' + ' local variables '.center(66, '-'))
+ for varname, value in sorted(frame.f_locals.items()):
+ output.append(' %s: %r' % (varname, value))
+ if varname == 'self': # special handy processing for self
+ for varname, value in sorted(vars(value).items()):
+ output.append(' self.%s: %r' % (varname, value))
+ output.append(' ' + '-' * 66)
+ output.append('')
+ output.append(''.join(traceback.format_exception_only(exctype, exc)))
+ return '\n'.join(output)
+
+ def addError(self, test, err):
+ """err -> (exc_type, exc, tcbk)"""
+ exc_type, exc, _ = err
+ if isinstance(exc, SkipTest):
+ assert exc_type == SkipTest
+ self.addSkip(test, exc)
+ else:
+ if self.exitfirst:
+ self.shouldStop = True
+ descr = self.getDescription(test)
+ super(SkipAwareTestResult, self).addError(test, err)
+ self._create_pdb(descr, 'error')
+
+ def addFailure(self, test, err):
+ if self.exitfirst:
+ self.shouldStop = True
+ descr = self.getDescription(test)
+ super(SkipAwareTestResult, self).addFailure(test, err)
+ self._create_pdb(descr, 'fail')
+
+ def addSkip(self, test, reason):
+ self.skipped.append((test, reason))
+ if self.showAll:
+ self.stream.writeln("SKIPPED")
+ elif self.dots:
+ self.stream.write('S')
+
+ def printErrors(self):
+ super(SkipAwareTestResult, self).printErrors()
+ self.printSkippedList()
+
+ def printSkippedList(self):
+ # format (test, err) compatible with unittest2
+ for test, err in self.skipped:
+ descr = self.getDescription(test)
+ self.stream.writeln(self.separator1)
+ self.stream.writeln("%s: %s" % ('SKIPPED', descr))
+ self.stream.writeln("\t%s" % err)
+
+ def printErrorList(self, flavour, errors):
+ for (_, descr), (test, err) in zip(self.descrs_for(flavour), errors):
+ self.stream.writeln(self.separator1)
+ self.stream.writeln("%s: %s" % (flavour, descr))
+ self.stream.writeln(self.separator2)
+ self.stream.writeln(err)
+ self.stream.writeln('no stdout'.center(len(self.separator2)))
+ self.stream.writeln('no stderr'.center(len(self.separator2)))
+
+# Add deprecation warnings about new api used by module level fixtures in unittest2
+# http://www.voidspace.org.uk/python/articles/unittest2.shtml#setupmodule-and-teardownmodule
+class _DebugResult(object): # simplify import statement among unittest flavors..
+ "Used by the TestSuite to hold previous class when running in debug."
+ _previousTestClass = None
+ _moduleSetUpFailed = False
+ shouldStop = False
+
+from logilab.common.decorators import monkeypatch
+@monkeypatch(unittest.TestSuite)
+def _handleModuleTearDown(self, result):
+ previousModule = self._get_previous_module(result)
+ if previousModule is None:
+ return
+ if result._moduleSetUpFailed:
+ return
+ try:
+ module = sys.modules[previousModule]
+ except KeyError:
+ return
+ # add testlib specific deprecation warning and switch to new api
+ if hasattr(module, 'teardown_module'):
+ warnings.warn('Please rename teardown_module() to tearDownModule() instead.',
+ DeprecationWarning)
+ setattr(module, 'tearDownModule', module.teardown_module)
+ # end of monkey-patching
+ tearDownModule = getattr(module, 'tearDownModule', None)
+ if tearDownModule is not None:
+ try:
+ tearDownModule()
+ except Exception as e:
+ if isinstance(result, _DebugResult):
+ raise
+ errorName = 'tearDownModule (%s)' % previousModule
+ self._addClassOrModuleLevelException(result, e, errorName)
+
+@monkeypatch(unittest.TestSuite)
+def _handleModuleFixture(self, test, result):
+ previousModule = self._get_previous_module(result)
+ currentModule = test.__class__.__module__
+ if currentModule == previousModule:
+ return
+ self._handleModuleTearDown(result)
+ result._moduleSetUpFailed = False
+ try:
+ module = sys.modules[currentModule]
+ except KeyError:
+ return
+ # add testlib specific deprecation warning and switch to new api
+ if hasattr(module, 'setup_module'):
+ warnings.warn('Please rename setup_module() to setUpModule() instead.',
+ DeprecationWarning)
+ setattr(module, 'setUpModule', module.setup_module)
+ # end of monkey-patching
+ setUpModule = getattr(module, 'setUpModule', None)
+ if setUpModule is not None:
+ try:
+ setUpModule()
+ except Exception as e:
+ if isinstance(result, _DebugResult):
+ raise
+ result._moduleSetUpFailed = True
+ errorName = 'setUpModule (%s)' % currentModule
+ self._addClassOrModuleLevelException(result, e, errorName)
+
+# backward compatibility: TestSuite might be imported from lgc.testlib
+TestSuite = unittest.TestSuite
+
+class keywords(dict):
+ """Keyword args (**kwargs) support for generative tests."""
+
+class starargs(tuple):
+ """Variable arguments (*args) for generative tests."""
+ def __new__(cls, *args):
+ return tuple.__new__(cls, args)
+
+unittest_main = unittest.main
+
+
+class InnerTestSkipped(SkipTest):
+ """raised when a test is skipped"""
+ pass
+
+def parse_generative_args(params):
+ args = []
+ varargs = ()
+ kwargs = {}
+ flags = 0 # 2 <=> starargs, 4 <=> kwargs
+ for param in params:
+ if isinstance(param, starargs):
+ varargs = param
+ if flags:
+ raise TypeError('found starargs after keywords !')
+ flags |= 2
+ args += list(varargs)
+ elif isinstance(param, keywords):
+ kwargs = param
+ if flags & 4:
+ raise TypeError('got multiple keywords parameters')
+ flags |= 4
+ elif flags & 2 or flags & 4:
+ raise TypeError('found parameters after kwargs or args')
+ else:
+ args.append(param)
+
+ return args, kwargs
+
+
+class InnerTest(tuple):
+ def __new__(cls, name, *data):
+ instance = tuple.__new__(cls, data)
+ instance.name = name
+ return instance
+
+class Tags(set):
+ """A set of tag able validate an expression"""
+
+ def __init__(self, *tags, **kwargs):
+ self.inherit = kwargs.pop('inherit', True)
+ if kwargs:
+ raise TypeError("%s are an invalid keyword argument for this function" % kwargs.keys())
+
+ if len(tags) == 1 and not isinstance(tags[0], string_types):
+ tags = tags[0]
+ super(Tags, self).__init__(tags, **kwargs)
+
+ def __getitem__(self, key):
+ return key in self
+
+ def match(self, exp):
+ return eval(exp, {}, self)
+
+
+# duplicate definition from unittest2 of the _deprecate decorator
+def _deprecate(original_func):
+ def deprecated_func(*args, **kwargs):
+ warnings.warn(
+ ('Please use %s instead.' % original_func.__name__),
+ DeprecationWarning, 2)
+ return original_func(*args, **kwargs)
+ return deprecated_func
+
+class TestCase(unittest.TestCase):
+ """A unittest.TestCase extension with some additional methods."""
+ maxDiff = None
+ pdbclass = Debugger
+ tags = Tags()
+
+ def __init__(self, methodName='runTest'):
+ super(TestCase, self).__init__(methodName)
+ self.__exc_info = sys.exc_info
+ self.__testMethodName = self._testMethodName
+ self._current_test_descr = None
+ self._options_ = None
+
+ @classproperty
+ @cached
+ def datadir(cls): # pylint: disable=E0213
+ """helper attribute holding the standard test's data directory
+
+ NOTE: this is a logilab's standard
+ """
+ mod = __import__(cls.__module__)
+ return osp.join(osp.dirname(osp.abspath(mod.__file__)), 'data')
+ # cache it (use a class method to cache on class since TestCase is
+ # instantiated for each test run)
+
+ @classmethod
+ def datapath(cls, *fname):
+ """joins the object's datadir and `fname`"""
+ return osp.join(cls.datadir, *fname)
+
+ def set_description(self, descr):
+ """sets the current test's description.
+ This can be useful for generative tests because it allows to specify
+ a description per yield
+ """
+ self._current_test_descr = descr
+
+ # override default's unittest.py feature
+ def shortDescription(self):
+ """override default unittest shortDescription to handle correctly
+ generative tests
+ """
+ if self._current_test_descr is not None:
+ return self._current_test_descr
+ return super(TestCase, self).shortDescription()
+
+ def quiet_run(self, result, func, *args, **kwargs):
+ try:
+ func(*args, **kwargs)
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except unittest.SkipTest as e:
+ if hasattr(result, 'addSkip'):
+ result.addSkip(self, str(e))
+ else:
+ warnings.warn("TestResult has no addSkip method, skips not reported",
+ RuntimeWarning, 2)
+ result.addSuccess(self)
+ return False
+ except:
+ result.addError(self, self.__exc_info())
+ return False
+ return True
+
+ def _get_test_method(self):
+ """return the test method"""
+ return getattr(self, self._testMethodName)
+
+ def optval(self, option, default=None):
+ """return the option value or default if the option is not define"""
+ return getattr(self._options_, option, default)
+
+ def __call__(self, result=None, runcondition=None, options=None):
+ """rewrite TestCase.__call__ to support generative tests
+ This is mostly a copy/paste from unittest.py (i.e same
+ variable names, same logic, except for the generative tests part)
+ """
+ from logilab.common.pytest import FILE_RESTART
+ if result is None:
+ result = self.defaultTestResult()
+ result.pdbclass = self.pdbclass
+ self._options_ = options
+ # if result.cvg:
+ # result.cvg.start()
+ testMethod = self._get_test_method()
+ if (getattr(self.__class__, "__unittest_skip__", False) or
+ getattr(testMethod, "__unittest_skip__", False)):
+ # If the class or method was skipped.
+ try:
+ skip_why = (getattr(self.__class__, '__unittest_skip_why__', '')
+ or getattr(testMethod, '__unittest_skip_why__', ''))
+ self._addSkip(result, skip_why)
+ finally:
+ result.stopTest(self)
+ return
+ if runcondition and not runcondition(testMethod):
+ return # test is skipped
+ result.startTest(self)
+ try:
+ if not self.quiet_run(result, self.setUp):
+ return
+ generative = isgeneratorfunction(testMethod)
+ # generative tests
+ if generative:
+ self._proceed_generative(result, testMethod,
+ runcondition)
+ else:
+ status = self._proceed(result, testMethod)
+ success = (status == 0)
+ if not self.quiet_run(result, self.tearDown):
+ return
+ if not generative and success:
+ if hasattr(options, "exitfirst") and options.exitfirst:
+ # add this test to restart file
+ try:
+ restartfile = open(FILE_RESTART, 'a')
+ try:
+ descr = '.'.join((self.__class__.__module__,
+ self.__class__.__name__,
+ self._testMethodName))
+ restartfile.write(descr+os.linesep)
+ finally:
+ restartfile.close()
+ except Exception:
+ print("Error while saving succeeded test into",
+ osp.join(os.getcwd(), FILE_RESTART),
+ file=sys.__stderr__)
+ raise
+ result.addSuccess(self)
+ finally:
+ # if result.cvg:
+ # result.cvg.stop()
+ result.stopTest(self)
+
+ def _proceed_generative(self, result, testfunc, runcondition=None):
+ # cancel startTest()'s increment
+ result.testsRun -= 1
+ success = True
+ try:
+ for params in testfunc():
+ if runcondition and not runcondition(testfunc,
+ skipgenerator=False):
+ if not (isinstance(params, InnerTest)
+ and runcondition(params)):
+ continue
+ if not isinstance(params, (tuple, list)):
+ params = (params, )
+ func = params[0]
+ args, kwargs = parse_generative_args(params[1:])
+ # increment test counter manually
+ result.testsRun += 1
+ status = self._proceed(result, func, args, kwargs)
+ if status == 0:
+ result.addSuccess(self)
+ success = True
+ else:
+ success = False
+ # XXX Don't stop anymore if an error occured
+ #if status == 2:
+ # result.shouldStop = True
+ if result.shouldStop: # either on error or on exitfirst + error
+ break
+ except:
+ # if an error occurs between two yield
+ result.addError(self, self.__exc_info())
+ success = False
+ return success
+
+ def _proceed(self, result, testfunc, args=(), kwargs=None):
+ """proceed the actual test
+ returns 0 on success, 1 on failure, 2 on error
+
+ Note: addSuccess can't be called here because we have to wait
+ for tearDown to be successfully executed to declare the test as
+ successful
+ """
+ kwargs = kwargs or {}
+ try:
+ testfunc(*args, **kwargs)
+ except self.failureException:
+ result.addFailure(self, self.__exc_info())
+ return 1
+ except KeyboardInterrupt:
+ raise
+ except InnerTestSkipped as e:
+ result.addSkip(self, e)
+ return 1
+ except SkipTest as e:
+ result.addSkip(self, e)
+ return 0
+ except:
+ result.addError(self, self.__exc_info())
+ return 2
+ return 0
+
+ def defaultTestResult(self):
+ """return a new instance of the defaultTestResult"""
+ return SkipAwareTestResult()
+
+ skip = _deprecate(unittest.TestCase.skipTest)
+ assertEquals = _deprecate(unittest.TestCase.assertEqual)
+ assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
+ assertAlmostEquals = _deprecate(unittest.TestCase.assertAlmostEqual)
+ assertNotAlmostEquals = _deprecate(unittest.TestCase.assertNotAlmostEqual)
+
+ def innerSkip(self, msg=None):
+ """mark a generative test as skipped for the <msg> reason"""
+ msg = msg or 'test was skipped'
+ raise InnerTestSkipped(msg)
+
+ @deprecated('Please use assertDictEqual instead.')
+ def assertDictEquals(self, dict1, dict2, msg=None, context=None):
+ """compares two dicts
+
+ If the two dict differ, the first difference is shown in the error
+ message
+ :param dict1: a Python Dictionary
+ :param dict2: a Python Dictionary
+ :param msg: custom message (String) in case of failure
+ """
+ dict1 = dict(dict1)
+ msgs = []
+ for key, value in dict2.items():
+ try:
+ if dict1[key] != value:
+ msgs.append('%r != %r for key %r' % (dict1[key], value,
+ key))
+ del dict1[key]
+ except KeyError:
+ msgs.append('missing %r key' % key)
+ if dict1:
+ msgs.append('dict2 is lacking %r' % dict1)
+ if msg:
+ self.failureException(msg)
+ elif msgs:
+ if context is not None:
+ base = '%s\n' % context
+ else:
+ base = ''
+ self.fail(base + '\n'.join(msgs))
+
+ @deprecated('Please use assertCountEqual instead.')
+ def assertUnorderedIterableEquals(self, got, expected, msg=None):
+ """compares two iterable and shows difference between both
+
+ :param got: the unordered Iterable that we found
+ :param expected: the expected unordered Iterable
+ :param msg: custom message (String) in case of failure
+ """
+ got, expected = list(got), list(expected)
+ self.assertSetEqual(set(got), set(expected), msg)
+ if len(got) != len(expected):
+ if msg is None:
+ msg = ['Iterable have the same elements but not the same number',
+ '\t<element>\t<expected>i\t<got>']
+ got_count = {}
+ expected_count = {}
+ for element in got:
+ got_count[element] = got_count.get(element, 0) + 1
+ for element in expected:
+ expected_count[element] = expected_count.get(element, 0) + 1
+ # we know that got_count.key() == expected_count.key()
+ # because of assertSetEqual
+ for element, count in got_count.iteritems():
+ other_count = expected_count[element]
+ if other_count != count:
+ msg.append('\t%s\t%s\t%s' % (element, other_count, count))
+
+ self.fail(msg)
+
+ assertUnorderedIterableEqual = assertUnorderedIterableEquals
+ assertUnordIterEquals = assertUnordIterEqual = assertUnorderedIterableEqual
+
+ @deprecated('Please use assertSetEqual instead.')
+ def assertSetEquals(self,got,expected, msg=None):
+ """compares two sets and shows difference between both
+
+ Don't use it for iterables other than sets.
+
+ :param got: the Set that we found
+ :param expected: the second Set to be compared to the first one
+ :param msg: custom message (String) in case of failure
+ """
+
+ if not(isinstance(got, set) and isinstance(expected, set)):
+ warnings.warn("the assertSetEquals function if now intended for set only."\
+ "use assertUnorderedIterableEquals instead.",
+ DeprecationWarning, 2)
+ return self.assertUnorderedIterableEquals(got, expected, msg)
+
+ items={}
+ items['missing'] = expected - got
+ items['unexpected'] = got - expected
+ if any(items.itervalues()):
+ if msg is None:
+ msg = '\n'.join('%s:\n\t%s' % (key, "\n\t".join(str(value) for value in values))
+ for key, values in items.iteritems() if values)
+ self.fail(msg)
+
+ @deprecated('Please use assertListEqual instead.')
+ def assertListEquals(self, list_1, list_2, msg=None):
+ """compares two lists
+
+ If the two list differ, the first difference is shown in the error
+ message
+
+ :param list_1: a Python List
+ :param list_2: a second Python List
+ :param msg: custom message (String) in case of failure
+ """
+ _l1 = list_1[:]
+ for i, value in enumerate(list_2):
+ try:
+ if _l1[0] != value:
+ from pprint import pprint
+ pprint(list_1)
+ pprint(list_2)
+ self.fail('%r != %r for index %d' % (_l1[0], value, i))
+ del _l1[0]
+ except IndexError:
+ if msg is None:
+ msg = 'list_1 has only %d elements, not %s '\
+ '(at least %r missing)'% (i, len(list_2), value)
+ self.fail(msg)
+ if _l1:
+ if msg is None:
+ msg = 'list_2 is lacking %r' % _l1
+ self.fail(msg)
+
+ @deprecated('Non-standard. Please use assertMultiLineEqual instead.')
+ def assertLinesEquals(self, string1, string2, msg=None, striplines=False):
+ """compare two strings and assert that the text lines of the strings
+ are equal.
+
+ :param string1: a String
+ :param string2: a String
+ :param msg: custom message (String) in case of failure
+ :param striplines: Boolean to trigger line stripping before comparing
+ """
+ lines1 = string1.splitlines()
+ lines2 = string2.splitlines()
+ if striplines:
+ lines1 = [l.strip() for l in lines1]
+ lines2 = [l.strip() for l in lines2]
+ self.assertListEqual(lines1, lines2, msg)
+ assertLineEqual = assertLinesEquals
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertXMLWellFormed(self, stream, msg=None, context=2):
+ """asserts the XML stream is well-formed (no DTD conformance check)
+
+ :param context: number of context lines in standard message
+ (show all data if negative).
+ Only available with element tree
+ """
+ try:
+ from xml.etree.ElementTree import parse
+ self._assertETXMLWellFormed(stream, parse, msg)
+ except ImportError:
+ from xml.sax import make_parser, SAXParseException
+ parser = make_parser()
+ try:
+ parser.parse(stream)
+ except SAXParseException as ex:
+ if msg is None:
+ stream.seek(0)
+ for _ in range(ex.getLineNumber()):
+ line = stream.readline()
+ pointer = ('' * (ex.getLineNumber() - 1)) + '^'
+ msg = 'XML stream not well formed: %s\n%s%s' % (ex, line, pointer)
+ self.fail(msg)
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertXMLStringWellFormed(self, xml_string, msg=None, context=2):
+ """asserts the XML string is well-formed (no DTD conformance check)
+
+ :param context: number of context lines in standard message
+ (show all data if negative).
+ Only available with element tree
+ """
+ try:
+ from xml.etree.ElementTree import fromstring
+ except ImportError:
+ from elementtree.ElementTree import fromstring
+ self._assertETXMLWellFormed(xml_string, fromstring, msg)
+
+ def _assertETXMLWellFormed(self, data, parse, msg=None, context=2):
+ """internal function used by /assertXML(String)?WellFormed/ functions
+
+ :param data: xml_data
+ :param parse: appropriate parser function for this data
+ :param msg: error message
+ :param context: number of context lines in standard message
+ (show all data if negative).
+ Only available with element tree
+ """
+ from xml.parsers.expat import ExpatError
+ try:
+ from xml.etree.ElementTree import ParseError
+ except ImportError:
+ # compatibility for <python2.7
+ ParseError = ExpatError
+ try:
+ parse(data)
+ except (ExpatError, ParseError) as ex:
+ if msg is None:
+ if hasattr(data, 'readlines'): #file like object
+ data.seek(0)
+ lines = data.readlines()
+ else:
+ lines = data.splitlines(True)
+ nb_lines = len(lines)
+ context_lines = []
+
+ # catch when ParseError doesn't set valid lineno
+ if ex.lineno is not None:
+ if context < 0:
+ start = 1
+ end = nb_lines
+ else:
+ start = max(ex.lineno-context, 1)
+ end = min(ex.lineno+context, nb_lines)
+ line_number_length = len('%i' % end)
+ line_pattern = " %%%ii: %%s" % line_number_length
+
+ for line_no in range(start, ex.lineno):
+ context_lines.append(line_pattern % (line_no, lines[line_no-1]))
+ context_lines.append(line_pattern % (ex.lineno, lines[ex.lineno-1]))
+ context_lines.append('%s^\n' % (' ' * (1 + line_number_length + 2 +ex.offset)))
+ for line_no in range(ex.lineno+1, end+1):
+ context_lines.append(line_pattern % (line_no, lines[line_no-1]))
+
+ rich_context = ''.join(context_lines)
+ msg = 'XML stream not well formed: %s\n%s' % (ex, rich_context)
+ self.fail(msg)
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertXMLEqualsTuple(self, element, tup):
+ """compare an ElementTree Element to a tuple formatted as follow:
+ (tagname, [attrib[, children[, text[, tail]]]])"""
+ # check tag
+ self.assertTextEquals(element.tag, tup[0])
+ # check attrib
+ if len(element.attrib) or len(tup)>1:
+ if len(tup)<=1:
+ self.fail( "tuple %s has no attributes (%s expected)"%(tup,
+ dict(element.attrib)))
+ self.assertDictEqual(element.attrib, tup[1])
+ # check children
+ if len(element) or len(tup)>2:
+ if len(tup)<=2:
+ self.fail( "tuple %s has no children (%i expected)"%(tup,
+ len(element)))
+ if len(element) != len(tup[2]):
+ self.fail( "tuple %s has %i children%s (%i expected)"%(tup,
+ len(tup[2]),
+ ('', 's')[len(tup[2])>1], len(element)))
+ for index in range(len(tup[2])):
+ self.assertXMLEqualsTuple(element[index], tup[2][index])
+ #check text
+ if element.text or len(tup)>3:
+ if len(tup)<=3:
+ self.fail( "tuple %s has no text value (%r expected)"%(tup,
+ element.text))
+ self.assertTextEquals(element.text, tup[3])
+ #check tail
+ if element.tail or len(tup)>4:
+ if len(tup)<=4:
+ self.fail( "tuple %s has no tail value (%r expected)"%(tup,
+ element.tail))
+ self.assertTextEquals(element.tail, tup[4])
+
+ def _difftext(self, lines1, lines2, junk=None, msg_prefix='Texts differ'):
+ junk = junk or (' ', '\t')
+ # result is a generator
+ result = difflib.ndiff(lines1, lines2, charjunk=lambda x: x in junk)
+ read = []
+ for line in result:
+ read.append(line)
+ # lines that don't start with a ' ' are diff ones
+ if not line.startswith(' '):
+ self.fail('\n'.join(['%s\n'%msg_prefix]+read + list(result)))
+
+ @deprecated('Non-standard. Please use assertMultiLineEqual instead.')
+ def assertTextEquals(self, text1, text2, junk=None,
+ msg_prefix='Text differ', striplines=False):
+ """compare two multiline strings (using difflib and splitlines())
+
+ :param text1: a Python BaseString
+ :param text2: a second Python Basestring
+ :param junk: List of Caracters
+ :param msg_prefix: String (message prefix)
+ :param striplines: Boolean to trigger line stripping before comparing
+ """
+ msg = []
+ if not isinstance(text1, string_types):
+ msg.append('text1 is not a string (%s)'%(type(text1)))
+ if not isinstance(text2, string_types):
+ msg.append('text2 is not a string (%s)'%(type(text2)))
+ if msg:
+ self.fail('\n'.join(msg))
+ lines1 = text1.strip().splitlines(True)
+ lines2 = text2.strip().splitlines(True)
+ if striplines:
+ lines1 = [line.strip() for line in lines1]
+ lines2 = [line.strip() for line in lines2]
+ self._difftext(lines1, lines2, junk, msg_prefix)
+ assertTextEqual = assertTextEquals
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertStreamEquals(self, stream1, stream2, junk=None,
+ msg_prefix='Stream differ'):
+ """compare two streams (using difflib and readlines())"""
+ # if stream2 is stream2, readlines() on stream1 will also read lines
+ # in stream2, so they'll appear different, although they're not
+ if stream1 is stream2:
+ return
+ # make sure we compare from the beginning of the stream
+ stream1.seek(0)
+ stream2.seek(0)
+ # compare
+ self._difftext(stream1.readlines(), stream2.readlines(), junk,
+ msg_prefix)
+
+ assertStreamEqual = assertStreamEquals
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertFileEquals(self, fname1, fname2, junk=(' ', '\t')):
+ """compares two files using difflib"""
+ self.assertStreamEqual(open(fname1), open(fname2), junk,
+ msg_prefix='Files differs\n-:%s\n+:%s\n'%(fname1, fname2))
+
+ assertFileEqual = assertFileEquals
+
+ @deprecated('Non-standard: please copy test method to your TestCase class')
+ def assertDirEquals(self, path_a, path_b):
+ """compares two files using difflib"""
+ assert osp.exists(path_a), "%s doesn't exists" % path_a
+ assert osp.exists(path_b), "%s doesn't exists" % path_b
+
+ all_a = [ (ipath[len(path_a):].lstrip('/'), idirs, ifiles)
+ for ipath, idirs, ifiles in os.walk(path_a)]
+ all_a.sort(key=itemgetter(0))
+
+ all_b = [ (ipath[len(path_b):].lstrip('/'), idirs, ifiles)
+ for ipath, idirs, ifiles in os.walk(path_b)]
+ all_b.sort(key=itemgetter(0))
+
+ iter_a, iter_b = iter(all_a), iter(all_b)
+ partial_iter = True
+ ipath_a, idirs_a, ifiles_a = data_a = None, None, None
+ while True:
+ try:
+ ipath_a, idirs_a, ifiles_a = datas_a = next(iter_a)
+ partial_iter = False
+ ipath_b, idirs_b, ifiles_b = datas_b = next(iter_b)
+ partial_iter = True
+
+
+ self.assertTrue(ipath_a == ipath_b,
+ "unexpected %s in %s while looking %s from %s" %
+ (ipath_a, path_a, ipath_b, path_b))
+
+
+ errors = {}
+ sdirs_a = set(idirs_a)
+ sdirs_b = set(idirs_b)
+ errors["unexpected directories"] = sdirs_a - sdirs_b
+ errors["missing directories"] = sdirs_b - sdirs_a
+
+ sfiles_a = set(ifiles_a)
+ sfiles_b = set(ifiles_b)
+ errors["unexpected files"] = sfiles_a - sfiles_b
+ errors["missing files"] = sfiles_b - sfiles_a
+
+
+ msgs = [ "%s: %s"% (name, items)
+ for name, items in errors.items() if items]
+
+ if msgs:
+ msgs.insert(0, "%s and %s differ :" % (
+ osp.join(path_a, ipath_a),
+ osp.join(path_b, ipath_b),
+ ))
+ self.fail("\n".join(msgs))
+
+ for files in (ifiles_a, ifiles_b):
+ files.sort()
+
+ for index, path in enumerate(ifiles_a):
+ self.assertFileEquals(osp.join(path_a, ipath_a, path),
+ osp.join(path_b, ipath_b, ifiles_b[index]))
+
+ except StopIteration:
+ break
+
+ assertDirEqual = assertDirEquals
+
+ def assertIsInstance(self, obj, klass, msg=None, strict=False):
+ """check if an object is an instance of a class
+
+ :param obj: the Python Object to be checked
+ :param klass: the target class
+ :param msg: a String for a custom message
+ :param strict: if True, check that the class of <obj> is <klass>;
+ else check with 'isinstance'
+ """
+ if strict:
+ warnings.warn('[API] Non-standard. Strict parameter has vanished',
+ DeprecationWarning, stacklevel=2)
+ if msg is None:
+ if strict:
+ msg = '%r is not of class %s but of %s'
+ else:
+ msg = '%r is not an instance of %s but of %s'
+ msg = msg % (obj, klass, type(obj))
+ if strict:
+ self.assertTrue(obj.__class__ is klass, msg)
+ else:
+ self.assertTrue(isinstance(obj, klass), msg)
+
+ @deprecated('Please use assertIsNone instead.')
+ def assertNone(self, obj, msg=None):
+ """assert obj is None
+
+ :param obj: Python Object to be tested
+ """
+ if msg is None:
+ msg = "reference to %r when None expected"%(obj,)
+ self.assertTrue( obj is None, msg )
+
+ @deprecated('Please use assertIsNotNone instead.')
+ def assertNotNone(self, obj, msg=None):
+ """assert obj is not None"""
+ if msg is None:
+ msg = "unexpected reference to None"
+ self.assertTrue( obj is not None, msg )
+
+ @deprecated('Non-standard. Please use assertAlmostEqual instead.')
+ def assertFloatAlmostEquals(self, obj, other, prec=1e-5,
+ relative=False, msg=None):
+ """compares if two floats have a distance smaller than expected
+ precision.
+
+ :param obj: a Float
+ :param other: another Float to be comparted to <obj>
+ :param prec: a Float describing the precision
+ :param relative: boolean switching to relative/absolute precision
+ :param msg: a String for a custom message
+ """
+ if msg is None:
+ msg = "%r != %r" % (obj, other)
+ if relative:
+ prec = prec*math.fabs(obj)
+ self.assertTrue(math.fabs(obj - other) < prec, msg)
+
+ def failUnlessRaises(self, excClass, callableObj=None, *args, **kwargs):
+ """override default failUnlessRaises method to return the raised
+ exception instance.
+
+ Fail unless an exception of class excClass is thrown
+ by callableObj when invoked with arguments args and keyword
+ arguments kwargs. If a different type of exception is
+ thrown, it will not be caught, and the test case will be
+ deemed to have suffered an error, exactly as for an
+ unexpected exception.
+
+ CAUTION! There are subtle differences between Logilab and unittest2
+ - exc is not returned in standard version
+ - context capabilities in standard version
+ - try/except/else construction (minor)
+
+ :param excClass: the Exception to be raised
+ :param callableObj: a callable Object which should raise <excClass>
+ :param args: a List of arguments for <callableObj>
+ :param kwargs: a List of keyword arguments for <callableObj>
+ """
+ # XXX cube vcslib : test_branches_from_app
+ if callableObj is None:
+ _assert = super(TestCase, self).assertRaises
+ return _assert(excClass, callableObj, *args, **kwargs)
+ try:
+ callableObj(*args, **kwargs)
+ except excClass as exc:
+ class ProxyException:
+ def __init__(self, obj):
+ self._obj = obj
+ def __getattr__(self, attr):
+ warn_msg = ("This exception was retrieved with the old testlib way "
+ "`exc = self.assertRaises(Exc, callable)`, please use "
+ "the context manager instead'")
+ warnings.warn(warn_msg, DeprecationWarning, 2)
+ return self._obj.__getattribute__(attr)
+ return ProxyException(exc)
+ else:
+ if hasattr(excClass, '__name__'):
+ excName = excClass.__name__
+ else:
+ excName = str(excClass)
+ raise self.failureException("%s not raised" % excName)
+
+ assertRaises = failUnlessRaises
+
+ if sys.version_info >= (3,2):
+ assertItemsEqual = unittest.TestCase.assertCountEqual
+ else:
+ assertCountEqual = unittest.TestCase.assertItemsEqual
+ if sys.version_info < (2,7):
+ def assertIsNotNone(self, value, *args, **kwargs):
+ self.assertNotEqual(None, value, *args, **kwargs)
+
+TestCase.assertItemsEqual = deprecated('assertItemsEqual is deprecated, use assertCountEqual')(
+ TestCase.assertItemsEqual)
+
+import doctest
+
+class SkippedSuite(unittest.TestSuite):
+ def test(self):
+ """just there to trigger test execution"""
+ self.skipped_test('doctest module has no DocTestSuite class')
+
+
+class DocTestFinder(doctest.DocTestFinder):
+
+ def __init__(self, *args, **kwargs):
+ self.skipped = kwargs.pop('skipped', ())
+ doctest.DocTestFinder.__init__(self, *args, **kwargs)
+
+ def _get_test(self, obj, name, module, globs, source_lines):
+ """override default _get_test method to be able to skip tests
+ according to skipped attribute's value
+ """
+ if getattr(obj, '__name__', '') in self.skipped:
+ return None
+ return doctest.DocTestFinder._get_test(self, obj, name, module,
+ globs, source_lines)
+
+
+class DocTest(TestCase):
+ """trigger module doctest
+ I don't know how to make unittest.main consider the DocTestSuite instance
+ without this hack
+ """
+ skipped = ()
+ def __call__(self, result=None, runcondition=None, options=None):\
+ # pylint: disable=W0613
+ try:
+ finder = DocTestFinder(skipped=self.skipped)
+ suite = doctest.DocTestSuite(self.module, test_finder=finder)
+ # XXX iirk
+ doctest.DocTestCase._TestCase__exc_info = sys.exc_info
+ except AttributeError:
+ suite = SkippedSuite()
+ # doctest may gork the builtins dictionnary
+ # This happen to the "_" entry used by gettext
+ old_builtins = builtins.__dict__.copy()
+ try:
+ return suite.run(result)
+ finally:
+ builtins.__dict__.clear()
+ builtins.__dict__.update(old_builtins)
+ run = __call__
+
+ def test(self):
+ """just there to trigger test execution"""
+
+MAILBOX = None
+
+class MockSMTP:
+ """fake smtplib.SMTP"""
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ global MAILBOX
+ self.reveived = MAILBOX = []
+
+ def set_debuglevel(self, debuglevel):
+ """ignore debug level"""
+
+ def sendmail(self, fromaddr, toaddres, body):
+ """push sent mail in the mailbox"""
+ self.reveived.append((fromaddr, toaddres, body))
+
+ def quit(self):
+ """ignore quit"""
+
+
+class MockConfigParser(configparser.ConfigParser):
+ """fake ConfigParser.ConfigParser"""
+
+ def __init__(self, options):
+ configparser.ConfigParser.__init__(self)
+ for section, pairs in options.iteritems():
+ self.add_section(section)
+ for key, value in pairs.iteritems():
+ self.set(section, key, value)
+ def write(self, _):
+ raise NotImplementedError()
+
+
+class MockConnection:
+ """fake DB-API 2.0 connexion AND cursor (i.e. cursor() return self)"""
+
+ def __init__(self, results):
+ self.received = []
+ self.states = []
+ self.results = results
+
+ def cursor(self):
+ """Mock cursor method"""
+ return self
+ def execute(self, query, args=None):
+ """Mock execute method"""
+ self.received.append( (query, args) )
+ def fetchone(self):
+ """Mock fetchone method"""
+ return self.results[0]
+ def fetchall(self):
+ """Mock fetchall method"""
+ return self.results
+ def commit(self):
+ """Mock commiy method"""
+ self.states.append( ('commit', len(self.received)) )
+ def rollback(self):
+ """Mock rollback method"""
+ self.states.append( ('rollback', len(self.received)) )
+ def close(self):
+ """Mock close method"""
+ pass
+
+
+def mock_object(**params):
+ """creates an object using params to set attributes
+ >>> option = mock_object(verbose=False, index=range(5))
+ >>> option.verbose
+ False
+ >>> option.index
+ [0, 1, 2, 3, 4]
+ """
+ return type('Mock', (), params)()
+
+
+def create_files(paths, chroot):
+ """Creates directories and files found in <path>.
+
+ :param paths: list of relative paths to files or directories
+ :param chroot: the root directory in which paths will be created
+
+ >>> from os.path import isdir, isfile
+ >>> isdir('/tmp/a')
+ False
+ >>> create_files(['a/b/foo.py', 'a/b/c/', 'a/b/c/d/e.py'], '/tmp')
+ >>> isdir('/tmp/a')
+ True
+ >>> isdir('/tmp/a/b/c')
+ True
+ >>> isfile('/tmp/a/b/c/d/e.py')
+ True
+ >>> isfile('/tmp/a/b/foo.py')
+ True
+ """
+ dirs, files = set(), set()
+ for path in paths:
+ path = osp.join(chroot, path)
+ filename = osp.basename(path)
+ # path is a directory path
+ if filename == '':
+ dirs.add(path)
+ # path is a filename path
+ else:
+ dirs.add(osp.dirname(path))
+ files.add(path)
+ for dirpath in dirs:
+ if not osp.isdir(dirpath):
+ os.makedirs(dirpath)
+ for filepath in files:
+ open(filepath, 'w').close()
+
+
+class AttrObject: # XXX cf mock_object
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
+
+def tag(*args, **kwargs):
+ """descriptor adding tag to a function"""
+ def desc(func):
+ assert not hasattr(func, 'tags')
+ func.tags = Tags(*args, **kwargs)
+ return func
+ return desc
+
+def require_version(version):
+ """ Compare version of python interpreter to the given one. Skip the test
+ if older.
+ """
+ def check_require_version(f):
+ version_elements = version.split('.')
+ try:
+ compare = tuple([int(v) for v in version_elements])
+ except ValueError:
+ raise ValueError('%s is not a correct version : should be X.Y[.Z].' % version)
+ current = sys.version_info[:3]
+ if current < compare:
+ def new_f(self, *args, **kwargs):
+ self.skipTest('Need at least %s version of python. Current version is %s.' % (version, '.'.join([str(element) for element in current])))
+ new_f.__name__ = f.__name__
+ return new_f
+ else:
+ return f
+ return check_require_version
+
+def require_module(module):
+ """ Check if the given module is loaded. Skip the test if not.
+ """
+ def check_require_module(f):
+ try:
+ __import__(module)
+ return f
+ except ImportError:
+ def new_f(self, *args, **kwargs):
+ self.skipTest('%s can not be imported.' % module)
+ new_f.__name__ = f.__name__
+ return new_f
+ return check_require_module
+
diff --git a/logilab/common/textutils.py b/logilab/common/textutils.py
new file mode 100644
index 0000000..9046f97
--- /dev/null
+++ b/logilab/common/textutils.py
@@ -0,0 +1,537 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Some text manipulation utility functions.
+
+
+:group text formatting: normalize_text, normalize_paragraph, pretty_match,\
+unquote, colorize_ansi
+:group text manipulation: searchall, splitstrip
+:sort: text formatting, text manipulation
+
+:type ANSI_STYLES: dict(str)
+:var ANSI_STYLES: dictionary mapping style identifier to ANSI terminal code
+
+:type ANSI_COLORS: dict(str)
+:var ANSI_COLORS: dictionary mapping color identifier to ANSI terminal code
+
+:type ANSI_PREFIX: str
+:var ANSI_PREFIX:
+ ANSI terminal code notifying the start of an ANSI escape sequence
+
+:type ANSI_END: str
+:var ANSI_END:
+ ANSI terminal code notifying the end of an ANSI escape sequence
+
+:type ANSI_RESET: str
+:var ANSI_RESET:
+ ANSI terminal code resetting format defined by a previous ANSI escape sequence
+"""
+__docformat__ = "restructuredtext en"
+
+import sys
+import re
+import os.path as osp
+from warnings import warn
+from unicodedata import normalize as _uninormalize
+try:
+ from os import linesep
+except ImportError:
+ linesep = '\n' # gae
+
+from logilab.common.deprecation import deprecated
+
+MANUAL_UNICODE_MAP = {
+ u'\xa1': u'!', # INVERTED EXCLAMATION MARK
+ u'\u0142': u'l', # LATIN SMALL LETTER L WITH STROKE
+ u'\u2044': u'/', # FRACTION SLASH
+ u'\xc6': u'AE', # LATIN CAPITAL LETTER AE
+ u'\xa9': u'(c)', # COPYRIGHT SIGN
+ u'\xab': u'"', # LEFT-POINTING DOUBLE ANGLE QUOTATION MARK
+ u'\xe6': u'ae', # LATIN SMALL LETTER AE
+ u'\xae': u'(r)', # REGISTERED SIGN
+ u'\u0153': u'oe', # LATIN SMALL LIGATURE OE
+ u'\u0152': u'OE', # LATIN CAPITAL LIGATURE OE
+ u'\xd8': u'O', # LATIN CAPITAL LETTER O WITH STROKE
+ u'\xf8': u'o', # LATIN SMALL LETTER O WITH STROKE
+ u'\xbb': u'"', # RIGHT-POINTING DOUBLE ANGLE QUOTATION MARK
+ u'\xdf': u'ss', # LATIN SMALL LETTER SHARP S
+ }
+
+def unormalize(ustring, ignorenonascii=None, substitute=None):
+ """replace diacritical characters with their corresponding ascii characters
+
+ Convert the unicode string to its long normalized form (unicode character
+ will be transform into several characters) and keep the first one only.
+ The normal form KD (NFKD) will apply the compatibility decomposition, i.e.
+ replace all compatibility characters with their equivalents.
+
+ :type substitute: str
+ :param substitute: replacement character to use if decomposition fails
+
+ :see: Another project about ASCII transliterations of Unicode text
+ http://pypi.python.org/pypi/Unidecode
+ """
+ # backward compatibility, ignorenonascii was a boolean
+ if ignorenonascii is not None:
+ warn("ignorenonascii is deprecated, use substitute named parameter instead",
+ DeprecationWarning, stacklevel=2)
+ if ignorenonascii:
+ substitute = ''
+ res = []
+ for letter in ustring[:]:
+ try:
+ replacement = MANUAL_UNICODE_MAP[letter]
+ except KeyError:
+ replacement = _uninormalize('NFKD', letter)[0]
+ if ord(replacement) >= 2 ** 7:
+ if substitute is None:
+ raise ValueError("can't deal with non-ascii based characters")
+ replacement = substitute
+ res.append(replacement)
+ return u''.join(res)
+
+def unquote(string):
+ """remove optional quotes (simple or double) from the string
+
+ :type string: str or unicode
+ :param string: an optionally quoted string
+
+ :rtype: str or unicode
+ :return: the unquoted string (or the input string if it wasn't quoted)
+ """
+ if not string:
+ return string
+ if string[0] in '"\'':
+ string = string[1:]
+ if string[-1] in '"\'':
+ string = string[:-1]
+ return string
+
+
+_BLANKLINES_RGX = re.compile('\r?\n\r?\n')
+_NORM_SPACES_RGX = re.compile('\s+')
+
+def normalize_text(text, line_len=80, indent='', rest=False):
+ """normalize a text to display it with a maximum line size and
+ optionally arbitrary indentation. Line jumps are normalized but blank
+ lines are kept. The indentation string may be used to insert a
+ comment (#) or a quoting (>) mark for instance.
+
+ :type text: str or unicode
+ :param text: the input text to normalize
+
+ :type line_len: int
+ :param line_len: expected maximum line's length, default to 80
+
+ :type indent: str or unicode
+ :param indent: optional string to use as indentation
+
+ :rtype: str or unicode
+ :return:
+ the input text normalized to fit on lines with a maximized size
+ inferior to `line_len`, and optionally prefixed by an
+ indentation string
+ """
+ if rest:
+ normp = normalize_rest_paragraph
+ else:
+ normp = normalize_paragraph
+ result = []
+ for text in _BLANKLINES_RGX.split(text):
+ result.append(normp(text, line_len, indent))
+ return ('%s%s%s' % (linesep, indent, linesep)).join(result)
+
+
+def normalize_paragraph(text, line_len=80, indent=''):
+ """normalize a text to display it with a maximum line size and
+ optionally arbitrary indentation. Line jumps are normalized. The
+ indentation string may be used top insert a comment mark for
+ instance.
+
+ :type text: str or unicode
+ :param text: the input text to normalize
+
+ :type line_len: int
+ :param line_len: expected maximum line's length, default to 80
+
+ :type indent: str or unicode
+ :param indent: optional string to use as indentation
+
+ :rtype: str or unicode
+ :return:
+ the input text normalized to fit on lines with a maximized size
+ inferior to `line_len`, and optionally prefixed by an
+ indentation string
+ """
+ text = _NORM_SPACES_RGX.sub(' ', text)
+ line_len = line_len - len(indent)
+ lines = []
+ while text:
+ aline, text = splittext(text.strip(), line_len)
+ lines.append(indent + aline)
+ return linesep.join(lines)
+
+def normalize_rest_paragraph(text, line_len=80, indent=''):
+ """normalize a ReST text to display it with a maximum line size and
+ optionally arbitrary indentation. Line jumps are normalized. The
+ indentation string may be used top insert a comment mark for
+ instance.
+
+ :type text: str or unicode
+ :param text: the input text to normalize
+
+ :type line_len: int
+ :param line_len: expected maximum line's length, default to 80
+
+ :type indent: str or unicode
+ :param indent: optional string to use as indentation
+
+ :rtype: str or unicode
+ :return:
+ the input text normalized to fit on lines with a maximized size
+ inferior to `line_len`, and optionally prefixed by an
+ indentation string
+ """
+ toreport = ''
+ lines = []
+ line_len = line_len - len(indent)
+ for line in text.splitlines():
+ line = toreport + _NORM_SPACES_RGX.sub(' ', line.strip())
+ toreport = ''
+ while len(line) > line_len:
+ # too long line, need split
+ line, toreport = splittext(line, line_len)
+ lines.append(indent + line)
+ if toreport:
+ line = toreport + ' '
+ toreport = ''
+ else:
+ line = ''
+ if line:
+ lines.append(indent + line.strip())
+ return linesep.join(lines)
+
+
+def splittext(text, line_len):
+ """split the given text on space according to the given max line size
+
+ return a 2-uple:
+ * a line <= line_len if possible
+ * the rest of the text which has to be reported on another line
+ """
+ if len(text) <= line_len:
+ return text, ''
+ pos = min(len(text)-1, line_len)
+ while pos > 0 and text[pos] != ' ':
+ pos -= 1
+ if pos == 0:
+ pos = min(len(text), line_len)
+ while len(text) > pos and text[pos] != ' ':
+ pos += 1
+ return text[:pos], text[pos+1:].strip()
+
+
+def splitstrip(string, sep=','):
+ """return a list of stripped string by splitting the string given as
+ argument on `sep` (',' by default). Empty string are discarded.
+
+ >>> splitstrip('a, b, c , 4,,')
+ ['a', 'b', 'c', '4']
+ >>> splitstrip('a')
+ ['a']
+ >>>
+
+ :type string: str or unicode
+ :param string: a csv line
+
+ :type sep: str or unicode
+ :param sep: field separator, default to the comma (',')
+
+ :rtype: str or unicode
+ :return: the unquoted string (or the input string if it wasn't quoted)
+ """
+ return [word.strip() for word in string.split(sep) if word.strip()]
+
+get_csv = deprecated('get_csv is deprecated, use splitstrip')(splitstrip)
+
+
+def split_url_or_path(url_or_path):
+ """return the latest component of a string containing either an url of the
+ form <scheme>://<path> or a local file system path
+ """
+ if '://' in url_or_path:
+ return url_or_path.rstrip('/').rsplit('/', 1)
+ return osp.split(url_or_path.rstrip(osp.sep))
+
+
+def text_to_dict(text):
+ """parse multilines text containing simple 'key=value' lines and return a
+ dict of {'key': 'value'}. When the same key is encountered multiple time,
+ value is turned into a list containing all values.
+
+ >>> d = text_to_dict('''multiple=1
+ ... multiple= 2
+ ... single =3
+ ... ''')
+ >>> d['single']
+ '3'
+ >>> d['multiple']
+ ['1', '2']
+
+ """
+ res = {}
+ if not text:
+ return res
+ for line in text.splitlines():
+ line = line.strip()
+ if line and not line.startswith('#'):
+ key, value = [w.strip() for w in line.split('=', 1)]
+ if key in res:
+ try:
+ res[key].append(value)
+ except AttributeError:
+ res[key] = [res[key], value]
+ else:
+ res[key] = value
+ return res
+
+
+_BLANK_URE = r'(\s|,)+'
+_BLANK_RE = re.compile(_BLANK_URE)
+__VALUE_URE = r'-?(([0-9]+\.[0-9]*)|((0x?)?[0-9]+))'
+__UNITS_URE = r'[a-zA-Z]+'
+_VALUE_RE = re.compile(r'(?P<value>%s)(?P<unit>%s)?'%(__VALUE_URE, __UNITS_URE))
+_VALIDATION_RE = re.compile(r'^((%s)(%s))*(%s)?$' % (__VALUE_URE, __UNITS_URE,
+ __VALUE_URE))
+
+BYTE_UNITS = {
+ "b": 1,
+ "kb": 1024,
+ "mb": 1024 ** 2,
+ "gb": 1024 ** 3,
+ "tb": 1024 ** 4,
+}
+
+TIME_UNITS = {
+ "ms": 0.0001,
+ "s": 1,
+ "min": 60,
+ "h": 60 * 60,
+ "d": 60 * 60 *24,
+}
+
+def apply_units(string, units, inter=None, final=float, blank_reg=_BLANK_RE,
+ value_reg=_VALUE_RE):
+ """Parse the string applying the units defined in units
+ (e.g.: "1.5m",{'m',60} -> 80).
+
+ :type string: str or unicode
+ :param string: the string to parse
+
+ :type units: dict (or any object with __getitem__ using basestring key)
+ :param units: a dict mapping a unit string repr to its value
+
+ :type inter: type
+ :param inter: used to parse every intermediate value (need __sum__)
+
+ :type blank_reg: regexp
+ :param blank_reg: should match every blank char to ignore.
+
+ :type value_reg: regexp with "value" and optional "unit" group
+ :param value_reg: match a value and it's unit into the
+ """
+ if inter is None:
+ inter = final
+ fstring = _BLANK_RE.sub('', string)
+ if not (fstring and _VALIDATION_RE.match(fstring)):
+ raise ValueError("Invalid unit string: %r." % string)
+ values = []
+ for match in value_reg.finditer(fstring):
+ dic = match.groupdict()
+ lit, unit = dic["value"], dic.get("unit")
+ value = inter(lit)
+ if unit is not None:
+ try:
+ value *= units[unit.lower()]
+ except KeyError:
+ raise KeyError('invalid unit %s. valid units are %s' %
+ (unit, units.keys()))
+ values.append(value)
+ return final(sum(values))
+
+
+_LINE_RGX = re.compile('\r\n|\r+|\n')
+
+def pretty_match(match, string, underline_char='^'):
+ """return a string with the match location underlined:
+
+ >>> import re
+ >>> print(pretty_match(re.search('mange', 'il mange du bacon'), 'il mange du bacon'))
+ il mange du bacon
+ ^^^^^
+ >>>
+
+ :type match: _sre.SRE_match
+ :param match: object returned by re.match, re.search or re.finditer
+
+ :type string: str or unicode
+ :param string:
+ the string on which the regular expression has been applied to
+ obtain the `match` object
+
+ :type underline_char: str or unicode
+ :param underline_char:
+ character to use to underline the matched section, default to the
+ carret '^'
+
+ :rtype: str or unicode
+ :return:
+ the original string with an inserted line to underline the match
+ location
+ """
+ start = match.start()
+ end = match.end()
+ string = _LINE_RGX.sub(linesep, string)
+ start_line_pos = string.rfind(linesep, 0, start)
+ if start_line_pos == -1:
+ start_line_pos = 0
+ result = []
+ else:
+ result = [string[:start_line_pos]]
+ start_line_pos += len(linesep)
+ offset = start - start_line_pos
+ underline = ' ' * offset + underline_char * (end - start)
+ end_line_pos = string.find(linesep, end)
+ if end_line_pos == -1:
+ string = string[start_line_pos:]
+ result.append(string)
+ result.append(underline)
+ else:
+ end = string[end_line_pos + len(linesep):]
+ string = string[start_line_pos:end_line_pos]
+ result.append(string)
+ result.append(underline)
+ result.append(end)
+ return linesep.join(result).rstrip()
+
+
+# Ansi colorization ###########################################################
+
+ANSI_PREFIX = '\033['
+ANSI_END = 'm'
+ANSI_RESET = '\033[0m'
+ANSI_STYLES = {
+ 'reset': "0",
+ 'bold': "1",
+ 'italic': "3",
+ 'underline': "4",
+ 'blink': "5",
+ 'inverse': "7",
+ 'strike': "9",
+}
+ANSI_COLORS = {
+ 'reset': "0",
+ 'black': "30",
+ 'red': "31",
+ 'green': "32",
+ 'yellow': "33",
+ 'blue': "34",
+ 'magenta': "35",
+ 'cyan': "36",
+ 'white': "37",
+}
+
+def _get_ansi_code(color=None, style=None):
+ """return ansi escape code corresponding to color and style
+
+ :type color: str or None
+ :param color:
+ the color name (see `ANSI_COLORS` for available values)
+ or the color number when 256 colors are available
+
+ :type style: str or None
+ :param style:
+ style string (see `ANSI_COLORS` for available values). To get
+ several style effects at the same time, use a coma as separator.
+
+ :raise KeyError: if an unexistent color or style identifier is given
+
+ :rtype: str
+ :return: the built escape code
+ """
+ ansi_code = []
+ if style:
+ style_attrs = splitstrip(style)
+ for effect in style_attrs:
+ ansi_code.append(ANSI_STYLES[effect])
+ if color:
+ if color.isdigit():
+ ansi_code.extend(['38', '5'])
+ ansi_code.append(color)
+ else:
+ ansi_code.append(ANSI_COLORS[color])
+ if ansi_code:
+ return ANSI_PREFIX + ';'.join(ansi_code) + ANSI_END
+ return ''
+
+def colorize_ansi(msg, color=None, style=None):
+ """colorize message by wrapping it with ansi escape codes
+
+ :type msg: str or unicode
+ :param msg: the message string to colorize
+
+ :type color: str or None
+ :param color:
+ the color identifier (see `ANSI_COLORS` for available values)
+
+ :type style: str or None
+ :param style:
+ style string (see `ANSI_COLORS` for available values). To get
+ several style effects at the same time, use a coma as separator.
+
+ :raise KeyError: if an unexistent color or style identifier is given
+
+ :rtype: str or unicode
+ :return: the ansi escaped string
+ """
+ # If both color and style are not defined, then leave the text as is
+ if color is None and style is None:
+ return msg
+ escape_code = _get_ansi_code(color, style)
+ # If invalid (or unknown) color, don't wrap msg with ansi codes
+ if escape_code:
+ return '%s%s%s' % (escape_code, msg, ANSI_RESET)
+ return msg
+
+DIFF_STYLE = {'separator': 'cyan', 'remove': 'red', 'add': 'green'}
+
+def diff_colorize_ansi(lines, out=sys.stdout, style=DIFF_STYLE):
+ for line in lines:
+ if line[:4] in ('--- ', '+++ '):
+ out.write(colorize_ansi(line, style['separator']))
+ elif line[0] == '-':
+ out.write(colorize_ansi(line, style['remove']))
+ elif line[0] == '+':
+ out.write(colorize_ansi(line, style['add']))
+ elif line[:4] == '--- ':
+ out.write(colorize_ansi(line, style['separator']))
+ elif line[:4] == '+++ ':
+ out.write(colorize_ansi(line, style['separator']))
+ else:
+ out.write(line)
+
diff --git a/logilab/common/tree.py b/logilab/common/tree.py
new file mode 100644
index 0000000..885eb0f
--- /dev/null
+++ b/logilab/common/tree.py
@@ -0,0 +1,369 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Base class to represent a tree structure.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+import sys
+
+from logilab.common import flatten
+from logilab.common.visitor import VisitedMixIn, FilteredIterator, no_filter
+
+## Exceptions #################################################################
+
+class NodeNotFound(Exception):
+ """raised when a node has not been found"""
+
+EX_SIBLING_NOT_FOUND = "No such sibling as '%s'"
+EX_CHILD_NOT_FOUND = "No such child as '%s'"
+EX_NODE_NOT_FOUND = "No such node as '%s'"
+
+
+# Base node ###################################################################
+
+class Node(object):
+ """a basic tree node, characterized by an id"""
+
+ def __init__(self, nid=None) :
+ self.id = nid
+ # navigation
+ self.parent = None
+ self.children = []
+
+ def __iter__(self):
+ return iter(self.children)
+
+ def __str__(self, indent=0):
+ s = ['%s%s %s' % (' '*indent, self.__class__.__name__, self.id)]
+ indent += 2
+ for child in self.children:
+ try:
+ s.append(child.__str__(indent))
+ except TypeError:
+ s.append(child.__str__())
+ return '\n'.join(s)
+
+ def is_leaf(self):
+ return not self.children
+
+ def append(self, child):
+ """add a node to children"""
+ self.children.append(child)
+ child.parent = self
+
+ def remove(self, child):
+ """remove a child node"""
+ self.children.remove(child)
+ child.parent = None
+
+ def insert(self, index, child):
+ """insert a child node"""
+ self.children.insert(index, child)
+ child.parent = self
+
+ def replace(self, old_child, new_child):
+ """replace a child node with another"""
+ i = self.children.index(old_child)
+ self.children.pop(i)
+ self.children.insert(i, new_child)
+ new_child.parent = self
+
+ def get_sibling(self, nid):
+ """return the sibling node that has given id"""
+ try:
+ return self.parent.get_child_by_id(nid)
+ except NodeNotFound :
+ raise NodeNotFound(EX_SIBLING_NOT_FOUND % nid)
+
+ def next_sibling(self):
+ """
+ return the next sibling for this node if any
+ """
+ parent = self.parent
+ if parent is None:
+ # root node has no sibling
+ return None
+ index = parent.children.index(self)
+ try:
+ return parent.children[index+1]
+ except IndexError:
+ return None
+
+ def previous_sibling(self):
+ """
+ return the previous sibling for this node if any
+ """
+ parent = self.parent
+ if parent is None:
+ # root node has no sibling
+ return None
+ index = parent.children.index(self)
+ if index > 0:
+ return parent.children[index-1]
+ return None
+
+ def get_node_by_id(self, nid):
+ """
+ return node in whole hierarchy that has given id
+ """
+ root = self.root()
+ try:
+ return root.get_child_by_id(nid, 1)
+ except NodeNotFound :
+ raise NodeNotFound(EX_NODE_NOT_FOUND % nid)
+
+ def get_child_by_id(self, nid, recurse=None):
+ """
+ return child of given id
+ """
+ if self.id == nid:
+ return self
+ for c in self.children :
+ if recurse:
+ try:
+ return c.get_child_by_id(nid, 1)
+ except NodeNotFound :
+ continue
+ if c.id == nid :
+ return c
+ raise NodeNotFound(EX_CHILD_NOT_FOUND % nid)
+
+ def get_child_by_path(self, path):
+ """
+ return child of given path (path is a list of ids)
+ """
+ if len(path) > 0 and path[0] == self.id:
+ if len(path) == 1 :
+ return self
+ else :
+ for c in self.children :
+ try:
+ return c.get_child_by_path(path[1:])
+ except NodeNotFound :
+ pass
+ raise NodeNotFound(EX_CHILD_NOT_FOUND % path)
+
+ def depth(self):
+ """
+ return depth of this node in the tree
+ """
+ if self.parent is not None:
+ return 1 + self.parent.depth()
+ else :
+ return 0
+
+ def depth_down(self):
+ """
+ return depth of the tree from this node
+ """
+ if self.children:
+ return 1 + max([c.depth_down() for c in self.children])
+ return 1
+
+ def width(self):
+ """
+ return the width of the tree from this node
+ """
+ return len(self.leaves())
+
+ def root(self):
+ """
+ return the root node of the tree
+ """
+ if self.parent is not None:
+ return self.parent.root()
+ return self
+
+ def leaves(self):
+ """
+ return a list with all the leaves nodes descendant from this node
+ """
+ leaves = []
+ if self.children:
+ for child in self.children:
+ leaves += child.leaves()
+ return leaves
+ else:
+ return [self]
+
+ def flatten(self, _list=None):
+ """
+ return a list with all the nodes descendant from this node
+ """
+ if _list is None:
+ _list = []
+ _list.append(self)
+ for c in self.children:
+ c.flatten(_list)
+ return _list
+
+ def lineage(self):
+ """
+ return list of parents up to root node
+ """
+ lst = [self]
+ if self.parent is not None:
+ lst.extend(self.parent.lineage())
+ return lst
+
+class VNode(Node, VisitedMixIn):
+ """a visitable node
+ """
+ pass
+
+
+class BinaryNode(VNode):
+ """a binary node (i.e. only two children
+ """
+ def __init__(self, lhs=None, rhs=None) :
+ VNode.__init__(self)
+ if lhs is not None or rhs is not None:
+ assert lhs and rhs
+ self.append(lhs)
+ self.append(rhs)
+
+ def remove(self, child):
+ """remove the child and replace this node with the other child
+ """
+ self.children.remove(child)
+ self.parent.replace(self, self.children[0])
+
+ def get_parts(self):
+ """
+ return the left hand side and the right hand side of this node
+ """
+ return self.children[0], self.children[1]
+
+
+
+if sys.version_info[0:2] >= (2, 2):
+ list_class = list
+else:
+ from UserList import UserList
+ list_class = UserList
+
+class ListNode(VNode, list_class):
+ """Used to manipulate Nodes as Lists
+ """
+ def __init__(self):
+ list_class.__init__(self)
+ VNode.__init__(self)
+ self.children = self
+
+ def __str__(self, indent=0):
+ return '%s%s %s' % (indent*' ', self.__class__.__name__,
+ ', '.join([str(v) for v in self]))
+
+ def append(self, child):
+ """add a node to children"""
+ list_class.append(self, child)
+ child.parent = self
+
+ def insert(self, index, child):
+ """add a node to children"""
+ list_class.insert(self, index, child)
+ child.parent = self
+
+ def remove(self, child):
+ """add a node to children"""
+ list_class.remove(self, child)
+ child.parent = None
+
+ def pop(self, index):
+ """add a node to children"""
+ child = list_class.pop(self, index)
+ child.parent = None
+
+ def __iter__(self):
+ return list_class.__iter__(self)
+
+# construct list from tree ####################################################
+
+def post_order_list(node, filter_func=no_filter):
+ """
+ create a list with tree nodes for which the <filter> function returned true
+ in a post order fashion
+ """
+ l, stack = [], []
+ poped, index = 0, 0
+ while node:
+ if filter_func(node):
+ if node.children and not poped:
+ stack.append((node, index))
+ index = 0
+ node = node.children[0]
+ else:
+ l.append(node)
+ index += 1
+ try:
+ node = stack[-1][0].children[index]
+ except IndexError:
+ node = None
+ else:
+ node = None
+ poped = 0
+ if node is None and stack:
+ node, index = stack.pop()
+ poped = 1
+ return l
+
+def pre_order_list(node, filter_func=no_filter):
+ """
+ create a list with tree nodes for which the <filter> function returned true
+ in a pre order fashion
+ """
+ l, stack = [], []
+ poped, index = 0, 0
+ while node:
+ if filter_func(node):
+ if not poped:
+ l.append(node)
+ if node.children and not poped:
+ stack.append((node, index))
+ index = 0
+ node = node.children[0]
+ else:
+ index += 1
+ try:
+ node = stack[-1][0].children[index]
+ except IndexError:
+ node = None
+ else:
+ node = None
+ poped = 0
+ if node is None and len(stack) > 1:
+ node, index = stack.pop()
+ poped = 1
+ return l
+
+class PostfixedDepthFirstIterator(FilteredIterator):
+ """a postfixed depth first iterator, designed to be used with visitors
+ """
+ def __init__(self, node, filter_func=None):
+ FilteredIterator.__init__(self, node, post_order_list, filter_func)
+
+class PrefixedDepthFirstIterator(FilteredIterator):
+ """a prefixed depth first iterator, designed to be used with visitors
+ """
+ def __init__(self, node, filter_func=None):
+ FilteredIterator.__init__(self, node, pre_order_list, filter_func)
+
diff --git a/logilab/common/umessage.py b/logilab/common/umessage.py
new file mode 100644
index 0000000..a5e4799
--- /dev/null
+++ b/logilab/common/umessage.py
@@ -0,0 +1,194 @@
+# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Unicode email support (extends email from stdlib)"""
+
+__docformat__ = "restructuredtext en"
+
+import email
+from encodings import search_function
+import sys
+if sys.version_info >= (2, 5):
+ from email.utils import parseaddr, parsedate
+ from email.header import decode_header
+else:
+ from email.Utils import parseaddr, parsedate
+ from email.Header import decode_header
+
+from datetime import datetime
+
+try:
+ from mx.DateTime import DateTime
+except ImportError:
+ DateTime = datetime
+
+import logilab.common as lgc
+
+
+def decode_QP(string):
+ parts = []
+ for decoded, charset in decode_header(string):
+ if not charset :
+ charset = 'iso-8859-15'
+ parts.append(decoded.decode(charset, 'replace'))
+
+ if sys.version_info < (3, 3):
+ # decoding was non-RFC compliant wrt to whitespace handling
+ # see http://bugs.python.org/issue1079
+ return u' '.join(parts)
+ return u''.join(parts)
+
+def message_from_file(fd):
+ try:
+ return UMessage(email.message_from_file(fd))
+ except email.Errors.MessageParseError:
+ return ''
+
+def message_from_string(string):
+ try:
+ return UMessage(email.message_from_string(string))
+ except email.Errors.MessageParseError:
+ return ''
+
+class UMessage:
+ """Encapsulates an email.Message instance and returns only unicode objects.
+ """
+
+ def __init__(self, message):
+ self.message = message
+
+ # email.Message interface #################################################
+
+ def get(self, header, default=None):
+ value = self.message.get(header, default)
+ if value:
+ return decode_QP(value)
+ return value
+
+ def __getitem__(self, header):
+ return self.get(header)
+
+ def get_all(self, header, default=()):
+ return [decode_QP(val) for val in self.message.get_all(header, default)
+ if val is not None]
+
+ def is_multipart(self):
+ return self.message.is_multipart()
+
+ def get_boundary(self):
+ return self.message.get_boundary()
+
+ def walk(self):
+ for part in self.message.walk():
+ yield UMessage(part)
+
+ if sys.version_info < (3, 0):
+
+ def get_payload(self, index=None, decode=False):
+ message = self.message
+ if index is None:
+ payload = message.get_payload(index, decode)
+ if isinstance(payload, list):
+ return [UMessage(msg) for msg in payload]
+ if message.get_content_maintype() != 'text':
+ return payload
+
+ charset = message.get_content_charset() or 'iso-8859-1'
+ if search_function(charset) is None:
+ charset = 'iso-8859-1'
+ return unicode(payload or '', charset, "replace")
+ else:
+ payload = UMessage(message.get_payload(index, decode))
+ return payload
+
+ def get_content_maintype(self):
+ return unicode(self.message.get_content_maintype())
+
+ def get_content_type(self):
+ return unicode(self.message.get_content_type())
+
+ def get_filename(self, failobj=None):
+ value = self.message.get_filename(failobj)
+ if value is failobj:
+ return value
+ try:
+ return unicode(value)
+ except UnicodeDecodeError:
+ return u'error decoding filename'
+
+ else:
+
+ def get_payload(self, index=None, decode=False):
+ message = self.message
+ if index is None:
+ payload = message.get_payload(index, decode)
+ if isinstance(payload, list):
+ return [UMessage(msg) for msg in payload]
+ return payload
+ else:
+ payload = UMessage(message.get_payload(index, decode))
+ return payload
+
+ def get_content_maintype(self):
+ return self.message.get_content_maintype()
+
+ def get_content_type(self):
+ return self.message.get_content_type()
+
+ def get_filename(self, failobj=None):
+ return self.message.get_filename(failobj)
+
+ # other convenience methods ###############################################
+
+ def headers(self):
+ """return an unicode string containing all the message's headers"""
+ values = []
+ for header in self.message.keys():
+ values.append(u'%s: %s' % (header, self.get(header)))
+ return '\n'.join(values)
+
+ def multi_addrs(self, header):
+ """return a list of 2-uple (name, address) for the given address (which
+ is expected to be an header containing address such as from, to, cc...)
+ """
+ persons = []
+ for person in self.get_all(header, ()):
+ name, mail = parseaddr(person)
+ persons.append((name, mail))
+ return persons
+
+ def date(self, alternative_source=False, return_str=False):
+ """return a datetime object for the email's date or None if no date is
+ set or if it can't be parsed
+ """
+ value = self.get('date')
+ if value is None and alternative_source:
+ unix_from = self.message.get_unixfrom()
+ if unix_from is not None:
+ try:
+ value = unix_from.split(" ", 2)[2]
+ except IndexError:
+ pass
+ if value is not None:
+ datetuple = parsedate(value)
+ if datetuple:
+ if lgc.USE_MX_DATETIME:
+ return DateTime(*datetuple[:6])
+ return datetime(*datetuple[:6])
+ elif not return_str:
+ return None
+ return value
diff --git a/logilab/common/ureports/__init__.py b/logilab/common/ureports/__init__.py
new file mode 100644
index 0000000..d76ebe5
--- /dev/null
+++ b/logilab/common/ureports/__init__.py
@@ -0,0 +1,172 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Universal report objects and some formatting drivers.
+
+A way to create simple reports using python objects, primarily designed to be
+formatted as text and html.
+"""
+__docformat__ = "restructuredtext en"
+
+import sys
+
+from logilab.common.compat import StringIO
+from logilab.common.textutils import linesep
+
+
+def get_nodes(node, klass):
+ """return an iterator on all children node of the given klass"""
+ for child in node.children:
+ if isinstance(child, klass):
+ yield child
+ # recurse (FIXME: recursion controled by an option)
+ for grandchild in get_nodes(child, klass):
+ yield grandchild
+
+def layout_title(layout):
+ """try to return the layout's title as string, return None if not found
+ """
+ for child in layout.children:
+ if isinstance(child, Title):
+ return u' '.join([node.data for node in get_nodes(child, Text)])
+
+def build_summary(layout, level=1):
+ """make a summary for the report, including X level"""
+ assert level > 0
+ level -= 1
+ summary = List(klass=u'summary')
+ for child in layout.children:
+ if not isinstance(child, Section):
+ continue
+ label = layout_title(child)
+ if not label and not child.id:
+ continue
+ if not child.id:
+ child.id = label.replace(' ', '-')
+ node = Link(u'#'+child.id, label=label or child.id)
+ # FIXME: Three following lines produce not very compliant
+ # docbook: there are some useless <para><para>. They might be
+ # replaced by the three commented lines but this then produces
+ # a bug in html display...
+ if level and [n for n in child.children if isinstance(n, Section)]:
+ node = Paragraph([node, build_summary(child, level)])
+ summary.append(node)
+# summary.append(node)
+# if level and [n for n in child.children if isinstance(n, Section)]:
+# summary.append(build_summary(child, level))
+ return summary
+
+
+class BaseWriter(object):
+ """base class for ureport writers"""
+
+ def format(self, layout, stream=None, encoding=None):
+ """format and write the given layout into the stream object
+
+ unicode policy: unicode strings may be found in the layout;
+ try to call stream.write with it, but give it back encoded using
+ the given encoding if it fails
+ """
+ if stream is None:
+ stream = sys.stdout
+ if not encoding:
+ encoding = getattr(stream, 'encoding', 'UTF-8')
+ self.encoding = encoding or 'UTF-8'
+ self.__compute_funcs = []
+ self.out = stream
+ self.begin_format(layout)
+ layout.accept(self)
+ self.end_format(layout)
+
+ def format_children(self, layout):
+ """recurse on the layout children and call their accept method
+ (see the Visitor pattern)
+ """
+ for child in getattr(layout, 'children', ()):
+ child.accept(self)
+
+ def writeln(self, string=u''):
+ """write a line in the output buffer"""
+ self.write(string + linesep)
+
+ def write(self, string):
+ """write a string in the output buffer"""
+ try:
+ self.out.write(string)
+ except UnicodeEncodeError:
+ self.out.write(string.encode(self.encoding))
+
+ def begin_format(self, layout):
+ """begin to format a layout"""
+ self.section = 0
+
+ def end_format(self, layout):
+ """finished to format a layout"""
+
+ def get_table_content(self, table):
+ """trick to get table content without actually writing it
+
+ return an aligned list of lists containing table cells values as string
+ """
+ result = [[]]
+ cols = table.cols
+ for cell in self.compute_content(table):
+ if cols == 0:
+ result.append([])
+ cols = table.cols
+ cols -= 1
+ result[-1].append(cell)
+ # fill missing cells
+ while len(result[-1]) < cols:
+ result[-1].append(u'')
+ return result
+
+ def compute_content(self, layout):
+ """trick to compute the formatting of children layout before actually
+ writing it
+
+ return an iterator on strings (one for each child element)
+ """
+ # use cells !
+ def write(data):
+ try:
+ stream.write(data)
+ except UnicodeEncodeError:
+ stream.write(data.encode(self.encoding))
+ def writeln(data=u''):
+ try:
+ stream.write(data+linesep)
+ except UnicodeEncodeError:
+ stream.write(data.encode(self.encoding)+linesep)
+ self.write = write
+ self.writeln = writeln
+ self.__compute_funcs.append((write, writeln))
+ for child in layout.children:
+ stream = StringIO()
+ child.accept(self)
+ yield stream.getvalue()
+ self.__compute_funcs.pop()
+ try:
+ self.write, self.writeln = self.__compute_funcs[-1]
+ except IndexError:
+ del self.write
+ del self.writeln
+
+
+from logilab.common.ureports.nodes import *
+from logilab.common.ureports.text_writer import TextWriter
+from logilab.common.ureports.html_writer import HTMLWriter
diff --git a/logilab/common/ureports/docbook_writer.py b/logilab/common/ureports/docbook_writer.py
new file mode 100644
index 0000000..857068c
--- /dev/null
+++ b/logilab/common/ureports/docbook_writer.py
@@ -0,0 +1,140 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""HTML formatting drivers for ureports"""
+__docformat__ = "restructuredtext en"
+
+from six.moves import range
+
+from logilab.common.ureports import HTMLWriter
+
+class DocbookWriter(HTMLWriter):
+ """format layouts as HTML"""
+
+ def begin_format(self, layout):
+ """begin to format a layout"""
+ super(HTMLWriter, self).begin_format(layout)
+ if self.snippet is None:
+ self.writeln('<?xml version="1.0" encoding="ISO-8859-1"?>')
+ self.writeln("""
+<book xmlns:xi='http://www.w3.org/2001/XInclude'
+ lang='fr'>
+""")
+
+ def end_format(self, layout):
+ """finished to format a layout"""
+ if self.snippet is None:
+ self.writeln('</book>')
+
+ def visit_section(self, layout):
+ """display a section (using <chapter> (level 0) or <section>)"""
+ if self.section == 0:
+ tag = "chapter"
+ else:
+ tag = "section"
+ self.section += 1
+ self.writeln(self._indent('<%s%s>' % (tag, self.handle_attrs(layout))))
+ self.format_children(layout)
+ self.writeln(self._indent('</%s>'% tag))
+ self.section -= 1
+
+ def visit_title(self, layout):
+ """display a title using <title>"""
+ self.write(self._indent(' <title%s>' % self.handle_attrs(layout)))
+ self.format_children(layout)
+ self.writeln('</title>')
+
+ def visit_table(self, layout):
+ """display a table as html"""
+ self.writeln(self._indent(' <table%s><title>%s</title>' \
+ % (self.handle_attrs(layout), layout.title)))
+ self.writeln(self._indent(' <tgroup cols="%s">'% layout.cols))
+ for i in range(layout.cols):
+ self.writeln(self._indent(' <colspec colname="c%s" colwidth="1*"/>' % i))
+
+ table_content = self.get_table_content(layout)
+ # write headers
+ if layout.cheaders:
+ self.writeln(self._indent(' <thead>'))
+ self._write_row(table_content[0])
+ self.writeln(self._indent(' </thead>'))
+ table_content = table_content[1:]
+ elif layout.rcheaders:
+ self.writeln(self._indent(' <thead>'))
+ self._write_row(table_content[-1])
+ self.writeln(self._indent(' </thead>'))
+ table_content = table_content[:-1]
+ # write body
+ self.writeln(self._indent(' <tbody>'))
+ for i in range(len(table_content)):
+ row = table_content[i]
+ self.writeln(self._indent(' <row>'))
+ for j in range(len(row)):
+ cell = row[j] or '&#160;'
+ self.writeln(self._indent(' <entry>%s</entry>' % cell))
+ self.writeln(self._indent(' </row>'))
+ self.writeln(self._indent(' </tbody>'))
+ self.writeln(self._indent(' </tgroup>'))
+ self.writeln(self._indent(' </table>'))
+
+ def _write_row(self, row):
+ """write content of row (using <row> <entry>)"""
+ self.writeln(' <row>')
+ for j in range(len(row)):
+ cell = row[j] or '&#160;'
+ self.writeln(' <entry>%s</entry>' % cell)
+ self.writeln(self._indent(' </row>'))
+
+ def visit_list(self, layout):
+ """display a list (using <itemizedlist>)"""
+ self.writeln(self._indent(' <itemizedlist%s>' % self.handle_attrs(layout)))
+ for row in list(self.compute_content(layout)):
+ self.writeln(' <listitem><para>%s</para></listitem>' % row)
+ self.writeln(self._indent(' </itemizedlist>'))
+
+ def visit_paragraph(self, layout):
+ """display links (using <para>)"""
+ self.write(self._indent(' <para>'))
+ self.format_children(layout)
+ self.writeln('</para>')
+
+ def visit_span(self, layout):
+ """display links (using <p>)"""
+ #TODO: translate in docbook
+ self.write('<literal %s>' % self.handle_attrs(layout))
+ self.format_children(layout)
+ self.write('</literal>')
+
+ def visit_link(self, layout):
+ """display links (using <ulink>)"""
+ self.write('<ulink url="%s"%s>%s</ulink>' % (layout.url,
+ self.handle_attrs(layout),
+ layout.label))
+
+ def visit_verbatimtext(self, layout):
+ """display verbatim text (using <programlisting>)"""
+ self.writeln(self._indent(' <programlisting>'))
+ self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))
+ self.writeln(self._indent(' </programlisting>'))
+
+ def visit_text(self, layout):
+ """add some text"""
+ self.write(layout.data.replace('&', '&amp;').replace('<', '&lt;'))
+
+ def _indent(self, string):
+ """correctly indent string according to section"""
+ return ' ' * 2*(self.section) + string
diff --git a/logilab/common/ureports/html_writer.py b/logilab/common/ureports/html_writer.py
new file mode 100644
index 0000000..eba34ea
--- /dev/null
+++ b/logilab/common/ureports/html_writer.py
@@ -0,0 +1,133 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""HTML formatting drivers for ureports"""
+__docformat__ = "restructuredtext en"
+
+from cgi import escape
+
+from six.moves import range
+
+from logilab.common.ureports import BaseWriter
+
+
+class HTMLWriter(BaseWriter):
+ """format layouts as HTML"""
+
+ def __init__(self, snippet=None):
+ super(HTMLWriter, self).__init__()
+ self.snippet = snippet
+
+ def handle_attrs(self, layout):
+ """get an attribute string from layout member attributes"""
+ attrs = u''
+ klass = getattr(layout, 'klass', None)
+ if klass:
+ attrs += u' class="%s"' % klass
+ nid = getattr(layout, 'id', None)
+ if nid:
+ attrs += u' id="%s"' % nid
+ return attrs
+
+ def begin_format(self, layout):
+ """begin to format a layout"""
+ super(HTMLWriter, self).begin_format(layout)
+ if self.snippet is None:
+ self.writeln(u'<html>')
+ self.writeln(u'<body>')
+
+ def end_format(self, layout):
+ """finished to format a layout"""
+ if self.snippet is None:
+ self.writeln(u'</body>')
+ self.writeln(u'</html>')
+
+
+ def visit_section(self, layout):
+ """display a section as html, using div + h[section level]"""
+ self.section += 1
+ self.writeln(u'<div%s>' % self.handle_attrs(layout))
+ self.format_children(layout)
+ self.writeln(u'</div>')
+ self.section -= 1
+
+ def visit_title(self, layout):
+ """display a title using <hX>"""
+ self.write(u'<h%s%s>' % (self.section, self.handle_attrs(layout)))
+ self.format_children(layout)
+ self.writeln(u'</h%s>' % self.section)
+
+ def visit_table(self, layout):
+ """display a table as html"""
+ self.writeln(u'<table%s>' % self.handle_attrs(layout))
+ table_content = self.get_table_content(layout)
+ for i in range(len(table_content)):
+ row = table_content[i]
+ if i == 0 and layout.rheaders:
+ self.writeln(u'<tr class="header">')
+ elif i+1 == len(table_content) and layout.rrheaders:
+ self.writeln(u'<tr class="header">')
+ else:
+ self.writeln(u'<tr class="%s">' % (i%2 and 'even' or 'odd'))
+ for j in range(len(row)):
+ cell = row[j] or u'&#160;'
+ if (layout.rheaders and i == 0) or \
+ (layout.cheaders and j == 0) or \
+ (layout.rrheaders and i+1 == len(table_content)) or \
+ (layout.rcheaders and j+1 == len(row)):
+ self.writeln(u'<th>%s</th>' % cell)
+ else:
+ self.writeln(u'<td>%s</td>' % cell)
+ self.writeln(u'</tr>')
+ self.writeln(u'</table>')
+
+ def visit_list(self, layout):
+ """display a list as html"""
+ self.writeln(u'<ul%s>' % self.handle_attrs(layout))
+ for row in list(self.compute_content(layout)):
+ self.writeln(u'<li>%s</li>' % row)
+ self.writeln(u'</ul>')
+
+ def visit_paragraph(self, layout):
+ """display links (using <p>)"""
+ self.write(u'<p>')
+ self.format_children(layout)
+ self.write(u'</p>')
+
+ def visit_span(self, layout):
+ """display links (using <p>)"""
+ self.write(u'<span%s>' % self.handle_attrs(layout))
+ self.format_children(layout)
+ self.write(u'</span>')
+
+ def visit_link(self, layout):
+ """display links (using <a>)"""
+ self.write(u' <a href="%s"%s>%s</a>' % (layout.url,
+ self.handle_attrs(layout),
+ layout.label))
+ def visit_verbatimtext(self, layout):
+ """display verbatim text (using <pre>)"""
+ self.write(u'<pre>')
+ self.write(layout.data.replace(u'&', u'&amp;').replace(u'<', u'&lt;'))
+ self.write(u'</pre>')
+
+ def visit_text(self, layout):
+ """add some text"""
+ data = layout.data
+ if layout.escaped:
+ data = data.replace(u'&', u'&amp;').replace(u'<', u'&lt;')
+ self.write(data)
diff --git a/logilab/common/ureports/nodes.py b/logilab/common/ureports/nodes.py
new file mode 100644
index 0000000..a9585b3
--- /dev/null
+++ b/logilab/common/ureports/nodes.py
@@ -0,0 +1,203 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Micro reports objects.
+
+A micro report is a tree of layout and content objects.
+"""
+__docformat__ = "restructuredtext en"
+
+from logilab.common.tree import VNode
+
+from six import string_types
+
+class BaseComponent(VNode):
+ """base report component
+
+ attributes
+ * id : the component's optional id
+ * klass : the component's optional klass
+ """
+ def __init__(self, id=None, klass=None):
+ VNode.__init__(self, id)
+ self.klass = klass
+
+class BaseLayout(BaseComponent):
+ """base container node
+
+ attributes
+ * BaseComponent attributes
+ * children : components in this table (i.e. the table's cells)
+ """
+ def __init__(self, children=(), **kwargs):
+ super(BaseLayout, self).__init__(**kwargs)
+ for child in children:
+ if isinstance(child, BaseComponent):
+ self.append(child)
+ else:
+ self.add_text(child)
+
+ def append(self, child):
+ """overridden to detect problems easily"""
+ assert child not in self.parents()
+ VNode.append(self, child)
+
+ def parents(self):
+ """return the ancestor nodes"""
+ assert self.parent is not self
+ if self.parent is None:
+ return []
+ return [self.parent] + self.parent.parents()
+
+ def add_text(self, text):
+ """shortcut to add text data"""
+ self.children.append(Text(text))
+
+
+# non container nodes #########################################################
+
+class Text(BaseComponent):
+ """a text portion
+
+ attributes :
+ * BaseComponent attributes
+ * data : the text value as an encoded or unicode string
+ """
+ def __init__(self, data, escaped=True, **kwargs):
+ super(Text, self).__init__(**kwargs)
+ #if isinstance(data, unicode):
+ # data = data.encode('ascii')
+ assert isinstance(data, string_types), data.__class__
+ self.escaped = escaped
+ self.data = data
+
+class VerbatimText(Text):
+ """a verbatim text, display the raw data
+
+ attributes :
+ * BaseComponent attributes
+ * data : the text value as an encoded or unicode string
+ """
+
+class Link(BaseComponent):
+ """a labelled link
+
+ attributes :
+ * BaseComponent attributes
+ * url : the link's target (REQUIRED)
+ * label : the link's label as a string (use the url by default)
+ """
+ def __init__(self, url, label=None, **kwargs):
+ super(Link, self).__init__(**kwargs)
+ assert url
+ self.url = url
+ self.label = label or url
+
+
+class Image(BaseComponent):
+ """an embedded or a single image
+
+ attributes :
+ * BaseComponent attributes
+ * filename : the image's filename (REQUIRED)
+ * stream : the stream object containing the image data (REQUIRED)
+ * title : the image's optional title
+ """
+ def __init__(self, filename, stream, title=None, **kwargs):
+ super(Image, self).__init__(**kwargs)
+ assert filename
+ assert stream
+ self.filename = filename
+ self.stream = stream
+ self.title = title
+
+
+# container nodes #############################################################
+
+class Section(BaseLayout):
+ """a section
+
+ attributes :
+ * BaseLayout attributes
+
+ a title may also be given to the constructor, it'll be added
+ as a first element
+ a description may also be given to the constructor, it'll be added
+ as a first paragraph
+ """
+ def __init__(self, title=None, description=None, **kwargs):
+ super(Section, self).__init__(**kwargs)
+ if description:
+ self.insert(0, Paragraph([Text(description)]))
+ if title:
+ self.insert(0, Title(children=(title,)))
+
+class Title(BaseLayout):
+ """a title
+
+ attributes :
+ * BaseLayout attributes
+
+ A title must not contains a section nor a paragraph!
+ """
+
+class Span(BaseLayout):
+ """a title
+
+ attributes :
+ * BaseLayout attributes
+
+ A span should only contains Text and Link nodes (in-line elements)
+ """
+
+class Paragraph(BaseLayout):
+ """a simple text paragraph
+
+ attributes :
+ * BaseLayout attributes
+
+ A paragraph must not contains a section !
+ """
+
+class Table(BaseLayout):
+ """some tabular data
+
+ attributes :
+ * BaseLayout attributes
+ * cols : the number of columns of the table (REQUIRED)
+ * rheaders : the first row's elements are table's header
+ * cheaders : the first col's elements are table's header
+ * title : the table's optional title
+ """
+ def __init__(self, cols, title=None,
+ rheaders=0, cheaders=0, rrheaders=0, rcheaders=0,
+ **kwargs):
+ super(Table, self).__init__(**kwargs)
+ assert isinstance(cols, int)
+ self.cols = cols
+ self.title = title
+ self.rheaders = rheaders
+ self.cheaders = cheaders
+ self.rrheaders = rrheaders
+ self.rcheaders = rcheaders
+
+class List(BaseLayout):
+ """some list data
+
+ attributes :
+ * BaseLayout attributes
+ """
diff --git a/logilab/common/ureports/text_writer.py b/logilab/common/ureports/text_writer.py
new file mode 100644
index 0000000..c87613c
--- /dev/null
+++ b/logilab/common/ureports/text_writer.py
@@ -0,0 +1,145 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Text formatting drivers for ureports"""
+
+from __future__ import print_function
+
+__docformat__ = "restructuredtext en"
+
+from six.moves import range
+
+from logilab.common.textutils import linesep
+from logilab.common.ureports import BaseWriter
+
+
+TITLE_UNDERLINES = [u'', u'=', u'-', u'`', u'.', u'~', u'^']
+BULLETS = [u'*', u'-']
+
+class TextWriter(BaseWriter):
+ """format layouts as text
+ (ReStructured inspiration but not totally handled yet)
+ """
+ def begin_format(self, layout):
+ super(TextWriter, self).begin_format(layout)
+ self.list_level = 0
+ self.pending_urls = []
+
+ def visit_section(self, layout):
+ """display a section as text
+ """
+ self.section += 1
+ self.writeln()
+ self.format_children(layout)
+ if self.pending_urls:
+ self.writeln()
+ for label, url in self.pending_urls:
+ self.writeln(u'.. _`%s`: %s' % (label, url))
+ self.pending_urls = []
+ self.section -= 1
+ self.writeln()
+
+ def visit_title(self, layout):
+ title = u''.join(list(self.compute_content(layout)))
+ self.writeln(title)
+ try:
+ self.writeln(TITLE_UNDERLINES[self.section] * len(title))
+ except IndexError:
+ print("FIXME TITLE TOO DEEP. TURNING TITLE INTO TEXT")
+
+ def visit_paragraph(self, layout):
+ """enter a paragraph"""
+ self.format_children(layout)
+ self.writeln()
+
+ def visit_span(self, layout):
+ """enter a span"""
+ self.format_children(layout)
+
+ def visit_table(self, layout):
+ """display a table as text"""
+ table_content = self.get_table_content(layout)
+ # get columns width
+ cols_width = [0]*len(table_content[0])
+ for row in table_content:
+ for index in range(len(row)):
+ col = row[index]
+ cols_width[index] = max(cols_width[index], len(col))
+ if layout.klass == 'field':
+ self.field_table(layout, table_content, cols_width)
+ else:
+ self.default_table(layout, table_content, cols_width)
+ self.writeln()
+
+ def default_table(self, layout, table_content, cols_width):
+ """format a table"""
+ cols_width = [size+1 for size in cols_width]
+ format_strings = u' '.join([u'%%-%ss'] * len(cols_width))
+ format_strings = format_strings % tuple(cols_width)
+ format_strings = format_strings.split(' ')
+ table_linesep = u'\n+' + u'+'.join([u'-'*w for w in cols_width]) + u'+\n'
+ headsep = u'\n+' + u'+'.join([u'='*w for w in cols_width]) + u'+\n'
+ # FIXME: layout.cheaders
+ self.write(table_linesep)
+ for i in range(len(table_content)):
+ self.write(u'|')
+ line = table_content[i]
+ for j in range(len(line)):
+ self.write(format_strings[j] % line[j])
+ self.write(u'|')
+ if i == 0 and layout.rheaders:
+ self.write(headsep)
+ else:
+ self.write(table_linesep)
+
+ def field_table(self, layout, table_content, cols_width):
+ """special case for field table"""
+ assert layout.cols == 2
+ format_string = u'%s%%-%ss: %%s' % (linesep, cols_width[0])
+ for field, value in table_content:
+ self.write(format_string % (field, value))
+
+
+ def visit_list(self, layout):
+ """display a list layout as text"""
+ bullet = BULLETS[self.list_level % len(BULLETS)]
+ indent = ' ' * self.list_level
+ self.list_level += 1
+ for child in layout.children:
+ self.write(u'%s%s%s ' % (linesep, indent, bullet))
+ child.accept(self)
+ self.list_level -= 1
+
+ def visit_link(self, layout):
+ """add a hyperlink"""
+ if layout.label != layout.url:
+ self.write(u'`%s`_' % layout.label)
+ self.pending_urls.append( (layout.label, layout.url) )
+ else:
+ self.write(layout.url)
+
+ def visit_verbatimtext(self, layout):
+ """display a verbatim layout as text (so difficult ;)
+ """
+ self.writeln(u'::\n')
+ for line in layout.data.splitlines():
+ self.writeln(u' ' + line)
+ self.writeln()
+
+ def visit_text(self, layout):
+ """add some text"""
+ self.write(u'%s' % layout.data)
diff --git a/logilab/common/urllib2ext.py b/logilab/common/urllib2ext.py
new file mode 100644
index 0000000..339aec0
--- /dev/null
+++ b/logilab/common/urllib2ext.py
@@ -0,0 +1,89 @@
+from __future__ import print_function
+
+import logging
+import urllib2
+
+import kerberos as krb
+
+class GssapiAuthError(Exception):
+ """raised on error during authentication process"""
+
+import re
+RGX = re.compile('(?:.*,)*\s*Negotiate\s*([^,]*),?', re.I)
+
+def get_negociate_value(headers):
+ for authreq in headers.getheaders('www-authenticate'):
+ match = RGX.search(authreq)
+ if match:
+ return match.group(1)
+
+class HTTPGssapiAuthHandler(urllib2.BaseHandler):
+ """Negotiate HTTP authentication using context from GSSAPI"""
+
+ handler_order = 400 # before Digest Auth
+
+ def __init__(self):
+ self._reset()
+
+ def _reset(self):
+ self._retried = 0
+ self._context = None
+
+ def clean_context(self):
+ if self._context is not None:
+ krb.authGSSClientClean(self._context)
+
+ def http_error_401(self, req, fp, code, msg, headers):
+ try:
+ if self._retried > 5:
+ raise urllib2.HTTPError(req.get_full_url(), 401,
+ "negotiate auth failed", headers, None)
+ self._retried += 1
+ logging.debug('gssapi handler, try %s' % self._retried)
+ negotiate = get_negociate_value(headers)
+ if negotiate is None:
+ logging.debug('no negociate found in a www-authenticate header')
+ return None
+ logging.debug('HTTPGssapiAuthHandler: negotiate 1 is %r' % negotiate)
+ result, self._context = krb.authGSSClientInit("HTTP@%s" % req.get_host())
+ if result < 1:
+ raise GssapiAuthError("HTTPGssapiAuthHandler: init failed with %d" % result)
+ result = krb.authGSSClientStep(self._context, negotiate)
+ if result < 0:
+ raise GssapiAuthError("HTTPGssapiAuthHandler: step 1 failed with %d" % result)
+ client_response = krb.authGSSClientResponse(self._context)
+ logging.debug('HTTPGssapiAuthHandler: client response is %s...' % client_response[:10])
+ req.add_unredirected_header('Authorization', "Negotiate %s" % client_response)
+ server_response = self.parent.open(req)
+ negotiate = get_negociate_value(server_response.info())
+ if negotiate is None:
+ logging.warning('HTTPGssapiAuthHandler: failed to authenticate server')
+ else:
+ logging.debug('HTTPGssapiAuthHandler negotiate 2: %s' % negotiate)
+ result = krb.authGSSClientStep(self._context, negotiate)
+ if result < 1:
+ raise GssapiAuthError("HTTPGssapiAuthHandler: step 2 failed with %d" % result)
+ return server_response
+ except GssapiAuthError as exc:
+ logging.error(repr(exc))
+ finally:
+ self.clean_context()
+ self._reset()
+
+if __name__ == '__main__':
+ import sys
+ # debug
+ import httplib
+ httplib.HTTPConnection.debuglevel = 1
+ httplib.HTTPSConnection.debuglevel = 1
+ # debug
+ import logging
+ logging.basicConfig(level=logging.DEBUG)
+ # handle cookies
+ import cookielib
+ cj = cookielib.CookieJar()
+ ch = urllib2.HTTPCookieProcessor(cj)
+ # test with url sys.argv[1]
+ h = HTTPGssapiAuthHandler()
+ response = urllib2.build_opener(h, ch).open(sys.argv[1])
+ print('\nresponse: %s\n--------------\n' % response.code, response.info())
diff --git a/logilab/common/vcgutils.py b/logilab/common/vcgutils.py
new file mode 100644
index 0000000..9cd2acd
--- /dev/null
+++ b/logilab/common/vcgutils.py
@@ -0,0 +1,216 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""Functions to generate files readable with Georg Sander's vcg
+(Visualization of Compiler Graphs).
+
+You can download vcg at http://rw4.cs.uni-sb.de/~sander/html/gshome.html
+Note that vcg exists as a debian package.
+
+See vcg's documentation for explanation about the different values that
+maybe used for the functions parameters.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+import string
+
+ATTRS_VAL = {
+ 'algos': ('dfs', 'tree', 'minbackward',
+ 'left_to_right', 'right_to_left',
+ 'top_to_bottom', 'bottom_to_top',
+ 'maxdepth', 'maxdepthslow', 'mindepth', 'mindepthslow',
+ 'mindegree', 'minindegree', 'minoutdegree',
+ 'maxdegree', 'maxindegree', 'maxoutdegree'),
+ 'booleans': ('yes', 'no'),
+ 'colors': ('black', 'white', 'blue', 'red', 'green', 'yellow',
+ 'magenta', 'lightgrey',
+ 'cyan', 'darkgrey', 'darkblue', 'darkred', 'darkgreen',
+ 'darkyellow', 'darkmagenta', 'darkcyan', 'gold',
+ 'lightblue', 'lightred', 'lightgreen', 'lightyellow',
+ 'lightmagenta', 'lightcyan', 'lilac', 'turquoise',
+ 'aquamarine', 'khaki', 'purple', 'yellowgreen', 'pink',
+ 'orange', 'orchid'),
+ 'shapes': ('box', 'ellipse', 'rhomb', 'triangle'),
+ 'textmodes': ('center', 'left_justify', 'right_justify'),
+ 'arrowstyles': ('solid', 'line', 'none'),
+ 'linestyles': ('continuous', 'dashed', 'dotted', 'invisible'),
+ }
+
+# meaning of possible values:
+# O -> string
+# 1 -> int
+# list -> value in list
+GRAPH_ATTRS = {
+ 'title': 0,
+ 'label': 0,
+ 'color': ATTRS_VAL['colors'],
+ 'textcolor': ATTRS_VAL['colors'],
+ 'bordercolor': ATTRS_VAL['colors'],
+ 'width': 1,
+ 'height': 1,
+ 'borderwidth': 1,
+ 'textmode': ATTRS_VAL['textmodes'],
+ 'shape': ATTRS_VAL['shapes'],
+ 'shrink': 1,
+ 'stretch': 1,
+ 'orientation': ATTRS_VAL['algos'],
+ 'vertical_order': 1,
+ 'horizontal_order': 1,
+ 'xspace': 1,
+ 'yspace': 1,
+ 'layoutalgorithm': ATTRS_VAL['algos'],
+ 'late_edge_labels': ATTRS_VAL['booleans'],
+ 'display_edge_labels': ATTRS_VAL['booleans'],
+ 'dirty_edge_labels': ATTRS_VAL['booleans'],
+ 'finetuning': ATTRS_VAL['booleans'],
+ 'manhattan_edges': ATTRS_VAL['booleans'],
+ 'smanhattan_edges': ATTRS_VAL['booleans'],
+ 'port_sharing': ATTRS_VAL['booleans'],
+ 'edges': ATTRS_VAL['booleans'],
+ 'nodes': ATTRS_VAL['booleans'],
+ 'splines': ATTRS_VAL['booleans'],
+ }
+NODE_ATTRS = {
+ 'title': 0,
+ 'label': 0,
+ 'color': ATTRS_VAL['colors'],
+ 'textcolor': ATTRS_VAL['colors'],
+ 'bordercolor': ATTRS_VAL['colors'],
+ 'width': 1,
+ 'height': 1,
+ 'borderwidth': 1,
+ 'textmode': ATTRS_VAL['textmodes'],
+ 'shape': ATTRS_VAL['shapes'],
+ 'shrink': 1,
+ 'stretch': 1,
+ 'vertical_order': 1,
+ 'horizontal_order': 1,
+ }
+EDGE_ATTRS = {
+ 'sourcename': 0,
+ 'targetname': 0,
+ 'label': 0,
+ 'linestyle': ATTRS_VAL['linestyles'],
+ 'class': 1,
+ 'thickness': 0,
+ 'color': ATTRS_VAL['colors'],
+ 'textcolor': ATTRS_VAL['colors'],
+ 'arrowcolor': ATTRS_VAL['colors'],
+ 'backarrowcolor': ATTRS_VAL['colors'],
+ 'arrowsize': 1,
+ 'backarrowsize': 1,
+ 'arrowstyle': ATTRS_VAL['arrowstyles'],
+ 'backarrowstyle': ATTRS_VAL['arrowstyles'],
+ 'textmode': ATTRS_VAL['textmodes'],
+ 'priority': 1,
+ 'anchor': 1,
+ 'horizontal_order': 1,
+ }
+
+
+# Misc utilities ###############################################################
+
+def latin_to_vcg(st):
+ """Convert latin characters using vcg escape sequence.
+ """
+ for char in st:
+ if char not in string.ascii_letters:
+ try:
+ num = ord(char)
+ if num >= 192:
+ st = st.replace(char, r'\fi%d'%ord(char))
+ except:
+ pass
+ return st
+
+
+class VCGPrinter:
+ """A vcg graph writer.
+ """
+
+ def __init__(self, output_stream):
+ self._stream = output_stream
+ self._indent = ''
+
+ def open_graph(self, **args):
+ """open a vcg graph
+ """
+ self._stream.write('%sgraph:{\n'%self._indent)
+ self._inc_indent()
+ self._write_attributes(GRAPH_ATTRS, **args)
+
+ def close_graph(self):
+ """close a vcg graph
+ """
+ self._dec_indent()
+ self._stream.write('%s}\n'%self._indent)
+
+
+ def node(self, title, **args):
+ """draw a node
+ """
+ self._stream.write('%snode: {title:"%s"' % (self._indent, title))
+ self._write_attributes(NODE_ATTRS, **args)
+ self._stream.write('}\n')
+
+
+ def edge(self, from_node, to_node, edge_type='', **args):
+ """draw an edge from a node to another.
+ """
+ self._stream.write(
+ '%s%sedge: {sourcename:"%s" targetname:"%s"' % (
+ self._indent, edge_type, from_node, to_node))
+ self._write_attributes(EDGE_ATTRS, **args)
+ self._stream.write('}\n')
+
+
+ # private ##################################################################
+
+ def _write_attributes(self, attributes_dict, **args):
+ """write graph, node or edge attributes
+ """
+ for key, value in args.items():
+ try:
+ _type = attributes_dict[key]
+ except KeyError:
+ raise Exception('''no such attribute %s
+possible attributes are %s''' % (key, attributes_dict.keys()))
+
+ if not _type:
+ self._stream.write('%s%s:"%s"\n' % (self._indent, key, value))
+ elif _type == 1:
+ self._stream.write('%s%s:%s\n' % (self._indent, key,
+ int(value)))
+ elif value in _type:
+ self._stream.write('%s%s:%s\n' % (self._indent, key, value))
+ else:
+ raise Exception('''value %s isn\'t correct for attribute %s
+correct values are %s''' % (value, key, _type))
+
+ def _inc_indent(self):
+ """increment indentation
+ """
+ self._indent = ' %s' % self._indent
+
+ def _dec_indent(self):
+ """decrement indentation
+ """
+ self._indent = self._indent[:-2]
diff --git a/logilab/common/visitor.py b/logilab/common/visitor.py
new file mode 100644
index 0000000..ed2b70f
--- /dev/null
+++ b/logilab/common/visitor.py
@@ -0,0 +1,109 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""A generic visitor abstract implementation.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+def no_filter(_):
+ return 1
+
+# Iterators ###################################################################
+class FilteredIterator(object):
+
+ def __init__(self, node, list_func, filter_func=None):
+ self._next = [(node, 0)]
+ if filter_func is None:
+ filter_func = no_filter
+ self._list = list_func(node, filter_func)
+
+ def __next__(self):
+ try:
+ return self._list.pop(0)
+ except :
+ return None
+
+ next = __next__
+
+# Base Visitor ################################################################
+class Visitor(object):
+
+ def __init__(self, iterator_class, filter_func=None):
+ self._iter_class = iterator_class
+ self.filter = filter_func
+
+ def visit(self, node, *args, **kargs):
+ """
+ launch the visit on a given node
+
+ call 'open_visit' before the beginning of the visit, with extra args
+ given
+ when all nodes have been visited, call the 'close_visit' method
+ """
+ self.open_visit(node, *args, **kargs)
+ return self.close_visit(self._visit(node))
+
+ def _visit(self, node):
+ iterator = self._get_iterator(node)
+ n = next(iterator)
+ while n:
+ result = n.accept(self)
+ n = next(iterator)
+ return result
+
+ def _get_iterator(self, node):
+ return self._iter_class(node, self.filter)
+
+ def open_visit(self, *args, **kargs):
+ """
+ method called at the beginning of the visit
+ """
+ pass
+
+ def close_visit(self, result):
+ """
+ method called at the end of the visit
+ """
+ return result
+
+# standard visited mixin ######################################################
+class VisitedMixIn(object):
+ """
+ Visited interface allow node visitors to use the node
+ """
+ def get_visit_name(self):
+ """
+ return the visit name for the mixed class. When calling 'accept', the
+ method <'visit_' + name returned by this method> will be called on the
+ visitor
+ """
+ try:
+ return self.TYPE.replace('-', '_')
+ except:
+ return self.__class__.__name__.lower()
+
+ def accept(self, visitor, *args, **kwargs):
+ func = getattr(visitor, 'visit_%s' % self.get_visit_name())
+ return func(self, *args, **kwargs)
+
+ def leave(self, visitor, *args, **kwargs):
+ func = getattr(visitor, 'leave_%s' % self.get_visit_name())
+ return func(self, *args, **kwargs)
diff --git a/logilab/common/xmlrpcutils.py b/logilab/common/xmlrpcutils.py
new file mode 100644
index 0000000..1d30d82
--- /dev/null
+++ b/logilab/common/xmlrpcutils.py
@@ -0,0 +1,131 @@
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""XML-RPC utilities."""
+__docformat__ = "restructuredtext en"
+
+import xmlrpclib
+from base64 import encodestring
+#from cStringIO import StringIO
+
+ProtocolError = xmlrpclib.ProtocolError
+
+## class BasicAuthTransport(xmlrpclib.Transport):
+## def __init__(self, username=None, password=None):
+## self.username = username
+## self.password = password
+## self.verbose = None
+## self.has_ssl = httplib.__dict__.has_key("HTTPConnection")
+
+## def request(self, host, handler, request_body, verbose=None):
+## # issue XML-RPC request
+## if self.has_ssl:
+## if host.startswith("https:"): h = httplib.HTTPSConnection(host)
+## else: h = httplib.HTTPConnection(host)
+## else: h = httplib.HTTP(host)
+
+## h.putrequest("POST", handler)
+
+## # required by HTTP/1.1
+## if not self.has_ssl: # HTTPConnection already does 1.1
+## h.putheader("Host", host)
+## h.putheader("Connection", "close")
+
+## if request_body: h.send(request_body)
+## if self.has_ssl:
+## response = h.getresponse()
+## if response.status != 200:
+## raise xmlrpclib.ProtocolError(host + handler,
+## response.status,
+## response.reason,
+## response.msg)
+## file = response.fp
+## else:
+## errcode, errmsg, headers = h.getreply()
+## if errcode != 200:
+## raise xmlrpclib.ProtocolError(host + handler, errcode,
+## errmsg, headers)
+
+## file = h.getfile()
+
+## return self.parse_response(file)
+
+
+
+class AuthMixin:
+ """basic http authentication mixin for xmlrpc transports"""
+
+ def __init__(self, username, password, encoding):
+ self.verbose = 0
+ self.username = username
+ self.password = password
+ self.encoding = encoding
+
+ def request(self, host, handler, request_body, verbose=0):
+ """issue XML-RPC request"""
+ h = self.make_connection(host)
+ h.putrequest("POST", handler)
+ # required by XML-RPC
+ h.putheader("User-Agent", self.user_agent)
+ h.putheader("Content-Type", "text/xml")
+ h.putheader("Content-Length", str(len(request_body)))
+ h.putheader("Host", host)
+ h.putheader("Connection", "close")
+ # basic auth
+ if self.username is not None and self.password is not None:
+ h.putheader("AUTHORIZATION", "Basic %s" % encodestring(
+ "%s:%s" % (self.username, self.password)).replace("\012", ""))
+ h.endheaders()
+ # send body
+ if request_body:
+ h.send(request_body)
+ # get and check reply
+ errcode, errmsg, headers = h.getreply()
+ if errcode != 200:
+ raise ProtocolError(host + handler, errcode, errmsg, headers)
+ file = h.getfile()
+## # FIXME: encoding ??? iirc, this fix a bug in xmlrpclib but...
+## data = h.getfile().read()
+## if self.encoding != 'UTF-8':
+## data = data.replace("version='1.0'",
+## "version='1.0' encoding='%s'" % self.encoding)
+## result = StringIO()
+## result.write(data)
+## result.seek(0)
+## return self.parse_response(result)
+ return self.parse_response(file)
+
+class BasicAuthTransport(AuthMixin, xmlrpclib.Transport):
+ """basic http authentication transport"""
+
+class BasicAuthSafeTransport(AuthMixin, xmlrpclib.SafeTransport):
+ """basic https authentication transport"""
+
+
+def connect(url, user=None, passwd=None, encoding='ISO-8859-1'):
+ """return an xml rpc server on <url>, using user / password if specified
+ """
+ if user or passwd:
+ assert user and passwd is not None
+ if url.startswith('https://'):
+ transport = BasicAuthSafeTransport(user, passwd, encoding)
+ else:
+ transport = BasicAuthTransport(user, passwd, encoding)
+ else:
+ transport = None
+ server = xmlrpclib.ServerProxy(url, transport, encoding=encoding)
+ return server
diff --git a/logilab/common/xmlutils.py b/logilab/common/xmlutils.py
new file mode 100644
index 0000000..d383b9d
--- /dev/null
+++ b/logilab/common/xmlutils.py
@@ -0,0 +1,61 @@
+# -*- coding: utf-8 -*-
+# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
+# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This file is part of logilab-common.
+#
+# logilab-common is free software: you can redistribute it and/or modify it under
+# the terms of the GNU Lesser General Public License as published by the Free
+# Software Foundation, either version 2.1 of the License, or (at your option) any
+# later version.
+#
+# logilab-common is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
+# details.
+#
+# You should have received a copy of the GNU Lesser General Public License along
+# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
+"""XML utilities.
+
+This module contains useful functions for parsing and using XML data. For the
+moment, there is only one function that can parse the data inside a processing
+instruction and return a Python dictionary.
+
+
+
+
+"""
+__docformat__ = "restructuredtext en"
+
+import re
+
+RE_DOUBLE_QUOTE = re.compile('([\w\-\.]+)="([^"]+)"')
+RE_SIMPLE_QUOTE = re.compile("([\w\-\.]+)='([^']+)'")
+
+def parse_pi_data(pi_data):
+ """
+ Utility function that parses the data contained in an XML
+ processing instruction and returns a dictionary of keywords and their
+ associated values (most of the time, the processing instructions contain
+ data like ``keyword="value"``, if a keyword is not associated to a value,
+ for example ``keyword``, it will be associated to ``None``).
+
+ :param pi_data: data contained in an XML processing instruction.
+ :type pi_data: unicode
+
+ :returns: Dictionary of the keywords (Unicode strings) associated to
+ their values (Unicode strings) as they were defined in the
+ data.
+ :rtype: dict
+ """
+ results = {}
+ for elt in pi_data.split():
+ if RE_DOUBLE_QUOTE.match(elt):
+ kwd, val = RE_DOUBLE_QUOTE.match(elt).groups()
+ elif RE_SIMPLE_QUOTE.match(elt):
+ kwd, val = RE_SIMPLE_QUOTE.match(elt).groups()
+ else:
+ kwd, val = elt, None
+ results[kwd] = val
+ return results